[Mlir-commits] [mlir] f9008e6 - [NFC][Py Reformat] Reformat python files in mlir subdir
Tobias Hieta
llvmlistbot at llvm.org
Thu May 25 23:05:53 PDT 2023
Author: Tobias Hieta
Date: 2023-05-26T08:05:40+02:00
New Revision: f9008e6366c2496b1ca1785b891d5578174ad63e
URL: https://github.com/llvm/llvm-project/commit/f9008e6366c2496b1ca1785b891d5578174ad63e
DIFF: https://github.com/llvm/llvm-project/commit/f9008e6366c2496b1ca1785b891d5578174ad63e.diff
LOG: [NFC][Py Reformat] Reformat python files in mlir subdir
This is part of an ongoing series of commits reformatting our
Python code.
Reformatting is done with `black`.
If you run into merge conflicts with this commit because you have
modified a Python file, the easiest way to resolve them is to run
`git checkout --ours <yourfile>` and then reformat the file with
`black`.
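As a rough sketch (the path below is only an illustration; substitute
whichever file actually conflicts for you, and this assumes `black` is
installed locally), the workflow looks like:

    git checkout --ours mlir/python/mlir/ir.py   # keep your local version of the file
    black mlir/python/mlir/ir.py                 # reformat it the same way this commit does
    git add mlir/python/mlir/ir.py               # mark the conflict as resolved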
If you run into any other problems, post on Discourse and we will try
to help.
RFC Thread below:
https://discourse.llvm.org/t/rfc-document-and-standardize-python-code-style
Differential Revision: https://reviews.llvm.org/D150782
Added:
Modified:
mlir/benchmark/python/benchmark_sparse.py
mlir/benchmark/python/common.py
mlir/examples/standalone/test/CAPI/lit.local.cfg
mlir/examples/standalone/test/lit.cfg.py
mlir/examples/standalone/test/python/lit.local.cfg
mlir/examples/standalone/test/python/smoketest.py
mlir/python/mlir/_mlir_libs/__init__.py
mlir/python/mlir/dialects/_arith_ops_ext.py
mlir/python/mlir/dialects/_bufferization_ops_ext.py
mlir/python/mlir/dialects/_builtin_ops_ext.py
mlir/python/mlir/dialects/_func_ops_ext.py
mlir/python/mlir/dialects/_linalg_ops_ext.py
mlir/python/mlir/dialects/_loop_transform_ops_ext.py
mlir/python/mlir/dialects/_memref_ops_ext.py
mlir/python/mlir/dialects/_ml_program_ops_ext.py
mlir/python/mlir/dialects/_ods_common.py
mlir/python/mlir/dialects/_pdl_ops_ext.py
mlir/python/mlir/dialects/_scf_ops_ext.py
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/python/mlir/dialects/_tensor_ops_ext.py
mlir/python/mlir/dialects/_transform_ops_ext.py
mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py
mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
mlir/python/mlir/dialects/linalg/opdsl/lang/types.py
mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/python/mlir/dialects/python_test.py
mlir/python/mlir/dialects/transform/__init__.py
mlir/python/mlir/execution_engine.py
mlir/python/mlir/ir.py
mlir/python/mlir/runtime/np_to_memref.py
mlir/test/CAPI/lit.local.cfg
mlir/test/Conversion/GPUToCUDA/lit.local.cfg
mlir/test/Conversion/GPUToROCm/lit.local.cfg
mlir/test/Examples/Toy/Ch6/lit.local.cfg
mlir/test/Examples/Toy/Ch7/lit.local.cfg
mlir/test/Examples/lit.local.cfg
mlir/test/Examples/standalone/lit.local.cfg
mlir/test/Integration/Dialect/Async/CPU/lit.local.cfg
mlir/test/Integration/Dialect/LLVMIR/CPU/X86/lit.local.cfg
mlir/test/Integration/Dialect/LLVMIR/CPU/lit.local.cfg
mlir/test/Integration/Dialect/SparseTensor/CPU/lit.local.cfg
mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg
mlir/test/Integration/Dialect/SparseTensor/python/lit.local.cfg
mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py
mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py
mlir/test/Integration/Dialect/SparseTensor/taco/lit.local.cfg
mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_Tensor.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_true_dense_tensor_algebra.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py
mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py
mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_utils.py
mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/lit.local.cfg
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/lit.local.cfg
mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg
mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg
mlir/test/Integration/GPU/CUDA/TensorCore/lit.local.cfg
mlir/test/Integration/GPU/CUDA/lit.local.cfg
mlir/test/Integration/GPU/ROCM/lit.local.cfg
mlir/test/Integration/lit.local.cfg
mlir/test/Unit/lit.cfg.py
mlir/test/lib/Dialect/Test/lit.local.cfg
mlir/test/lib/Dialect/Transform/lit.local.cfg
mlir/test/lib/Tools/PDLL/lit.local.cfg
mlir/test/lib/Transforms/lit.local.cfg
mlir/test/lit.cfg.py
mlir/test/mlir-cpu-runner/lit.local.cfg
mlir/test/mlir-pdll-lsp-server/lit.local.cfg
mlir/test/mlir-pdll/lit.local.cfg
mlir/test/mlir-spirv-cpu-runner/lit.local.cfg
mlir/test/mlir-vulkan-runner/lit.local.cfg
mlir/test/python/develoment_files.py
mlir/test/python/dialects/arith_dialect.py
mlir/test/python/dialects/async_dialect.py
mlir/test/python/dialects/builtin.py
mlir/test/python/dialects/complex_dialect.py
mlir/test/python/dialects/func.py
mlir/test/python/dialects/gpu.py
mlir/test/python/dialects/linalg/opdsl/arguments.py
mlir/test/python/dialects/linalg/opdsl/assignments.py
mlir/test/python/dialects/linalg/opdsl/doctests.py
mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
mlir/test/python/dialects/linalg/opdsl/emit_fill.py
mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
mlir/test/python/dialects/linalg/opdsl/emit_misc.py
mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
mlir/test/python/dialects/linalg/opdsl/lit.local.cfg
mlir/test/python/dialects/linalg/opdsl/metadata.py
mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
mlir/test/python/dialects/linalg/ops.py
mlir/test/python/dialects/math_dialect.py
mlir/test/python/dialects/memref.py
mlir/test/python/dialects/ml_program.py
mlir/test/python/dialects/ods_helpers.py
mlir/test/python/dialects/pdl_ops.py
mlir/test/python/dialects/python_test.py
mlir/test/python/dialects/quant.py
mlir/test/python/dialects/scf.py
mlir/test/python/dialects/shape.py
mlir/test/python/dialects/sparse_tensor/dialect.py
mlir/test/python/dialects/sparse_tensor/passes.py
mlir/test/python/dialects/tensor.py
mlir/test/python/dialects/transform.py
mlir/test/python/dialects/transform_loop_ext.py
mlir/test/python/dialects/transform_structured_ext.py
mlir/test/python/dialects/vector.py
mlir/test/python/execution_engine.py
mlir/test/python/integration/dialects/linalg/opsrun.py
mlir/test/python/ir/affine_expr.py
mlir/test/python/ir/affine_map.py
mlir/test/python/ir/array_attributes.py
mlir/test/python/ir/attributes.py
mlir/test/python/ir/blocks.py
mlir/test/python/ir/builtin_types.py
mlir/test/python/ir/context_managers.py
mlir/test/python/ir/debug.py
mlir/test/python/ir/diagnostic_handler.py
mlir/test/python/ir/dialects.py
mlir/test/python/ir/exception.py
mlir/test/python/ir/insertion_point.py
mlir/test/python/ir/integer_set.py
mlir/test/python/ir/location.py
mlir/test/python/ir/module.py
mlir/test/python/ir/operation.py
mlir/test/python/ir/symbol_table.py
mlir/test/python/ir/value.py
mlir/test/python/lit.local.cfg
mlir/test/python/pass_manager.py
mlir/test/tblgen-lsp-server/lit.local.cfg
mlir/utils/gdb-scripts/prettyprinters.py
mlir/utils/generate-test-checks.py
mlir/utils/jupyter/mlir_opt_kernel/__main__.py
mlir/utils/jupyter/mlir_opt_kernel/install.py
mlir/utils/jupyter/mlir_opt_kernel/kernel.py
mlir/utils/lldb-scripts/mlirDataFormatters.py
mlir/utils/mbr/mbr/__init__.py
mlir/utils/mbr/mbr/discovery.py
mlir/utils/mbr/mbr/main.py
mlir/utils/mbr/mbr/stats.py
mlir/utils/spirv/gen_spirv_dialect.py
Removed:
################################################################################
diff --git a/mlir/benchmark/python/benchmark_sparse.py b/mlir/benchmark/python/benchmark_sparse.py
index 6d7a39690734c..72b3ef1a3f424 100644
--- a/mlir/benchmark/python/benchmark_sparse.py
+++ b/mlir/benchmark/python/benchmark_sparse.py
@@ -25,7 +25,7 @@
def matmul_dsl(
A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
- C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)
+ C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
):
"""Helper function for mlir sparse matrix multiplication benchmark."""
C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
@@ -43,6 +43,7 @@ def benchmark_sparse_mlir_multiplication():
param2_type = ir.RankedTensorType.get([1500, 2000], f64)
result_type = ir.RankedTensorType.get([1000, 2000], f64)
with ir.InsertionPoint(module.body):
+
@func.FuncOp.from_py_func(param1_type, param2_type, result_type)
def sparse_kernel(x, y, z):
return matmul_dsl(x, y, outs=[z])
@@ -51,37 +52,34 @@ def compiler():
with ir.Context(), ir.Location.unknown():
kernel_func = get_kernel_func_from_module(module)
timer_func = emit_timer_func()
- wrapped_func = emit_benchmark_wrapped_main_func(
- kernel_func,
- timer_func
- )
+ wrapped_func = emit_benchmark_wrapped_main_func(kernel_func, timer_func)
main_module_with_benchmark = ir.Module.parse(
str(timer_func) + str(wrapped_func) + str(kernel_func)
)
setup_passes(main_module_with_benchmark)
c_runner_utils = os.getenv("MLIR_C_RUNNER_UTILS", "")
- assert os.path.exists(c_runner_utils),\
- f"{c_runner_utils} does not exist." \
- f" Please pass a valid value for" \
+ assert os.path.exists(c_runner_utils), (
+ f"{c_runner_utils} does not exist."
+ f" Please pass a valid value for"
f" MLIR_C_RUNNER_UTILS environment variable."
+ )
runner_utils = os.getenv("MLIR_RUNNER_UTILS", "")
- assert os.path.exists(runner_utils),\
- f"{runner_utils} does not exist." \
- f" Please pass a valid value for MLIR_RUNNER_UTILS" \
+ assert os.path.exists(runner_utils), (
+ f"{runner_utils} does not exist."
+ f" Please pass a valid value for MLIR_RUNNER_UTILS"
f" environment variable."
+ )
engine = ExecutionEngine(
main_module_with_benchmark,
3,
- shared_libs=[c_runner_utils, runner_utils]
+ shared_libs=[c_runner_utils, runner_utils],
)
return engine.invoke
def runner(engine_invoke):
compiled_program_args = []
- for argument_type in [
- result_type, param1_type, param2_type, result_type
- ]:
+ for argument_type in [result_type, param1_type, param2_type, result_type]:
argument_type_str = str(argument_type)
dimensions_str = re.sub("<|>|tensor", "", argument_type_str)
dimensions = [int(dim) for dim in dimensions_str.split("x")[:-1]]
@@ -111,6 +109,7 @@ def benchmark_np_matrix_multiplication():
benchmark, we don't have any `compiler` function returned. We just return
the `runner` function.
"""
+
def runner():
argument1 = np.random.uniform(low=0.0, high=100.0, size=(1000, 1500))
argument2 = np.random.uniform(low=0.0, high=100.0, size=(1500, 2000))
diff --git a/mlir/benchmark/python/common.py b/mlir/benchmark/python/common.py
index 3634641074f3e..c605726df2a5f 100644
--- a/mlir/benchmark/python/common.py
+++ b/mlir/benchmark/python/common.py
@@ -10,8 +10,7 @@
def setup_passes(mlir_module):
- """Setup pass pipeline parameters for benchmark functions.
- """
+ """Setup pass pipeline parameters for benchmark functions."""
opt = (
"parallelization-strategy=none"
" vectorization-strategy=none vl=1 enable-simd-index32=False"
@@ -43,12 +42,15 @@ def get_kernel_func_from_module(module: ir.Module) -> func.FuncOp:
This function only works for a module with one region, one block, and one
operation.
"""
- assert len(module.operation.regions) == 1, \
- "Expected kernel module to have only one region"
- assert len(module.operation.regions[0].blocks) == 1, \
- "Expected kernel module to have only one block"
- assert len(module.operation.regions[0].blocks[0].operations) == 1, \
- "Expected kernel module to have only one operation"
+ assert (
+ len(module.operation.regions) == 1
+ ), "Expected kernel module to have only one region"
+ assert (
+ len(module.operation.regions[0].blocks) == 1
+ ), "Expected kernel module to have only one block"
+ assert (
+ len(module.operation.regions[0].blocks[0].operations) == 1
+ ), "Expected kernel module to have only one operation"
return module.operation.regions[0].blocks[0].operations[0]
@@ -57,8 +59,7 @@ def emit_timer_func() -> func.FuncOp:
used, the `MLIR_RUNNER_UTILS` and `MLIR_C_RUNNER_UTILS` must be included.
"""
i64_type = ir.IntegerType.get_signless(64)
- nanoTime = func.FuncOp(
- "nanoTime", ([], [i64_type]), visibility="private")
+ nanoTime = func.FuncOp("nanoTime", ([], [i64_type]), visibility="private")
nanoTime.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
return nanoTime
@@ -76,9 +77,8 @@ def emit_benchmark_wrapped_main_func(kernel_func, timer_func):
wrapped_func = func.FuncOp(
# Same signature and an extra buffer of indices to save timings.
"main",
- (kernel_func.arguments.types + [memref_of_i64_type],
- kernel_func.type.results),
- visibility="public"
+ (kernel_func.arguments.types + [memref_of_i64_type], kernel_func.type.results),
+ visibility="public",
)
wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
@@ -88,13 +88,13 @@ def emit_benchmark_wrapped_main_func(kernel_func, timer_func):
zero = arith.ConstantOp.create_index(0)
n_iterations = memref.DimOp(ir.IndexType.get(), timer_buffer, zero)
one = arith.ConstantOp.create_index(1)
- iter_args = list(wrapped_func.arguments[-num_results - 1:-1])
+ iter_args = list(wrapped_func.arguments[-num_results - 1 : -1])
loop = scf.ForOp(zero, n_iterations, one, iter_args)
with ir.InsertionPoint(loop.body):
start = func.CallOp(timer_func, [])
call = func.CallOp(
kernel_func,
- wrapped_func.arguments[:-num_results - 1] + loop.inner_iter_args
+ wrapped_func.arguments[: -num_results - 1] + loop.inner_iter_args,
)
end = func.CallOp(timer_func, [])
time_taken = arith.SubIOp(end, start)
diff --git a/mlir/examples/standalone/test/CAPI/lit.local.cfg b/mlir/examples/standalone/test/CAPI/lit.local.cfg
index f08a0de488ddd..bb0c17cdbada7 100644
--- a/mlir/examples/standalone/test/CAPI/lit.local.cfg
+++ b/mlir/examples/standalone/test/CAPI/lit.local.cfg
@@ -1 +1 @@
-config.suffixes.add('.c')
+config.suffixes.add(".c")
diff --git a/mlir/examples/standalone/test/lit.cfg.py b/mlir/examples/standalone/test/lit.cfg.py
index 601ac8f769f93..e27dddd7fb0b9 100644
--- a/mlir/examples/standalone/test/lit.cfg.py
+++ b/mlir/examples/standalone/test/lit.cfg.py
@@ -16,52 +16,55 @@
# Configuration file for the 'lit' test runner.
# name: The name of this test suite.
-config.name = 'STANDALONE'
+config.name = "STANDALONE"
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
# suffixes: A list of file extensions to treat as test files.
-config.suffixes = ['.mlir']
+config.suffixes = [".mlir"]
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
# test_exec_root: The root path where tests should be run.
-config.test_exec_root = os.path.join(config.standalone_obj_root, 'test')
+config.test_exec_root = os.path.join(config.standalone_obj_root, "test")
-config.substitutions.append(('%PATH%', config.environment['PATH']))
-config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
+config.substitutions.append(("%PATH%", config.environment["PATH"]))
+config.substitutions.append(("%shlibext", config.llvm_shlib_ext))
-llvm_config.with_system_environment(
- ['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
+llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])
llvm_config.use_default_substitutions()
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
# subdirectories contain auxiliary inputs for various tests in their parent
# directories.
-config.excludes = ['Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt']
+config.excludes = ["Inputs", "Examples", "CMakeLists.txt", "README.txt", "LICENSE.txt"]
# test_exec_root: The root path where tests should be run.
-config.test_exec_root = os.path.join(config.standalone_obj_root, 'test')
-config.standalone_tools_dir = os.path.join(config.standalone_obj_root, 'bin')
-config.standalone_libs_dir = os.path.join(config.standalone_obj_root, 'lib')
+config.test_exec_root = os.path.join(config.standalone_obj_root, "test")
+config.standalone_tools_dir = os.path.join(config.standalone_obj_root, "bin")
+config.standalone_libs_dir = os.path.join(config.standalone_obj_root, "lib")
-config.substitutions.append(('%standalone_libs', config.standalone_libs_dir))
+config.substitutions.append(("%standalone_libs", config.standalone_libs_dir))
# Tweak the PATH to include the tools dir.
-llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
+llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)
tool_dirs = [config.standalone_tools_dir, config.llvm_tools_dir]
tools = [
- 'mlir-opt',
- 'standalone-capi-test',
- 'standalone-opt',
- 'standalone-translate',
+ "mlir-opt",
+ "standalone-capi-test",
+ "standalone-opt",
+ "standalone-translate",
]
llvm_config.add_tool_substitutions(tools, tool_dirs)
-llvm_config.with_environment('PYTHONPATH', [
- os.path.join(config.mlir_obj_dir, 'python_packages', 'standalone'),
-], append_path=True)
+llvm_config.with_environment(
+ "PYTHONPATH",
+ [
+ os.path.join(config.mlir_obj_dir, "python_packages", "standalone"),
+ ],
+ append_path=True,
+)
diff --git a/mlir/examples/standalone/test/python/lit.local.cfg b/mlir/examples/standalone/test/python/lit.local.cfg
index b70b9d7a34fdd..3394f180e5121 100644
--- a/mlir/examples/standalone/test/python/lit.local.cfg
+++ b/mlir/examples/standalone/test/python/lit.local.cfg
@@ -1,4 +1,4 @@
-config.suffixes.add('.py')
+config.suffixes.add(".py")
if not config.enable_bindings_python:
- config.unsupported = True
+ config.unsupported = True
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index 0d8f41c27e8ef..08e08cbd2fe24 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -1,17 +1,16 @@
# RUN: %python %s | FileCheck %s
from mlir_standalone.ir import *
-from mlir_standalone.dialects import (
- builtin as builtin_d,
- standalone as standalone_d
-)
+from mlir_standalone.dialects import builtin as builtin_d, standalone as standalone_d
with Context():
- standalone_d.register_dialect()
- module = Module.parse("""
+ standalone_d.register_dialect()
+ module = Module.parse(
+ """
%0 = arith.constant 2 : i32
%1 = standalone.foo %0 : i32
- """)
- # CHECK: %[[C:.*]] = arith.constant 2 : i32
- # CHECK: standalone.foo %[[C]] : i32
- print(str(module))
+ """
+ )
+ # CHECK: %[[C:.*]] = arith.constant 2 : i32
+ # CHECK: standalone.foo %[[C]] : i32
+ print(str(module))
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 7d3d1f6ca873a..03fcb10130c3a 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -10,26 +10,26 @@
def get_lib_dirs() -> Sequence[str]:
- """Gets the lib directory for linking to shared libraries.
+ """Gets the lib directory for linking to shared libraries.
- On some platforms, the package may need to be built specially to export
- development libraries.
- """
- return [_this_dir]
+ On some platforms, the package may need to be built specially to export
+ development libraries.
+ """
+ return [_this_dir]
def get_include_dirs() -> Sequence[str]:
- """Gets the include directory for compiling against exported C libraries.
+ """Gets the include directory for compiling against exported C libraries.
- Depending on how the package was build, development C libraries may or may
- not be present.
- """
- return [os.path.join(_this_dir, "include")]
+ Depending on how the package was build, development C libraries may or may
+ not be present.
+ """
+ return [os.path.join(_this_dir, "include")]
# Perform Python level site initialization. This involves:
# 1. Attempting to load initializer modules, specific to the distribution.
-# 2. Defining the concrete mlir.ir.Context that does site specific
+# 2. Defining the concrete mlir.ir.Context that does site specific
# initialization.
#
# Aside from just being far more convenient to do this at the Python level,
@@ -38,91 +38,106 @@ def get_include_dirs() -> Sequence[str]:
# in the scope of the base class __init__).
#
# For #1, we:
-# a. Probe for modules named '_mlirRegisterEverything' and
-# '_site_initialize_{i}', where 'i' is a number starting at zero and
+# a. Probe for modules named '_mlirRegisterEverything' and
+# '_site_initialize_{i}', where 'i' is a number starting at zero and
# proceeding so long as a module with the name is found.
# b. If the module has a 'register_dialects' attribute, it will be called
# immediately with a DialectRegistry to populate.
# c. If the module has a 'context_init_hook', it will be added to a list
-# of callbacks that are invoked as the last step of Context
+# of callbacks that are invoked as the last step of Context
# initialization (and passed the Context under construction).
#
# This facility allows downstreams to customize Context creation to their
# needs.
def _site_initialize():
- import importlib
- import itertools
- import logging
- from ._mlir import ir
- logger = logging.getLogger(__name__)
- registry = ir.DialectRegistry()
- post_init_hooks = []
-
- def process_initializer_module(module_name):
- try:
- m = importlib.import_module(f".{module_name}", __name__)
- except ModuleNotFoundError:
- return False
- except ImportError:
- message = (f"Error importing mlir initializer {module_name}. This may "
- "happen in unclean incremental builds but is likely a real bug if "
- "encountered otherwise and the MLIR Python API may not function.")
- logger.warning(message, exc_info=True)
-
- logger.debug("Initializing MLIR with module: %s", module_name)
- if hasattr(m, "register_dialects"):
- logger.debug("Registering dialects from initializer %r", m)
- m.register_dialects(registry)
- if hasattr(m, "context_init_hook"):
- logger.debug("Adding context init hook from %r", m)
- post_init_hooks.append(m.context_init_hook)
- return True
-
-
- # If _mlirRegisterEverything is built, then include it as an initializer
- # module.
- process_initializer_module("_mlirRegisterEverything")
-
- # Load all _site_initialize_{i} modules, where 'i' is a number starting
- # at 0.
- for i in itertools.count():
- module_name = f"_site_initialize_{i}"
- if not process_initializer_module(module_name):
- break
-
- class Context(ir._BaseContext):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.append_dialect_registry(registry)
- for hook in post_init_hooks:
- hook(self)
- # TODO: There is some debate about whether we should eagerly load
- # all dialects. It is being done here in order to preserve existing
- # behavior. See: https://github.com/llvm/llvm-project/issues/56037
- self.load_all_available_dialects()
- ir.Context = Context
-
- class MLIRError(Exception):
- """
- An exception with diagnostic information. Has the following fields:
- message: str
- error_diagnostics: List[ir.DiagnosticInfo]
- """
- def __init__(self, message, error_diagnostics):
- self.message = message
- self.error_diagnostics = error_diagnostics
- super().__init__(message, error_diagnostics)
-
- def __str__(self):
- s = self.message
- if self.error_diagnostics:
- s += ':'
- for diag in self.error_diagnostics:
- s += "\nerror: " + str(diag.location)[4:-1] + ": " + diag.message.replace('\n', '\n ')
- for note in diag.notes:
- s += "\n note: " + str(note.location)[4:-1] + ": " + note.message.replace('\n', '\n ')
- return s
- ir.MLIRError = MLIRError
+ import importlib
+ import itertools
+ import logging
+ from ._mlir import ir
+
+ logger = logging.getLogger(__name__)
+ registry = ir.DialectRegistry()
+ post_init_hooks = []
+
+ def process_initializer_module(module_name):
+ try:
+ m = importlib.import_module(f".{module_name}", __name__)
+ except ModuleNotFoundError:
+ return False
+ except ImportError:
+ message = (
+ f"Error importing mlir initializer {module_name}. This may "
+ "happen in unclean incremental builds but is likely a real bug if "
+ "encountered otherwise and the MLIR Python API may not function."
+ )
+ logger.warning(message, exc_info=True)
+
+ logger.debug("Initializing MLIR with module: %s", module_name)
+ if hasattr(m, "register_dialects"):
+ logger.debug("Registering dialects from initializer %r", m)
+ m.register_dialects(registry)
+ if hasattr(m, "context_init_hook"):
+ logger.debug("Adding context init hook from %r", m)
+ post_init_hooks.append(m.context_init_hook)
+ return True
+
+ # If _mlirRegisterEverything is built, then include it as an initializer
+ # module.
+ process_initializer_module("_mlirRegisterEverything")
+
+ # Load all _site_initialize_{i} modules, where 'i' is a number starting
+ # at 0.
+ for i in itertools.count():
+ module_name = f"_site_initialize_{i}"
+ if not process_initializer_module(module_name):
+ break
+
+ class Context(ir._BaseContext):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.append_dialect_registry(registry)
+ for hook in post_init_hooks:
+ hook(self)
+ # TODO: There is some debate about whether we should eagerly load
+ # all dialects. It is being done here in order to preserve existing
+ # behavior. See: https://github.com/llvm/llvm-project/issues/56037
+ self.load_all_available_dialects()
+
+ ir.Context = Context
+
+ class MLIRError(Exception):
+ """
+ An exception with diagnostic information. Has the following fields:
+ message: str
+ error_diagnostics: List[ir.DiagnosticInfo]
+ """
+
+ def __init__(self, message, error_diagnostics):
+ self.message = message
+ self.error_diagnostics = error_diagnostics
+ super().__init__(message, error_diagnostics)
+
+ def __str__(self):
+ s = self.message
+ if self.error_diagnostics:
+ s += ":"
+ for diag in self.error_diagnostics:
+ s += (
+ "\nerror: "
+ + str(diag.location)[4:-1]
+ + ": "
+ + diag.message.replace("\n", "\n ")
+ )
+ for note in diag.notes:
+ s += (
+ "\n note: "
+ + str(note.location)[4:-1]
+ + ": "
+ + note.message.replace("\n", "\n ")
+ )
+ return s
+
+ ir.MLIRError = MLIRError
_site_initialize()
diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py
index 240859352ce21..df38f871710fe 100644
--- a/mlir/python/mlir/dialects/_arith_ops_ext.py
+++ b/mlir/python/mlir/dialects/_arith_ops_ext.py
@@ -3,72 +3,67 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
- from ..ir import *
- from ._ods_common import get_default_loc_context as _get_default_loc_context
+ from ..ir import *
+ from ._ods_common import get_default_loc_context as _get_default_loc_context
- from typing import Any, List, Union
+ from typing import Any, List, Union
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports from extension module") from e
def _isa(obj: Any, cls: type):
- try:
- cls(obj)
- except ValueError:
- return False
- return True
+ try:
+ cls(obj)
+ except ValueError:
+ return False
+ return True
def _is_any_of(obj: Any, classes: List[type]):
- return any(_isa(obj, cls) for cls in classes)
+ return any(_isa(obj, cls) for cls in classes)
def _is_integer_like_type(type: Type):
- return _is_any_of(type, [IntegerType, IndexType])
+ return _is_any_of(type, [IntegerType, IndexType])
def _is_float_type(type: Type):
- return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
+ return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
class ConstantOp:
- """Specialization for the constant op class."""
-
- def __init__(self,
- result: Type,
- value: Union[int, float, Attribute],
- *,
- loc=None,
- ip=None):
- if isinstance(value, int):
- super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
- elif isinstance(value, float):
- super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
- else:
- super().__init__(value, loc=loc, ip=ip)
-
- @classmethod
- def create_index(cls, value: int, *, loc=None, ip=None):
- """Create an index-typed constant."""
- return cls(
- IndexType.get(context=_get_default_loc_context(loc)),
- value,
- loc=loc,
- ip=ip)
-
- @property
- def type(self):
- return self.results[0].type
-
- @property
- def value(self):
- return Attribute(self.operation.attributes["value"])
-
- @property
- def literal_value(self) -> Union[int, float]:
- if _is_integer_like_type(self.type):
- return IntegerAttr(self.value).value
- elif _is_float_type(self.type):
- return FloatAttr(self.value).value
- else:
- raise ValueError("only integer and float constants have literal values")
+ """Specialization for the constant op class."""
+
+ def __init__(
+ self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+ ):
+ if isinstance(value, int):
+ super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
+ elif isinstance(value, float):
+ super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+ else:
+ super().__init__(value, loc=loc, ip=ip)
+
+ @classmethod
+ def create_index(cls, value: int, *, loc=None, ip=None):
+ """Create an index-typed constant."""
+ return cls(
+ IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
+ )
+
+ @property
+ def type(self):
+ return self.results[0].type
+
+ @property
+ def value(self):
+ return Attribute(self.operation.attributes["value"])
+
+ @property
+ def literal_value(self) -> Union[int, float]:
+ if _is_integer_like_type(self.type):
+ return IntegerAttr(self.value).value
+ elif _is_float_type(self.type):
+ return FloatAttr(self.value).value
+ else:
+ raise ValueError("only integer and float constants have literal values")
diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py
index 6ed35f4445c56..1066cb4c775ca 100644
--- a/mlir/python/mlir/dialects/_bufferization_ops_ext.py
+++ b/mlir/python/mlir/dialects/_bufferization_ops_ext.py
@@ -3,36 +3,39 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
- from typing import Sequence, Union
- from ..ir import *
- from ._ods_common import get_default_loc_context
+ from typing import Sequence, Union
+ from ..ir import *
+ from ._ods_common import get_default_loc_context
- from typing import Any, List, Union
+ from typing import Any, List, Union
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports from extension module") from e
class AllocTensorOp:
- """Extends the bufferization.alloc_tensor op."""
+ """Extends the bufferization.alloc_tensor op."""
- def __init__(self,
- tensor_type: Type,
- dynamic_sizes: Sequence[Value],
- copy: Value,
- size_hint: Value,
- escape: BoolAttr,
- *,
- loc=None,
- ip=None):
- """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
- context = get_default_loc_context(loc)
- attributes = {}
- if escape:
- attributes["escape"] = escape
- op = self.build_generic(
- results=[tensor_type],
- operands=[dynamic_sizes, copy, size_hint],
- attributes=attributes,
- loc=loc,
- ip=ip)
- OpView.__init__(self, op)
+ def __init__(
+ self,
+ tensor_type: Type,
+ dynamic_sizes: Sequence[Value],
+ copy: Value,
+ size_hint: Value,
+ escape: BoolAttr,
+ *,
+ loc=None,
+ ip=None
+ ):
+ """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
+ context = get_default_loc_context(loc)
+ attributes = {}
+ if escape:
+ attributes["escape"] = escape
+ op = self.build_generic(
+ results=[tensor_type],
+ operands=[dynamic_sizes, copy, size_hint],
+ attributes=attributes,
+ loc=loc,
+ ip=ip,
+ )
+ OpView.__init__(self, op)
diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py
index b69163fa41519..27a60123050ac 100644
--- a/mlir/python/mlir/dialects/_builtin_ops_ext.py
+++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py
@@ -3,18 +3,18 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
- from ..ir import *
+ from ..ir import *
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports from extension module") from e
+
class ModuleOp:
- """Specialization for the module op class."""
+ """Specialization for the module op class."""
- def __init__(self, *, loc=None, ip=None):
- super().__init__(self.build_generic(results=[], operands=[], loc=loc,
- ip=ip))
- body = self.regions[0].blocks.append()
+ def __init__(self, *, loc=None, ip=None):
+ super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip))
+ body = self.regions[0].blocks.append()
- @property
- def body(self):
- return self.regions[0].blocks[0]
+ @property
+ def body(self):
+ return self.regions[0].blocks[0]
diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py
index 56df423d30a0f..6d264c33f1f9d 100644
--- a/mlir/python/mlir/dialects/_func_ops_ext.py
+++ b/mlir/python/mlir/dialects/_func_ops_ext.py
@@ -3,298 +3,317 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
- from ..ir import *
- from ._ods_common import get_default_loc_context as _get_default_loc_context
+ from ..ir import *
+ from ._ods_common import get_default_loc_context as _get_default_loc_context
- import inspect
+ import inspect
- from typing import Any, List, Optional, Sequence, Union
+ from typing import Any, List, Optional, Sequence, Union
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports from extension module") from e
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
RESULT_ATTRIBUTE_NAME = "res_attrs"
+
class ConstantOp:
- """Specialization for the constant op class."""
+ """Specialization for the constant op class."""
- def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
- super().__init__(result, value, loc=loc, ip=ip)
+ def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
+ super().__init__(result, value, loc=loc, ip=ip)
- @property
- def type(self):
- return self.results[0].type
+ @property
+ def type(self):
+ return self.results[0].type
class FuncOp:
- """Specialization for the func op class."""
-
- def __init__(self,
- name,
- type,
- *,
- visibility=None,
- body_builder=None,
- loc=None,
- ip=None):
- """
- Create a FuncOp with the provided `name`, `type`, and `visibility`.
- - `name` is a string representing the function name.
- - `type` is either a FunctionType or a pair of list describing inputs and
- results.
- - `visibility` is a string matching `public`, `private`, or `nested`. None
- implies private visibility.
- - `body_builder` is an optional callback, when provided a new entry block
- is created and the callback is invoked with the new op as argument within
- an InsertionPoint context already set for the block. The callback is
- expected to insert a terminator in the block.
- """
- sym_name = StringAttr.get(str(name))
-
- # If the type is passed as a tuple, build a FunctionType on the fly.
- if isinstance(type, tuple):
- type = FunctionType.get(inputs=type[0], results=type[1])
-
- type = TypeAttr.get(type)
- sym_visibility = StringAttr.get(
- str(visibility)) if visibility is not None else None
- super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
- if body_builder:
- entry_block = self.add_entry_block()
- with InsertionPoint(entry_block):
- body_builder(self)
-
- @property
- def is_external(self):
- return len(self.regions[0].blocks) == 0
-
- @property
- def body(self):
- return self.regions[0]
-
- @property
- def type(self):
- return FunctionType(TypeAttr(self.attributes["function_type"]).value)
-
- @property
- def visibility(self):
- return self.attributes["sym_visibility"]
-
- @property
- def name(self) -> StringAttr:
- return StringAttr(self.attributes["sym_name"])
-
- @property
- def entry_block(self):
- if self.is_external:
- raise IndexError('External function does not have a body')
- return self.regions[0].blocks[0]
-
- def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
- """
- Add an entry block to the function body using the function signature to
- infer block arguments.
- Returns the newly created block
- """
- if not self.is_external:
- raise IndexError('The function already has an entry block!')
- self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
- return self.body.blocks[0]
-
- @property
- def arg_attrs(self):
- return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
-
- @arg_attrs.setter
- def arg_attrs(self, attribute: Union[ArrayAttr, list]):
- if isinstance(attribute, ArrayAttr):
- self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
- else:
- self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
- attribute, context=self.context)
-
- @property
- def arguments(self):
- return self.entry_block.arguments
-
- @property
- def result_attrs(self):
- return self.attributes[RESULT_ATTRIBUTE_NAME]
-
- @result_attrs.setter
- def result_attrs(self, attribute: ArrayAttr):
- self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
-
- @classmethod
- def from_py_func(FuncOp,
- *inputs: Type,
- results: Optional[Sequence[Type]] = None,
- name: Optional[str] = None):
- """Decorator to define an MLIR FuncOp specified as a python function.
-
- Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
- active for the current thread (i.e. established in a `with` block).
-
- When applied as a decorator to a Python function, an entry block will
- be constructed for the FuncOp with types as specified in `*inputs`. The
- block arguments will be passed positionally to the Python function. In
- addition, if the Python function accepts keyword arguments generally or
- has a corresponding keyword argument, the following will be passed:
- * `func_op`: The `func` op being defined.
-
- By default, the function name will be the Python function `__name__`. This
- can be overriden by passing the `name` argument to the decorator.
-
- If `results` is not specified, then the decorator will implicitly
- insert a `ReturnOp` with the `Value`'s returned from the decorated
- function. It will also set the `FuncOp` type with the actual return
- value types. If `results` is specified, then the decorated function
- must return `None` and no implicit `ReturnOp` is added (nor are the result
- types updated). The implicit behavior is intended for simple, single-block
- cases, and users should specify result types explicitly for any complicated
- cases.
-
- The decorated function can further be called from Python and will insert
- a `CallOp` at the then-current insertion point, returning either None (
- if no return values), a unary Value (for one result), or a list of Values).
- This mechanism cannot be used to emit recursive calls (by construction).
- """
-
- def decorator(f):
- from . import func
- # Introspect the callable for optional features.
- sig = inspect.signature(f)
- has_arg_func_op = False
- for param in sig.parameters.values():
- if param.kind == param.VAR_KEYWORD:
- has_arg_func_op = True
- if param.name == "func_op" and (param.kind
- == param.POSITIONAL_OR_KEYWORD or
- param.kind == param.KEYWORD_ONLY):
- has_arg_func_op = True
-
- # Emit the FuncOp.
- implicit_return = results is None
- symbol_name = name or f.__name__
- function_type = FunctionType.get(
- inputs=inputs, results=[] if implicit_return else results)
- func_op = FuncOp(name=symbol_name, type=function_type)
- with InsertionPoint(func_op.add_entry_block()):
- func_args = func_op.entry_block.arguments
- func_kwargs = {}
- if has_arg_func_op:
- func_kwargs["func_op"] = func_op
- return_values = f(*func_args, **func_kwargs)
- if not implicit_return:
- return_types = list(results)
- assert return_values is None, (
- "Capturing a python function with explicit `results=` "
- "requires that the wrapped function returns None.")
- else:
- # Coerce return values, add ReturnOp and rewrite func type.
- if return_values is None:
- return_values = []
- elif isinstance(return_values, tuple):
- return_values = list(return_values)
- elif isinstance(return_values, Value):
- # Returning a single value is fine, coerce it into a list.
- return_values = [return_values]
- elif isinstance(return_values, OpView):
- # Returning a single operation is fine, coerce its results a list.
- return_values = return_values.operation.results
- elif isinstance(return_values, Operation):
- # Returning a single operation is fine, coerce its results a list.
- return_values = return_values.results
- else:
- return_values = list(return_values)
- func.ReturnOp(return_values)
- # Recompute the function type.
- return_types = [v.type for v in return_values]
- function_type = FunctionType.get(inputs=inputs, results=return_types)
- func_op.attributes["function_type"] = TypeAttr.get(function_type)
-
- def emit_call_op(*call_args):
- call_op = func.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name),
- call_args)
- if return_types is None:
- return None
- elif len(return_types) == 1:
- return call_op.result
+ """Specialization for the func op class."""
+
+ def __init__(
+ self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
+ ):
+ """
+ Create a FuncOp with the provided `name`, `type`, and `visibility`.
+ - `name` is a string representing the function name.
+ - `type` is either a FunctionType or a pair of list describing inputs and
+ results.
+ - `visibility` is a string matching `public`, `private`, or `nested`. None
+ implies private visibility.
+ - `body_builder` is an optional callback, when provided a new entry block
+ is created and the callback is invoked with the new op as argument within
+ an InsertionPoint context already set for the block. The callback is
+ expected to insert a terminator in the block.
+ """
+ sym_name = StringAttr.get(str(name))
+
+ # If the type is passed as a tuple, build a FunctionType on the fly.
+ if isinstance(type, tuple):
+ type = FunctionType.get(inputs=type[0], results=type[1])
+
+ type = TypeAttr.get(type)
+ sym_visibility = (
+ StringAttr.get(str(visibility)) if visibility is not None else None
+ )
+ super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
+ if body_builder:
+ entry_block = self.add_entry_block()
+ with InsertionPoint(entry_block):
+ body_builder(self)
+
+ @property
+ def is_external(self):
+ return len(self.regions[0].blocks) == 0
+
+ @property
+ def body(self):
+ return self.regions[0]
+
+ @property
+ def type(self):
+ return FunctionType(TypeAttr(self.attributes["function_type"]).value)
+
+ @property
+ def visibility(self):
+ return self.attributes["sym_visibility"]
+
+ @property
+ def name(self) -> StringAttr:
+ return StringAttr(self.attributes["sym_name"])
+
+ @property
+ def entry_block(self):
+ if self.is_external:
+ raise IndexError("External function does not have a body")
+ return self.regions[0].blocks[0]
+
+ def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
+ """
+ Add an entry block to the function body using the function signature to
+ infer block arguments.
+ Returns the newly created block
+ """
+ if not self.is_external:
+ raise IndexError("The function already has an entry block!")
+ self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
+ return self.body.blocks[0]
+
+ @property
+ def arg_attrs(self):
+ return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
+
+ @arg_attrs.setter
+ def arg_attrs(self, attribute: Union[ArrayAttr, list]):
+ if isinstance(attribute, ArrayAttr):
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
else:
- return call_op.results
-
- wrapped = emit_call_op
- wrapped.__name__ = f.__name__
- wrapped.func_op = func_op
- return wrapped
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+ attribute, context=self.context
+ )
+
+ @property
+ def arguments(self):
+ return self.entry_block.arguments
+
+ @property
+ def result_attrs(self):
+ return self.attributes[RESULT_ATTRIBUTE_NAME]
+
+ @result_attrs.setter
+ def result_attrs(self, attribute: ArrayAttr):
+ self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
+
+ @classmethod
+ def from_py_func(
+ FuncOp,
+ *inputs: Type,
+ results: Optional[Sequence[Type]] = None,
+ name: Optional[str] = None,
+ ):
+ """Decorator to define an MLIR FuncOp specified as a python function.
+
+ Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
+ active for the current thread (i.e. established in a `with` block).
+
+ When applied as a decorator to a Python function, an entry block will
+ be constructed for the FuncOp with types as specified in `*inputs`. The
+ block arguments will be passed positionally to the Python function. In
+ addition, if the Python function accepts keyword arguments generally or
+ has a corresponding keyword argument, the following will be passed:
+ * `func_op`: The `func` op being defined.
+
+ By default, the function name will be the Python function `__name__`. This
+ can be overriden by passing the `name` argument to the decorator.
+
+ If `results` is not specified, then the decorator will implicitly
+ insert a `ReturnOp` with the `Value`'s returned from the decorated
+ function. It will also set the `FuncOp` type with the actual return
+ value types. If `results` is specified, then the decorated function
+ must return `None` and no implicit `ReturnOp` is added (nor are the result
+ types updated). The implicit behavior is intended for simple, single-block
+ cases, and users should specify result types explicitly for any complicated
+ cases.
+
+ The decorated function can further be called from Python and will insert
+ a `CallOp` at the then-current insertion point, returning either None (
+ if no return values), a unary Value (for one result), or a list of Values).
+ This mechanism cannot be used to emit recursive calls (by construction).
+ """
+
+ def decorator(f):
+ from . import func
+
+ # Introspect the callable for optional features.
+ sig = inspect.signature(f)
+ has_arg_func_op = False
+ for param in sig.parameters.values():
+ if param.kind == param.VAR_KEYWORD:
+ has_arg_func_op = True
+ if param.name == "func_op" and (
+ param.kind == param.POSITIONAL_OR_KEYWORD
+ or param.kind == param.KEYWORD_ONLY
+ ):
+ has_arg_func_op = True
+
+ # Emit the FuncOp.
+ implicit_return = results is None
+ symbol_name = name or f.__name__
+ function_type = FunctionType.get(
+ inputs=inputs, results=[] if implicit_return else results
+ )
+ func_op = FuncOp(name=symbol_name, type=function_type)
+ with InsertionPoint(func_op.add_entry_block()):
+ func_args = func_op.entry_block.arguments
+ func_kwargs = {}
+ if has_arg_func_op:
+ func_kwargs["func_op"] = func_op
+ return_values = f(*func_args, **func_kwargs)
+ if not implicit_return:
+ return_types = list(results)
+ assert return_values is None, (
+ "Capturing a python function with explicit `results=` "
+ "requires that the wrapped function returns None."
+ )
+ else:
+ # Coerce return values, add ReturnOp and rewrite func type.
+ if return_values is None:
+ return_values = []
+ elif isinstance(return_values, tuple):
+ return_values = list(return_values)
+ elif isinstance(return_values, Value):
+ # Returning a single value is fine, coerce it into a list.
+ return_values = [return_values]
+ elif isinstance(return_values, OpView):
+ # Returning a single operation is fine, coerce its results a list.
+ return_values = return_values.operation.results
+ elif isinstance(return_values, Operation):
+ # Returning a single operation is fine, coerce its results a list.
+ return_values = return_values.results
+ else:
+ return_values = list(return_values)
+ func.ReturnOp(return_values)
+ # Recompute the function type.
+ return_types = [v.type for v in return_values]
+ function_type = FunctionType.get(
+ inputs=inputs, results=return_types
+ )
+ func_op.attributes["function_type"] = TypeAttr.get(function_type)
+
+ def emit_call_op(*call_args):
+ call_op = func.CallOp(
+ return_types, FlatSymbolRefAttr.get(symbol_name), call_args
+ )
+ if return_types is None:
+ return None
+ elif len(return_types) == 1:
+ return call_op.result
+ else:
+ return call_op.results
+
+ wrapped = emit_call_op
+ wrapped.__name__ = f.__name__
+ wrapped.func_op = func_op
+ return wrapped
+
+ return decorator
- return decorator
class CallOp:
- """Specialization for the call op class."""
-
- def __init__(self,
- calleeOrResults: Union[FuncOp, List[Type]],
- argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
- arguments: Optional[List] = None,
- *,
- loc=None,
- ip=None):
- """Creates an call operation.
-
-    The constructor accepts three different forms:
-
- 1. A function op to be called followed by a list of arguments.
- 2. A list of result types, followed by the name of the function to be
- called as string, following by a list of arguments.
- 3. A list of result types, followed by the name of the function to be
- called as symbol reference attribute, followed by a list of arguments.
-
- For example
-
- f = func.FuncOp("foo", ...)
- func.CallOp(f, [args])
- func.CallOp([result_types], "foo", [args])
-
- In all cases, the location and insertion point may be specified as keyword
- arguments if not provided by the surrounding context managers.
- """
-
- # TODO: consider supporting constructor "overloads", e.g., through a custom
- # or pybind-provided metaclass.
- if isinstance(calleeOrResults, FuncOp):
- if not isinstance(argumentsOrCallee, list):
- raise ValueError(
- "when constructing a call to a function, expected " +
- "the second argument to be a list of call arguments, " +
- f"got {type(argumentsOrCallee)}")
- if arguments is not None:
- raise ValueError("unexpected third argument when constructing a call" +
- "to a function")
-
- super().__init__(
- calleeOrResults.type.results,
- FlatSymbolRefAttr.get(
- calleeOrResults.name.value,
- context=_get_default_loc_context(loc)),
- argumentsOrCallee,
- loc=loc,
- ip=ip)
- return
-
- if isinstance(argumentsOrCallee, list):
- raise ValueError("when constructing a call to a function by name, " +
- "expected the second argument to be a string or a " +
- f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}")
-
- if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
- super().__init__(
- calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip)
- elif isinstance(argumentsOrCallee, str):
- super().__init__(
- calleeOrResults,
- FlatSymbolRefAttr.get(
- argumentsOrCallee, context=_get_default_loc_context(loc)),
- arguments,
- loc=loc,
- ip=ip)
+ """Specialization for the call op class."""
+
+ def __init__(
+ self,
+ calleeOrResults: Union[FuncOp, List[Type]],
+ argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
+ arguments: Optional[List] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ """Creates an call operation.
+
+        The constructor accepts three different forms:
+
+ 1. A function op to be called followed by a list of arguments.
+ 2. A list of result types, followed by the name of the function to be
+ called as string, following by a list of arguments.
+ 3. A list of result types, followed by the name of the function to be
+ called as symbol reference attribute, followed by a list of arguments.
+
+ For example
+
+ f = func.FuncOp("foo", ...)
+ func.CallOp(f, [args])
+ func.CallOp([result_types], "foo", [args])
+
+ In all cases, the location and insertion point may be specified as keyword
+ arguments if not provided by the surrounding context managers.
+ """
+
+ # TODO: consider supporting constructor "overloads", e.g., through a custom
+ # or pybind-provided metaclass.
+ if isinstance(calleeOrResults, FuncOp):
+ if not isinstance(argumentsOrCallee, list):
+ raise ValueError(
+ "when constructing a call to a function, expected "
+ + "the second argument to be a list of call arguments, "
+ + f"got {type(argumentsOrCallee)}"
+ )
+ if arguments is not None:
+ raise ValueError(
+ "unexpected third argument when constructing a call"
+ + "to a function"
+ )
+
+ super().__init__(
+ calleeOrResults.type.results,
+ FlatSymbolRefAttr.get(
+ calleeOrResults.name.value, context=_get_default_loc_context(loc)
+ ),
+ argumentsOrCallee,
+ loc=loc,
+ ip=ip,
+ )
+ return
+
+ if isinstance(argumentsOrCallee, list):
+ raise ValueError(
+ "when constructing a call to a function by name, "
+ + "expected the second argument to be a string or a "
+ + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
+ )
+
+ if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
+ super().__init__(
+ calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip
+ )
+ elif isinstance(argumentsOrCallee, str):
+ super().__init__(
+ calleeOrResults,
+ FlatSymbolRefAttr.get(
+ argumentsOrCallee, context=_get_default_loc_context(loc)
+ ),
+ arguments,
+ loc=loc,
+ ip=ip,
+ )
diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
index eb9e969f33602..3f6d854ca3e2b 100644
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py
@@ -3,39 +3,45 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
- from typing import Optional, Sequence, Union
- from ..ir import *
- from ._ods_common import get_default_loc_context
- from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
+ from typing import Optional, Sequence, Union
+ from ..ir import *
+ from ._ods_common import get_default_loc_context
+ from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports from extension module") from e
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+
def isa(cls: Type, ty: Type):
- try:
- cls(ty)
- return True
- except ValueError:
- return False
+ try:
+ cls(ty)
+ return True
+ except ValueError:
+ return False
class StructuredOpMixin:
- """All structured ops use the same mixin class."""
+ """All structured ops use the same mixin class."""
- def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
- super().__init__(
- self.build_generic(results=list(results),
- operands=[list(inputs), list(outputs)],
- loc=loc,
- ip=ip))
+ def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
+ super().__init__(
+ self.build_generic(
+ results=list(results),
+ operands=[list(inputs), list(outputs)],
+ loc=loc,
+ ip=ip,
+ )
+ )
def select_opview_mixin(parent_opview_cls):
- # TODO: This shouldn't be a heuristic: we should have a way to annotate
- # the OpView to note that it is a structured op.
- if ("__init__" not in parent_opview_cls.__dict__ and
- hasattr(parent_opview_cls, "inputs") and
- hasattr(parent_opview_cls, "outputs") and
- hasattr(parent_opview_cls, "result_tensors")):
- return StructuredOpMixin
+ # TODO: This shouldn't be a heuristic: we should have a way to annotate
+ # the OpView to note that it is a structured op.
+ if (
+ "__init__" not in parent_opview_cls.__dict__
+ and hasattr(parent_opview_cls, "inputs")
+ and hasattr(parent_opview_cls, "outputs")
+ and hasattr(parent_opview_cls, "result_tensors")
+ ):
+ return StructuredOpMixin
diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
index 10079d32fd925..3536d45ab7369 100644
--- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
@@ -3,125 +3,130 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
- from ..ir import *
- from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+ from ..ir import *
+ from ._ods_common import get_op_result_or_value as _get_op_result_or_value
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Union
class GetParentForOp:
- """Extension for GetParentForOp."""
-
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- num_loops: Optional[int] = None,
- ip=None,
- loc=None,
- ):
- if num_loops is None:
- num_loops = 1
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- num_loops=num_loops,
- ip=ip,
- loc=loc,
- )
+ """Extension for GetParentForOp."""
+
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ num_loops: Optional[int] = None,
+ ip=None,
+ loc=None,
+ ):
+ if num_loops is None:
+ num_loops = 1
+ super().__init__(
+ result_type,
+ _get_op_result_or_value(target),
+ num_loops=num_loops,
+ ip=ip,
+ loc=loc,
+ )
class LoopOutlineOp:
- """Extension for LoopOutlineOp."""
-
- def __init__(
- self,
- function_type: Type,
- call_type: Type,
- target: Union[Operation, Value],
- *,
- func_name: Union[str, StringAttr],
- ip=None,
- loc=None,
- ):
- super().__init__(
- function_type,
- call_type,
- _get_op_result_or_value(target),
- func_name=(func_name if isinstance(func_name, StringAttr) else
- StringAttr.get(func_name)),
- ip=ip,
- loc=loc,
- )
+ """Extension for LoopOutlineOp."""
+
+ def __init__(
+ self,
+ function_type: Type,
+ call_type: Type,
+ target: Union[Operation, Value],
+ *,
+ func_name: Union[str, StringAttr],
+ ip=None,
+ loc=None,
+ ):
+ super().__init__(
+ function_type,
+ call_type,
+ _get_op_result_or_value(target),
+ func_name=(
+ func_name
+ if isinstance(func_name, StringAttr)
+ else StringAttr.get(func_name)
+ ),
+ ip=ip,
+ loc=loc,
+ )
class LoopPeelOp:
- """Extension for LoopPeelOp."""
-
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- fail_if_already_divisible: Union[bool, BoolAttr] = False,
- ip=None,
- loc=None,
- ):
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- fail_if_already_divisible=(fail_if_already_divisible if isinstance(
- fail_if_already_divisible, BoolAttr) else
- BoolAttr.get(fail_if_already_divisible)),
- ip=ip,
- loc=loc,
- )
+ """Extension for LoopPeelOp."""
+
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ fail_if_already_divisible: Union[bool, BoolAttr] = False,
+ ip=None,
+ loc=None,
+ ):
+ super().__init__(
+ result_type,
+ _get_op_result_or_value(target),
+ fail_if_already_divisible=(
+ fail_if_already_divisible
+ if isinstance(fail_if_already_divisible, BoolAttr)
+ else BoolAttr.get(fail_if_already_divisible)
+ ),
+ ip=ip,
+ loc=loc,
+ )
class LoopPipelineOp:
- """Extension for LoopPipelineOp."""
-
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- iteration_interval: Optional[Union[int, IntegerAttr]] = None,
- read_latency: Optional[Union[int, IntegerAttr]] = None,
- ip=None,
- loc=None,
- ):
- if iteration_interval is None:
- iteration_interval = 1
- if read_latency is None:
- read_latency = 10
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- iteration_interval=iteration_interval,
- read_latency=read_latency,
- ip=ip,
- loc=loc,
- )
+ """Extension for LoopPipelineOp."""
+
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ iteration_interval: Optional[Union[int, IntegerAttr]] = None,
+ read_latency: Optional[Union[int, IntegerAttr]] = None,
+ ip=None,
+ loc=None,
+ ):
+ if iteration_interval is None:
+ iteration_interval = 1
+ if read_latency is None:
+ read_latency = 10
+ super().__init__(
+ result_type,
+ _get_op_result_or_value(target),
+ iteration_interval=iteration_interval,
+ read_latency=read_latency,
+ ip=ip,
+ loc=loc,
+ )
class LoopUnrollOp:
- """Extension for LoopUnrollOp."""
-
- def __init__(
- self,
- target: Union[Operation, Value],
- *,
- factor: Union[int, IntegerAttr],
- ip=None,
- loc=None,
- ):
- super().__init__(
- _get_op_result_or_value(target),
- factor=factor,
- ip=ip,
- loc=loc,
- )
+ """Extension for LoopUnrollOp."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ *,
+ factor: Union[int, IntegerAttr],
+ ip=None,
+ loc=None,
+ ):
+ super().__init__(
+ _get_op_result_or_value(target),
+ factor=factor,
+ ip=ip,
+ loc=loc,
+ )
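
A hedged sketch of chaining the loop-transform builders above (illustrative only: it assumes a transform-dialect sequence whose block argument `target` is an operation handle, and uses the pdl.operation handle type for the result, as elsewhere in these bindings):

from mlir.dialects import pdl

# Find the enclosing loop of the target and unroll it by a factor of 4.
loop = GetParentForOp(pdl.OperationType.get(), target, num_loops=1)
LoopUnrollOp(loop, factor=4)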
diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py
index a00a087be79b2..825f1a0a7a6fa 100644
--- a/mlir/python/mlir/dialects/_memref_ops_ext.py
+++ b/mlir/python/mlir/dialects/_memref_ops_ext.py
@@ -3,34 +3,34 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
- from ..ir import *
- from ._ods_common import get_op_result_or_value as _get_op_result_or_value
- from ._ods_common import get_op_results_or_values as _get_op_results_or_values
+ from ..ir import *
+ from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+ from ._ods_common import get_op_results_or_values as _get_op_results_or_values
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
class LoadOp:
- """Specialization for the MemRef load operation."""
+ """Specialization for the MemRef load operation."""
- def __init__(self,
- memref: Union[Operation, OpView, Value],
- indices: Optional[Union[Operation, OpView,
- Sequence[Value]]] = None,
- *,
- loc=None,
- ip=None):
- """Creates a memref load operation.
+ def __init__(
+ self,
+ memref: Union[Operation, OpView, Value],
+ indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+ *,
+ loc=None,
+ ip=None
+ ):
+ """Creates a memref load operation.
- Args:
- memref: the buffer to load from.
- indices: the list of subscripts, may be empty for zero-dimensional
- buffers.
- loc: user-visible location of the operation.
- ip: insertion point.
- """
- indices_resolved = [] if indices is None else _get_op_results_or_values(
- indices)
- super().__init__(memref, indices_resolved, loc=loc, ip=ip)
+ Args:
+ memref: the buffer to load from.
+ indices: the list of subscripts, may be empty for zero-dimensional
+ buffers.
+ loc: user-visible location of the operation.
+ ip: insertion point.
+ """
+ indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
+ super().__init__(memref, indices_resolved, loc=loc, ip=ip)
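
A minimal sketch of the LoadOp builder above (assumes an insertion point and existing SSA values: `buf` of some memref type, `idx` of index type, and `scalar_buf` a zero-dimensional memref):

load = LoadOp(buf, [idx])    # the op's single result is the loaded element
zero_d = LoadOp(scalar_buf)  # indices default to [] for 0-d buffers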
diff --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/_ml_program_ops_ext.py
index 8db82cf81c678..c84d23c16ef93 100644
--- a/mlir/python/mlir/dialects/_ml_program_ops_ext.py
+++ b/mlir/python/mlir/dialects/_ml_program_ops_ext.py
@@ -3,11 +3,11 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
- from typing import Union
- from ..ir import *
- from ._ods_common import get_default_loc_context as _get_default_loc_context
+ from typing import Union
+ from ..ir import *
+ from ._ods_common import get_default_loc_context as _get_default_loc_context
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports from extension module") from e
from ._ml_program_ops_gen import *
@@ -17,100 +17,97 @@
class FuncOp:
- """Specialization for the func op class."""
-
- def __init__(self,
- name,
- type,
- *,
- visibility=None,
- body_builder=None,
- loc=None,
- ip=None):
- """
- Create a FuncOp with the provided `name`, `type`, and `visibility`.
- - `name` is a string representing the function name.
- - `type` is either a FunctionType or a pair of list describing inputs and
- results.
- - `visibility` is a string matching `public`, `private`, or `nested`. None
- implies private visibility.
- - `body_builder` is an optional callback, when provided a new entry block
- is created and the callback is invoked with the new op as argument within
- an InsertionPoint context already set for the block. The callback is
- expected to insert a terminator in the block.
- """
- sym_name = StringAttr.get(str(name))
-
- # If the type is passed as a tuple, build a FunctionType on the fly.
- if isinstance(type, tuple):
- type = FunctionType.get(inputs=type[0], results=type[1])
-
- type = TypeAttr.get(type)
- sym_visibility = StringAttr.get(
- str(visibility)) if visibility is not None else None
- super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
- if body_builder:
- entry_block = self.add_entry_block()
- with InsertionPoint(entry_block):
- body_builder(self)
-
- @property
- def is_external(self):
- return len(self.regions[0].blocks) == 0
-
- @property
- def body(self):
- return self.regions[0]
-
- @property
- def type(self):
- return FunctionType(TypeAttr(self.attributes["function_type"]).value)
-
- @property
- def visibility(self):
- return self.attributes["sym_visibility"]
-
- @property
- def name(self) -> StringAttr:
- return StringAttr(self.attributes["sym_name"])
-
- @property
- def entry_block(self):
- if self.is_external:
- raise IndexError('External function does not have a body')
- return self.regions[0].blocks[0]
-
- def add_entry_block(self):
- """
- Add an entry block to the function body using the function signature to
- infer block arguments.
- Returns the newly created block
- """
- if not self.is_external:
- raise IndexError('The function already has an entry block!')
- self.body.blocks.append(*self.type.inputs)
- return self.body.blocks[0]
-
- @property
- def arg_attrs(self):
- return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
-
- @arg_attrs.setter
- def arg_attrs(self, attribute: Union[ArrayAttr, list]):
- if isinstance(attribute, ArrayAttr):
- self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
- else:
- self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
- attribute, context=self.context)
-
- @property
- def arguments(self):
- return self.entry_block.arguments
-
- @property
- def result_attrs(self):
- return self.attributes[RESULT_ATTRIBUTE_NAME]
-
- @result_attrs.setter
- def result_attrs(self, attribute: ArrayAttr):
- self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
+ """Specialization for the func op class."""
+
+ def __init__(
+ self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
+ ):
+ """
+ Create a FuncOp with the provided `name`, `type`, and `visibility`.
+ - `name` is a string representing the function name.
+        - `type` is either a FunctionType or a pair of lists describing inputs and
+ results.
+ - `visibility` is a string matching `public`, `private`, or `nested`. None
+ implies private visibility.
+        - `body_builder` is an optional callback; when provided, a new entry block
+ is created and the callback is invoked with the new op as argument within
+ an InsertionPoint context already set for the block. The callback is
+ expected to insert a terminator in the block.
+ """
+ sym_name = StringAttr.get(str(name))
+
+ # If the type is passed as a tuple, build a FunctionType on the fly.
+ if isinstance(type, tuple):
+ type = FunctionType.get(inputs=type[0], results=type[1])
+
+ type = TypeAttr.get(type)
+ sym_visibility = (
+ StringAttr.get(str(visibility)) if visibility is not None else None
+ )
+ super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
+ if body_builder:
+ entry_block = self.add_entry_block()
+ with InsertionPoint(entry_block):
+ body_builder(self)
+
+ @property
+ def is_external(self):
+ return len(self.regions[0].blocks) == 0
+
+ @property
+ def body(self):
+ return self.regions[0]
+
+ @property
+ def type(self):
+ return FunctionType(TypeAttr(self.attributes["function_type"]).value)
+
+ @property
+ def visibility(self):
+ return self.attributes["sym_visibility"]
+
+ @property
+ def name(self) -> StringAttr:
+ return StringAttr(self.attributes["sym_name"])
+
+ @property
+ def entry_block(self):
+ if self.is_external:
+ raise IndexError("External function does not have a body")
+ return self.regions[0].blocks[0]
+
+ def add_entry_block(self):
+ """
+ Add an entry block to the function body using the function signature to
+ infer block arguments.
+        Returns the newly created block.
+ """
+ if not self.is_external:
+ raise IndexError("The function already has an entry block!")
+ self.body.blocks.append(*self.type.inputs)
+ return self.body.blocks[0]
+
+ @property
+ def arg_attrs(self):
+ return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
+
+ @arg_attrs.setter
+ def arg_attrs(self, attribute: Union[ArrayAttr, list]):
+ if isinstance(attribute, ArrayAttr):
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+ else:
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+ attribute, context=self.context
+ )
+
+ @property
+ def arguments(self):
+ return self.entry_block.arguments
+
+ @property
+ def result_attrs(self):
+ return self.attributes[RESULT_ATTRIBUTE_NAME]
+
+ @result_attrs.setter
+ def result_attrs(self, attribute: ArrayAttr):
+ self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
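
A hedged sketch of the FuncOp builder above (assumes an active Context/Location and an insertion point at a module body; the function name and types are illustrative):

f32 = F32Type.get()
# An (inputs, results) tuple is turned into a FunctionType on the fly; with no
# body_builder the function stays external (no entry block is created).
decl = FuncOp("increment", ([f32], [f32]), visibility="private")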
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 51b90081973c4..7655629a55425 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -18,144 +18,152 @@
def extend_opview_class(ext_module):
- """Decorator to extend an OpView class from an extension module.
-
- Extension modules can expose various entry-points:
- Stand-alone class with the same name as a parent OpView class (i.e.
- "ReturnOp"). A name-based match is attempted first before falling back
- to a below mechanism.
-
- def select_opview_mixin(parent_opview_cls):
- If defined, allows an appropriate mixin class to be selected dynamically
- based on the parent OpView class. Should return NotImplemented if a
- decision is not made.
-
- Args:
- ext_module: A module from which to locate extensions. Can be None if not
- available.
-
- Returns:
- A decorator that takes an OpView subclass and further extends it as
- needed.
- """
-
- def class_decorator(parent_opview_cls: type):
- if ext_module is None:
- return parent_opview_cls
- mixin_cls = NotImplemented
- # First try to resolve by name.
- try:
- mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
- except AttributeError:
- # Fall back to a select_opview_mixin hook.
- try:
- select_mixin = getattr(ext_module, "select_opview_mixin")
- except AttributeError:
- pass
- else:
- mixin_cls = select_mixin(parent_opview_cls)
-
- if mixin_cls is NotImplemented or mixin_cls is None:
- return parent_opview_cls
-
- # Have a mixin_cls. Create an appropriate subclass.
- try:
-
- class LocalOpView(mixin_cls, parent_opview_cls):
- pass
- except TypeError as e:
- raise TypeError(
- f"Could not mixin {mixin_cls} into {parent_opview_cls}") from e
- LocalOpView.__name__ = parent_opview_cls.__name__
- LocalOpView.__qualname__ = parent_opview_cls.__qualname__
- return LocalOpView
-
- return class_decorator
+ """Decorator to extend an OpView class from an extension module.
+
+ Extension modules can expose various entry-points:
+      Stand-alone class with the same name as a parent OpView class (e.g.
+      "ReturnOp"). A name-based match is attempted first before falling back
+      to the mechanism below.
+
+ def select_opview_mixin(parent_opview_cls):
+ If defined, allows an appropriate mixin class to be selected dynamically
+ based on the parent OpView class. Should return NotImplemented if a
+ decision is not made.
+
+ Args:
+ ext_module: A module from which to locate extensions. Can be None if not
+ available.
+
+ Returns:
+ A decorator that takes an OpView subclass and further extends it as
+ needed.
+ """
+
+ def class_decorator(parent_opview_cls: type):
+ if ext_module is None:
+ return parent_opview_cls
+ mixin_cls = NotImplemented
+ # First try to resolve by name.
+ try:
+ mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
+ except AttributeError:
+ # Fall back to a select_opview_mixin hook.
+ try:
+ select_mixin = getattr(ext_module, "select_opview_mixin")
+ except AttributeError:
+ pass
+ else:
+ mixin_cls = select_mixin(parent_opview_cls)
+
+ if mixin_cls is NotImplemented or mixin_cls is None:
+ return parent_opview_cls
+
+ # Have a mixin_cls. Create an appropriate subclass.
+ try:
+
+ class LocalOpView(mixin_cls, parent_opview_cls):
+ pass
+
+ except TypeError as e:
+ raise TypeError(
+ f"Could not mixin {mixin_cls} into {parent_opview_cls}"
+ ) from e
+ LocalOpView.__name__ = parent_opview_cls.__name__
+ LocalOpView.__qualname__ = parent_opview_cls.__qualname__
+ return LocalOpView
+
+ return class_decorator
def segmented_accessor(elements, raw_segments, idx):
- """
- Returns a slice of elements corresponding to the idx-th segment.
-
- elements: a sliceable container (operands or results).
- raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing
- sizes of the segments.
- idx: index of the segment.
- """
- segments = _cext.ir.DenseI32ArrayAttr(raw_segments)
- start = sum(segments[i] for i in range(idx))
- end = start + segments[idx]
- return elements[start:end]
-
-
-def equally_sized_accessor(elements, n_variadic, n_preceding_simple,
- n_preceding_variadic):
- """
- Returns a starting position and a number of elements per variadic group
- assuming equally-sized groups and the given numbers of preceding groups.
-
- elements: a sequential container.
- n_variadic: the number of variadic groups in the container.
- n_preceding_simple: the number of non-variadic groups preceding the current
- group.
- n_preceding_variadic: the number of variadic groups preceding the current
- group.
- """
-
- total_variadic_length = len(elements) - n_variadic + 1
- # This should be enforced by the C++-side trait verifier.
- assert total_variadic_length % n_variadic == 0
-
- elements_per_group = total_variadic_length // n_variadic
- start = n_preceding_simple + n_preceding_variadic * elements_per_group
- return start, elements_per_group
+ """
+ Returns a slice of elements corresponding to the idx-th segment.
+
+ elements: a sliceable container (operands or results).
+ raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing
+ sizes of the segments.
+ idx: index of the segment.
+ """
+ segments = _cext.ir.DenseI32ArrayAttr(raw_segments)
+ start = sum(segments[i] for i in range(idx))
+ end = start + segments[idx]
+ return elements[start:end]
+
+
+def equally_sized_accessor(
+ elements, n_variadic, n_preceding_simple, n_preceding_variadic
+):
+ """
+ Returns a starting position and a number of elements per variadic group
+ assuming equally-sized groups and the given numbers of preceding groups.
+
+ elements: a sequential container.
+ n_variadic: the number of variadic groups in the container.
+ n_preceding_simple: the number of non-variadic groups preceding the current
+ group.
+ n_preceding_variadic: the number of variadic groups preceding the current
+ group.
+ """
+
+ total_variadic_length = len(elements) - n_variadic + 1
+ # This should be enforced by the C++-side trait verifier.
+ assert total_variadic_length % n_variadic == 0
+
+ elements_per_group = total_variadic_length // n_variadic
+ start = n_preceding_simple + n_preceding_variadic * elements_per_group
+ return start, elements_per_group
def get_default_loc_context(location=None):
- """
- Returns a context in which the defaulted location is created. If the location
- is None, takes the current location from the stack, raises ValueError if there
- is no location on the stack.
- """
- if location is None:
- # Location.current raises ValueError if there is no current location.
- return _cext.ir.Location.current.context
- return location.context
+ """
+ Returns a context in which the defaulted location is created. If the location
+ is None, takes the current location from the stack, raises ValueError if there
+ is no location on the stack.
+ """
+ if location is None:
+ # Location.current raises ValueError if there is no current location.
+ return _cext.ir.Location.current.context
+ return location.context
def get_op_result_or_value(
- arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList]
+ arg: _Union[
+ _cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList
+ ]
) -> _cext.ir.Value:
- """Returns the given value or the single result of the given op.
-
- This is useful to implement op constructors so that they can take other ops as
- arguments instead of requiring the caller to extract results for every op.
- Raises ValueError if provided with an op that doesn't have a single result.
- """
- if isinstance(arg, _cext.ir.OpView):
- return arg.operation.result
- elif isinstance(arg, _cext.ir.Operation):
- return arg.result
- elif isinstance(arg, _cext.ir.OpResultList):
- return arg[0]
- else:
- assert isinstance(arg, _cext.ir.Value)
- return arg
+ """Returns the given value or the single result of the given op.
+
+ This is useful to implement op constructors so that they can take other ops as
+ arguments instead of requiring the caller to extract results for every op.
+ Raises ValueError if provided with an op that doesn't have a single result.
+ """
+ if isinstance(arg, _cext.ir.OpView):
+ return arg.operation.result
+ elif isinstance(arg, _cext.ir.Operation):
+ return arg.result
+ elif isinstance(arg, _cext.ir.OpResultList):
+ return arg[0]
+ else:
+ assert isinstance(arg, _cext.ir.Value)
+ return arg
def get_op_results_or_values(
- arg: _Union[_cext.ir.OpView, _cext.ir.Operation,
- _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]]]
+ arg: _Union[
+ _cext.ir.OpView,
+ _cext.ir.Operation,
+ _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
+ ]
) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
- """Returns the given sequence of values or the results of the given op.
-
- This is useful to implement op constructors so that they can take other ops as
- lists of arguments instead of requiring the caller to extract results for
- every op.
- """
- if isinstance(arg, _cext.ir.OpView):
- return arg.operation.results
- elif isinstance(arg, _cext.ir.Operation):
- return arg.results
- else:
- return [get_op_result_or_value(element) for element in arg]
+ """Returns the given sequence of values or the results of the given op.
+
+ This is useful to implement op constructors so that they can take other ops as
+ lists of arguments instead of requiring the caller to extract results for
+ every op.
+ """
+ if isinstance(arg, _cext.ir.OpView):
+ return arg.operation.results
+ elif isinstance(arg, _cext.ir.Operation):
+ return arg.results
+ else:
+ return [get_op_result_or_value(element) for element in arg]
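
A worked sketch of the accessor arithmetic above (plain Python, no MLIR context needed; the element list is illustrative):

# 7 elements, 2 variadic groups, 1 simple and 1 variadic group preceding the
# current one: total_variadic_length = 7 - 2 + 1 = 6, so 3 elements per variadic
# group, and the current group starts at 1 + 1 * 3 = 4.
start, per_group = equally_sized_accessor(list(range(7)), 2, 1, 1)
assert (start, per_group) == (4, 3)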
diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py
index 40ccbef6351dc..fc9de0b7f7db6 100644
--- a/mlir/python/mlir/dialects/_pdl_ops_ext.py
+++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py
@@ -3,10 +3,10 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
- from ..ir import *
- from ..dialects import pdl
+ from ..ir import *
+ from ..dialects import pdl
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports from extension module") from e
from typing import Union, Optional, Sequence, Mapping
from ._ods_common import (
@@ -16,264 +16,256 @@
class ApplyNativeConstraintOp:
- """Specialization for PDL apply native constraint op class."""
-
- def __init__(
- self,
- name: Union[str, StringAttr],
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if args is None:
- args = []
- args = _get_values(args)
- super().__init__(name, args, loc=loc, ip=ip)
+ """Specialization for PDL apply native constraint op class."""
+
+ def __init__(
+ self,
+ name: Union[str, StringAttr],
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if args is None:
+ args = []
+ args = _get_values(args)
+ super().__init__(name, args, loc=loc, ip=ip)
class ApplyNativeRewriteOp:
- """Specialization for PDL apply native rewrite op class."""
-
- def __init__(
- self,
- results: Sequence[Type],
- name: Union[str, StringAttr],
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if args is None:
- args = []
- args = _get_values(args)
- super().__init__(results, name, args, loc=loc, ip=ip)
+ """Specialization for PDL apply native rewrite op class."""
+
+ def __init__(
+ self,
+ results: Sequence[Type],
+ name: Union[str, StringAttr],
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if args is None:
+ args = []
+ args = _get_values(args)
+ super().__init__(results, name, args, loc=loc, ip=ip)
class AttributeOp:
- """Specialization for PDL attribute op class."""
+ """Specialization for PDL attribute op class."""
- def __init__(
- self,
- valueType: Optional[Union[OpView, Operation, Value]] = None,
- value: Optional[Attribute] = None,
- *,
- loc=None,
- ip=None,
- ):
- valueType = valueType if valueType is None else _get_value(valueType)
- result = pdl.AttributeType.get()
- super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
+ def __init__(
+ self,
+ valueType: Optional[Union[OpView, Operation, Value]] = None,
+ value: Optional[Attribute] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ valueType = valueType if valueType is None else _get_value(valueType)
+ result = pdl.AttributeType.get()
+ super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
class EraseOp:
- """Specialization for PDL erase op class."""
+ """Specialization for PDL erase op class."""
- def __init__(
- self,
- operation: Optional[Union[OpView, Operation, Value]] = None,
- *,
- loc=None,
- ip=None,
- ):
- operation = _get_value(operation)
- super().__init__(operation, loc=loc, ip=ip)
+ def __init__(
+ self,
+ operation: Optional[Union[OpView, Operation, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ operation = _get_value(operation)
+ super().__init__(operation, loc=loc, ip=ip)
class OperandOp:
- """Specialization for PDL operand op class."""
+ """Specialization for PDL operand op class."""
- def __init__(
- self,
- type: Optional[Union[OpView, Operation, Value]] = None,
- *,
- loc=None,
- ip=None,
- ):
- type = type if type is None else _get_value(type)
- result = pdl.ValueType.get()
- super().__init__(result, valueType=type, loc=loc, ip=ip)
+ def __init__(
+ self,
+ type: Optional[Union[OpView, Operation, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ type = type if type is None else _get_value(type)
+ result = pdl.ValueType.get()
+ super().__init__(result, valueType=type, loc=loc, ip=ip)
class OperandsOp:
- """Specialization for PDL operands op class."""
+ """Specialization for PDL operands op class."""
- def __init__(
- self,
- types: Optional[Union[OpView, Operation, Value]] = None,
- *,
- loc=None,
- ip=None,
- ):
- types = types if types is None else _get_value(types)
- result = pdl.RangeType.get(pdl.ValueType.get())
- super().__init__(result, valueType=types, loc=loc, ip=ip)
+ def __init__(
+ self,
+ types: Optional[Union[OpView, Operation, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ types = types if types is None else _get_value(types)
+ result = pdl.RangeType.get(pdl.ValueType.get())
+ super().__init__(result, valueType=types, loc=loc, ip=ip)
class OperationOp:
- """Specialization for PDL operand op class."""
-
- def __init__(
- self,
- name: Optional[Union[str, StringAttr]] = None,
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- attributes: Optional[Mapping[str, Union[OpView, Operation,
- Value]]] = None,
- types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if types is None:
- types = []
- if attributes is None:
- attributes = {}
- if args is None:
- args = []
- args = _get_values(args)
- attrNames = []
- attrValues = []
- for attrName, attrValue in attributes.items():
- attrNames.append(StringAttr.get(attrName))
- attrValues.append(_get_value(attrValue))
- attrNames = ArrayAttr.get(attrNames)
- types = _get_values(types)
- result = pdl.OperationType.get()
- super().__init__(result,
- args,
- attrValues,
- attrNames,
- types,
- opName=name,
- loc=loc,
- ip=ip)
+    """Specialization for PDL operation op class."""
+
+ def __init__(
+ self,
+ name: Optional[Union[str, StringAttr]] = None,
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
+ types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if types is None:
+ types = []
+ if attributes is None:
+ attributes = {}
+ if args is None:
+ args = []
+ args = _get_values(args)
+ attrNames = []
+ attrValues = []
+ for attrName, attrValue in attributes.items():
+ attrNames.append(StringAttr.get(attrName))
+ attrValues.append(_get_value(attrValue))
+ attrNames = ArrayAttr.get(attrNames)
+ types = _get_values(types)
+ result = pdl.OperationType.get()
+ super().__init__(
+ result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip
+ )
class PatternOp:
- """Specialization for PDL pattern op class."""
-
- def __init__(
- self,
- benefit: Union[IntegerAttr, int],
- name: Optional[Union[StringAttr, str]] = None,
- *,
- loc=None,
- ip=None,
- ):
- """Creates an PDL `pattern` operation."""
- super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
- self.regions[0].blocks.append()
-
- @property
- def body(self):
- """Return the body (block) of the pattern."""
- return self.regions[0].blocks[0]
+ """Specialization for PDL pattern op class."""
+
+ def __init__(
+ self,
+ benefit: Union[IntegerAttr, int],
+ name: Optional[Union[StringAttr, str]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+        """Creates a PDL `pattern` operation."""
+ super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
+ self.regions[0].blocks.append()
+
+ @property
+ def body(self):
+ """Return the body (block) of the pattern."""
+ return self.regions[0].blocks[0]
class ReplaceOp:
- """Specialization for PDL replace op class."""
-
- def __init__(
- self,
- op: Union[OpView, Operation, Value],
- *,
- with_op: Optional[Union[OpView, Operation, Value]] = None,
- with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- loc=None,
- ip=None,
- ):
- if with_values is None:
- with_values = []
- op = _get_value(op)
- with_op = with_op if with_op is None else _get_value(with_op)
- with_values = _get_values(with_values)
- super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
+ """Specialization for PDL replace op class."""
+
+ def __init__(
+ self,
+ op: Union[OpView, Operation, Value],
+ *,
+ with_op: Optional[Union[OpView, Operation, Value]] = None,
+ with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ loc=None,
+ ip=None,
+ ):
+ if with_values is None:
+ with_values = []
+ op = _get_value(op)
+ with_op = with_op if with_op is None else _get_value(with_op)
+ with_values = _get_values(with_values)
+ super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
class ResultOp:
- """Specialization for PDL result op class."""
+ """Specialization for PDL result op class."""
- def __init__(
- self,
- parent: Union[OpView, Operation, Value],
- index: Union[IntegerAttr, int],
- *,
- loc=None,
- ip=None,
- ):
- parent = _get_value(parent)
- result = pdl.ValueType.get()
- super().__init__(result, parent, index, loc=loc, ip=ip)
+ def __init__(
+ self,
+ parent: Union[OpView, Operation, Value],
+ index: Union[IntegerAttr, int],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ parent = _get_value(parent)
+ result = pdl.ValueType.get()
+ super().__init__(result, parent, index, loc=loc, ip=ip)
class ResultsOp:
- """Specialization for PDL results op class."""
+ """Specialization for PDL results op class."""
- def __init__(
- self,
- result: Type,
- parent: Union[OpView, Operation, Value],
- index: Optional[Union[IntegerAttr, int]] = None,
- *,
- loc=None,
- ip=None,
- ):
- parent = _get_value(parent)
- super().__init__(result, parent, index=index, loc=loc, ip=ip)
+ def __init__(
+ self,
+ result: Type,
+ parent: Union[OpView, Operation, Value],
+ index: Optional[Union[IntegerAttr, int]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ parent = _get_value(parent)
+ super().__init__(result, parent, index=index, loc=loc, ip=ip)
class RewriteOp:
- """Specialization for PDL rewrite op class."""
-
- def __init__(
- self,
- root: Optional[Union[OpView, Operation, Value]] = None,
- name: Optional[Union[StringAttr, str]] = None,
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if args is None:
- args = []
- root = root if root is None else _get_value(root)
- args = _get_values(args)
- super().__init__(args, root=root, name=name, loc=loc, ip=ip)
-
- def add_body(self):
- """Add body (block) to the rewrite."""
- self.regions[0].blocks.append()
- return self.body
-
- @property
- def body(self):
- """Return the body (block) of the rewrite."""
- return self.regions[0].blocks[0]
+ """Specialization for PDL rewrite op class."""
+
+ def __init__(
+ self,
+ root: Optional[Union[OpView, Operation, Value]] = None,
+ name: Optional[Union[StringAttr, str]] = None,
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if args is None:
+ args = []
+ root = root if root is None else _get_value(root)
+ args = _get_values(args)
+ super().__init__(args, root=root, name=name, loc=loc, ip=ip)
+
+ def add_body(self):
+ """Add body (block) to the rewrite."""
+ self.regions[0].blocks.append()
+ return self.body
+
+ @property
+ def body(self):
+ """Return the body (block) of the rewrite."""
+ return self.regions[0].blocks[0]
class TypeOp:
- """Specialization for PDL type op class."""
+ """Specialization for PDL type op class."""
- def __init__(self,
- constantType: Optional[Union[TypeAttr, Type]] = None,
- *,
- loc=None,
- ip=None):
- result = pdl.TypeType.get()
- super().__init__(result, constantType=constantType, loc=loc, ip=ip)
+ def __init__(
+ self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None
+ ):
+ result = pdl.TypeType.get()
+ super().__init__(result, constantType=constantType, loc=loc, ip=ip)
class TypesOp:
- """Specialization for PDL types op class."""
-
- def __init__(
- self,
- constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if constantTypes is None:
- constantTypes = []
- result = pdl.RangeType.get(pdl.TypeType.get())
- super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
+ """Specialization for PDL types op class."""
+
+ def __init__(
+ self,
+ constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if constantTypes is None:
+ constantTypes = []
+ result = pdl.RangeType.get(pdl.TypeType.get())
+ super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
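
A hedged sketch composing the PDL builders above into a tiny pattern (illustrative only and not verified: it assumes an active Context/Location and an insertion point inside a module, and simply erases every matched arith.addi):

pattern = PatternOp(1, "erase_addi")
with InsertionPoint(pattern.body):
    root = OperationOp(name="arith.addi")
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
        EraseOp(root)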
diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
index 3c3e673021585..4b2519ef35357 100644
--- a/mlir/python/mlir/dialects/_scf_ops_ext.py
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -3,105 +3,104 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
- from ..ir import *
+ from ..ir import *
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports from extension module") from e
from typing import Any, Optional, Sequence, Union
-from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+)
+
class ForOp:
- """Specialization for the SCF for op class."""
-
- def __init__(self,
- lower_bound,
- upper_bound,
- step,
- iter_args: Optional[Union[Operation, OpView,
- Sequence[Value]]] = None,
- *,
- loc=None,
- ip=None):
- """Creates an SCF `for` operation.
-
- - `lower_bound` is the value to use as lower bound of the loop.
- - `upper_bound` is the value to use as upper bound of the loop.
- - `step` is the value to use as loop step.
- - `iter_args` is a list of additional loop-carried arguments or an operation
- producing them as results.
- """
- if iter_args is None:
- iter_args = []
- iter_args = _get_op_results_or_values(iter_args)
-
- results = [arg.type for arg in iter_args]
- super().__init__(
- self.build_generic(
- regions=1,
- results=results,
- operands=[
- _get_op_result_or_value(o)
- for o in [lower_bound, upper_bound, step]
- ] + list(iter_args),
- loc=loc,
- ip=ip))
- self.regions[0].blocks.append(IndexType.get(), *results)
-
- @property
- def body(self):
- """Returns the body (block) of the loop."""
- return self.regions[0].blocks[0]
-
- @property
- def induction_variable(self):
- """Returns the induction variable of the loop."""
- return self.body.arguments[0]
-
- @property
- def inner_iter_args(self):
- """Returns the loop-carried arguments usable within the loop.
-
- To obtain the loop-carried operands, use `iter_args`.
- """
- return self.body.arguments[1:]
+ """Specialization for the SCF for op class."""
+
+ def __init__(
+ self,
+ lower_bound,
+ upper_bound,
+ step,
+ iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+ *,
+ loc=None,
+ ip=None
+ ):
+ """Creates an SCF `for` operation.
+
+ - `lower_bound` is the value to use as lower bound of the loop.
+ - `upper_bound` is the value to use as upper bound of the loop.
+ - `step` is the value to use as loop step.
+ - `iter_args` is a list of additional loop-carried arguments or an operation
+ producing them as results.
+ """
+ if iter_args is None:
+ iter_args = []
+ iter_args = _get_op_results_or_values(iter_args)
+
+ results = [arg.type for arg in iter_args]
+ super().__init__(
+ self.build_generic(
+ regions=1,
+ results=results,
+ operands=[
+ _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
+ ]
+ + list(iter_args),
+ loc=loc,
+ ip=ip,
+ )
+ )
+ self.regions[0].blocks.append(IndexType.get(), *results)
+
+ @property
+ def body(self):
+ """Returns the body (block) of the loop."""
+ return self.regions[0].blocks[0]
+
+ @property
+ def induction_variable(self):
+ """Returns the induction variable of the loop."""
+ return self.body.arguments[0]
+
+ @property
+ def inner_iter_args(self):
+ """Returns the loop-carried arguments usable within the loop.
+
+ To obtain the loop-carried operands, use `iter_args`.
+ """
+ return self.body.arguments[1:]
class IfOp:
- """Specialization for the SCF if op class."""
-
- def __init__(self,
- cond,
- results_=[],
- *,
- hasElse=False,
- loc=None,
- ip=None):
- """Creates an SCF `if` operation.
-
- - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
- - `hasElse` determines whether the if operation has the else branch.
- """
- operands = []
- operands.append(cond)
- results = []
- results.extend(results_)
- super().__init__(
- self.build_generic(
- regions=2,
- results=results,
- operands=operands,
- loc=loc,
- ip=ip))
- self.regions[0].blocks.append(*[])
- if hasElse:
- self.regions[1].blocks.append(*[])
-
- @property
- def then_block(self):
- """Returns the then block of the if operation."""
- return self.regions[0].blocks[0]
-
- @property
- def else_block(self):
- """Returns the else block of the if operation."""
- return self.regions[1].blocks[0]
+ """Specialization for the SCF if op class."""
+
+ def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
+ """Creates an SCF `if` operation.
+
+        - `cond` is an MLIR value of 'i1' type to determine which region of code will be executed.
+ - `hasElse` determines whether the if operation has the else branch.
+ """
+ operands = []
+ operands.append(cond)
+ results = []
+ results.extend(results_)
+ super().__init__(
+ self.build_generic(
+ regions=2, results=results, operands=operands, loc=loc, ip=ip
+ )
+ )
+ self.regions[0].blocks.append(*[])
+ if hasElse:
+ self.regions[1].blocks.append(*[])
+
+ @property
+ def then_block(self):
+ """Returns the then block of the if operation."""
+ return self.regions[0].blocks[0]
+
+ @property
+ def else_block(self):
+ """Returns the else block of the if operation."""
+ return self.regions[1].blocks[0]
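
A hedged sketch of the ForOp builder above (assumes index-typed SSA values `lb`, `ub`, `step`, a loop-carried value `acc`, and an insertion point):

loop = ForOp(lb, ub, step, iter_args=[acc])
with InsertionPoint(loop.body):
    iv = loop.induction_variable       # block argument 0
    carried = loop.inner_iter_args[0]  # loop-carried block arguments
    # ... build the body here and terminate the block with an scf.yield ...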
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 9c051cd3d146d..30dafff6a11c5 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -3,11 +3,11 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
- from ..ir import *
- from ._ods_common import get_op_result_or_value as _get_op_result_or_value
- from ..dialects import pdl, transform
+ from ..ir import *
+ from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+ from ..dialects import pdl, transform
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports from extension module") from e
from typing import List, Optional, Sequence, Union, overload
@@ -16,312 +16,315 @@
def _get_int_int_array_attr(
- values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
- IntOrAttrList]]]]
+ values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
) -> ArrayAttr:
- """Creates an array attribute containing array attributes of integers.
+ """Creates an array attribute containing array attributes of integers.
If the operand is already an array attribute, forwards it. Otherwise treats
the operand as a list of attributes or integers, potentially interpserced, to
create a new array-of-array attribute. Expects the thread-local MLIR context
to have been set by the context manager.
"""
- if values is None:
- return ArrayAttr.get([])
- if isinstance(values, ArrayAttr):
- return values
- if isinstance(values, list):
- values = [
- ArrayAttr.get(
- [IntegerAttr.get(IntegerType.get_signless(64), v)
- for v in value])
- for value in values
- ]
+ if values is None:
+ return ArrayAttr.get([])
+ if isinstance(values, ArrayAttr):
+ return values
+ if isinstance(values, list):
+ values = [
+ ArrayAttr.get(
+ [IntegerAttr.get(IntegerType.get_signless(64), v) for v in value]
+ )
+ for value in values
+ ]
- return ArrayAttr.get(values)
+ return ArrayAttr.get(values)
class DecomposeOp:
- """Specialization for DecomposeOp class."""
+ """Specialization for DecomposeOp class."""
- def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
- super().__init__(pdl.OperationType.get(),
- _get_op_result_or_value(target),
- loc=loc,
- ip=ip)
+ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+ super().__init__(
+ pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
+ )
class GeneralizeOp:
- """Specialization for GeneralizeOp class."""
+ """Specialization for GeneralizeOp class."""
- def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
- super().__init__(pdl.OperationType.get(),
- _get_op_result_or_value(target),
- loc=loc,
- ip=ip)
+ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+ super().__init__(
+ pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
+ )
class InterchangeOp:
- """Specialization for InterchangeOp class."""
-
- def __init__(
- self,
- target: Union[Operation, Value],
- *,
- iterator_interchange: OptionalIntList = None,
- loc=None,
- ip=None,
- ):
- pdl_operation_type = pdl.OperationType.get()
- super().__init__(
- pdl_operation_type,
- _get_op_result_or_value(target),
- iterator_interchange=iterator_interchange,
- loc=loc,
- ip=ip,
- )
+ """Specialization for InterchangeOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ *,
+ iterator_interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ pdl_operation_type = pdl.OperationType.get()
+ super().__init__(
+ pdl_operation_type,
+ _get_op_result_or_value(target),
+ iterator_interchange=iterator_interchange,
+ loc=loc,
+ ip=ip,
+ )
class MatchOp:
- """Specialization for MatchOp class."""
-
- @classmethod
- def match_op_names(
- MatchOp,
- target: Union[Operation, Value],
- names: Sequence[str],
- loc=None,
- ip=None,
- ):
- pdl_operation_type = pdl.OperationType.get()
- return MatchOp(
- pdl_operation_type,
- _get_op_result_or_value(target),
- ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
- loc=loc,
- ip=ip,
- )
+ """Specialization for MatchOp class."""
+
+ @classmethod
+ def match_op_names(
+ MatchOp,
+ target: Union[Operation, Value],
+ names: Sequence[str],
+ loc=None,
+ ip=None,
+ ):
+ pdl_operation_type = pdl.OperationType.get()
+ return MatchOp(
+ pdl_operation_type,
+ _get_op_result_or_value(target),
+ ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
+ loc=loc,
+ ip=ip,
+ )
class MultiTileSizesOp:
- """Specialization for MultitileSizesOp class."""
-
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- dimension: Union[int, IntegerAttr],
- target_size: Union[int, IntegerAttr],
- divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
- loc=None,
- ip=None,
- ):
- if divisor is None:
- divisor = 1
- super().__init__(
- result_type,
- result_type,
- result_type,
- _get_op_result_or_value(target),
- dimension=dimension,
- target_size=target_size,
- divisor=divisor,
- loc=loc,
- ip=ip,
- )
+    """Specialization for MultiTileSizesOp class."""
+
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ dimension: Union[int, IntegerAttr],
+ target_size: Union[int, IntegerAttr],
+ divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
+ loc=None,
+ ip=None,
+ ):
+ if divisor is None:
+ divisor = 1
+ super().__init__(
+ result_type,
+ result_type,
+ result_type,
+ _get_op_result_or_value(target),
+ dimension=dimension,
+ target_size=target_size,
+ divisor=divisor,
+ loc=loc,
+ ip=ip,
+ )
class PadOp:
- """Specialization for PadOp class."""
-
- def __init__(
- self,
- target: Union[Operation, Value],
- *,
- padding_values: Optional[Optional[Union[ArrayAttr,
- Sequence[Attribute]]]] = None,
- padding_dimensions: OptionalIntList = None,
- pack_paddings: OptionalIntList = None,
- transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
- ArrayAttr, IntOrAttrList]]]] = None,
- loc=None,
- ip=None,
- ):
- if transpose_paddings is None:
- transpose_paddings = []
- if pack_paddings is None:
- pack_paddings = []
- if padding_dimensions is None:
- padding_dimensions = []
- if padding_values is None:
- padding_values = []
- pdl_operation_type = pdl.OperationType.get()
- transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
- super().__init__(
- pdl_operation_type,
- _get_op_result_or_value(target),
- padding_values=padding_values,
- padding_dimensions=padding_dimensions,
- pack_paddings=pack_paddings,
- transpose_paddings=transpose_paddings_attr,
- loc=loc,
- ip=ip,
- )
+ """Specialization for PadOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ *,
+ padding_values: Optional[
+ Optional[Union[ArrayAttr, Sequence[Attribute]]]
+ ] = None,
+ padding_dimensions: OptionalIntList = None,
+ pack_paddings: OptionalIntList = None,
+ transpose_paddings: Optional[
+ Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
+ ] = None,
+ loc=None,
+ ip=None,
+ ):
+ if transpose_paddings is None:
+ transpose_paddings = []
+ if pack_paddings is None:
+ pack_paddings = []
+ if padding_dimensions is None:
+ padding_dimensions = []
+ if padding_values is None:
+ padding_values = []
+ pdl_operation_type = pdl.OperationType.get()
+ transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
+ super().__init__(
+ pdl_operation_type,
+ _get_op_result_or_value(target),
+ padding_values=padding_values,
+ padding_dimensions=padding_dimensions,
+ pack_paddings=pack_paddings,
+ transpose_paddings=transpose_paddings_attr,
+ loc=loc,
+ ip=ip,
+ )
class ScalarizeOp:
- """Specialization for ScalarizeOp class."""
+ """Specialization for ScalarizeOp class."""
- def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
- pdl_operation_type = pdl.OperationType.get()
- super().__init__(pdl_operation_type,
- _get_op_result_or_value(target),
- loc=loc,
- ip=ip)
+ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+ pdl_operation_type = pdl.OperationType.get()
+ super().__init__(
+ pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip
+ )
class SplitOp:
- """Specialization for SplitOp class."""
-
- def __init__(
- self,
- target: Union[Operation, Value],
- dimension: Union[int, Attribute],
- split_point: Union[int, Operation, Value, Attribute],
- *,
- loc=None,
- ip=None,
- ):
- if isinstance(split_point, int):
- static_split_point = split_point
- dynamic_split_point = None
- else:
- static_split_point = ShapedType.get_dynamic_size()
- dynamic_split_point = _get_op_result_or_value(split_point)
-
- target = _get_op_result_or_value(target)
-
- super().__init__(
- target.type,
- target.type,
- target,
- dimension=dimension,
- static_split_point=static_split_point,
- dynamic_split_point=dynamic_split_point,
- loc=loc,
- ip=ip,
- )
+ """Specialization for SplitOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ dimension: Union[int, Attribute],
+ split_point: Union[int, Operation, Value, Attribute],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(split_point, int):
+ static_split_point = split_point
+ dynamic_split_point = None
+ else:
+ static_split_point = ShapedType.get_dynamic_size()
+ dynamic_split_point = _get_op_result_or_value(split_point)
+
+ target = _get_op_result_or_value(target)
+
+ super().__init__(
+ target.type,
+ target.type,
+ target,
+ dimension=dimension,
+ static_split_point=static_split_point,
+ dynamic_split_point=dynamic_split_point,
+ loc=loc,
+ ip=ip,
+ )
class TileOp:
- """Specialization for TileOp class."""
-
- @overload
- def __init__(
- self,
- loop_types: Union[Type, List[Type]],
- target: Union[Operation, Value],
- *,
- sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
- ArrayAttr]] = None,
- interchange: OptionalIntList = None,
- loc=None,
- ip=None,
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, Value, OpView],
- *,
- sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
- ArrayAttr]] = None,
- interchange: OptionalIntList = None,
- loc=None,
- ip=None,
- ):
- ...
-
- def __init__(
- self,
- loop_types_or_target: Union[Type, List[Type], Operation, Value],
- target_or_none: Optional[Union[Operation, Value, OpView]] = None,
- *,
- sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
- ArrayAttr]] = None,
- interchange: OptionalIntList = None,
- loc=None,
- ip=None,
- ):
- if interchange is None:
- interchange = []
- if sizes is None:
- sizes = []
-
- static_sizes = []
- dynamic_sizes = []
- if isinstance(sizes, ArrayAttr):
- sizes_attr = sizes
- else:
- for size in sizes:
- if isinstance(size, int):
- static_sizes.append(size)
+ """Specialization for TileOp class."""
+
+ @overload
+ def __init__(
+ self,
+ loop_types: Union[Type, List[Type]],
+ target: Union[Operation, Value],
+ *,
+ sizes: Optional[
+ Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
+ ] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, Value, OpView],
+ *,
+ sizes: Optional[
+ Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
+ ] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ loop_types_or_target: Union[Type, List[Type], Operation, Value],
+ target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+ *,
+ sizes: Optional[
+ Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
+ ] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ if interchange is None:
+ interchange = []
+ if sizes is None:
+ sizes = []
+
+ static_sizes = []
+ dynamic_sizes = []
+ if isinstance(sizes, ArrayAttr):
+ sizes_attr = sizes
+ else:
+ for size in sizes:
+ if isinstance(size, int):
+ static_sizes.append(size)
+ else:
+ static_sizes.append(ShapedType.get_dynamic_size())
+ dynamic_sizes.append(_get_op_result_or_value(size))
+ sizes_attr = DenseI64ArrayAttr.get(static_sizes)
+
+ num_loops = sum(v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
+
+ if isinstance(loop_types_or_target, (Operation, Value, OpView)):
+ loop_types = [transform.AnyOpType.get()] * num_loops
+ target = loop_types_or_target
+ assert target_or_none is None, "Cannot construct TileOp with two targets."
else:
- static_sizes.append(ShapedType.get_dynamic_size())
- dynamic_sizes.append(_get_op_result_or_value(size))
- sizes_attr = DenseI64ArrayAttr.get(static_sizes)
-
- num_loops = sum(
- v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
-
- if isinstance(loop_types_or_target, (Operation, Value, OpView)):
- loop_types = [transform.AnyOpType.get()] * num_loops
- target = loop_types_or_target
- assert target_or_none is None, "Cannot construct TileOp with two targets."
- else:
- loop_types = (([loop_types_or_target] * num_loops) if isinstance(
- loop_types_or_target, Type) else loop_types_or_target)
- target = target_or_none
-
- target = _get_op_result_or_value(target)
-
- super().__init__(
- target.type,
- loop_types,
- target,
- dynamic_sizes=dynamic_sizes,
- static_sizes=sizes_attr,
- interchange=interchange,
- loc=loc,
- ip=ip,
- )
-
- def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
- if not attr:
- return []
- return [element for element in attr]
+ loop_types = (
+ ([loop_types_or_target] * num_loops)
+ if isinstance(loop_types_or_target, Type)
+ else loop_types_or_target
+ )
+ target = target_or_none
+
+ target = _get_op_result_or_value(target)
+
+ super().__init__(
+ target.type,
+ loop_types,
+ target,
+ dynamic_sizes=dynamic_sizes,
+ static_sizes=sizes_attr,
+ interchange=interchange,
+ loc=loc,
+ ip=ip,
+ )
+
+ def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
+ if not attr:
+ return []
+ return [element for element in attr]
class VectorizeOp:
- """Specialization for VectorizeOp class."""
-
- def __init__(
- self,
- target: Union[Operation, Value],
- *,
- vectorize_padding: Union[bool, BoolAttr] = False,
- loc=None,
- ip=None,
- ):
- pdl_operation_type = pdl.OperationType.get()
- if isinstance(vectorize_padding, bool):
- vectorize_padding = UnitAttr.get()
- super().__init__(
- pdl_operation_type,
- _get_op_result_or_value(target),
- vectorize_padding=vectorize_padding,
- loc=loc,
- ip=ip,
- )
+ """Specialization for VectorizeOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ *,
+ vectorize_padding: Union[bool, BoolAttr] = False,
+ loc=None,
+ ip=None,
+ ):
+ pdl_operation_type = pdl.OperationType.get()
+ if isinstance(vectorize_padding, bool):
+ vectorize_padding = UnitAttr.get()
+ super().__init__(
+ pdl_operation_type,
+ _get_op_result_or_value(target),
+ vectorize_padding=vectorize_padding,
+ loc=loc,
+ ip=ip,
+ )
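
A hedged sketch of the structured-transform builders above (illustrative only: it assumes `target` is an operation handle available inside a transform sequence body; the matched op name and tile sizes are made up):

matmuls = MatchOp.match_op_names(target, ["linalg.matmul"])
# Single-target overload of TileOp: one loop result per non-zero size.
tiled = TileOp(matmuls, sizes=[32, 32, 0])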
diff --git a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/_tensor_ops_ext.py
index 51d998b6e3ceb..09b9ec68db7d9 100644
--- a/mlir/python/mlir/dialects/_tensor_ops_ext.py
+++ b/mlir/python/mlir/dialects/_tensor_ops_ext.py
@@ -3,40 +3,42 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
- from ..ir import *
+ from ..ir import *
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports from extension module") from e
from typing import Any, Optional, Sequence, Union
-from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+)
class EmptyOp:
- """Extends the tensor.empty op."""
+ """Extends the tensor.empty op."""
- def __init__(self,
- sizes: Sequence[Union[int, Value]],
- element_type: Type,
- *,
- loc=None,
- ip=None):
- """Constructs an `empty` with mixed static/dynamic sizes."""
- # TODO: Refactor the EmptyOp to take an element type attribute and
- # then use normal result type inference, unifying the Python and C++ side
- # with a standard mechanism (versus stashing that in builders).
- dynamic_sizes = []
- static_sizes = []
- for s in sizes:
- if isinstance(s, int):
- static_sizes.append(s)
- else:
- static_sizes.append(ShapedType.get_dynamic_size())
- dynamic_sizes.append(s)
- result_type = RankedTensorType.get(static_sizes, element_type)
- op = self.build_generic(
- results=[result_type],
- operands=dynamic_sizes,
- attributes={},
- loc=loc,
- ip=ip)
- OpView.__init__(self, op)
+ def __init__(
+ self,
+ sizes: Sequence[Union[int, Value]],
+ element_type: Type,
+ *,
+ loc=None,
+ ip=None
+ ):
+ """Constructs an `empty` with mixed static/dynamic sizes."""
+ # TODO: Refactor the EmptyOp to take an element type attribute and
+ # then use normal result type inference, unifying the Python and C++ side
+ # with a standard mechanism (versus stashing that in builders).
+ dynamic_sizes = []
+ static_sizes = []
+ for s in sizes:
+ if isinstance(s, int):
+ static_sizes.append(s)
+ else:
+ static_sizes.append(ShapedType.get_dynamic_size())
+ dynamic_sizes.append(s)
+ result_type = RankedTensorType.get(static_sizes, element_type)
+ op = self.build_generic(
+ results=[result_type], operands=dynamic_sizes, attributes={}, loc=loc, ip=ip
+ )
+ OpView.__init__(self, op)
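As a rough usage sketch (not part of the diff above): the builder accepts a mix of Python ints and index-typed Values; ints become static extents and Values become dynamic_sizes operands. The module scaffolding and the arith ConstantOp.create_index helper are assumptions for illustration.

from mlir.ir import Context, F32Type, InsertionPoint, Location, Module
from mlir.dialects import arith, tensor

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        # Assumed helper: an index-typed constant to serve as a dynamic extent.
        dyn = arith.ConstantOp.create_index(8)
        # 4 is folded into the result type, dyn becomes an operand, so the
        # result type is tensor<4x?xf32>.
        empty = tensor.EmptyOp([4, dyn], F32Type.get())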
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
index cc4428ea5b115..425ec65859d39 100644
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -3,144 +3,131 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
- from ..ir import *
- from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
- )
+ from ..ir import *
+ from ._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ )
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
class CastOp:
-
- def __init__(self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- loc=None,
- ip=None):
- super().__init__(result_type,
- _get_op_result_or_value(target),
- loc=loc,
- ip=ip)
+ def __init__(
+ self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None
+ ):
+ super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
class GetClosestIsolatedParentOp:
-
- def __init__(self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- loc=None,
- ip=None):
- super().__init__(result_type,
- _get_op_result_or_value(target),
- loc=loc,
- ip=ip)
+ def __init__(
+ self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None
+ ):
+ super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
class MergeHandlesOp:
-
- def __init__(
- self,
- handles: Sequence[Union[Operation, Value]],
- *,
- deduplicate: bool = False,
- loc=None,
- ip=None,
- ):
- super().__init__(
- [_get_op_result_or_value(h) for h in handles],
- deduplicate=deduplicate,
- loc=loc,
- ip=ip,
- )
+ def __init__(
+ self,
+ handles: Sequence[Union[Operation, Value]],
+ *,
+ deduplicate: bool = False,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(
+ [_get_op_result_or_value(h) for h in handles],
+ deduplicate=deduplicate,
+ loc=loc,
+ ip=ip,
+ )
class ReplicateOp:
-
- def __init__(
- self,
- pattern: Union[Operation, Value],
- handles: Sequence[Union[Operation, Value]],
- *,
- loc=None,
- ip=None,
- ):
- super().__init__(
- [_get_op_result_or_value(h).type for h in handles],
- _get_op_result_or_value(pattern),
- [_get_op_result_or_value(h) for h in handles],
- loc=loc,
- ip=ip,
- )
+ def __init__(
+ self,
+ pattern: Union[Operation, Value],
+ handles: Sequence[Union[Operation, Value]],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(
+ [_get_op_result_or_value(h).type for h in handles],
+ _get_op_result_or_value(pattern),
+ [_get_op_result_or_value(h) for h in handles],
+ loc=loc,
+ ip=ip,
+ )
class SequenceOp:
-
- def __init__(
- self,
- failure_propagation_mode,
- results: Sequence[Type],
- target: Union[Operation, Value, Type],
- extra_bindings: Optional[Union[Sequence[Value], Sequence[Type], Operation,
- OpView]] = None,
- ):
- root = (_get_op_result_or_value(target) if isinstance(
- target, (Operation, Value)) else None)
- root_type = root.type if not isinstance(target, Type) else target
- if not isinstance(failure_propagation_mode, Attribute):
- failure_propagation_mode_attr = IntegerAttr.get(
- IntegerType.get_signless(32), failure_propagation_mode._as_int())
- else:
- failure_propagation_mode_attr = failure_propagation_mode
-
- if extra_bindings is None:
- extra_bindings = []
- if isinstance(extra_bindings, (Operation, OpView)):
- extra_bindings = _get_op_results_or_values(extra_bindings)
-
- extra_binding_types = []
- if len(extra_bindings) != 0:
- if isinstance(extra_bindings[0], Type):
- extra_binding_types = extra_bindings
- extra_bindings = []
- else:
- extra_binding_types = [v.type for v in extra_bindings]
-
- super().__init__(
- results_=results,
- failure_propagation_mode=failure_propagation_mode_attr,
- root=root,
- extra_bindings=extra_bindings,
- )
- self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
-
- @property
- def body(self) -> Block:
- return self.regions[0].blocks[0]
-
- @property
- def bodyTarget(self) -> Value:
- return self.body.arguments[0]
-
- @property
- def bodyExtraArgs(self) -> BlockArgumentList:
- return self.body.arguments[1:]
+ def __init__(
+ self,
+ failure_propagation_mode,
+ results: Sequence[Type],
+ target: Union[Operation, Value, Type],
+ extra_bindings: Optional[
+ Union[Sequence[Value], Sequence[Type], Operation, OpView]
+ ] = None,
+ ):
+ root = (
+ _get_op_result_or_value(target)
+ if isinstance(target, (Operation, Value))
+ else None
+ )
+ root_type = root.type if not isinstance(target, Type) else target
+ if not isinstance(failure_propagation_mode, Attribute):
+ failure_propagation_mode_attr = IntegerAttr.get(
+ IntegerType.get_signless(32), failure_propagation_mode._as_int()
+ )
+ else:
+ failure_propagation_mode_attr = failure_propagation_mode
+
+ if extra_bindings is None:
+ extra_bindings = []
+ if isinstance(extra_bindings, (Operation, OpView)):
+ extra_bindings = _get_op_results_or_values(extra_bindings)
+
+ extra_binding_types = []
+ if len(extra_bindings) != 0:
+ if isinstance(extra_bindings[0], Type):
+ extra_binding_types = extra_bindings
+ extra_bindings = []
+ else:
+ extra_binding_types = [v.type for v in extra_bindings]
+
+ super().__init__(
+ results_=results,
+ failure_propagation_mode=failure_propagation_mode_attr,
+ root=root,
+ extra_bindings=extra_bindings,
+ )
+ self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
+
+ @property
+ def body(self) -> Block:
+ return self.regions[0].blocks[0]
+
+ @property
+ def bodyTarget(self) -> Value:
+ return self.body.arguments[0]
+
+ @property
+ def bodyExtraArgs(self) -> BlockArgumentList:
+ return self.body.arguments[1:]
class YieldOp:
-
- def __init__(
- self,
- operands: Optional[Union[Operation, Sequence[Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if operands is None:
- operands = []
- super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
+ def __init__(
+ self,
+ operands: Optional[Union[Operation, Sequence[Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if operands is None:
+ operands = []
+ super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
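A minimal sketch (not part of the diff above) of how these helpers compose. Passing the failure propagation mode as a plain IntegerAttr exercises the Attribute branch shown in SequenceOp.__init__; the value 1 standing for "propagate" and the module scaffolding are assumptions for illustration.

from mlir.ir import (
    Context, InsertionPoint, IntegerAttr, IntegerType, Location, Module,
)
from mlir.dialects import pdl, transform

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        mode = IntegerAttr.get(IntegerType.get_signless(32), 1)  # assumed: propagate
        seq = transform.SequenceOp(mode, [], pdl.OperationType.get())
        # The constructor appended the entry block, so body/bodyTarget are usable.
        with InsertionPoint(seq.body):
            parent = transform.GetClosestIsolatedParentOp(
                pdl.OperationType.get(), seq.bodyTarget
            )
            transform.YieldOp()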
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py
index 5a695d6216770..2f651319930fc 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py
@@ -31,61 +31,60 @@
def create_arg_parser() -> argparse.ArgumentParser:
- p = argparse.ArgumentParser(description="Dump an oplib in various formats")
- p.add_argument("modules",
- metavar="M",
- type=str,
- nargs="*",
- help="Op module to dump")
- p.add_argument("--file",
- metavar="F",
- type=str,
- nargs="*",
- help="Python op file to dump")
- p.add_argument("--format",
- type=str,
- dest="format",
- default="yaml",
- choices=("yaml", "repr"),
- help="Format in which to dump")
- return p
+ p = argparse.ArgumentParser(description="Dump an oplib in various formats")
+ p.add_argument(
+ "modules", metavar="M", type=str, nargs="*", help="Op module to dump"
+ )
+ p.add_argument(
+ "--file", metavar="F", type=str, nargs="*", help="Python op file to dump"
+ )
+ p.add_argument(
+ "--format",
+ type=str,
+ dest="format",
+ default="yaml",
+ choices=("yaml", "repr"),
+ help="Format in which to dump",
+ )
+ return p
def load_module_from_file(module_name, file_path):
- spec = importlib.util.spec_from_file_location(module_name, file_path)
- m = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(m)
- return m
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ m = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(m)
+ return m
def main(args):
- # Load all configs.
- configs = []
- modules = []
- for module_name in args.modules:
- modules.append(
- importlib.import_module(module_name,
- package="mlir.dialects.linalg.opdsl"))
- for i, file_path in enumerate(args.file or []):
- modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path))
- for m in modules:
- for attr_name, value in m.__dict__.items():
- # TODO: This class layering is awkward.
- if isinstance(value, DefinedOpCallable):
- try:
- linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def)
- except Exception as e:
- raise ValueError(
- f"Could not create LinalgOpConfig from {value.op_def}") from e
- configs.extend(linalg_config)
-
- # Print.
- if args.format == "yaml":
- print(yaml_dump_all(configs))
- elif args.format == "repr":
- for config in configs:
- print(repr(config))
+ # Load all configs.
+ configs = []
+ modules = []
+ for module_name in args.modules:
+ modules.append(
+ importlib.import_module(module_name, package="mlir.dialects.linalg.opdsl")
+ )
+ for i, file_path in enumerate(args.file or []):
+ modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path))
+ for m in modules:
+ for attr_name, value in m.__dict__.items():
+ # TODO: This class layering is awkward.
+ if isinstance(value, DefinedOpCallable):
+ try:
+ linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def)
+ except Exception as e:
+ raise ValueError(
+ f"Could not create LinalgOpConfig from {value.op_def}"
+ ) from e
+ configs.extend(linalg_config)
+
+ # Print.
+ if args.format == "yaml":
+ print(yaml_dump_all(configs))
+ elif args.format == "repr":
+ for config in configs:
+ print(repr(config))
if __name__ == "__main__":
- main(create_arg_parser().parse_args())
+ main(create_arg_parser().parse_args())
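For reference (not part of the diff above), the parser keeps the module usable both from the command line and programmatically; "my_ops.py" below is a hypothetical path.

# Roughly equivalent to:
#   python -m mlir.dialects.linalg.opdsl.dump_oplib --file my_ops.py --format yaml
from mlir.dialects.linalg.opdsl import dump_oplib

args = dump_oplib.create_arg_parser().parse_args(
    ["--file", "my_ops.py", "--format", "yaml"]
)
dump_oplib.main(args)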
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py
index 038f068345428..9fa626dfa78b1 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py
@@ -66,201 +66,201 @@
class AffineBuildState:
- """Internal state for the AffineExprDef._create impls.
-
- Note that a "local" AffineBuildState can be created relative to a "global"
- AffineBuildState. In that case, any affine expressions built will inherit
- symbol and dim bindings from the global state and will update both as new
- ones are discovered. This allows for building expressions across contexts
- which share a common symbol and dim space.
- """
-
- def __init__(self,
- *,
- global_state: "AffineBuildState" = None,
- allow_new_symbols: bool = True,
- allow_new_dims: bool = True):
- if not global_state:
- self.all_symbols = dict() # type: Dict[str, int]
- self.all_dims = dict() # type: Dict[str, int]
- else:
- # Alias the global dict.
- self.all_symbols = global_state.all_symbols
- self.all_dims = global_state.all_dims
-
- # Map of symbols and dims in the current build.
- self.local_symbols = dict() # type: Dict[str, int]
- self.local_dims = dict() # type: Dict[str, int]
- self.allow_new_symbols = allow_new_symbols
- self.allow_new_dims = allow_new_dims
-
- def get_dim(self, dimname: str) -> int:
- """Gets the dim position given a name."""
- pos = self.all_dims.get(dimname)
- if pos is None:
- if not self.allow_new_dims:
- raise ValueError(
- f"New dimensions not allowed in the current affine expression: "
- f"Requested '{dimname}', Availble: {self.all_dims}")
- pos = len(self.all_dims)
- self.all_dims[dimname] = pos
- self.local_dims[dimname] = pos
- return pos
-
- def get_symbol(self, symname: str) -> int:
- """Geta a symbol position given a name."""
- pos = self.all_symbols.get(symname)
- if pos is None:
- if not self.allow_new_symbols:
- raise ValueError(
- f"New symbols not allowed in the current affine expression: "
- f"Requested '{symname}', Availble: {self.all_symbols}")
- pos = len(self.all_symbols)
- self.all_symbols[symname] = pos
- self.local_symbols[symname] = pos
- return pos
-
- @property
- def local_dim_count(self) -> int:
- return len(self.local_dims)
-
- @property
- def local_symbol_count(self) -> int:
- return len(self.local_symbols)
-
- @property
- def dim_count(self) -> int:
- return len(self.all_dims)
-
- @property
- def symbol_count(self) -> int:
- return len(self.all_symbols)
-
- def __repr__(self):
- lines = [f"AffineBuildState<"]
- lines.append(f" symbols={self.local_symbols}")
- lines.append(f" dims={self.local_dims}>")
- return "\n".join(lines)
+ """Internal state for the AffineExprDef._create impls.
+
+ Note that a "local" AffineBuildState can be created relative to a "global"
+ AffineBuildState. In that case, any affine expressions built will inherit
+ symbol and dim bindings from the global state and will update both as new
+ ones are discovered. This allows for building expressions across contexts
+ which share a common symbol and dim space.
+ """
+
+ def __init__(
+ self,
+ *,
+ global_state: "AffineBuildState" = None,
+ allow_new_symbols: bool = True,
+ allow_new_dims: bool = True,
+ ):
+ if not global_state:
+ self.all_symbols = dict() # type: Dict[str, int]
+ self.all_dims = dict() # type: Dict[str, int]
+ else:
+ # Alias the global dict.
+ self.all_symbols = global_state.all_symbols
+ self.all_dims = global_state.all_dims
+
+ # Map of symbols and dims in the current build.
+ self.local_symbols = dict() # type: Dict[str, int]
+ self.local_dims = dict() # type: Dict[str, int]
+ self.allow_new_symbols = allow_new_symbols
+ self.allow_new_dims = allow_new_dims
+
+ def get_dim(self, dimname: str) -> int:
+ """Gets the dim position given a name."""
+ pos = self.all_dims.get(dimname)
+ if pos is None:
+ if not self.allow_new_dims:
+ raise ValueError(
+ f"New dimensions not allowed in the current affine expression: "
+ f"Requested '{dimname}', Availble: {self.all_dims}"
+ )
+ pos = len(self.all_dims)
+ self.all_dims[dimname] = pos
+ self.local_dims[dimname] = pos
+ return pos
+
+ def get_symbol(self, symname: str) -> int:
+ """Geta a symbol position given a name."""
+ pos = self.all_symbols.get(symname)
+ if pos is None:
+ if not self.allow_new_symbols:
+ raise ValueError(
+ f"New symbols not allowed in the current affine expression: "
+ f"Requested '{symname}', Availble: {self.all_symbols}"
+ )
+ pos = len(self.all_symbols)
+ self.all_symbols[symname] = pos
+ self.local_symbols[symname] = pos
+ return pos
+
+ @property
+ def local_dim_count(self) -> int:
+ return len(self.local_dims)
+
+ @property
+ def local_symbol_count(self) -> int:
+ return len(self.local_symbols)
+
+ @property
+ def dim_count(self) -> int:
+ return len(self.all_dims)
+
+ @property
+ def symbol_count(self) -> int:
+ return len(self.all_symbols)
+
+ def __repr__(self):
+ lines = [f"AffineBuildState<"]
+ lines.append(f" symbols={self.local_symbols}")
+ lines.append(f" dims={self.local_dims}>")
+ return "\n".join(lines)
class AffineExprDef:
- """Base class for an affine expression being defined."""
+ """Base class for an affine expression being defined."""
- def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr:
- """Builds the corresponding _ir.AffineExpr from the definitions.
- """
- state = AffineBuildState() if state is None else state
- expr = self._create(state)
- return expr
+ def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr:
+ """Builds the corresponding _ir.AffineExpr from the definitions."""
+ state = AffineBuildState() if state is None else state
+ expr = self._create(state)
+ return expr
- def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
- raise NotImplementedError()
+ def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+ raise NotImplementedError()
- @staticmethod
- def coerce_from(py_value):
- if isinstance(py_value, int):
- return AffineConstantExpr(py_value)
- assert isinstance(py_value, AffineExprDef)
- return py_value
+ @staticmethod
+ def coerce_from(py_value):
+ if isinstance(py_value, int):
+ return AffineConstantExpr(py_value)
+ assert isinstance(py_value, AffineExprDef)
+ return py_value
- def visit_affine_exprs(self, callback):
- """Visits all AffineExprDefs including self."""
- callback(self)
+ def visit_affine_exprs(self, callback):
+ """Visits all AffineExprDefs including self."""
+ callback(self)
- def __add__(lhs, rhs):
- rhs = AffineExprDef.coerce_from(rhs)
- return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)
+ def __add__(lhs, rhs):
+ rhs = AffineExprDef.coerce_from(rhs)
+ return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)
- def __mul__(lhs, rhs):
- rhs = AffineExprDef.coerce_from(rhs)
- return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)
+ def __mul__(lhs, rhs):
+ rhs = AffineExprDef.coerce_from(rhs)
+ return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)
- def __mod__(lhs, rhs):
- rhs = AffineExprDef.coerce_from(rhs)
- return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs)
+ def __mod__(lhs, rhs):
+ rhs = AffineExprDef.coerce_from(rhs)
+ return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs)
- def __floordiv__(lhs, rhs):
- rhs = AffineExprDef.coerce_from(rhs)
- return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs)
+ def __floordiv__(lhs, rhs):
+ rhs = AffineExprDef.coerce_from(rhs)
+ return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs)
- def __truediv__(lhs, rhs):
- # TODO: Not really a ceil div - taking liberties for the DSL.
- rhs = AffineExprDef.coerce_from(rhs)
- return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs)
+ def __truediv__(lhs, rhs):
+ # TODO: Not really a ceil div - taking liberties for the DSL.
+ rhs = AffineExprDef.coerce_from(rhs)
+ return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs)
class AffineConstantExpr(AffineExprDef):
- """An affine constant being defined."""
+ """An affine constant being defined."""
- def __init__(self, value: int):
- assert isinstance(value, int)
- self.value = value
+ def __init__(self, value: int):
+ assert isinstance(value, int)
+ self.value = value
- def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
- return _ir.AffineConstantExpr.get(self.value)
+ def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+ return _ir.AffineConstantExpr.get(self.value)
- def __repr__(self):
- return f"Const({self.value})"
+ def __repr__(self):
+ return f"Const({self.value})"
class AffineBinaryExprDef(AffineExprDef):
- """An affine binary expression being defined."""
+ """An affine binary expression being defined."""
- def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef):
- self.ir_ctor = ir_ctor
- self.lhs = lhs
- self.rhs = rhs
+ def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef):
+ self.ir_ctor = ir_ctor
+ self.lhs = lhs
+ self.rhs = rhs
- def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
- return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state))
+ def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+ return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state))
- def visit_affine_exprs(self, callback):
- """Visits all AffineExprDefs including self."""
- super().visit_affine_exprs(callback)
- self.lhs.visit_affine_exprs(callback)
- self.rhs.visit_affine_exprs(callback)
+ def visit_affine_exprs(self, callback):
+ """Visits all AffineExprDefs including self."""
+ super().visit_affine_exprs(callback)
+ self.lhs.visit_affine_exprs(callback)
+ self.rhs.visit_affine_exprs(callback)
- def __repr__(self):
- return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})"
+ def __repr__(self):
+ return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})"
class DimDef(AffineExprDef):
- """Represents a named dimension.
-
- """
- ALL_DIMS = dict() # type: Dict[str, "DimDef"]
-
- def __new__(cls, dimname: str):
- existing = cls.ALL_DIMS.get(dimname)
- if existing is not None:
- return existing
- new = super().__new__(cls)
- new.dimname = dimname
- cls.ALL_DIMS[dimname] = new
- return new
-
- def __repr__(self):
- return f"Dim({self.dimname})"
-
- def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
- pos = state.get_dim(self.dimname)
- return _ir.AffineDimExpr.get(position=pos)
-
- @classmethod
- def create_expando(cls):
- """Create an expando class that creates unique symbols based on attr access.
- """
+ """Represents a named dimension."""
+
+ ALL_DIMS = dict() # type: Dict[str, "DimDef"]
+
+ def __new__(cls, dimname: str):
+ existing = cls.ALL_DIMS.get(dimname)
+ if existing is not None:
+ return existing
+ new = super().__new__(cls)
+ new.dimname = dimname
+ cls.ALL_DIMS[dimname] = new
+ return new
- class ExpandoDims:
+ def __repr__(self):
+ return f"Dim({self.dimname})"
- def __getattr__(self, n):
- return cls(n)
+ def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+ pos = state.get_dim(self.dimname)
+ return _ir.AffineDimExpr.get(position=pos)
- return ExpandoDims()
+ @classmethod
+ def create_expando(cls):
+ """Create an expando class that creates unique symbols based on attr access."""
+
+ class ExpandoDims:
+ def __getattr__(self, n):
+ return cls(n)
+
+ return ExpandoDims()
class SymbolDef(AffineExprDef):
- """Represents a named symbol.
+ """Represents a named symbol.
>>> s1 = SymbolDef("s1")
>>> s1
@@ -270,36 +270,35 @@ class SymbolDef(AffineExprDef):
False
>>> s1 is SymbolDef("s1")
True
- """
- ALL_SYMBOLS = dict() # type: Dict[str, "SymbolDef"]
-
- def __new__(cls, symname: str):
- existing = cls.ALL_SYMBOLS.get(symname)
- if existing is not None:
- return existing
- new = super().__new__(cls)
- new.symname = symname
- cls.ALL_SYMBOLS[symname] = new
- return new
-
- def __repr__(self):
- return f"Symbol({self.symname})"
-
- def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
- pos = state.get_symbol(self.symname)
- return _ir.AffineSymbolExpr.get(position=pos)
-
- @classmethod
- def create_expando(cls):
- """Create an expando class that creates unique symbols based on attr access.
"""
- class ExpandoSymbols:
+ ALL_SYMBOLS = dict() # type: Dict[str, "SymbolDef"]
+
+ def __new__(cls, symname: str):
+ existing = cls.ALL_SYMBOLS.get(symname)
+ if existing is not None:
+ return existing
+ new = super().__new__(cls)
+ new.symname = symname
+ cls.ALL_SYMBOLS[symname] = new
+ return new
+
+ def __repr__(self):
+ return f"Symbol({self.symname})"
+
+ def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+ pos = state.get_symbol(self.symname)
+ return _ir.AffineSymbolExpr.get(position=pos)
+
+ @classmethod
+ def create_expando(cls):
+ """Create an expando class that creates unique symbols based on attr access."""
- def __getattr__(self, n):
- return cls(n)
+ class ExpandoSymbols:
+ def __getattr__(self, n):
+ return cls(n)
- return ExpandoSymbols()
+ return ExpandoSymbols()
# Global accessor for on-demand dims and symbols.
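A small sketch (not part of the diff above) of the machinery reformatted here: DimDef/SymbolDef instances are interned by name, the operator overloads build AffineBinaryExprDef trees, and build() assigns dim/symbol positions through an AffineBuildState. Only the explicit Context management is an assumption.

from mlir import ir
from mlir.dialects.linalg.opdsl.lang.affine import AffineBuildState, DimDef, SymbolDef

m, n = DimDef("m"), DimDef("n")      # DimDef("m") is m -> True (interned)
s = SymbolDef("s")
expr = m * s + n                     # AffineAddExpr(AffineMulExpr(m, s), n)

with ir.Context():
    state = AffineBuildState()
    affine_expr = expr.build(state)  # positions are assigned on first use
    # state.dim_count == 2, state.symbol_count == 1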
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 135f55ea516d0..5d5866fdeabf6 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -23,223 +23,232 @@
class TensorExpression:
- """An expression that can appear on the RHS of a comprehension."""
+ """An expression that can appear on the RHS of a comprehension."""
- def to_scalar_expression(self) -> ScalarExpression:
- raise NotImplementedError()
+ def to_scalar_expression(self) -> ScalarExpression:
+ raise NotImplementedError()
- def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
- """Visits all tensor expression reachable by the expression."""
- callback(self)
+ def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
+ """Visits all tensor expression reachable by the expression."""
+ callback(self)
- def collect_dim_uses(self, uses: Set["DimDef"]):
- """Collects all DimDefs reachable through this expression."""
+ def collect_dim_uses(self, uses: Set["DimDef"]):
+ """Collects all DimDefs reachable through this expression."""
- def visit_dim_def(dim_def: AffineExprDef):
- if isinstance(dim_def, DimDef):
- uses.add(dim_def)
+ def visit_dim_def(dim_def: AffineExprDef):
+ if isinstance(dim_def, DimDef):
+ uses.add(dim_def)
- def visit_affine_exprs(expr: "TensorExpression"):
- if isinstance(expr, TensorUse):
- for ind in expr.indices:
- ind.visit_affine_exprs(visit_dim_def)
- if isinstance(expr, TensorReduceFn):
- for ind in expr.reduce_fn.reduce_dims:
- ind.visit_affine_exprs(visit_dim_def)
+ def visit_affine_exprs(expr: "TensorExpression"):
+ if isinstance(expr, TensorUse):
+ for ind in expr.indices:
+ ind.visit_affine_exprs(visit_dim_def)
+ if isinstance(expr, TensorReduceFn):
+ for ind in expr.reduce_fn.reduce_dims:
+ ind.visit_affine_exprs(visit_dim_def)
- self.visit_tensor_exprs(visit_affine_exprs)
+ self.visit_tensor_exprs(visit_affine_exprs)
- def collect_tensor_uses(self, uses: Set["TensorUse"]):
- """Collects all TensorUses reachable through this expression."""
+ def collect_tensor_uses(self, uses: Set["TensorUse"]):
+ """Collects all TensorUses reachable through this expression."""
- def visit_tensor_use(expr: "TensorExpression"):
- if isinstance(expr, TensorUse):
- uses.add(expr)
+ def visit_tensor_use(expr: "TensorExpression"):
+ if isinstance(expr, TensorUse):
+ uses.add(expr)
- self.visit_tensor_exprs(visit_tensor_use)
+ self.visit_tensor_exprs(visit_tensor_use)
- def collect_indices(self, indices: Set["index"]):
- """Collects all index accesses reachable through this expression."""
+ def collect_indices(self, indices: Set["index"]):
+ """Collects all index accesses reachable through this expression."""
- def visit_index(expr: "TensorExpression"):
- if isinstance(expr, index):
- indices.add(expr)
+ def visit_index(expr: "TensorExpression"):
+ if isinstance(expr, index):
+ indices.add(expr)
- self.visit_tensor_exprs(visit_index)
+ self.visit_tensor_exprs(visit_index)
- def collect_scalar_uses(self, uses: Set["ScalarDef"]):
- """Collects all ScalarDefs reachable through this expression."""
+ def collect_scalar_uses(self, uses: Set["ScalarDef"]):
+ """Collects all ScalarDefs reachable through this expression."""
- def visit_scalar_def(expr: "TensorExpression"):
- if isinstance(expr, ScalarDef):
- uses.add(expr)
+ def visit_scalar_def(expr: "TensorExpression"):
+ if isinstance(expr, ScalarDef):
+ uses.add(expr)
- self.visit_tensor_exprs(visit_scalar_def)
+ self.visit_tensor_exprs(visit_scalar_def)
- def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
- return BinaryFn.add(self, rhs)
+ def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
+ return BinaryFn.add(self, rhs)
- def __mul__(self, rhs) -> "TensorExpression":
- return BinaryFn.mul(self, rhs)
+ def __mul__(self, rhs) -> "TensorExpression":
+ return BinaryFn.mul(self, rhs)
- def __sub__(self, rhs) -> "TensorExpression":
- return BinaryFn.sub(self, rhs)
+ def __sub__(self, rhs) -> "TensorExpression":
+ return BinaryFn.sub(self, rhs)
- def __hash__(self):
- return hash(id(self))
+ def __hash__(self):
+ return hash(id(self))
class TensorUse(TensorExpression):
- """A used tensor represented by its (tensor_name, indices).
-
- Note that forming a comprehension via direct assignment is performed through
- __setitem__ on the TensorDef level. However, performing a reduction with
- compound ops (+=, *=, etc) is done by doing a:
- TensorDef.__getitem__
- TensorUse.__iadd__
- TensorDef.__setitem__
- """
-
- def __init__(self, operand_def: "OperandDef",
- indices: Sequence[AffineExprDef]):
- self.operand_def = operand_def
- self.indices = tuple(indices)
-
- def to_scalar_expression(self) -> ScalarExpression:
- return ScalarArg(self.tensor_name).expr()
-
- @property
- def tensor_name(self) -> str:
- name = self.operand_def.name
- assert name is not None, "TensorDef not registered with an op"
- return name
-
- def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
- # Computes the reduction dims for implicit reductions. Assumes that the rhs
- # is the expression being reduced and self is being reduced into. Any
- # indices referenced on the rhs and not in self are considered reduction
- # dims and will be ordered as encountered on the rhs.
- rhs_dims = set()
- lhs_dims = set()
- rhs.collect_dim_uses(rhs_dims)
- self.collect_dim_uses(lhs_dims)
- return rhs_dims - lhs_dims
-
- def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
- return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs)
-
- def __repr__(self):
- return (f"{self.operand_def.name}"
- f"[{', '.join([repr(i) for i in self.indices])}]")
+ """A used tensor represented by its (tensor_name, indices).
+
+ Note that forming a comprehension via direct assignment is performed through
+ __setitem__ on the TensorDef level. However, performing a reduction with
+ compound ops (+=, *=, etc) is done by doing a:
+ TensorDef.__getitem__
+ TensorUse.__iadd__
+ TensorDef.__setitem__
+ """
+
+ def __init__(self, operand_def: "OperandDef", indices: Sequence[AffineExprDef]):
+ self.operand_def = operand_def
+ self.indices = tuple(indices)
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ return ScalarArg(self.tensor_name).expr()
+
+ @property
+ def tensor_name(self) -> str:
+ name = self.operand_def.name
+ assert name is not None, "TensorDef not registered with an op"
+ return name
+
+ def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
+ # Computes the reduction dims for implicit reductions. Assumes that the rhs
+ # is the expression being reduced and self is being reduced into. Any
+ # indices referenced on the rhs and not in self are considered reduction
+ # dims and will be ordered as encountered on the rhs.
+ rhs_dims = set()
+ lhs_dims = set()
+ rhs.collect_dim_uses(rhs_dims)
+ self.collect_dim_uses(lhs_dims)
+ return rhs_dims - lhs_dims
+
+ def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
+ return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs)
+
+ def __repr__(self):
+ return (
+ f"{self.operand_def.name}" f"[{', '.join([repr(i) for i in self.indices])}]"
+ )
class TensorFn(TensorExpression):
- """Application of a tensor function."""
-
- def __init__(self, kind: "FunctionKind", name: Optional[str],
- operand_def: Optional["OperandDef"], type_var: Optional[TypeVar],
- args: Sequence[TensorExpression]):
- if bool(name) + bool(operand_def) != 1:
- raise ValueError("One of 'name', 'operand_def' must be specified")
- self.name = name
- self.kind = kind
- self.operand_def = operand_def
- self.type_var = type_var
- self.args = args
-
- def to_scalar_expression(self) -> ScalarExpression:
- if self.operand_def:
- assert self.operand_def.name, "TensorFn not registered with an op"
- attr_name = self.operand_def.name if self.operand_def else None
- args = [arg.to_scalar_expression() for arg in self.args]
- return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr()
-
- def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
- super().visit_tensor_exprs(callback)
- for arg in self.args:
- arg.visit_tensor_exprs(callback)
-
- def __repr__(self):
- name = self.operand_def.name if self.operand_def else self.name
- return (f"{self.kind.name}.{name}(type_var={self.type_var}, "
- f"args={', '.join(repr(a) for a in self.args)})")
+ """Application of a tensor function."""
+
+ def __init__(
+ self,
+ kind: "FunctionKind",
+ name: Optional[str],
+ operand_def: Optional["OperandDef"],
+ type_var: Optional[TypeVar],
+ args: Sequence[TensorExpression],
+ ):
+ if bool(name) + bool(operand_def) != 1:
+ raise ValueError("One of 'name', 'operand_def' must be specified")
+ self.name = name
+ self.kind = kind
+ self.operand_def = operand_def
+ self.type_var = type_var
+ self.args = args
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ if self.operand_def:
+ assert self.operand_def.name, "TensorFn not registered with an op"
+ attr_name = self.operand_def.name if self.operand_def else None
+ args = [arg.to_scalar_expression() for arg in self.args]
+ return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr()
+
+ def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
+ super().visit_tensor_exprs(callback)
+ for arg in self.args:
+ arg.visit_tensor_exprs(callback)
+
+ def __repr__(self):
+ name = self.operand_def.name if self.operand_def else self.name
+ return (
+ f"{self.kind.name}.{name}(type_var={self.type_var}, "
+ f"args={', '.join(repr(a) for a in self.args)})"
+ )
class TensorReduceFn(TensorExpression):
- """Application of a reduction function.
-
- This captures the lhs (initial value) separately from the rhs.
- """
-
- def __init__(self, reduce_use: "ReduceFnUse",
- args: Sequence[TensorExpression]):
- self.reduce_use = reduce_use
- self.lhs = None # type: Optional[TensorUse]
- self.args = args
-
- def to_scalar_expression(self) -> ScalarExpression:
- if self.lhs is None:
- raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been "
- f"bound to its lhs: {self}")
- full_args = [self.lhs.to_scalar_expression()
- ] + [arg.to_scalar_expression() for arg in self.args]
- fn_name = None
- attr_name = None
- if self.reduce_use.binary_fn:
- fn_name = self.reduce_use.binary_fn.fn_name
- if self.reduce_use.binary_attr:
- attr_name = self.reduce_use.binary_attr.operand_def.name
- return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None,
- full_args).expr()
-
- def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
- for arg in self.args:
- arg.visit_tensor_exprs(callback)
-
- def __repr__(self):
- return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"
+ """Application of a reduction function.
+
+ This captures the lhs (initial value) separately from the rhs.
+ """
+
+ def __init__(self, reduce_use: "ReduceFnUse", args: Sequence[TensorExpression]):
+ self.reduce_use = reduce_use
+ self.lhs = None # type: Optional[TensorUse]
+ self.args = args
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ if self.lhs is None:
+ raise ValueError(
+ f"Cannot scalarize a TensorReduceFn that has not been "
+ f"bound to its lhs: {self}"
+ )
+ full_args = [self.lhs.to_scalar_expression()] + [
+ arg.to_scalar_expression() for arg in self.args
+ ]
+ fn_name = None
+ attr_name = None
+ if self.reduce_use.binary_fn:
+ fn_name = self.reduce_use.binary_fn.fn_name
+ if self.reduce_use.binary_attr:
+ attr_name = self.reduce_use.binary_attr.operand_def.name
+ return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None, full_args).expr()
+
+ def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
+ for arg in self.args:
+ arg.visit_tensor_exprs(callback)
+
+ def __repr__(self):
+ return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"
class const(TensorExpression):
- """Returns the given constant floating point or integer value."""
+ """Returns the given constant floating point or integer value."""
- def __init__(self, value: Any):
- with _ir.Context():
- if isinstance(value, float):
- self.value = str(_ir.FloatAttr.get_f64(float(value)))
- elif isinstance(value, int):
- self.value = str(
- _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
- else:
- raise ValueError(f"const requires int or float but got {type(value)}")
+ def __init__(self, value: Any):
+ with _ir.Context():
+ if isinstance(value, float):
+ self.value = str(_ir.FloatAttr.get_f64(float(value)))
+ elif isinstance(value, int):
+ self.value = str(
+ _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))
+ )
+ else:
+ raise ValueError(f"const requires int or float but got {type(value)}")
- def to_scalar_expression(self) -> ScalarExpression:
- return ScalarConst(self.value).expr()
+ def to_scalar_expression(self) -> ScalarExpression:
+ return ScalarConst(self.value).expr()
- def __repr__(self):
- return f"const({self.value})"
+ def __repr__(self):
+ return f"const({self.value})"
class index(TensorExpression):
- """Returns the iteration index for a given dimension name.
+ """Returns the iteration index for a given dimension name.
- Resolves the given dimension name to obtain its position in the iteration
- domain of the operation.
- """
+ Resolves the given dimension name to obtain its position in the iteration
+ domain of the operation.
+ """
- def __init__(self, dim: DimDef):
- self.dim_def = dim
- self.dim = -1
+ def __init__(self, dim: DimDef):
+ self.dim_def = dim
+ self.dim = -1
- def resolve_dimension_name(self, affine_state: AffineBuildState):
- self.dim = affine_state.get_dim(self.dim_def.dimname)
+ def resolve_dimension_name(self, affine_state: AffineBuildState):
+ self.dim = affine_state.get_dim(self.dim_def.dimname)
- def to_scalar_expression(self) -> ScalarExpression:
- assert self.dim != -1, "Dimension name not resolved"
- return ScalarIndex(self.dim).expr()
+ def to_scalar_expression(self) -> ScalarExpression:
+ assert self.dim != -1, "Dimension name not resolved"
+ return ScalarIndex(self.dim).expr()
- def __repr__(self):
- return f"index({repr(self.dim)})"
+ def __repr__(self):
+ return f"index({repr(self.dim)})"
###############################################################################
@@ -248,155 +257,160 @@ def __repr__(self):
class FunctionKind(Enum):
- UNARY = 0
- BINARY = 1
- TYPE = 2
+ UNARY = 0
+ BINARY = 1
+ TYPE = 2
class UnaryFnType:
- """Unary function.
+ """Unary function.
- A unary function takes one tensor expression and returns the
- function evaluation result.
- """
+ A unary function takes one tensor expression and returns the
+ function evaluation result.
+ """
- def __init__(self, fn_name: str):
- self.fn_name = fn_name
+ def __init__(self, fn_name: str):
+ self.fn_name = fn_name
- def __call__(self, arg: TensorExpression) -> "TensorFn":
- return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg])
+ def __call__(self, arg: TensorExpression) -> "TensorFn":
+ return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg])
- def __repr__(self):
- return f"{self.fn_name}"
+ def __repr__(self):
+ return f"{self.fn_name}"
class UnaryFn:
- """Unary function namespace."""
- exp = UnaryFnType("exp")
- log = UnaryFnType("log")
- abs = UnaryFnType("abs")
- ceil = UnaryFnType("ceil")
- floor = UnaryFnType("floor")
- negf = UnaryFnType("negf")
+ """Unary function namespace."""
+
+ exp = UnaryFnType("exp")
+ log = UnaryFnType("log")
+ abs = UnaryFnType("abs")
+ ceil = UnaryFnType("ceil")
+ floor = UnaryFnType("floor")
+ negf = UnaryFnType("negf")
class BinaryFnType:
- """Binary function.
+ """Binary function.
- A binary function takes two tensor expressions and returns the
- function evaluation result.
- """
+ A binary function takes two tensor expressions and returns the
+ function evaluation result.
+ """
- def __init__(self, fn_name: str):
- self.fn_name = fn_name
+ def __init__(self, fn_name: str):
+ self.fn_name = fn_name
- def __call__(self, arg0: TensorExpression,
- arg1: TensorExpression) -> "TensorFn":
- return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1])
+ def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> "TensorFn":
+ return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1])
- def __repr__(self):
- return f"{self.fn_name}"
+ def __repr__(self):
+ return f"{self.fn_name}"
class BinaryFn:
- """Binary function namespace.
+ """Binary function namespace.
-    As the integer types are signless, signedness is implement by different
- functions that treat integers as signed or unsigned values.
+    As the integer types are signless, signedness is implement by different
+ functions that treat integers as signed or unsigned values.
+
+ Examples:
+ - max -> `arith.MaxSIOp`
+ - max_unsinged -> `arith.MaxUIOp`
+ """
- Examples:
- - max -> `arith.MaxSIOp`
- - max_unsinged -> `arith.MaxUIOp`
- """
- add = BinaryFnType("add")
- sub = BinaryFnType("sub")
- mul = BinaryFnType("mul")
- max_signed = BinaryFnType("max_signed")
- min_signed = BinaryFnType("min_signed")
- max_unsigned = BinaryFnType("max_unsigned")
- min_unsigned = BinaryFnType("min_unsigned")
+ add = BinaryFnType("add")
+ sub = BinaryFnType("sub")
+ mul = BinaryFnType("mul")
+ max_signed = BinaryFnType("max_signed")
+ min_signed = BinaryFnType("min_signed")
+ max_unsigned = BinaryFnType("max_unsigned")
+ min_unsigned = BinaryFnType("min_unsigned")
class TypeFnType:
- """Type conversion function.
+ """Type conversion function.
- A type conversion function takes a target type and a tensor expression and
- returns the casted tensor expression.
- """
+ A type conversion function takes a target type and a tensor expression and
+ returns the casted tensor expression.
+ """
- def __init__(self, fn_name: str):
- self.fn_name = fn_name
+ def __init__(self, fn_name: str):
+ self.fn_name = fn_name
- def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
- return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])
+ def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
+ return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])
- def __repr__(self):
- return f"{self.fn_name}"
+ def __repr__(self):
+ return f"{self.fn_name}"
class TypeFn:
- """Type conversion function namespace.
+ """Type conversion function namespace.
+
+    As the integer types are signless, signedness is implement by different cast
+ functions that treat integers as signed (`cast_signed`) or unsigned
+ (`cast_unsigned`) values.
-    As the integer types are signless, signedness is implement by different cast
- functions that treat integers as signed (`cast_signed`) or unsigned
- (`cast_unsigned`) values.
+ Examples:
+ - cast_signed(I32 -> I64) -> `arith.ExtSIOp`
+ - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
+ """
- Examples:
- - cast_signed(I32 -> I64) -> `arith.ExtSIOp`
- - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
- """
- cast_signed = TypeFnType("cast_signed")
- cast_unsigned = TypeFnType("cast_unsigned")
+ cast_signed = TypeFnType("cast_signed")
+ cast_unsigned = TypeFnType("cast_unsigned")
class ReduceFnUse:
- """Reduction function use.
+ """Reduction function use.
- A reduction use specifies the reduction function and dimensions.
- """
+ A reduction use specifies the reduction function and dimensions.
+ """
- def __init__(self, binary_fn: Optional[BinaryFnType],
- binary_attr: Optional["BinaryFnAttrDef"], *reduce_dims: DimDef):
- if bool(binary_fn) + bool(binary_attr) != 1:
- raise ValueError("One of 'binary_fn', 'binary_attr' must be specified")
- self.binary_fn = binary_fn
- self.binary_attr = binary_attr
- self.reduce_dims = reduce_dims
+ def __init__(
+ self,
+ binary_fn: Optional[BinaryFnType],
+ binary_attr: Optional["BinaryFnAttrDef"],
+ *reduce_dims: DimDef,
+ ):
+ if bool(binary_fn) + bool(binary_attr) != 1:
+ raise ValueError("One of 'binary_fn', 'binary_attr' must be specified")
+ self.binary_fn = binary_fn
+ self.binary_attr = binary_attr
+ self.reduce_dims = reduce_dims
- def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
- return TensorReduceFn(self, args)
+ def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
+ return TensorReduceFn(self, args)
- def __repr__(self):
- fn = self.binary_fn if self.binary_fn else self.binary_attr
- return (
- f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})")
+ def __repr__(self):
+ fn = self.binary_fn if self.binary_fn else self.binary_attr
+ return f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})"
class ReduceFnType:
- """Reduction function.
+ """Reduction function.
- A binary function that reduces its RHS into its LHS.
- """
+ A binary function that reduces its RHS into its LHS.
+ """
- def __init__(self, binary_fn: BinaryFnType):
- if not isinstance(binary_fn, BinaryFnType):
- raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}")
- self.binary_fn = binary_fn
+ def __init__(self, binary_fn: BinaryFnType):
+ if not isinstance(binary_fn, BinaryFnType):
+ raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}")
+ self.binary_fn = binary_fn
- def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
- return ReduceFnUse(self.binary_fn, None, *reduce_dims)
+ def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+ return ReduceFnUse(self.binary_fn, None, *reduce_dims)
- def __repr__(self):
- return f"reduce_{repr(self.binary_fn)}"
+ def __repr__(self):
+ return f"reduce_{repr(self.binary_fn)}"
class ReduceFn:
- add = ReduceFnType(BinaryFn.add)
- mul = ReduceFnType(BinaryFn.mul)
- max_signed = ReduceFnType(BinaryFn.max_signed)
- min_signed = ReduceFnType(BinaryFn.min_signed)
- max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
- min_unsigned = ReduceFnType(BinaryFn.min_unsigned)
+ add = ReduceFnType(BinaryFn.add)
+ mul = ReduceFnType(BinaryFn.mul)
+ max_signed = ReduceFnType(BinaryFn.max_signed)
+ min_signed = ReduceFnType(BinaryFn.min_signed)
+ max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
+ min_unsigned = ReduceFnType(BinaryFn.min_unsigned)
###############################################################################
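A short sketch (not part of the diff above): these namespaces are plain factories; applying them yields TensorFn / TensorReduceFn expression nodes rather than IR, and the reduction use is only later bound to its LHS by the comprehension machinery below.

from mlir.dialects.linalg.opdsl.lang.affine import DimDef
from mlir.dialects.linalg.opdsl.lang.comprehension import BinaryFn, ReduceFn, const

prod = BinaryFn.mul(const(2), const(3))   # TensorFn(BINARY, "mul", ...)
k0, k1 = DimDef("k0"), DimDef("k1")
use = ReduceFn.add[k0, k1]                # ReduceFnUse over two reduction dims
red = use(prod)                           # TensorReduceFn, not yet bound to an lhs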
@@ -405,237 +419,265 @@ class ReduceFn:
class OperandKind(Enum):
- INPUT_TENSOR = 0
- SCALAR = 1
- OUTPUT_TENSOR = 2
- INDEX_ATTR = 3
- UNARY_FN_ATTR = 4
- BINARY_FN_ATTR = 5
- TYPE_FN_ATTR = 6
+ INPUT_TENSOR = 0
+ SCALAR = 1
+ OUTPUT_TENSOR = 2
+ INDEX_ATTR = 3
+ UNARY_FN_ATTR = 4
+ BINARY_FN_ATTR = 5
+ TYPE_FN_ATTR = 6
class OperandDef:
- """Definition of an operand passed to an operation.
-
- Keep the meta information of Tensor, Scalar, and Attribute operands and
- provide the shared registration functionality.
- """
-
- def __init__(self,
- kind: OperandKind,
- type_var: Optional[TypeVar] = None,
- size_exprs: Optional[Sequence[AffineExprDef]] = None,
- index_dims: Optional[Sequence[DimDef]] = None,
- default_indices: Optional[Sequence[int]] = None,
- default_fn: Optional[str] = None):
- if type_var and not isinstance(type_var, TypeVar):
- raise ValueError(
- f"OperandDef requires a TypeVar but got {repr(type_var)}")
- self.owner = None # type: Optional["LinalgOpDef"]
- self.type_var = type_var
- self.size_exprs = size_exprs
- self.index_dims = index_dims
- self.default_indices = default_indices
- self.default_fn = default_fn
- self.kind = kind
- self.name = None # type: Optional[str]
- self.registered_index = -1 # type: int
-
- def attach(self, index: int, name: str, owner: "LinalgOpDef"):
- if self.owner:
- raise ValueError(f"OperandDef already registered with an op: {self}")
- self.registered_index = index
- self.name = name
- self.owner = owner
-
- def is_input(self) -> bool:
- return (self.kind == OperandKind.SCALAR or
- self.kind == OperandKind.INPUT_TENSOR)
-
- def is_tensor(self) -> bool:
- return (self.kind == OperandKind.INPUT_TENSOR or
- self.kind == OperandKind.OUTPUT_TENSOR)
-
- def is_attribute(self) -> bool:
- return (self.kind == OperandKind.INDEX_ATTR or
- self.kind == OperandKind.UNARY_FN_ATTR or
- self.kind == OperandKind.BINARY_FN_ATTR or
- self.kind == OperandKind.TYPE_FN_ATTR)
-
- def __hash__(self):
- return hash(id(self))
-
- def __repr__(self):
- return (f"{self.name}:OperandDef(kind={self.kind.name}, "
+ """Definition of an operand passed to an operation.
+
+ Keep the meta information of Tensor, Scalar, and Attribute operands and
+ provide the shared registration functionality.
+ """
+
+ def __init__(
+ self,
+ kind: OperandKind,
+ type_var: Optional[TypeVar] = None,
+ size_exprs: Optional[Sequence[AffineExprDef]] = None,
+ index_dims: Optional[Sequence[DimDef]] = None,
+ default_indices: Optional[Sequence[int]] = None,
+ default_fn: Optional[str] = None,
+ ):
+ if type_var and not isinstance(type_var, TypeVar):
+ raise ValueError(f"OperandDef requires a TypeVar but got {repr(type_var)}")
+ self.owner = None # type: Optional["LinalgOpDef"]
+ self.type_var = type_var
+ self.size_exprs = size_exprs
+ self.index_dims = index_dims
+ self.default_indices = default_indices
+ self.default_fn = default_fn
+ self.kind = kind
+ self.name = None # type: Optional[str]
+ self.registered_index = -1 # type: int
+
+ def attach(self, index: int, name: str, owner: "LinalgOpDef"):
+ if self.owner:
+ raise ValueError(f"OperandDef already registered with an op: {self}")
+ self.registered_index = index
+ self.name = name
+ self.owner = owner
+
+ def is_input(self) -> bool:
+ return self.kind == OperandKind.SCALAR or self.kind == OperandKind.INPUT_TENSOR
+
+ def is_tensor(self) -> bool:
+ return (
+ self.kind == OperandKind.INPUT_TENSOR
+ or self.kind == OperandKind.OUTPUT_TENSOR
+ )
+
+ def is_attribute(self) -> bool:
+ return (
+ self.kind == OperandKind.INDEX_ATTR
+ or self.kind == OperandKind.UNARY_FN_ATTR
+ or self.kind == OperandKind.BINARY_FN_ATTR
+ or self.kind == OperandKind.TYPE_FN_ATTR
+ )
+
+ def __hash__(self):
+ return hash(id(self))
+
+ def __repr__(self):
+ return (
+ f"{self.name}:OperandDef(kind={self.kind.name}, "
f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, "
f"index_dims={self.index_dims}, "
f"default_indices={self.default_indices}, "
- f"default_fn={self.default_fn})")
+ f"default_fn={self.default_fn})"
+ )
class TensorDef:
- """Tensor operand definition.
-
- Tensor operands are indexed using the associated indexing_map when forwarded
- to the body of the structured op. A unique name identifies the tensor operands
- and an index determines their position in the operation's parameter list. A
- tensor definition takes type, a shape, and an optional flag to mark output
- tensors. Additionally, a tuple of index dimensions may be used to map the
- tensor to the loop dimensions of the operation. This mapping is needed to
- compute the indexing map of shape-only tensors that have no uses.
- """
-
- def __init__(self,
- type_var: TypeVar,
- *shape: AffineExprDef,
- index_dims: Optional[Sequence[DimDef]] = None,
- output: bool = False):
- if index_dims and len(shape) != len(index_dims):
- raise ValueError(f"Expected the shape rank {len(shape)} to match the "
- f"number of index_dims {len(index_dims)}")
- if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
- raise ValueError(f"TensorDef requires index dims of type DimDef but "
- f"got {index_dims}")
- kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR
- self.operand_def = OperandDef(
- kind, type_var=type_var, size_exprs=shape, index_dims=index_dims)
-
- def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
- assert self.operand_def.owner, "TensorDef is not registered with an op"
- state = AffineBuildState(
- global_state=self.operand_def.owner._affine_state,
- allow_new_symbols=False)
- if not isinstance(dims, tuple):
- dims = (dims,) # Handle single subscript case.
- # Special case: (None) is a 0d-scalar use.
- if dims == (None,):
- dims = ()
-
- exprs = []
- for expr_def in dims:
- if not isinstance(expr_def, AffineExprDef):
- raise KeyError(
- "A TensorDef can only be subscripted by a tuple of affine dims")
- exprs.append(expr_def)
- return TensorUse(self.operand_def, exprs)
-
- def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression):
- """Creates a new 1:1 comprehension by binding this tensor to an expression.
-
- Note that due to the way assignment works in Python, we have to capture
- direct assignment as a setitem on the TensorDef.
+ """Tensor operand definition.
+
+ Tensor operands are indexed using the associated indexing_map when forwarded
+ to the body of the structured op. A unique name identifies the tensor operands
+ and an index determines their position in the operation's parameter list. A
+ tensor definition takes type, a shape, and an optional flag to mark output
+ tensors. Additionally, a tuple of index dimensions may be used to map the
+ tensor to the loop dimensions of the operation. This mapping is needed to
+ compute the indexing map of shape-only tensors that have no uses.
"""
- if not isinstance(value, TensorExpression):
- raise ValueError(f"Only TensorExpressions can be assigned to TensorDefs. "
- f"Got: {repr(value)}")
- use = self[dims]
- comp = Comprehension((use, value))
- self.operand_def.owner.comprehensions.append(comp)
+
+ def __init__(
+ self,
+ type_var: TypeVar,
+ *shape: AffineExprDef,
+ index_dims: Optional[Sequence[DimDef]] = None,
+ output: bool = False,
+ ):
+ if index_dims and len(shape) != len(index_dims):
+ raise ValueError(
+ f"Expected the shape rank {len(shape)} to match the "
+ f"number of index_dims {len(index_dims)}"
+ )
+ if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
+ raise ValueError(
+ f"TensorDef requires index dims of type DimDef but " f"got {index_dims}"
+ )
+ kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR
+ self.operand_def = OperandDef(
+ kind, type_var=type_var, size_exprs=shape, index_dims=index_dims
+ )
+
+ def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
+ assert self.operand_def.owner, "TensorDef is not registered with an op"
+ state = AffineBuildState(
+ global_state=self.operand_def.owner._affine_state, allow_new_symbols=False
+ )
+ if not isinstance(dims, tuple):
+ dims = (dims,) # Handle single subscript case.
+ # Special case: (None) is a 0d-scalar use.
+ if dims == (None,):
+ dims = ()
+
+ exprs = []
+ for expr_def in dims:
+ if not isinstance(expr_def, AffineExprDef):
+ raise KeyError(
+ "A TensorDef can only be subscripted by a tuple of affine dims"
+ )
+ exprs.append(expr_def)
+ return TensorUse(self.operand_def, exprs)
+
+ def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression):
+ """Creates a new 1:1 comprehension by binding this tensor to an expression.
+
+ Note that due to the way assignment works in Python, we have to capture
+ direct assignment as a setitem on the TensorDef.
+ """
+ if not isinstance(value, TensorExpression):
+ raise ValueError(
+ f"Only TensorExpressions can be assigned to TensorDefs. "
+ f"Got: {repr(value)}"
+ )
+ use = self[dims]
+ comp = Comprehension((use, value))
+ self.operand_def.owner.comprehensions.append(comp)
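As the TensorUse and TensorDef docstrings above describe, an op body combines __getitem__, __iadd__, and __setitem__ to record comprehensions. A hedged sketch in the OpDSL style follows (not part of the diff); the @linalg_structured_op decorator, the domain helper, and the S/D/TV expandos come from the sibling opdsl lang modules, and their availability through the star import is an assumption here.

from mlir.dialects.linalg.opdsl.lang import *

T1, T2, U = TV.T1, TV.T2, TV.U  # on-demand type variables (assumed accessor)

@linalg_structured_op
def matmul_sketch(
    A=TensorDef(T1, S.M, S.K),
    B=TensorDef(T2, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True),
):
    domain(D.m, D.n, D.k)
    # C[...] is TensorDef.__getitem__ -> TensorUse; += builds a TensorReduceFn
    # over D.k; the assignment lands in TensorDef.__setitem__, which appends a
    # Comprehension to the op definition.
    C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
        U, B[D.k, D.n]
    )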
class ScalarDef(TensorExpression):
- """Scalar operand definition.
+ """Scalar operand definition.
- Scalar operands are forwarded to the body of the structured op as they are.
- A unique name identifies the scalars and an index determines their position in
- the operation's parameter list.
- """
+ Scalar operands are forwarded to the body of the structured op as they are.
+ A unique name identifies the scalars and an index determines their position in
+ the operation's parameter list.
+ """
- def __init__(self, type_var: TypeVar):
- self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var)
+ def __init__(self, type_var: TypeVar):
+ self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var)
- @property
- def scalar_name(self) -> str:
- name = self.operand_def.name
- assert name is not None, "ScalarDef not registered with an op"
- return name
+ @property
+ def scalar_name(self) -> str:
+ name = self.operand_def.name
+ assert name is not None, "ScalarDef not registered with an op"
+ return name
- def to_scalar_expression(self) -> ScalarExpression:
- return ScalarArg(self.scalar_name).expr()
+ def to_scalar_expression(self) -> ScalarExpression:
+ return ScalarArg(self.scalar_name).expr()
class IndexAttrDef:
- """Index attribute definition.
-
- Index attributes provide a way to define and set symbols that can be used in
- indexing expressions. Every attribute specifies a tuple of symbols that at
- compile-time are replaced by integer values as well as their default values.
- """
-
- def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
- if any(not isinstance(size, SymbolDef) for size in sizes):
- raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef "
- f"but got {sizes}")
- if any(not isinstance(default_val, int) for default_val in default):
- raise ValueError(f"IndexAttrDef requires default values of type int "
- f"but got {default}")
- if len(sizes) != len(default):
- raise ValueError(f"IndexAttrDef expects {len(sizes)} default values "
- f"but got {len(default)}")
- self.operand_def = OperandDef(
- OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default)
+ """Index attribute definition.
+
+ Index attributes provide a way to define and set symbols that can be used in
+ indexing expressions. Every attribute specifies a tuple of symbols that at
+ compile-time are replaced by integer values as well as their default values.
+ """
+
+ def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
+ if any(not isinstance(size, SymbolDef) for size in sizes):
+ raise ValueError(
+ f"IndexAttrDef requires sizes of type SymbolDef " f"but got {sizes}"
+ )
+ if any(not isinstance(default_val, int) for default_val in default):
+ raise ValueError(
+ f"IndexAttrDef requires default values of type int "
+ f"but got {default}"
+ )
+ if len(sizes) != len(default):
+ raise ValueError(
+ f"IndexAttrDef expects {len(sizes)} default values "
+ f"but got {len(default)}"
+ )
+ self.operand_def = OperandDef(
+ OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default
+ )
class UnaryFnAttrDef:
- """Unary function attribute definition.
+ """Unary function attribute definition.
- Unary function attributes provide a way to make the arithmetic computation
- parametrizable. Every attribute specifies a default unary function
- that may be overwritten at operation instantiation time.
- """
+ Unary function attributes provide a way to make the arithmetic computation
+ parametrizable. Every attribute specifies a default unary function
+ that may be overwritten at operation instantiation time.
+ """
- def __init__(self, default: "UnaryFnType"):
- if not isinstance(default, UnaryFnType):
- raise ValueError(f"UnaryFnAttrDef requires default of type UnaryFnType "
- f"but got {default}")
- self.operand_def = OperandDef(
- OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name)
+ def __init__(self, default: "UnaryFnType"):
+ if not isinstance(default, UnaryFnType):
+ raise ValueError(
+ f"UnaryFnAttrDef requires default of type UnaryFnType "
+ f"but got {default}"
+ )
+ self.operand_def = OperandDef(
+ OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name
+ )
- def __call__(self, arg: TensorExpression) -> TensorFn:
- return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg])
+ def __call__(self, arg: TensorExpression) -> TensorFn:
+ return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg])
class BinaryFnAttrDef:
- """Binary function attribute definition.
+ """Binary function attribute definition.
- Binary function attributes provide a way to make the arithmetic computation
- parametrizable. Every attribute specifies a default binary function
- that may be overwritten at operation instantiation time.
- """
+ Binary function attributes provide a way to make the arithmetic computation
+ parametrizable. Every attribute specifies a default binary function
+ that may be overwritten at operation instantiation time.
+ """
- def __init__(self, default: "BinaryFnType"):
- if not isinstance(default, BinaryFnType):
- raise ValueError(f"BinaryFnAttrDef requires default of type BinaryFnType "
- f"but got {default}")
- self.operand_def = OperandDef(
- OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name)
+ def __init__(self, default: "BinaryFnType"):
+ if not isinstance(default, BinaryFnType):
+ raise ValueError(
+ f"BinaryFnAttrDef requires default of type BinaryFnType "
+ f"but got {default}"
+ )
+ self.operand_def = OperandDef(
+ OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name
+ )
- def __call__(self, arg0: TensorExpression,
- arg1: TensorExpression) -> TensorFn:
- return TensorFn(FunctionKind.BINARY, None, self.operand_def, None,
- [arg0, arg1])
+ def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn:
+ return TensorFn(FunctionKind.BINARY, None, self.operand_def, None, [arg0, arg1])
- def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
- return ReduceFnUse(None, self, *reduce_dims)
+ def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+ return ReduceFnUse(None, self, *reduce_dims)
class TypeFnAttrDef:
- """Type conversion function attribute definition.
+ """Type conversion function attribute definition.
- Type conversion function attributes provide a way to make type conversions
- parameterizable. Every attribute specifies a default type conversion function
- that may be overwritten at operation instantiation time.
- """
+ Type conversion function attributes provide a way to make type conversions
+ parameterizable. Every attribute specifies a default type conversion function
+ that may be overwritten at operation instantiation time.
+ """
- def __init__(self, default: "TypeFnType"):
- if not isinstance(default, TypeFnType):
- raise ValueError(f"TypeFnAttrDef requires default of type TypeFnType "
- f"but got {default}")
- self.operand_def = OperandDef(
- OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name)
+ def __init__(self, default: "TypeFnType"):
+ if not isinstance(default, TypeFnType):
+ raise ValueError(
+ f"TypeFnAttrDef requires default of type TypeFnType "
+ f"but got {default}"
+ )
+ self.operand_def = OperandDef(
+ OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name
+ )
- def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
- return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])
+ def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
+ return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])
###############################################################################
@@ -644,48 +686,48 @@ def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
class Comprehension:
- """Represents a single comprehension."""
-
- def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
- self.definitions = list() # List[TensorUse]
- self.values = list() # List[TensorExpression]
-
- # Find the lhs to reduction rhs.
- for assign, value in bindings:
- if isinstance(value, TensorReduceFn):
- if value.lhs:
- raise ValueError(f"Reduction expression already assigns: {value}")
- value.lhs = assign
- self.definitions.append(assign)
- self.values.append(value)
-
- @property
- def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
- """Gets the reduction dims for the comprehension or None."""
- result = set()
- for use in self.values:
- if isinstance(use, TensorReduceFn):
- result.add(use.reduce_use.reduce_dims)
- else:
- result.add(tuple())
- return result
-
- def __repr__(self):
- if len(self.definitions) > 1:
- defs_repr = f"({', '.join(repr(d) for d in self.definitions)})"
- values_repr = f"({', '.join(repr(v) for v in self.values)})"
- else:
- defs_repr = f"{repr(self.definitions[0])}"
- values_repr = f"{repr(self.values[0])}"
-
- return f"{defs_repr} = {values_repr}"
+ """Represents a single comprehension."""
+
+ def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
+ self.definitions = list() # List[TensorUse]
+ self.values = list() # List[TensorExpression]
+
+ # Find the lhs to reduction rhs.
+ for assign, value in bindings:
+ if isinstance(value, TensorReduceFn):
+ if value.lhs:
+ raise ValueError(f"Reduction expression already assigns: {value}")
+ value.lhs = assign
+ self.definitions.append(assign)
+ self.values.append(value)
+
+ @property
+ def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
+ """Gets the reduction dims for the comprehension or None."""
+ result = set()
+ for use in self.values:
+ if isinstance(use, TensorReduceFn):
+ result.add(use.reduce_use.reduce_dims)
+ else:
+ result.add(tuple())
+ return result
+
+ def __repr__(self):
+ if len(self.definitions) > 1:
+ defs_repr = f"({', '.join(repr(d) for d in self.definitions)})"
+ values_repr = f"({', '.join(repr(v) for v in self.values)})"
+ else:
+ defs_repr = f"{repr(self.definitions[0])}"
+ values_repr = f"{repr(self.values[0])}"
+
+ return f"{defs_repr} = {values_repr}"
class OpInterfaceDef:
- """An interface that an op implements."""
+ """An interface that an op implements."""
- def __init__(self, cpp_name: str):
- self.cpp_name = cpp_name
+ def __init__(self, cpp_name: str):
+ self.cpp_name = cpp_name
ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")
@@ -694,86 +736,94 @@ def __init__(self, cpp_name: str):
class OpDefinitionDef:
- """A method that an op implements."""
+ """A method that an op implements."""
- def __init__(self, def_name: str):
- self.def_name = def_name
+ def __init__(self, def_name: str):
+ self.def_name = def_name
Canonicalizer = OpDefinitionDef("hasCanonicalizer")
class OpMetadataDef(YAMLObject):
- """Metadata about the op (generally not behavior impacting)."""
- yaml_tag = "!LinalgOpMetadata"
-
- def __init__(self, name: str, cpp_class_name: Optional[str],
- doc: Optional[str]):
- self.name = name
- self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
- self.doc = doc
- self.implements = [] # type: List[OpInterfaceDef]
- self.defines = [] # type: List[OpDefinitionsDef]
-
- def to_yaml_custom_dict(self):
- d = dict(
- name=self.name,
- cpp_class_name=self.cpp_class_name,
- doc=self.doc,
- )
- if self.implements:
- d["implements"] = [intr.cpp_name for intr in self.implements]
- if self.defines:
- d["defines"] = [defi.def_name for defi in self.defines]
- return d
+ """Metadata about the op (generally not behavior impacting)."""
+
+ yaml_tag = "!LinalgOpMetadata"
+
+ def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]):
+ self.name = name
+ self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
+ self.doc = doc
+ self.implements = [] # type: List[OpInterfaceDef]
+ self.defines = [] # type: List[OpDefinitionsDef]
+
+ def to_yaml_custom_dict(self):
+ d = dict(
+ name=self.name,
+ cpp_class_name=self.cpp_class_name,
+ doc=self.doc,
+ )
+ if self.implements:
+ d["implements"] = [intr.cpp_name for intr in self.implements]
+ if self.defines:
+ d["defines"] = [defi.def_name for defi in self.defines]
+ return d
class LinalgOpDef:
- """Definition of a linalg op."""
-
- def __init__(self,
- name: str,
- cpp_class_name: Optional[str] = None,
- doc: Optional[str] = None):
- self.metadata = OpMetadataDef(
- name=name, cpp_class_name=cpp_class_name, doc=doc)
- self.registered_operands = dict() # type: Dict[str, OperandDef]
- self.domain = list() # type: List[DimDef]
- self.comprehensions = list() # type: List[Comprehension]
- self._affine_state = AffineBuildState()
-
- def add_operand(self, name: str, operand: OperandDef):
- """Registers an operand."""
- if name in self.registered_operands:
- raise ValueError(f"The operand {name} is already registered "
- f"to {self.registered_operands['name']}")
- structured_op_methods = [
- "inputs", "outputs", "result_tensors", "region", "iterator_types",
- "indexing_maps", "getRegionBuilder", "getLibraryCallName"
- ]
- if operand.is_attribute() and name in structured_op_methods:
- raise ValueError(f"The attribute name {name} conflicts with a structured "
- f"op method name")
- # Ensure output tensors are registered after input tensors and scalars and
- # attributes are registered after all other operand types.
- if operand.is_input() and any(
- not op_def.is_input() for op_def in self.registered_operands.values()):
- raise ValueError(f"Input {name} registered after an output or attribute")
- if operand.kind == OperandKind.OUTPUT_TENSOR and any(
- op_def.is_attribute() for op_def in self.registered_operands.values()):
- raise ValueError(f"Output {name} registered after an attribute")
- operand.attach(len(self.registered_operands), name, self)
- self.registered_operands[name] = operand
-
- def __repr__(self):
- lines = [
- f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"
- ]
- for name, operand in self.registered_operands.items():
- lines.append(f" {operand}")
- if self.comprehensions:
- lines[-1] += " {"
- for comprehension in self.comprehensions:
- lines.append(f" {comprehension}")
- lines.append("}")
- return "\n".join(lines)
+ """Definition of a linalg op."""
+
+ def __init__(
+ self, name: str, cpp_class_name: Optional[str] = None, doc: Optional[str] = None
+ ):
+ self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc)
+ self.registered_operands = dict() # type: Dict[str, OperandDef]
+ self.domain = list() # type: List[DimDef]
+ self.comprehensions = list() # type: List[Comprehension]
+ self._affine_state = AffineBuildState()
+
+ def add_operand(self, name: str, operand: OperandDef):
+ """Registers an operand."""
+ if name in self.registered_operands:
+ raise ValueError(
+ f"The operand {name} is already registered "
+ f"to {self.registered_operands['name']}"
+ )
+ structured_op_methods = [
+ "inputs",
+ "outputs",
+ "result_tensors",
+ "region",
+ "iterator_types",
+ "indexing_maps",
+ "getRegionBuilder",
+ "getLibraryCallName",
+ ]
+ if operand.is_attribute() and name in structured_op_methods:
+ raise ValueError(
+ f"The attribute name {name} conflicts with a structured "
+ f"op method name"
+ )
+ # Ensure output tensors are registered after input tensors and scalars and
+ # attributes are registered after all other operand types.
+ if operand.is_input() and any(
+ not op_def.is_input() for op_def in self.registered_operands.values()
+ ):
+ raise ValueError(f"Input {name} registered after an output or attribute")
+ if operand.kind == OperandKind.OUTPUT_TENSOR and any(
+ op_def.is_attribute() for op_def in self.registered_operands.values()
+ ):
+ raise ValueError(f"Output {name} registered after an attribute")
+ operand.attach(len(self.registered_operands), name, self)
+ self.registered_operands[name] = operand
+
+ def __repr__(self):
+ lines = [f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"]
+ for name, operand in self.registered_operands.items():
+ lines.append(f" {operand}")
+ if self.comprehensions:
+ lines[-1] += " {"
+ for comprehension in self.comprehensions:
+ lines.append(f" {comprehension}")
+ lines.append("}")
+ return "\n".join(lines)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
index 2a0da68295265..d522d5712d253 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -21,422 +21,468 @@
def _serialize_affine_map(affine_map: _ir.AffineMap) -> str:
- with affine_map.context:
- # Affine map printing/parsing is via an AffineMap attr.
- attr = _ir.AffineMapAttr.get(affine_map)
- return str(attr)
+ with affine_map.context:
+ # Affine map printing/parsing is via an AffineMap attr.
+ attr = _ir.AffineMapAttr.get(affine_map)
+ return str(attr)
class TensorUseConfig:
- """Wrapper around a TensorUse with additional context-bound state."""
+ """Wrapper around a TensorUse with additional context-bound state."""
- def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap):
- self.tensor_use = tensor_use
- self.indexing_map = indexing_map
+ def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap):
+ self.tensor_use = tensor_use
+ self.indexing_map = indexing_map
- def __repr__(self):
- return f"Use({self.tensor_use}, indexing_map={self.indexing_map})"
+ def __repr__(self):
+ return f"Use({self.tensor_use}, indexing_map={self.indexing_map})"
class OperandDefConfig(YAMLObject):
- """Wrapper containing an operand definition with additional state."""
- yaml_tag = "!LinalgOperandDefConfig"
-
- def __init__(self,
- operand_def: OperandDef,
- shape_map: Optional[_ir.AffineMap] = None,
- index_attr_map: Optional[_ir.AffineMap] = None):
- self.operand_def = operand_def
- self.shape_map = shape_map # type: Optional[_ir.AffineMap]
- self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap]
- self.indexing_map = None # type: Optional[_ir.AffineMap]
-
- @property
- def name(self) -> str:
- return self.operand_def.name
-
- @property
- def kind(self) -> OperandKind:
- return self.operand_def.kind
-
- @property
- def type_var(self) -> TypeVar:
- return self.operand_def.type_var
-
- def to_yaml_custom_dict(self):
- self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower())
- if self.type_var:
- self_dict["type_var"] = self.type_var.name
- if self.shape_map:
- self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
- if self.index_attr_map:
- self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
- if self.operand_def.default_indices:
- self_dict["default_indices"] = self.operand_def.default_indices
- if self.operand_def.default_fn:
- self_dict["default_fn"] = self.operand_def.default_fn
- return self_dict
-
- def __repr__(self):
- return (f"OperandDefConfig({self.operand_def}, "
+ """Wrapper containing an operand definition with additional state."""
+
+ yaml_tag = "!LinalgOperandDefConfig"
+
+ def __init__(
+ self,
+ operand_def: OperandDef,
+ shape_map: Optional[_ir.AffineMap] = None,
+ index_attr_map: Optional[_ir.AffineMap] = None,
+ ):
+ self.operand_def = operand_def
+ self.shape_map = shape_map # type: Optional[_ir.AffineMap]
+ self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap]
+ self.indexing_map = None # type: Optional[_ir.AffineMap]
+
+ @property
+ def name(self) -> str:
+ return self.operand_def.name
+
+ @property
+ def kind(self) -> OperandKind:
+ return self.operand_def.kind
+
+ @property
+ def type_var(self) -> TypeVar:
+ return self.operand_def.type_var
+
+ def to_yaml_custom_dict(self):
+ self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower())
+ if self.type_var:
+ self_dict["type_var"] = self.type_var.name
+ if self.shape_map:
+ self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
+ if self.index_attr_map:
+ self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
+ if self.operand_def.default_indices:
+ self_dict["default_indices"] = self.operand_def.default_indices
+ if self.operand_def.default_fn:
+ self_dict["default_fn"] = self.operand_def.default_fn
+ return self_dict
+
+ def __repr__(self):
+ return (
+ f"OperandDefConfig({self.operand_def}, "
f"shape_map={self.shape_map}, "
f"index_attr_map={self.index_attr_map}, "
- f"indexing_map={self.indexing_map})")
+ f"indexing_map={self.indexing_map})"
+ )
class LinalgIndexingMapsConfig(YAMLObject):
- """Abstracts the style of indexing maps that the op exports.
-
- Presently only static (tied to the op name) indexing maps are supported. In
- the future, it is expected that we will have additional variants:
- - Dynamic based on attributes
- - Dynamic based on operands
- Each is expected to require a different variant of specification.
- """
- yaml_tag = "!LinalgIndexingMapsConfig"
-
- def __init__(self,
- static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None):
- self.static_indexing_maps = static_indexing_maps
-
- def to_yaml_custom_dict(self):
- if self.static_indexing_maps is not None:
- return dict(static_indexing_maps=[
- _serialize_affine_map(m) for m in self.static_indexing_maps
- ])
- raise ValueError(
- f"LinalgIndexingMapsConfig must have one type of indexing map"
- f"(got none)")
+ """Abstracts the style of indexing maps that the op exports.
+ Presently only static (tied to the op name) indexing maps are supported. In
+ the future, it is expected that we will have additional variants:
+ - Dynamic based on attributes
+ - Dynamic based on operands
+ Each is expected to require a different variant of specification.
+ """
-class LinalgStructuredOpConfig(YAMLObject):
- """Configuration for metadata sufficient to construct a linalg named op."""
-
- yaml_tag = "!LinalgStructuredOpConfig"
-
- def __init__(self,
- comprehension: Comprehension,
- domain: Sequence[DimDef],
- registered_operands: Sequence[OperandDef],
- context: Optional[_ir.Context] = None):
- self.context = context if context is not None else _ir.Context()
- self.affine_state = AffineBuildState()
- self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]]
- self.operands = dict() # type: Dict[OperandDef, OperandDefConfig]
- self.uses = dict() # type: Dict[TensorUse, TensorUseConfig]
-
- # Compute the ordered set of writes and collect the tensor, capture, dims,
- # and index uses.
- collected_tensor_uses = set()
- collected_scalar_uses = set()
- collected_dim_uses = set()
- collected_indices = set()
- for write_use, read_use in zip(comprehension.definitions,
- comprehension.values):
- self.writes.append((write_use, read_use))
-
- for write_use, read_use in self.writes:
- collected_tensor_uses.add(write_use)
- read_use.collect_tensor_uses(collected_tensor_uses)
- read_use.collect_scalar_uses(collected_scalar_uses)
- read_use.collect_dim_uses(collected_dim_uses)
- write_use.collect_dim_uses(collected_dim_uses)
- read_use.collect_indices(collected_indices)
-
- # Set domain to the sorted list of uses if no domain annotation is given.
- if not domain:
- domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname)
-
- # Verify the domain dimensions match the used dimensions.
- if (len(domain) != len(collected_dim_uses) or
- any(dim not in collected_dim_uses for dim in domain)):
- raise ValueError(f"Expected the annotated domain dimensions {domain} to "
- f"match the set of dimension used by the tensor "
- f"comprehension {collected_dim_uses}")
-
- # Instantiate the dimensions in the given order.
- with self.context:
- local_state = AffineBuildState(
- global_state=self.affine_state, allow_new_symbols=False)
- for dim in domain:
- dim.build(state=local_state)
-
- # Collect all attribute definitions.
- collected_attr_defs = list()
- for operand in registered_operands:
- if operand.is_attribute():
- collected_attr_defs.append(operand)
-
- # Collect all tensors with manual indexing annotation.
- collected_index_defs = list()
- for operand in registered_operands:
- if operand.index_dims:
- if any(dim not in collected_dim_uses for dim in operand.index_dims):
- raise ValueError(f"Expected all index dims {operand.index_dims} of "
- f"operand {operand.name} to have uses.")
- collected_index_defs.append(operand)
-
- # Collect the operand definitions of all tensor/scalar uses, attributes, and
- # shape-only tensors.
- all_operand_defs = list()
- for use in collected_tensor_uses:
- all_operand_defs.append(use.operand_def)
- for use in collected_scalar_uses:
- all_operand_defs.append(use.operand_def)
- for definition in collected_attr_defs:
- all_operand_defs.append(definition)
- for definition in collected_index_defs:
- all_operand_defs.append(definition)
-
- # Add all operands in registration order to ensure the symbols are
- # registered in the order they appear.
- all_operand_defs = sorted(
- all_operand_defs, key=lambda operand_def: operand_def.registered_index)
- for operand_def in all_operand_defs:
- self.add_operand(operand_def)
-
- # Add all shape-only tensor index_dim annotations and all tensor uses.
- for definition in collected_index_defs:
- self.add_indexed_operand(definition)
- for use in collected_tensor_uses:
- self.add_tensor_use(use)
-
- # Normalize all shape and indexing maps now that full count of dims and
- # symbols are known.
- for cuse in self.uses.values():
- cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
- for definition in collected_index_defs:
- self.operands[definition].indexing_map = self._normalize_affine_map(
- self.operands[definition].indexing_map)
- for operand_config in self.operands.values():
- if operand_config.shape_map:
- operand_config.shape_map = self._normalize_affine_map(
- operand_config.shape_map, with_dims=False)
- if operand_config.index_attr_map:
- operand_config.index_attr_map = self._normalize_affine_map(
- operand_config.index_attr_map, with_dims=False)
-
- # Now for each write use, propagate the indexing maps from the use to the
- # tensor, ensuring that there are not conflicts.
- for write_use, _ in self.writes:
- write_tensor_config = self.operands[write_use.operand_def]
- if write_tensor_config.indexing_map:
- raise ValueError(
- f"Unexpected multi-write to a single tensor: {write_tensor_config}")
- write_tensor_config.indexing_map = self.uses[write_use].indexing_map
-
- # For each read use, propagate the indexing maps from the use to the
- # tensor, ensuring that there are not conflicts.
- for _, read_expr in self.writes:
- read_uses = set() # type: Set[TensorUse]
- read_expr.collect_tensor_uses(read_uses)
- for read_use in read_uses:
- read_operand_config = self.operands[read_use.operand_def]
- if (read_operand_config.indexing_map and
- read_operand_config.indexing_map !=
- self.uses[read_use].indexing_map):
- raise ValueError(
- f"Unexpected multi-read of a tensor with
diff erent accesses:"
- f"{read_operand_config} vs {read_use}")
- read_operand_config.indexing_map = self.uses[read_use].indexing_map
-
- # Set the indexing map of all scalar uses to the empty map.
- for operand_config in self.operands.values():
- if operand_config.operand_def.kind == OperandKind.SCALAR:
- operand_config.indexing_map = self._get_scalar_map()
-
- # Check all registered tensor and scalar operands have an indexing map.
- for operand in registered_operands:
- if operand.is_attribute():
- continue
- if not (operand in self.operands and self.operands[operand].indexing_map):
- raise ValueError(f"Failed to compute an indexing map for operand "
- f"{operand.name}")
-
- # Collect reduction dims and ensure all the same.
- all_reduction_dims = set(comprehension.all_reduction_dims)
- if len(all_reduction_dims) != 1:
- raise ValueError(
- f"All writes within a generic must have the same reduction "
- f"dims. Got: {all_reduction_dims}")
- self.reduction_dims = next(iter(all_reduction_dims))
-
- # Check the index dimension exists and resolve.
- for index in collected_indices:
- if index.dim_def.dimname not in self.affine_state.all_dims:
+ yaml_tag = "!LinalgIndexingMapsConfig"
+
+ def __init__(self, static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None):
+ self.static_indexing_maps = static_indexing_maps
+
+ def to_yaml_custom_dict(self):
+ if self.static_indexing_maps is not None:
+ return dict(
+ static_indexing_maps=[
+ _serialize_affine_map(m) for m in self.static_indexing_maps
+ ]
+ )
raise ValueError(
- f"The dimension {index.dim_def.dimname} is not part of the "
- f"iteration domain {self.affine_state.all_dims}")
- index.resolve_dimension_name(self.affine_state)
-
- # Generate the scalar assignments (used to build a body).
- self.assignments = [
- ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
- for write_use, read_expr in self.writes
- ]
-
- @property
- def ordered_operands(self) -> Sequence[OperandDefConfig]:
- return sorted(
- self.operands.values(),
- key=lambda operand: operand.operand_def.registered_index)
-
- @property
- def ordered_dims(self) -> Sequence[Tuple[str, int]]:
- """Gets the ordered list of dim bindings (symbolic name, position).
-
- TODO: The original parser relies on parse ordering to arrive at the
- iterator types, but that ordering is not defined on the Python side, so
- this may be ambiguous.
- """
- return list(self.affine_state.all_dims.items())
-
- @property
- def indexing_maps(self) -> Sequence[_ir.AffineMap]:
- return [o.indexing_map for o in self.ordered_operands if o.indexing_map]
-
- @property
- def iterator_types(self) -> Sequence[str]:
-
- def get_type(symbolic_name, position):
- for reduction_dim_expr in self.reduction_dims:
- if reduction_dim_expr.dimname == symbolic_name:
- return "reduction"
- return "parallel"
-
- return [get_type(*dim) for dim in self.ordered_dims]
-
- def add_operand(self, operand_def: OperandDef):
- if operand_def in self.operands:
- return
- if not (operand_def.is_tensor() or
- operand_def.kind == OperandKind.INDEX_ATTR):
- self.operands[operand_def] = OperandDefConfig(operand_def)
- return
- with self.context:
- local_state = AffineBuildState(
- global_state=self.affine_state, allow_new_dims=False)
- exprs = []
- for expr in operand_def.size_exprs:
- exprs.append(expr.build(state=local_state))
- assert local_state.local_dim_count == 0
- affine_map = _ir.AffineMap.get(
- dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
- if operand_def.kind == OperandKind.INDEX_ATTR:
- self.operands[operand_def] = OperandDefConfig(
- operand_def, index_attr_map=affine_map)
- else:
- self.operands[operand_def] = OperandDefConfig(
- operand_def, shape_map=affine_map)
-
- def add_indexed_operand(self, operand_def: OperandDef):
- with self.context:
- local_state = AffineBuildState(
- global_state=self.affine_state, allow_new_symbols=False)
- exprs = []
- for expr in operand_def.index_dims:
- exprs.append(expr.build(state=local_state))
- self.operands[operand_def].indexing_map = _ir.AffineMap.get(
- dim_count=local_state.dim_count,
- symbol_count=local_state.symbol_count,
- exprs=exprs)
-
- def add_tensor_use(self, tensor_use: TensorUse):
- if tensor_use in self.uses:
- return
- with self.context:
- local_state = AffineBuildState(
- global_state=self.affine_state, allow_new_symbols=False)
- exprs = []
- for expr in tensor_use.indices:
- exprs.append(expr.build(state=local_state))
- indexing_map = _ir.AffineMap.get(
- dim_count=local_state.dim_count,
- symbol_count=local_state.symbol_count,
- exprs=exprs)
-
- use_config = TensorUseConfig(tensor_use, indexing_map)
- self.uses[tensor_use] = use_config
-
- def _get_scalar_map(self) -> _ir.AffineMap:
- """Create an empty affine map used to index a scalar."""
- with self.context:
- return _ir.AffineMap.get(
- dim_count=self.affine_state.dim_count,
- symbol_count=self.affine_state.symbol_count,
- exprs=list())
-
- def _normalize_affine_map(self,
- affine_map: _ir.AffineMap,
- with_dims: bool = True) -> _ir.AffineMap:
- """Normalizes an indexing map to have the max known symbols and dims."""
- with self.context:
- return _ir.AffineMap.get(
- dim_count=self.affine_state.dim_count if with_dims else 0,
- symbol_count=self.affine_state.symbol_count,
- exprs=list(affine_map.results))
-
- def to_yaml_custom_dict(self):
- self_dict = dict(args=self.ordered_operands)
- # TODO: Refactor the hierarchy internally when supporting more
- # than static (preserving this serialized form).
- self_dict["indexing_maps"] = LinalgIndexingMapsConfig(
- static_indexing_maps=self.indexing_maps)
- self_dict["iterator_types"] = self.iterator_types
- self_dict["assignments"] = self.assignments
- return self_dict
-
- def __repr__(self):
- lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"]
- lines.append("operands=[")
- for def_config in self.ordered_operands:
- lines.append(f" {repr(def_config)}")
- lines.append("], indexing_maps=[")
- for m in self.indexing_maps:
- lines.append(f" {repr(m)}")
- lines.append(f"], iterator_types=[")
- for t in self.iterator_types:
- lines.append(f" {t}")
- lines.append("])")
- return "\n".join(lines)
+ f"LinalgIndexingMapsConfig must have one type of indexing map" f"(got none)"
+ )
+
+
+class LinalgStructuredOpConfig(YAMLObject):
+ """Configuration for metadata sufficient to construct a linalg named op."""
+
+ yaml_tag = "!LinalgStructuredOpConfig"
+
+ def __init__(
+ self,
+ comprehension: Comprehension,
+ domain: Sequence[DimDef],
+ registered_operands: Sequence[OperandDef],
+ context: Optional[_ir.Context] = None,
+ ):
+ self.context = context if context is not None else _ir.Context()
+ self.affine_state = AffineBuildState()
+ self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]]
+ self.operands = dict() # type: Dict[OperandDef, OperandDefConfig]
+ self.uses = dict() # type: Dict[TensorUse, TensorUseConfig]
+
+ # Compute the ordered set of writes and collect the tensor, capture, dims,
+ # and index uses.
+ collected_tensor_uses = set()
+ collected_scalar_uses = set()
+ collected_dim_uses = set()
+ collected_indices = set()
+ for write_use, read_use in zip(comprehension.definitions, comprehension.values):
+ self.writes.append((write_use, read_use))
+
+ for write_use, read_use in self.writes:
+ collected_tensor_uses.add(write_use)
+ read_use.collect_tensor_uses(collected_tensor_uses)
+ read_use.collect_scalar_uses(collected_scalar_uses)
+ read_use.collect_dim_uses(collected_dim_uses)
+ write_use.collect_dim_uses(collected_dim_uses)
+ read_use.collect_indices(collected_indices)
+
+ # Set domain to the sorted list of uses if no domain annotation is given.
+ if not domain:
+ domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname)
+
+ # Verify the domain dimensions match the used dimensions.
+ if len(domain) != len(collected_dim_uses) or any(
+ dim not in collected_dim_uses for dim in domain
+ ):
+ raise ValueError(
+ f"Expected the annotated domain dimensions {domain} to "
+ f"match the set of dimension used by the tensor "
+ f"comprehension {collected_dim_uses}"
+ )
+
+ # Instantiate the dimensions in the given order.
+ with self.context:
+ local_state = AffineBuildState(
+ global_state=self.affine_state, allow_new_symbols=False
+ )
+ for dim in domain:
+ dim.build(state=local_state)
+
+ # Collect all attribute definitions.
+ collected_attr_defs = list()
+ for operand in registered_operands:
+ if operand.is_attribute():
+ collected_attr_defs.append(operand)
+
+ # Collect all tensors with manual indexing annotation.
+ collected_index_defs = list()
+ for operand in registered_operands:
+ if operand.index_dims:
+ if any(dim not in collected_dim_uses for dim in operand.index_dims):
+ raise ValueError(
+ f"Expected all index dims {operand.index_dims} of "
+ f"operand {operand.name} to have uses."
+ )
+ collected_index_defs.append(operand)
+
+ # Collect the operand definitions of all tensor/scalar uses, attributes, and
+ # shape-only tensors.
+ all_operand_defs = list()
+ for use in collected_tensor_uses:
+ all_operand_defs.append(use.operand_def)
+ for use in collected_scalar_uses:
+ all_operand_defs.append(use.operand_def)
+ for definition in collected_attr_defs:
+ all_operand_defs.append(definition)
+ for definition in collected_index_defs:
+ all_operand_defs.append(definition)
+
+ # Add all operands in registration order to ensure the symbols are
+ # registered in the order they appear.
+ all_operand_defs = sorted(
+ all_operand_defs, key=lambda operand_def: operand_def.registered_index
+ )
+ for operand_def in all_operand_defs:
+ self.add_operand(operand_def)
+
+ # Add all shape-only tensor index_dim annotations and all tensor uses.
+ for definition in collected_index_defs:
+ self.add_indexed_operand(definition)
+ for use in collected_tensor_uses:
+ self.add_tensor_use(use)
+
+ # Normalize all shape and indexing maps now that full count of dims and
+ # symbols are known.
+ for cuse in self.uses.values():
+ cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
+ for definition in collected_index_defs:
+ self.operands[definition].indexing_map = self._normalize_affine_map(
+ self.operands[definition].indexing_map
+ )
+ for operand_config in self.operands.values():
+ if operand_config.shape_map:
+ operand_config.shape_map = self._normalize_affine_map(
+ operand_config.shape_map, with_dims=False
+ )
+ if operand_config.index_attr_map:
+ operand_config.index_attr_map = self._normalize_affine_map(
+ operand_config.index_attr_map, with_dims=False
+ )
+
+ # Now for each write use, propagate the indexing maps from the use to the
+ # tensor, ensuring that there are not conflicts.
+ for write_use, _ in self.writes:
+ write_tensor_config = self.operands[write_use.operand_def]
+ if write_tensor_config.indexing_map:
+ raise ValueError(
+ f"Unexpected multi-write to a single tensor: {write_tensor_config}"
+ )
+ write_tensor_config.indexing_map = self.uses[write_use].indexing_map
+
+ # For each read use, propagate the indexing maps from the use to the
+ # tensor, ensuring that there are not conflicts.
+ for _, read_expr in self.writes:
+ read_uses = set() # type: Set[TensorUse]
+ read_expr.collect_tensor_uses(read_uses)
+ for read_use in read_uses:
+ read_operand_config = self.operands[read_use.operand_def]
+ if (
+ read_operand_config.indexing_map
+ and read_operand_config.indexing_map
+ != self.uses[read_use].indexing_map
+ ):
+ raise ValueError(
+ f"Unexpected multi-read of a tensor with
diff erent accesses:"
+ f"{read_operand_config} vs {read_use}"
+ )
+ read_operand_config.indexing_map = self.uses[read_use].indexing_map
+
+ # Set the indexing map of all scalar uses to the empty map.
+ for operand_config in self.operands.values():
+ if operand_config.operand_def.kind == OperandKind.SCALAR:
+ operand_config.indexing_map = self._get_scalar_map()
+
+ # Check all registered tensor and scalar operands have an indexing map.
+ for operand in registered_operands:
+ if operand.is_attribute():
+ continue
+ if not (operand in self.operands and self.operands[operand].indexing_map):
+ raise ValueError(
+ f"Failed to compute an indexing map for operand " f"{operand.name}"
+ )
+
+ # Collect reduction dims and ensure all the same.
+ all_reduction_dims = set(comprehension.all_reduction_dims)
+ if len(all_reduction_dims) != 1:
+ raise ValueError(
+ f"All writes within a generic must have the same reduction "
+ f"dims. Got: {all_reduction_dims}"
+ )
+ self.reduction_dims = next(iter(all_reduction_dims))
+
+ # Check the index dimension exists and resolve.
+ for index in collected_indices:
+ if index.dim_def.dimname not in self.affine_state.all_dims:
+ raise ValueError(
+ f"The dimension {index.dim_def.dimname} is not part of the "
+ f"iteration domain {self.affine_state.all_dims}"
+ )
+ index.resolve_dimension_name(self.affine_state)
+
+ # Generate the scalar assignments (used to build a body).
+ self.assignments = [
+ ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
+ for write_use, read_expr in self.writes
+ ]
+
+ @property
+ def ordered_operands(self) -> Sequence[OperandDefConfig]:
+ return sorted(
+ self.operands.values(),
+ key=lambda operand: operand.operand_def.registered_index,
+ )
+
+ @property
+ def ordered_dims(self) -> Sequence[Tuple[str, int]]:
+ """Gets the ordered list of dim bindings (symbolic name, position).
+
+ TODO: The original parser relies on parse ordering to arrive at the
+ iterator types, but that ordering is not defined on the Python side, so
+ this may be ambiguous.
+ """
+ return list(self.affine_state.all_dims.items())
+
+ @property
+ def indexing_maps(self) -> Sequence[_ir.AffineMap]:
+ return [o.indexing_map for o in self.ordered_operands if o.indexing_map]
+
+ @property
+ def iterator_types(self) -> Sequence[str]:
+ def get_type(symbolic_name, position):
+ for reduction_dim_expr in self.reduction_dims:
+ if reduction_dim_expr.dimname == symbolic_name:
+ return "reduction"
+ return "parallel"
+
+ return [get_type(*dim) for dim in self.ordered_dims]
+
+ def add_operand(self, operand_def: OperandDef):
+ if operand_def in self.operands:
+ return
+ if not (operand_def.is_tensor() or operand_def.kind == OperandKind.INDEX_ATTR):
+ self.operands[operand_def] = OperandDefConfig(operand_def)
+ return
+ with self.context:
+ local_state = AffineBuildState(
+ global_state=self.affine_state, allow_new_dims=False
+ )
+ exprs = []
+ for expr in operand_def.size_exprs:
+ exprs.append(expr.build(state=local_state))
+ assert local_state.local_dim_count == 0
+ affine_map = _ir.AffineMap.get(
+ dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs
+ )
+ if operand_def.kind == OperandKind.INDEX_ATTR:
+ self.operands[operand_def] = OperandDefConfig(
+ operand_def, index_attr_map=affine_map
+ )
+ else:
+ self.operands[operand_def] = OperandDefConfig(
+ operand_def, shape_map=affine_map
+ )
+
+ def add_indexed_operand(self, operand_def: OperandDef):
+ with self.context:
+ local_state = AffineBuildState(
+ global_state=self.affine_state, allow_new_symbols=False
+ )
+ exprs = []
+ for expr in operand_def.index_dims:
+ exprs.append(expr.build(state=local_state))
+ self.operands[operand_def].indexing_map = _ir.AffineMap.get(
+ dim_count=local_state.dim_count,
+ symbol_count=local_state.symbol_count,
+ exprs=exprs,
+ )
+
+ def add_tensor_use(self, tensor_use: TensorUse):
+ if tensor_use in self.uses:
+ return
+ with self.context:
+ local_state = AffineBuildState(
+ global_state=self.affine_state, allow_new_symbols=False
+ )
+ exprs = []
+ for expr in tensor_use.indices:
+ exprs.append(expr.build(state=local_state))
+ indexing_map = _ir.AffineMap.get(
+ dim_count=local_state.dim_count,
+ symbol_count=local_state.symbol_count,
+ exprs=exprs,
+ )
+
+ use_config = TensorUseConfig(tensor_use, indexing_map)
+ self.uses[tensor_use] = use_config
+
+ def _get_scalar_map(self) -> _ir.AffineMap:
+ """Create an empty affine map used to index a scalar."""
+ with self.context:
+ return _ir.AffineMap.get(
+ dim_count=self.affine_state.dim_count,
+ symbol_count=self.affine_state.symbol_count,
+ exprs=list(),
+ )
+
+ def _normalize_affine_map(
+ self, affine_map: _ir.AffineMap, with_dims: bool = True
+ ) -> _ir.AffineMap:
+ """Normalizes an indexing map to have the max known symbols and dims."""
+ with self.context:
+ return _ir.AffineMap.get(
+ dim_count=self.affine_state.dim_count if with_dims else 0,
+ symbol_count=self.affine_state.symbol_count,
+ exprs=list(affine_map.results),
+ )
+
+ def to_yaml_custom_dict(self):
+ self_dict = dict(args=self.ordered_operands)
+ # TODO: Refactor the hierarchy internally when supporting more
+ # than static (preserving this serialized form).
+ self_dict["indexing_maps"] = LinalgIndexingMapsConfig(
+ static_indexing_maps=self.indexing_maps
+ )
+ self_dict["iterator_types"] = self.iterator_types
+ self_dict["assignments"] = self.assignments
+ return self_dict
+
+ def __repr__(self):
+ lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"]
+ lines.append("operands=[")
+ for def_config in self.ordered_operands:
+ lines.append(f" {repr(def_config)}")
+ lines.append("], indexing_maps=[")
+ for m in self.indexing_maps:
+ lines.append(f" {repr(m)}")
+ lines.append(f"], iterator_types=[")
+ for t in self.iterator_types:
+ lines.append(f" {t}")
+ lines.append("])")
+ return "\n".join(lines)
class LinalgOpConfig(YAMLObject):
- """Container for any supported linalg op type.
-
- This includes the concrete type by name for ease of parsing by systems
- that ignore tags.
- """
- yaml_tag = "!LinalgOpConfig"
-
- def __init__(self,
- metadata: OpMetadataDef,
- *,
- structured_op: Optional[LinalgStructuredOpConfig] = None):
- self.metadata = metadata
- self.structured_op = structured_op
-
- def to_yaml_custom_dict(self):
- self_dict = dict(metadata=self.metadata,)
- if self.structured_op:
- self_dict["structured_op"] = self.structured_op
- return self_dict
-
- @staticmethod
- def from_linalg_op_def(
- op_def: LinalgOpDef,
- context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]:
- """Expands a LinalgOpDef into corresponding Linalg configured ops."""
- # TODO: Many LinalgOpDef patterns need to expand to multiple generics.
- assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
- return [
- LinalgOpConfig(
- op_def.metadata,
- structured_op=LinalgStructuredOpConfig(
- op_def.comprehensions[0], op_def.domain,
- op_def.registered_operands.values(), context)),
- ]
-
- def __repr__(self):
- return (f"LinalgOpConfig(metadata={self.metadata},\n"
- f"structured_op={self.structured_op})")
+ """Container for any supported linalg op type.
+
+ This includes the concrete type by name for ease of parsing by systems
+ that ignore tags.
+ """
+
+ yaml_tag = "!LinalgOpConfig"
+
+ def __init__(
+ self,
+ metadata: OpMetadataDef,
+ *,
+ structured_op: Optional[LinalgStructuredOpConfig] = None,
+ ):
+ self.metadata = metadata
+ self.structured_op = structured_op
+
+ def to_yaml_custom_dict(self):
+ self_dict = dict(
+ metadata=self.metadata,
+ )
+ if self.structured_op:
+ self_dict["structured_op"] = self.structured_op
+ return self_dict
+
+ @staticmethod
+ def from_linalg_op_def(
+ op_def: LinalgOpDef, context: Optional[_ir.Context] = None
+ ) -> Sequence["LinalgOpConfig"]:
+ """Expands a LinalgOpDef into corresponding Linalg configured ops."""
+ # TODO: Many LinalgOpDef patterns need to expand to multiple generics.
+ assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
+ return [
+ LinalgOpConfig(
+ op_def.metadata,
+ structured_op=LinalgStructuredOpConfig(
+ op_def.comprehensions[0],
+ op_def.domain,
+ op_def.registered_operands.values(),
+ context,
+ ),
+ ),
+ ]
+
+ def __repr__(self):
+ return (
+ f"LinalgOpConfig(metadata={self.metadata},\n"
+ f"structured_op={self.structured_op})"
+ )
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
index 45b8d5ccd13d6..8b8726f8f9a03 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -10,160 +10,192 @@
import threading
from ..... import ir
-from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ...._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+)
from .comprehension import *
from .config import *
from .emitter import *
_CONTEXT = threading.local()
-StructuredOpOuts = Union[ir.Operation, ir.OpView, ir.OpResultList,
- Sequence[Union[ir.Value, ir.Operation, ir.OpView]]]
+StructuredOpOuts = Union[
+ ir.Operation,
+ ir.OpView,
+ ir.OpResultList,
+ Sequence[Union[ir.Value, ir.Operation, ir.OpView]],
+]
@contextmanager
def bind_op_def(op_def: LinalgOpDef):
- if hasattr(_CONTEXT, "current_op_def"):
- raise ValueError("Cannot recursively define an operation")
- _CONTEXT.current_op_def = op_def
- try:
- yield op_def
- finally:
- del _CONTEXT.current_op_def
+ if hasattr(_CONTEXT, "current_op_def"):
+ raise ValueError("Cannot recursively define an operation")
+ _CONTEXT.current_op_def = op_def
+ try:
+ yield op_def
+ finally:
+ del _CONTEXT.current_op_def
def current_op_def() -> LinalgOpDef:
- try:
- return _CONTEXT.current_op_def
- except AttributeError:
- raise ValueError(
- "Attempt to access the current op definition being defined "
- "but none is set. Did you mean to call this in an op definition?")
+ try:
+ return _CONTEXT.current_op_def
+ except AttributeError:
+ raise ValueError(
+ "Attempt to access the current op definition being defined "
+ "but none is set. Did you mean to call this in an op definition?"
+ )
def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList:
- if isinstance(outs, (ir.Operation, ir.OpView)):
- return _get_op_results_or_values(outs)
- elif isinstance(outs, ir.OpResultList):
- return outs
+ if isinstance(outs, (ir.Operation, ir.OpView)):
+ return _get_op_results_or_values(outs)
+ elif isinstance(outs, ir.OpResultList):
+ return outs
- return [_get_op_result_or_value(o) for o in outs]
+ return [_get_op_result_or_value(o) for o in outs]
class DefinedOpCallable:
- """Callable that wraps any defined op function."""
-
- def __init__(self, op_name: str, op_def: LinalgOpDef):
- self.op_name = op_name
- self.op_def = op_def
-
- def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value],
- outs: StructuredOpOuts, **kwargs):
- """Emits the corresponding op definition as IR.
-
- Most arguments are passed through to the underlying emitter. The following
- keyword argument is interpreted here:
- emit_generic: Emits a generic form as appropriate (default True). If
- False, a named form is emitted (which must have been built in to the
- compiler).
- """
- emit_generic = kwargs.pop("emit_generic", False)
- if not isinstance(emit_generic, bool):
- raise ValueError(f"The named argument 'emit_generic' needs to be "
- f" of type bool but got {type(emit_generic)}")
-
- op_configs = LinalgOpConfig.from_linalg_op_def(
- self.op_def, context=ir.Context.current)
-
- if len(op_configs) != 1:
- # TODO: Support composite ops.
- raise NotImplementedError(
- f"Emission of composite linalg ops not supported: {op_configs}")
-
- ctx = ir.Context.current
- linalgDialect = ctx.get_dialect_descriptor("linalg")
- fully_qualified_name = "linalg." + self.op_name
- emit_generic = (
- emit_generic or not ctx.is_registered_operation(fully_qualified_name))
-
- op_config = op_configs[0]
- out_values = _prepare_structured_op_outs(outs)
- in_values = [_get_op_result_or_value(i) for i in ins]
- if op_config.structured_op:
- if emit_generic:
- return emit_generic_structured_op(
- op_config.structured_op, *in_values, outs=out_values, **kwargs)
- else:
- return emit_named_structured_op(
- op_config.structured_op,
- self.op_name,
- self.op_def.metadata.cpp_class_name,
- *in_values,
- outs=out_values,
- **kwargs)
-
- raise NotImplementedError(
- f"Emission of linalg op type not supported: {op_config}")
-
-
-def linalg_structured_op(dsl_func=None,
- *,
- op_name=None,
- op_class_name=None) -> DefinedOpCallable:
- if dsl_func is None:
- # Curry the keyword args in for delayed application.
- return functools.partial(
- linalg_structured_op, op_name=op_name, op_class_name=op_class_name)
- # Determine default names by introspecting the function.
- if op_name is None:
- op_name = dsl_func.__name__
- if op_class_name is None:
- # Camel case it.
- op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"
-
- op_def = LinalgOpDef(
- name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func))
-
- # Extract arguments and TensorDefs from the signature.
- dsl_func_args = list()
- sig = inspect.signature(dsl_func)
- for param_name, param in sig.parameters.items():
- param_default = param.default
- if isinstance(param_default,
- (TensorDef, ScalarDef, IndexAttrDef, UnaryFnAttrDef,
- BinaryFnAttrDef, TypeFnAttrDef)):
- op_def.add_operand(param_name, param_default.operand_def)
- else:
- raise ValueError(
- f"@linalg_structured_op function parameters must be defaulted as "
- f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): "
- f"Found {param_name}: {param_default}")
- dsl_func_args.append(param_default)
-
- # Invoke the DSL func to finish populating the op definition.
- with bind_op_def(op_def):
- dsl_func(*dsl_func_args)
-
- # TODO: The returned callable should be an IR emitter but that is not
- # upstreamed yet.
- return DefinedOpCallable(op_name, op_def)
+ """Callable that wraps any defined op function."""
+
+ def __init__(self, op_name: str, op_def: LinalgOpDef):
+ self.op_name = op_name
+ self.op_def = op_def
+
+ def __call__(
+ self,
+ *ins: Union[ir.Operation, ir.OpView, ir.Value],
+ outs: StructuredOpOuts,
+ **kwargs,
+ ):
+ """Emits the corresponding op definition as IR.
+
+ Most arguments are passed through to the underlying emitter. The following
+ keyword argument is interpreted here:
+ emit_generic: Emits a generic form as appropriate (default True). If
+ False, a named form is emitted (which must have been built in to the
+ compiler).
+ """
+ emit_generic = kwargs.pop("emit_generic", False)
+ if not isinstance(emit_generic, bool):
+ raise ValueError(
+ f"The named argument 'emit_generic' needs to be "
+ f" of type bool but got {type(emit_generic)}"
+ )
+
+ op_configs = LinalgOpConfig.from_linalg_op_def(
+ self.op_def, context=ir.Context.current
+ )
+
+ if len(op_configs) != 1:
+ # TODO: Support composite ops.
+ raise NotImplementedError(
+ f"Emission of composite linalg ops not supported: {op_configs}"
+ )
+
+ ctx = ir.Context.current
+ linalgDialect = ctx.get_dialect_descriptor("linalg")
+ fully_qualified_name = "linalg." + self.op_name
+ emit_generic = emit_generic or not ctx.is_registered_operation(
+ fully_qualified_name
+ )
+
+ op_config = op_configs[0]
+ out_values = _prepare_structured_op_outs(outs)
+ in_values = [_get_op_result_or_value(i) for i in ins]
+ if op_config.structured_op:
+ if emit_generic:
+ return emit_generic_structured_op(
+ op_config.structured_op, *in_values, outs=out_values, **kwargs
+ )
+ else:
+ return emit_named_structured_op(
+ op_config.structured_op,
+ self.op_name,
+ self.op_def.metadata.cpp_class_name,
+ *in_values,
+ outs=out_values,
+ **kwargs,
+ )
+
+ raise NotImplementedError(
+ f"Emission of linalg op type not supported: {op_config}"
+ )
+
+
+def linalg_structured_op(
+ dsl_func=None, *, op_name=None, op_class_name=None
+) -> DefinedOpCallable:
+ if dsl_func is None:
+ # Curry the keyword args in for delayed application.
+ return functools.partial(
+ linalg_structured_op, op_name=op_name, op_class_name=op_class_name
+ )
+ # Determine default names by introspecting the function.
+ if op_name is None:
+ op_name = dsl_func.__name__
+ if op_class_name is None:
+ # Camel case it.
+ op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"
+
+ op_def = LinalgOpDef(
+ name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func)
+ )
+
+ # Extract arguments and TensorDefs from the signature.
+ dsl_func_args = list()
+ sig = inspect.signature(dsl_func)
+ for param_name, param in sig.parameters.items():
+ param_default = param.default
+ if isinstance(
+ param_default,
+ (
+ TensorDef,
+ ScalarDef,
+ IndexAttrDef,
+ UnaryFnAttrDef,
+ BinaryFnAttrDef,
+ TypeFnAttrDef,
+ ),
+ ):
+ op_def.add_operand(param_name, param_default.operand_def)
+ else:
+ raise ValueError(
+ f"@linalg_structured_op function parameters must be defaulted as "
+ f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): "
+ f"Found {param_name}: {param_default}"
+ )
+ dsl_func_args.append(param_default)
+
+ # Invoke the DSL func to finish populating the op definition.
+ with bind_op_def(op_def):
+ dsl_func(*dsl_func_args)
+
+ # TODO: The returned callable should be an IR emitter but that is not
+ # upstreamed yet.
+ return DefinedOpCallable(op_name, op_def)
def domain(*dimensions: DimDef):
- if any(not isinstance(d, DimDef) for d in dimensions):
- raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
- current_op_def().domain.extend(dimensions)
+ if any(not isinstance(d, DimDef) for d in dimensions):
+ raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
+ current_op_def().domain.extend(dimensions)
def implements(*interfaces: OpInterfaceDef):
- if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
- raise ValueError(
- f"Expected interfaces of type OpInterfaceDef but got {interfaces}")
- current_op_def().metadata.implements.extend(interfaces)
+ if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
+ raise ValueError(
+ f"Expected interfaces of type OpInterfaceDef but got {interfaces}"
+ )
+ current_op_def().metadata.implements.extend(interfaces)
def defines(*definitions: OpDefinitionDef):
- if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
- raise ValueError(
- f"Expected definitions of type OpDefinitionDef but got {definitions}")
- current_op_def().metadata.defines.extend(definitions)
+ if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
+ raise ValueError(
+ f"Expected definitions of type OpDefinitionDef but got {definitions}"
+ )
+ current_op_def().metadata.defines.extend(definitions)
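
Once defined, the DefinedOpCallable above emits IR when called with SSA values. A rough usage sketch, assuming the in-tree Python bindings (func.FuncOp.from_py_func) and the hypothetical matmul_like op from before; passing emit_generic=True requests the linalg.generic form handled in __call__ above:

from mlir import ir
from mlir.dialects import func, linalg

with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    f32 = ir.F32Type.get()
    with ir.InsertionPoint(module.body):

        @func.FuncOp.from_py_func(
            ir.RankedTensorType.get((4, 8), f32),
            ir.RankedTensorType.get((8, 16), f32),
            ir.RankedTensorType.get((4, 16), f32),
        )
        def emit(lhs, rhs, init):
            # outs takes the destination tensor(s); emit_generic is popped
            # from kwargs in DefinedOpCallable.__call__.
            return matmul_like(lhs, rhs, outs=[init], emit_generic=True)

    print(module)
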
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index b63cb40717aed..62730d9ca4d8e 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -11,7 +11,10 @@
from .... import math
from .... import arith
from .... import complex
-from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ...._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+)
from .scalar_expr import *
from .config import *
@@ -29,529 +32,618 @@
def isa(cls: Type, ty: Type):
- try:
- cls(ty)
- return True
- except ValueError:
- return False
-
-
-def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
- *ins: Value, outs: ValueList,
- **attrs: Union[Sequence[int], TypeFnType]):
- all_arg_defs = op_config.ordered_operands
- in_arg_defs = [
- d for d in all_arg_defs
- if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR]
- ]
- out_arg_defs = [
- d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR
- ]
- index_attr_arg_defs = [
- d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR
- ]
- fn_attr_arg_defs = [
- d for d in all_arg_defs if d.kind in [
- OperandKind.UNARY_FN_ATTR, OperandKind.BINARY_FN_ATTR,
- OperandKind.TYPE_FN_ATTR
- ]
- ]
-
- # Verify outs is a sequence or a list of results.
- if not isinstance(outs, (Sequence, OpResultList)):
- raise ValueError(f"Expected named argument outs to have type Sequence or "
- f"OpResultLis but got {type(outs)}")
-
- # Arity validation.
- if len(ins) != len(in_arg_defs):
- raise ValueError(f"Expected {len(in_arg_defs)} inputs but got "
- f"{len(ins)} for {op_config}")
- if outs and len(outs) != len(out_arg_defs):
- raise ValueError(f"Expected {len(out_arg_defs)} outputs but got "
- f"{len(outs)} for {op_config}")
-
- # Compute a replacement list for all index attribute symbols.
- expressions = [] # type: Sequence[AffineExpr]
- replacements = [] # type: Sequence[AffineExpr]
- for index_attr in index_attr_arg_defs:
- index_attr_vals = index_attr.operand_def.default_indices
- if index_attr.name in attrs:
- index_attr_vals = attrs.get(index_attr.name)
- assert index_attr_vals, "Index attribute has no value"
- if not all(isinstance(value, int) for value in index_attr_vals):
- raise ValueError(f"Attribute {index_attr.name} needs to be of type "
- f"Sequence[int] but got {type(index_attr_vals)}")
- results = index_attr.index_attr_map.results # type: AffineExprList
- if len(index_attr_vals) != len(results):
- raise ValueError(f"Attribute {index_attr.name} has length {len(results)} "
- f"but got {len(index_attr_vals)} values")
- for expr, value in zip(results, index_attr_vals):
- expressions.append(expr)
- replacements.append(AffineConstantExpr.get(value))
-
- # Replace all index attribute symbols by their value.
- # TODO: Add support for shape symbols.
- indexing_maps = [] # type: Sequence[AffineMap]
- for curr in op_config.indexing_maps:
- for expression, replacement in zip(expressions, replacements):
- curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols)
- indexing_maps.append(curr)
-
- # TODO: Linalg verification does not currently allow symbols.
- # Compress them for now and verify none are left.
- indexing_maps = AffineMap.compress_unused_symbols(indexing_maps,
- Context.current)
- if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps):
- raise ValueError(f"Expected indexing_maps to use no symbols after "
- f"replacement and compression but got {indexing_maps}")
-
- outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
- out_arg_defs, outs)
-
- result_types = [t for t in out_types if isa(RankedTensorType, t)]
-
- # Initialize the type dictionary with the predefined types.
- type_mapping = dict() # type: Dict[str, Type]
- type_mapping["F32"] = F32Type.get()
- type_mapping["F64"] = F64Type.get()
- type_mapping["I32"] = IntegerType.get_signless(32)
- type_mapping["I64"] = IntegerType.get_signless(64)
-
- # Extract type vars for input/output based types.
- block_arg_types = list() # type: List[Type]
- for arg_def, arg_element_type in zip(in_arg_defs + out_arg_defs,
- _get_types_from_values(*ins, *outs)):
- _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types)
-
- # Emit the generic op.
- # TODO: Support emission of pure memref form.
- indexing_maps_attr = ArrayAttr.get(
- [AffineMapAttr.get(am) for am in indexing_maps])
- iterator_types_attr = ArrayAttr.get([
- Attribute.parse(f"#linalg.iterator_type<{s}>")
- for s in op_config.iterator_types
- ])
-
- # Compute the index attributes used when emitting a named structured op.
- index_attrs = {} # type: Dict[str, DenseElementAttr]
- for index_attr in index_attr_arg_defs:
- index_attr_vals = attrs.get(index_attr.name)
- # Only forward attributes set to a non-default value.
- if index_attr_vals:
- array = np.array(index_attr_vals, dtype=np.int64)
- index_attrs[index_attr.name] = DenseElementsAttr.get(array)
-
- # Compute the function attribute mapping.
- fn_attr_mapping = {}
- for fn_attr in fn_attr_arg_defs:
- attr_val = fn_attr.operand_def.default_fn
- attr_kind = fn_attr.kind
- if fn_attr.name in attrs:
- fn = attrs.get(fn_attr.name)
- if attr_kind == OperandKind.UNARY_FN_ATTR:
- if not isinstance(fn, UnaryFnType):
- raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
- f"UnaryFnType but got {type(attr_val)}")
- elif attr_kind == OperandKind.BINARY_FN_ATTR:
- if not isinstance(fn, BinaryFnType):
- raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
- f"BinaryFnType but got {type(attr_val)}")
- else:
- if not isinstance(fn, TypeFnType):
- raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
- f"TypeFnType but got {type(attr_val)}")
- attr_val = fn.fn_name
- assert attr_val, "Function attribute has no value"
- fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind)
-
- return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
- type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs,
- fn_attr_mapping, block_arg_types)
-
-
-def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
- outs: ValueList, **attrs: Sequence[int]):
- all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
- indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \
- block_arg_types = \
- prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
-
- # An operation that accesses only scalars and scalar/rank zero tensors is
- # rank polymorhpic. We implement rank polymorphism by generating different
- # indexing maps and iterators that match the rank of the first output tensor.
- # An operation is rank polymorphic if the iteration domain has rank zero.
- if not iterator_types_attr:
- rank = ShapedType(outs[0].type).rank
- iterator_types_attr = ArrayAttr.get(
- [Attribute.parse("#linalg.iterator_type<parallel>")] * rank)
- scalar_map = AffineMap.get(rank, 0, [])
- tensor_map = AffineMap.get_identity(rank)
- indexing_maps = []
- for arg_def in all_arg_defs:
- if arg_def.operand_def.kind == OperandKind.SCALAR:
- indexing_maps.append(scalar_map)
- if arg_def.operand_def.is_tensor():
- idx = arg_def.operand_def.registered_index
- if idx < len(ins) and ShapedType(ins[idx].type).rank == 0:
- indexing_maps.append(scalar_map)
- else:
- indexing_maps.append(tensor_map)
- indexing_maps_attr = ArrayAttr.get(
- [AffineMapAttr.get(am) for am in indexing_maps])
-
- generic_op = linalg.GenericOp(
- result_tensors=result_types,
- inputs=ins,
- outputs=outs,
- indexing_maps=indexing_maps_attr,
- iterator_types=iterator_types_attr,
- doc=None, # TODO: Make optional.
- library_call=None) # TODO: Make optional.
-
- # Construct the body.
- block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs)
- block = generic_op.regions[0].blocks.append(*block_arg_types)
- block_arg_mapping = dict(zip(block_arg_names, block.arguments))
- with InsertionPoint(block):
- body_builder = _BodyBuilder(type_mapping, block_arg_mapping,
- fn_attr_mapping)
- for assignment in op_config.assignments:
- body_builder.assign(assignment)
- body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))
-
- if len(result_types) == 1:
- return generic_op.result
- else:
- return generic_op.results
-
-
-def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
- op_class_name: str, *ins: Value, outs: ValueList,
- **attrs: Sequence[int]):
- all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
- indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \
- block_arg_types = \
- prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
-
- # If we get here, there must exist a builtin class `op_class_name`.
- ctx = Context.current
- fully_qualified_name = "linalg." + op_name
- if (not ctx.is_registered_operation(fully_qualified_name) or
- not op_class_name in linalg.__dict__.keys()):
- raise NotImplementedError(
- f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
-
- # Set the index attributes used to compute the indexing maps.
- named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
- for name, value in index_attrs.items():
- named_op.operation.attributes[name] = value
-
- # Compute the function attributes by combining operand kind and function name.
- for name, (fn_name, kind) in fn_attr_mapping.items():
- assert kind.name.lower().endswith("_attr")
- enum_name = kind.name.lower()[:-5]
- named_op.operation.attributes[name] = Attribute.parse(
- f"#linalg.{enum_name}<{fn_name}>")
+ try:
+ cls(ty)
+ return True
+ except ValueError:
+ return False
- linalg.fill_builtin_region(named_op.operation)
- if len(result_types) == 1:
- return named_op.result
- else:
- return named_op.results
+def prepare_common_structured_op(
+ op_config: LinalgStructuredOpConfig,
+ *ins: Value,
+ outs: ValueList,
+ **attrs: Union[Sequence[int], TypeFnType],
+):
+ all_arg_defs = op_config.ordered_operands
+ in_arg_defs = [
+ d
+ for d in all_arg_defs
+ if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR]
+ ]
+ out_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR]
+ index_attr_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR]
+ fn_attr_arg_defs = [
+ d
+ for d in all_arg_defs
+ if d.kind
+ in [
+ OperandKind.UNARY_FN_ATTR,
+ OperandKind.BINARY_FN_ATTR,
+ OperandKind.TYPE_FN_ATTR,
+ ]
+ ]
+
+ # Verify outs is a sequence or a list of results.
+ if not isinstance(outs, (Sequence, OpResultList)):
+ raise ValueError(
+ f"Expected named argument outs to have type Sequence or "
+ f"OpResultLis but got {type(outs)}"
+ )
+
+ # Arity validation.
+ if len(ins) != len(in_arg_defs):
+ raise ValueError(
+ f"Expected {len(in_arg_defs)} inputs but got " f"{len(ins)} for {op_config}"
+ )
+ if outs and len(outs) != len(out_arg_defs):
+ raise ValueError(
+ f"Expected {len(out_arg_defs)} outputs but got "
+ f"{len(outs)} for {op_config}"
+ )
+
+ # Compute a replacement list for all index attribute symbols.
+ expressions = [] # type: Sequence[AffineExpr]
+ replacements = [] # type: Sequence[AffineExpr]
+ for index_attr in index_attr_arg_defs:
+ index_attr_vals = index_attr.operand_def.default_indices
+ if index_attr.name in attrs:
+ index_attr_vals = attrs.get(index_attr.name)
+ assert index_attr_vals, "Index attribute has no value"
+ if not all(isinstance(value, int) for value in index_attr_vals):
+ raise ValueError(
+ f"Attribute {index_attr.name} needs to be of type "
+ f"Sequence[int] but got {type(index_attr_vals)}"
+ )
+ results = index_attr.index_attr_map.results # type: AffineExprList
+ if len(index_attr_vals) != len(results):
+ raise ValueError(
+ f"Attribute {index_attr.name} has length {len(results)} "
+ f"but got {len(index_attr_vals)} values"
+ )
+ for expr, value in zip(results, index_attr_vals):
+ expressions.append(expr)
+ replacements.append(AffineConstantExpr.get(value))
+
+ # Replace all index attribute symbols by their value.
+ # TODO: Add support for shape symbols.
+ indexing_maps = [] # type: Sequence[AffineMap]
+ for curr in op_config.indexing_maps:
+ for expression, replacement in zip(expressions, replacements):
+ curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols)
+ indexing_maps.append(curr)
+
+ # TODO: Linalg verification does not currently allow symbols.
+ # Compress them for now and verify none are left.
+ indexing_maps = AffineMap.compress_unused_symbols(indexing_maps, Context.current)
+ if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps):
+ raise ValueError(
+ f"Expected indexing_maps to use no symbols after "
+ f"replacement and compression but got {indexing_maps}"
+ )
+
+ outs, out_types = _infer_structured_outs(
+ op_config, in_arg_defs, ins, out_arg_defs, outs
+ )
+
+ result_types = [t for t in out_types if isa(RankedTensorType, t)]
+
+ # Initialize the type dictionary with the predefined types.
+ type_mapping = dict() # type: Dict[str, Type]
+ type_mapping["F32"] = F32Type.get()
+ type_mapping["F64"] = F64Type.get()
+ type_mapping["I32"] = IntegerType.get_signless(32)
+ type_mapping["I64"] = IntegerType.get_signless(64)
+
+ # Extract type vars for input/output based types.
+ block_arg_types = list() # type: List[Type]
+ for arg_def, arg_element_type in zip(
+ in_arg_defs + out_arg_defs, _get_types_from_values(*ins, *outs)
+ ):
+ _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types)
+
+ # Emit the generic op.
+ # TODO: Support emission of pure memref form.
+ indexing_maps_attr = ArrayAttr.get([AffineMapAttr.get(am) for am in indexing_maps])
+ iterator_types_attr = ArrayAttr.get(
+ [
+ Attribute.parse(f"#linalg.iterator_type<{s}>")
+ for s in op_config.iterator_types
+ ]
+ )
+
+ # Compute the index attributes used when emitting a named structured op.
+ index_attrs = {} # type: Dict[str, DenseElementAttr]
+ for index_attr in index_attr_arg_defs:
+ index_attr_vals = attrs.get(index_attr.name)
+ # Only forward attributes set to a non-default value.
+ if index_attr_vals:
+ array = np.array(index_attr_vals, dtype=np.int64)
+ index_attrs[index_attr.name] = DenseElementsAttr.get(array)
+
+ # Compute the function attribute mapping.
+ fn_attr_mapping = {}
+ for fn_attr in fn_attr_arg_defs:
+ attr_val = fn_attr.operand_def.default_fn
+ attr_kind = fn_attr.kind
+ if fn_attr.name in attrs:
+ fn = attrs.get(fn_attr.name)
+ if attr_kind == OperandKind.UNARY_FN_ATTR:
+ if not isinstance(fn, UnaryFnType):
+ raise ValueError(
+ f"Attribute {fn_attr.name} needs to be of type "
+ f"UnaryFnType but got {type(attr_val)}"
+ )
+ elif attr_kind == OperandKind.BINARY_FN_ATTR:
+ if not isinstance(fn, BinaryFnType):
+ raise ValueError(
+ f"Attribute {fn_attr.name} needs to be of type "
+ f"BinaryFnType but got {type(attr_val)}"
+ )
+ else:
+ if not isinstance(fn, TypeFnType):
+ raise ValueError(
+ f"Attribute {fn_attr.name} needs to be of type "
+ f"TypeFnType but got {type(attr_val)}"
+ )
+ attr_val = fn.fn_name
+ assert attr_val, "Function attribute has no value"
+ fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind)
+
+ return (
+ all_arg_defs,
+ in_arg_defs,
+ out_arg_defs,
+ outs,
+ result_types,
+ type_mapping,
+ indexing_maps_attr,
+ iterator_types_attr,
+ index_attrs,
+ fn_attr_mapping,
+ block_arg_types,
+ )
+
+
+def emit_generic_structured_op(
+ op_config: LinalgStructuredOpConfig,
+ *ins: Value,
+ outs: ValueList,
+ **attrs: Sequence[int],
+):
+ (
+ all_arg_defs,
+ in_arg_defs,
+ out_arg_defs,
+ outs,
+ result_types,
+ type_mapping,
+ indexing_maps_attr,
+ iterator_types_attr,
+ index_attrs,
+ fn_attr_mapping,
+ block_arg_types,
+ ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs)
+
+ # An operation that accesses only scalars and scalar/rank zero tensors is
+ # rank polymorhpic. We implement rank polymorphism by generating different
+ # indexing maps and iterators that match the rank of the first output tensor.
+ # An operation is rank polymorphic if the iteration domain has rank zero.
+ if not iterator_types_attr:
+ rank = ShapedType(outs[0].type).rank
+ iterator_types_attr = ArrayAttr.get(
+ [Attribute.parse("#linalg.iterator_type<parallel>")] * rank
+ )
+ scalar_map = AffineMap.get(rank, 0, [])
+ tensor_map = AffineMap.get_identity(rank)
+ indexing_maps = []
+ for arg_def in all_arg_defs:
+ if arg_def.operand_def.kind == OperandKind.SCALAR:
+ indexing_maps.append(scalar_map)
+ if arg_def.operand_def.is_tensor():
+ idx = arg_def.operand_def.registered_index
+ if idx < len(ins) and ShapedType(ins[idx].type).rank == 0:
+ indexing_maps.append(scalar_map)
+ else:
+ indexing_maps.append(tensor_map)
+ indexing_maps_attr = ArrayAttr.get(
+ [AffineMapAttr.get(am) for am in indexing_maps]
+ )
+
+ generic_op = linalg.GenericOp(
+ result_tensors=result_types,
+ inputs=ins,
+ outputs=outs,
+ indexing_maps=indexing_maps_attr,
+ iterator_types=iterator_types_attr,
+ doc=None, # TODO: Make optional.
+ library_call=None,
+ ) # TODO: Make optional.
+
+ # Construct the body.
+ block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs)
+ block = generic_op.regions[0].blocks.append(*block_arg_types)
+ block_arg_mapping = dict(zip(block_arg_names, block.arguments))
+ with InsertionPoint(block):
+ body_builder = _BodyBuilder(type_mapping, block_arg_mapping, fn_attr_mapping)
+ for assignment in op_config.assignments:
+ body_builder.assign(assignment)
+ body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))
+
+ if len(result_types) == 1:
+ return generic_op.result
+ else:
+ return generic_op.results
+
+
+def emit_named_structured_op(
+ op_config: LinalgStructuredOpConfig,
+ op_name: str,
+ op_class_name: str,
+ *ins: Value,
+ outs: ValueList,
+ **attrs: Sequence[int],
+):
+ (
+ all_arg_defs,
+ in_arg_defs,
+ out_arg_defs,
+ outs,
+ result_types,
+ type_mapping,
+ indexing_maps_attr,
+ iterator_types_attr,
+ index_attrs,
+ fn_attr_mapping,
+ block_arg_types,
+ ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs)
+
+ # If we get here, there must exist a builtin class `op_class_name`.
+ ctx = Context.current
+ fully_qualified_name = "linalg." + op_name
+ if (
+ not ctx.is_registered_operation(fully_qualified_name)
+ or not op_class_name in linalg.__dict__.keys()
+ ):
+ raise NotImplementedError(
+ f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}"
+ )
+
+ # Set the index attributes used to compute the indexing maps.
+ named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
+ for name, value in index_attrs.items():
+ named_op.operation.attributes[name] = value
+
+ # Compute the function attributes by combining operand kind and function name.
+ for name, (fn_name, kind) in fn_attr_mapping.items():
+ assert kind.name.lower().endswith("_attr")
+ enum_name = kind.name.lower()[:-5]
+ named_op.operation.attributes[name] = Attribute.parse(
+ f"#linalg.{enum_name}<{fn_name}>"
+ )
+
+ linalg.fill_builtin_region(named_op.operation)
+
+ if len(result_types) == 1:
+ return named_op.result
+ else:
+ return named_op.results
class _BodyBuilder:
- """Constructs a structured op body by evaluating assignments."""
-
- def __init__(self, type_mapping: Dict[str, Type],
- block_arg_mapping: Dict[str, Value], fn_attr_mapping: Dict[str,
- str]):
- self.type_mapping = type_mapping
- self.block_arg_mapping = block_arg_mapping
- self.fn_attr_mapping = fn_attr_mapping
- self.yield_mapping = dict() # type: Dict[str, Value]
-
- def assign(self, assignment: ScalarAssign):
- if assignment.arg in self.yield_mapping:
- raise ValueError(
- f"Multiple assignments to the same argument are forbidden: "
- f"{assignment}")
- self.yield_mapping[assignment.arg] = self.expression(assignment.value)
-
- def expression(self, expr: ScalarExpression) -> Value:
- if expr.scalar_arg:
- try:
- return self.block_arg_mapping[expr.scalar_arg.arg]
- except KeyError:
- raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for "
- f"this structured op.")
- elif expr.scalar_const:
- value_attr = Attribute.parse(expr.scalar_const.value)
- return arith.ConstantOp(value_attr.type, value_attr).result
- elif expr.scalar_index:
- dim_attr = IntegerAttr.get(
- IntegerType.get_signless(64), expr.scalar_index.dim)
- return linalg.IndexOp(dim_attr).result
- elif expr.scalar_fn:
- kind = expr.scalar_fn.kind.name.lower()
- fn_name = expr.scalar_fn.fn_name
- if expr.scalar_fn.attr_name:
- fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name]
- fn = self._get_function(f"_{kind}_{fn_name}")
- operand_values = [
- self.expression(operand) for operand in expr.scalar_fn.operands
- ]
- if expr.scalar_fn.kind == FunctionKind.TYPE:
- operand_values = [expr.scalar_fn.type_var.name] + operand_values
- return fn(*operand_values)
- raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
-
- def yield_outputs(self, *output_names: str):
- output_values = []
- for n in output_names:
- try:
- output_values.append(self.yield_mapping[n])
- except KeyError:
- raise ValueError(f"Body assignments do not assign all outputs: "
- f"missing '{n}'")
- linalg.YieldOp(output_values)
-
- def _get_function(self, fn_name: str) -> Callable:
- try:
- fn = getattr(self, f"{fn_name}")
- except AttributeError:
- raise ValueError(f"Function '{fn_name}' is not a known function")
- return fn
-
- def _cast(self,
- type_var_name: str,
- operand: Value,
- is_unsigned_cast: bool = False) -> Value:
- try:
- to_type = self.type_mapping[type_var_name]
- except KeyError:
- raise ValueError(f"Unbound type variable '{type_var_name}' ("
- f"expected one of {self.type_mapping.keys()}")
- if operand.type == to_type:
- return operand
- if _is_integer_type(to_type):
- return self._cast_to_integer(to_type, operand, is_unsigned_cast)
- elif _is_floating_point_type(to_type):
- return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
-
- def _cast_to_integer(self, to_type: Type, operand: Value,
- is_unsigned_cast: bool) -> Value:
- to_width = IntegerType(to_type).width
- operand_type = operand.type
- if _is_floating_point_type(operand_type):
- if is_unsigned_cast:
- return arith.FPToUIOp(to_type, operand).result
- return arith.FPToSIOp(to_type, operand).result
- if _is_index_type(operand_type):
- return arith.IndexCastOp(to_type, operand).result
- # Assume integer.
- from_width = IntegerType(operand_type).width
- if to_width > from_width:
- if is_unsigned_cast:
- return arith.ExtUIOp(to_type, operand).result
- return arith.ExtSIOp(to_type, operand).result
- elif to_width < from_width:
- return arith.TruncIOp(to_type, operand).result
- raise ValueError(f"Unable to cast body expression from {operand_type} to "
- f"{to_type}")
-
- def _cast_to_floating_point(self, to_type: Type, operand: Value,
- is_unsigned_cast: bool) -> Value:
- operand_type = operand.type
- if _is_integer_type(operand_type):
- if is_unsigned_cast:
- return arith.UIToFPOp(to_type, operand).result
- return arith.SIToFPOp(to_type, operand).result
- # Assume FloatType.
- to_width = _get_floating_point_width(to_type)
- from_width = _get_floating_point_width(operand_type)
- if to_width > from_width:
- return arith.ExtFOp(to_type, operand).result
- elif to_width < from_width:
- return arith.TruncFOp(to_type, operand).result
- raise ValueError(f"Unable to cast body expression from {operand_type} to "
- f"{to_type}")
-
- def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value:
- return self._cast(type_var_name, operand, False)
-
- def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
- return self._cast(type_var_name, operand, True)
-
- def _unary_exp(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
- return math.ExpOp(x).result
- raise NotImplementedError("Unsupported 'exp' operand: {x}")
-
- def _unary_log(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
- return math.LogOp(x).result
- raise NotImplementedError("Unsupported 'log' operand: {x}")
-
- def _unary_abs(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
- return math.AbsFOp(x).result
- raise NotImplementedError("Unsupported 'abs' operand: {x}")
-
- def _unary_ceil(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
- return math.CeilOp(x).result
- raise NotImplementedError("Unsupported 'ceil' operand: {x}")
-
- def _unary_floor(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
- return math.FloorOp(x).result
- raise NotImplementedError("Unsupported 'floor' operand: {x}")
-
- def _unary_negf(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
- return arith.NegFOp(x).result
- if _is_complex_type(x.type):
- return complex.NegOp(x).result
- raise NotImplementedError("Unsupported 'negf' operand: {x}")
-
- def _binary_add(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
- return arith.AddFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return arith.AddIOp(lhs, rhs).result
- if _is_complex_type(lhs.type):
- return complex.AddOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
-
- def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
- return arith.SubFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return arith.SubIOp(lhs, rhs).result
- if _is_complex_type(lhs.type):
- return complex.SubOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
-
- def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
- return arith.MulFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return arith.MulIOp(lhs, rhs).result
- if _is_complex_type(lhs.type):
- return complex.MulOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
-
- def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
- return arith.MaxFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return arith.MaxSIOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
-
- def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
- return arith.MaxFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return arith.MaxUIOp(lhs, rhs).result
- raise NotImplementedError(
- "Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
-
- def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
- return arith.MinFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return arith.MinSIOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
-
- def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
- return arith.MinFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return arith.MinUIOp(lhs, rhs).result
- raise NotImplementedError(
- "Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
+ """Constructs a structured op body by evaluating assignments."""
+
+ def __init__(
+ self,
+ type_mapping: Dict[str, Type],
+ block_arg_mapping: Dict[str, Value],
+ fn_attr_mapping: Dict[str, str],
+ ):
+ self.type_mapping = type_mapping
+ self.block_arg_mapping = block_arg_mapping
+ self.fn_attr_mapping = fn_attr_mapping
+ self.yield_mapping = dict() # type: Dict[str, Value]
+
+ def assign(self, assignment: ScalarAssign):
+ if assignment.arg in self.yield_mapping:
+ raise ValueError(
+ f"Multiple assignments to the same argument are forbidden: "
+ f"{assignment}"
+ )
+ self.yield_mapping[assignment.arg] = self.expression(assignment.value)
+
+ def expression(self, expr: ScalarExpression) -> Value:
+ if expr.scalar_arg:
+ try:
+ return self.block_arg_mapping[expr.scalar_arg.arg]
+ except KeyError:
+ raise ValueError(
+ f"Argument {expr.scalar_arg.arg} is not bound for "
+ f"this structured op."
+ )
+ elif expr.scalar_const:
+ value_attr = Attribute.parse(expr.scalar_const.value)
+ return arith.ConstantOp(value_attr.type, value_attr).result
+ elif expr.scalar_index:
+ dim_attr = IntegerAttr.get(
+ IntegerType.get_signless(64), expr.scalar_index.dim
+ )
+ return linalg.IndexOp(dim_attr).result
+ elif expr.scalar_fn:
+ kind = expr.scalar_fn.kind.name.lower()
+ fn_name = expr.scalar_fn.fn_name
+ if expr.scalar_fn.attr_name:
+ fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name]
+ fn = self._get_function(f"_{kind}_{fn_name}")
+ operand_values = [
+ self.expression(operand) for operand in expr.scalar_fn.operands
+ ]
+ if expr.scalar_fn.kind == FunctionKind.TYPE:
+ operand_values = [expr.scalar_fn.type_var.name] + operand_values
+ return fn(*operand_values)
+ raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
+
+ def yield_outputs(self, *output_names: str):
+ output_values = []
+ for n in output_names:
+ try:
+ output_values.append(self.yield_mapping[n])
+ except KeyError:
+ raise ValueError(
+ f"Body assignments do not assign all outputs: " f"missing '{n}'"
+ )
+ linalg.YieldOp(output_values)
+
+ def _get_function(self, fn_name: str) -> Callable:
+ try:
+ fn = getattr(self, f"{fn_name}")
+ except AttributeError:
+ raise ValueError(f"Function '{fn_name}' is not a known function")
+ return fn
+
+ def _cast(
+ self, type_var_name: str, operand: Value, is_unsigned_cast: bool = False
+ ) -> Value:
+ try:
+ to_type = self.type_mapping[type_var_name]
+ except KeyError:
+ raise ValueError(
+ f"Unbound type variable '{type_var_name}' ("
+ f"expected one of {self.type_mapping.keys()}"
+ )
+ if operand.type == to_type:
+ return operand
+ if _is_integer_type(to_type):
+ return self._cast_to_integer(to_type, operand, is_unsigned_cast)
+ elif _is_floating_point_type(to_type):
+ return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
+
+ def _cast_to_integer(
+ self, to_type: Type, operand: Value, is_unsigned_cast: bool
+ ) -> Value:
+ to_width = IntegerType(to_type).width
+ operand_type = operand.type
+ if _is_floating_point_type(operand_type):
+ if is_unsigned_cast:
+ return arith.FPToUIOp(to_type, operand).result
+ return arith.FPToSIOp(to_type, operand).result
+ if _is_index_type(operand_type):
+ return arith.IndexCastOp(to_type, operand).result
+ # Assume integer.
+ from_width = IntegerType(operand_type).width
+ if to_width > from_width:
+ if is_unsigned_cast:
+ return arith.ExtUIOp(to_type, operand).result
+ return arith.ExtSIOp(to_type, operand).result
+ elif to_width < from_width:
+ return arith.TruncIOp(to_type, operand).result
+ raise ValueError(
+ f"Unable to cast body expression from {operand_type} to " f"{to_type}"
+ )
+
+ def _cast_to_floating_point(
+ self, to_type: Type, operand: Value, is_unsigned_cast: bool
+ ) -> Value:
+ operand_type = operand.type
+ if _is_integer_type(operand_type):
+ if is_unsigned_cast:
+ return arith.UIToFPOp(to_type, operand).result
+ return arith.SIToFPOp(to_type, operand).result
+ # Assume FloatType.
+ to_width = _get_floating_point_width(to_type)
+ from_width = _get_floating_point_width(operand_type)
+ if to_width > from_width:
+ return arith.ExtFOp(to_type, operand).result
+ elif to_width < from_width:
+ return arith.TruncFOp(to_type, operand).result
+ raise ValueError(
+ f"Unable to cast body expression from {operand_type} to " f"{to_type}"
+ )
+
+ def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value:
+ return self._cast(type_var_name, operand, False)
+
+ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
+ return self._cast(type_var_name, operand, True)
+
+ def _unary_exp(self, x: Value) -> Value:
+ if _is_floating_point_type(x.type):
+ return math.ExpOp(x).result
+ raise NotImplementedError("Unsupported 'exp' operand: {x}")
+
+ def _unary_log(self, x: Value) -> Value:
+ if _is_floating_point_type(x.type):
+ return math.LogOp(x).result
+ raise NotImplementedError("Unsupported 'log' operand: {x}")
+
+ def _unary_abs(self, x: Value) -> Value:
+ if _is_floating_point_type(x.type):
+ return math.AbsFOp(x).result
+ raise NotImplementedError("Unsupported 'abs' operand: {x}")
+
+ def _unary_ceil(self, x: Value) -> Value:
+ if _is_floating_point_type(x.type):
+ return math.CeilOp(x).result
+ raise NotImplementedError("Unsupported 'ceil' operand: {x}")
+
+ def _unary_floor(self, x: Value) -> Value:
+ if _is_floating_point_type(x.type):
+ return math.FloorOp(x).result
+ raise NotImplementedError("Unsupported 'floor' operand: {x}")
+
+ def _unary_negf(self, x: Value) -> Value:
+ if _is_floating_point_type(x.type):
+ return arith.NegFOp(x).result
+ if _is_complex_type(x.type):
+ return complex.NegOp(x).result
+ raise NotImplementedError("Unsupported 'negf' operand: {x}")
+
+ def _binary_add(self, lhs: Value, rhs: Value) -> Value:
+ if _is_floating_point_type(lhs.type):
+ return arith.AddFOp(lhs, rhs).result
+ if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ return arith.AddIOp(lhs, rhs).result
+ if _is_complex_type(lhs.type):
+ return complex.AddOp(lhs, rhs).result
+ raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
+
+ def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
+ if _is_floating_point_type(lhs.type):
+ return arith.SubFOp(lhs, rhs).result
+ if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ return arith.SubIOp(lhs, rhs).result
+ if _is_complex_type(lhs.type):
+ return complex.SubOp(lhs, rhs).result
+ raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
+
+ def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
+ if _is_floating_point_type(lhs.type):
+ return arith.MulFOp(lhs, rhs).result
+ if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ return arith.MulIOp(lhs, rhs).result
+ if _is_complex_type(lhs.type):
+ return complex.MulOp(lhs, rhs).result
+ raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
+
+ def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
+ if _is_floating_point_type(lhs.type):
+ return arith.MaxFOp(lhs, rhs).result
+ if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ return arith.MaxSIOp(lhs, rhs).result
+ raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
+
+ def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
+ if _is_floating_point_type(lhs.type):
+ return arith.MaxFOp(lhs, rhs).result
+ if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ return arith.MaxUIOp(lhs, rhs).result
+ raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
+
+ def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
+ if _is_floating_point_type(lhs.type):
+ return arith.MinFOp(lhs, rhs).result
+ if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ return arith.MinSIOp(lhs, rhs).result
+ raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
+
+ def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
+ if _is_floating_point_type(lhs.type):
+ return arith.MinFOp(lhs, rhs).result
+ if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ return arith.MinUIOp(lhs, rhs).result
+ raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
def _infer_structured_outs(
op_config: LinalgStructuredOpConfig,
- in_arg_defs: Sequence[OperandDefConfig], ins: Sequence[Value],
+ in_arg_defs: Sequence[OperandDefConfig],
+ ins: Sequence[Value],
out_arg_defs: Sequence[OperandDefConfig],
- outs: Union[Sequence[Value], OpResultList]) -> Tuple[ValueList, List[Type]]:
- """Infers implicit outs and output types.
+ outs: Union[Sequence[Value], OpResultList],
+) -> Tuple[ValueList, List[Type]]:
+ """Infers implicit outs and output types.
- Respects existing contents of outs if not empty.
+ Respects existing contents of outs if not empty.
- Returns:
- normalized outs, output types
- """
- # If outs were explicitly provided, we accept them verbatim.
- if outs:
- return outs, [out.type for out in outs]
+ Returns:
+ normalized outs, output types
+ """
+ # If outs were explicitly provided, we accept them verbatim.
+ if outs:
+ return outs, [out.type for out in outs]
- raise NotImplementedError(f"Output tensor inference not yet supported for "
- "structured ops")
+ raise NotImplementedError(
+ f"Output tensor inference not yet supported for " "structured ops"
+ )
def _get_types_from_values(*values: Value) -> Sequence[Type]:
- types = []
- for v in values:
- types.append(v.type)
- return types
+ types = []
+ for v in values:
+ types.append(v.type)
+ return types
def _get_operand_def_names(*operand_configs: OperandDefConfig) -> Sequence[str]:
- return [odc.operand_def.name for odc in operand_configs]
-
-
-def _add_type_mapping(operand_config: OperandDefConfig, operand_type: Type,
- type_mapping: Dict[str, Type],
- block_arg_types: Sequence[Type]):
- element_or_self_type = operand_type
- # Get the element type for tensor operands and the type itself for scalars.
- if operand_config.shape_map:
- try:
- element_or_self_type = ShapedType(operand_type).element_type
- except Exception as e:
- raise ValueError(f"Expected ShapedType but got {operand_type}") from e
- name = operand_config.type_var.name
- if name in type_mapping:
- if type_mapping[name] != element_or_self_type:
- raise ValueError(f"Cannot overwrite type mapping {name} = "
- f"{type_mapping[name]} by type {element_or_self_type}")
- type_mapping[name] = element_or_self_type
- block_arg_types.append(element_or_self_type)
+ return [odc.operand_def.name for odc in operand_configs]
+
+
+def _add_type_mapping(
+ operand_config: OperandDefConfig,
+ operand_type: Type,
+ type_mapping: Dict[str, Type],
+ block_arg_types: Sequence[Type],
+):
+ element_or_self_type = operand_type
+ # Get the element type for tensor operands and the type itself for scalars.
+ if operand_config.shape_map:
+ try:
+ element_or_self_type = ShapedType(operand_type).element_type
+ except Exception as e:
+ raise ValueError(f"Expected ShapedType but got {operand_type}") from e
+ name = operand_config.type_var.name
+ if name in type_mapping:
+ if type_mapping[name] != element_or_self_type:
+ raise ValueError(
+ f"Cannot overwrite type mapping {name} = "
+ f"{type_mapping[name]} by type {element_or_self_type}"
+ )
+ type_mapping[name] = element_or_self_type
+ block_arg_types.append(element_or_self_type)
def _is_complex_type(t: Type) -> bool:
- return ComplexType.isinstance(t)
+ return ComplexType.isinstance(t)
def _is_floating_point_type(t: Type) -> bool:
- # TODO: Create a FloatType in the Python API and implement the switch
- # there.
- return (F64Type.isinstance(t) or F32Type.isinstance(t) or
- F16Type.isinstance(t) or BF16Type.isinstance(t))
+ # TODO: Create a FloatType in the Python API and implement the switch
+ # there.
+ return (
+ F64Type.isinstance(t)
+ or F32Type.isinstance(t)
+ or F16Type.isinstance(t)
+ or BF16Type.isinstance(t)
+ )
def _is_integer_type(t: Type) -> bool:
- return IntegerType.isinstance(t)
+ return IntegerType.isinstance(t)
def _is_index_type(t: Type) -> bool:
- return IndexType.isinstance(t)
+ return IndexType.isinstance(t)
def _get_floating_point_width(t: Type) -> int:
- # TODO: Create a FloatType in the Python API and implement the switch
- # there.
- if F64Type.isinstance(t):
- return 64
- if F32Type.isinstance(t):
- return 32
- if F16Type.isinstance(t):
- return 16
- if BF16Type.isinstance(t):
- return 16
- raise NotImplementedError(f"Unhandled floating point type switch {t}")
+ # TODO: Create a FloatType in the Python API and implement the switch
+ # there.
+ if F64Type.isinstance(t):
+ return 64
+ if F32Type.isinstance(t):
+ return 32
+ if F16Type.isinstance(t):
+ return 16
+ if BF16Type.isinstance(t):
+ return 16
+ raise NotImplementedError(f"Unhandled floating point type switch {t}")
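As a quick reference for the cast helpers above, the integer-cast selection reduces to the following pure-Python decision table (an illustration of _cast_to_integer's logic using MLIR op names; not an MLIR API call):

def pick_int_cast(src_kind: str, src_width: int, dst_width: int, unsigned: bool) -> str:
    # Floats and index values use dedicated conversion ops; integers are
    # extended or truncated based on relative bit width.
    if src_kind == "float":
        return "arith.fptoui" if unsigned else "arith.fptosi"
    if src_kind == "index":
        return "arith.index_cast"
    if dst_width > src_width:
        return "arith.extui" if unsigned else "arith.extsi"
    if dst_width < src_width:
        return "arith.trunci"
    raise ValueError("unable to cast between these types")

assert pick_int_cast("int", 8, 32, unsigned=False) == "arith.extsi"
assert pick_int_cast("int", 64, 32, unsigned=True) == "arith.trunci"
assert pick_int_cast("float", 32, 32, unsigned=True) == "arith.fptoui"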
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
index aa894dc10954f..86853994c0a1c 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
@@ -30,123 +30,137 @@
class ScalarFn:
- """A type of ScalarExpression that applies a function."""
-
- def __init__(self, kind: "FunctionKind", fn_name: Optional[str],
- attr_name: Optional[str], type_var: Optional["TypeVar"],
- operands: Sequence["ScalarExpression"]):
- if bool(fn_name) + bool(attr_name) != 1:
- raise ValueError("One of 'fn_name', 'attr_name' must be specified")
- self.kind = kind
- self.fn_name = fn_name
- self.attr_name = attr_name
- self.type_var = type_var
- self.operands = operands
-
- def expr(self) -> "ScalarExpression":
- return ScalarExpression(scalar_fn=self)
-
- def __repr__(self):
- name = self.fn_name if self.fn_name else self.attr_name
- return (f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, "
- f"operands=[{', '.join(self.operands)}])")
+ """A type of ScalarExpression that applies a function."""
+
+ def __init__(
+ self,
+ kind: "FunctionKind",
+ fn_name: Optional[str],
+ attr_name: Optional[str],
+ type_var: Optional["TypeVar"],
+ operands: Sequence["ScalarExpression"],
+ ):
+ if bool(fn_name) + bool(attr_name) != 1:
+ raise ValueError("One of 'fn_name', 'attr_name' must be specified")
+ self.kind = kind
+ self.fn_name = fn_name
+ self.attr_name = attr_name
+ self.type_var = type_var
+ self.operands = operands
+
+ def expr(self) -> "ScalarExpression":
+ return ScalarExpression(scalar_fn=self)
+
+ def __repr__(self):
+ name = self.fn_name if self.fn_name else self.attr_name
+ return (
+ f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, "
+ f"operands=[{', '.join(self.operands)}])"
+ )
class ScalarArg:
- """A type of ScalarExpression that references a named argument."""
+ """A type of ScalarExpression that references a named argument."""
- def __init__(self, arg: str):
- self.arg = arg
+ def __init__(self, arg: str):
+ self.arg = arg
- def expr(self) -> "ScalarExpression":
- return ScalarExpression(scalar_arg=self)
+ def expr(self) -> "ScalarExpression":
+ return ScalarExpression(scalar_arg=self)
- def __repr__(self):
- return f"(ScalarArg({self.arg})"
+ def __repr__(self):
+ return f"(ScalarArg({self.arg})"
class ScalarConst:
- """A type of ScalarExpression representing a constant."""
+ """A type of ScalarExpression representing a constant."""
- def __init__(self, value: str):
- self.value = value
+ def __init__(self, value: str):
+ self.value = value
- def expr(self) -> "ScalarExpression":
- return ScalarExpression(scalar_const=self)
+ def expr(self) -> "ScalarExpression":
+ return ScalarExpression(scalar_const=self)
- def __repr__(self):
- return f"(ScalarConst({self.value})"
+ def __repr__(self):
+ return f"(ScalarConst({self.value})"
class ScalarIndex:
- """A type of ScalarExpression accessing an iteration index."""
+ """A type of ScalarExpression accessing an iteration index."""
- def __init__(self, dim: int):
- self.dim = dim
+ def __init__(self, dim: int):
+ self.dim = dim
- def expr(self) -> "ScalarExpression":
- return ScalarExpression(scalar_index=self)
+ def expr(self) -> "ScalarExpression":
+ return ScalarExpression(scalar_index=self)
- def __repr__(self):
- return f"(ScalarIndex({self.dim})"
+ def __repr__(self):
+ return f"(ScalarIndex({self.dim})"
class ScalarExpression(YAMLObject):
- """An expression on scalar values.
-
- Can be one of:
- - ScalarFn
- - ScalarArg
- - ScalarConst
- - ScalarIndex
- """
- yaml_tag = "!ScalarExpression"
-
- def __init__(self,
- scalar_fn: Optional[ScalarFn] = None,
- scalar_arg: Optional[ScalarArg] = None,
- scalar_const: Optional[ScalarConst] = None,
- scalar_index: Optional[ScalarIndex] = None):
- if (bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) +
- bool(scalar_index)) != 1:
- raise ValueError("One of 'scalar_fn', 'scalar_arg', 'scalar_const', or "
- "'scalar_index' must be specified")
- self.scalar_fn = scalar_fn
- self.scalar_arg = scalar_arg
- self.scalar_const = scalar_const
- self.scalar_index = scalar_index
-
- def to_yaml_custom_dict(self):
- if self.scalar_fn:
- scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower())
- if self.scalar_fn.fn_name:
- scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name
- if self.scalar_fn.attr_name:
- scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name
- if self.scalar_fn.type_var:
- scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name
- scalar_fn_dict["operands"] = list(self.scalar_fn.operands)
- return dict(scalar_fn=scalar_fn_dict)
- elif self.scalar_arg:
- return dict(scalar_arg=self.scalar_arg.arg)
- elif self.scalar_const:
- return dict(scalar_const=self.scalar_const.value)
- elif self.scalar_index:
- return dict(scalar_index=self.scalar_index.dim)
- else:
- raise ValueError(f"Unexpected ScalarExpression type: {self}")
+ """An expression on scalar values.
+
+ Can be one of:
+ - ScalarFn
+ - ScalarArg
+ - ScalarConst
+ - ScalarIndex
+ """
+
+ yaml_tag = "!ScalarExpression"
+
+ def __init__(
+ self,
+ scalar_fn: Optional[ScalarFn] = None,
+ scalar_arg: Optional[ScalarArg] = None,
+ scalar_const: Optional[ScalarConst] = None,
+ scalar_index: Optional[ScalarIndex] = None,
+ ):
+ if (
+ bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) + bool(scalar_index)
+ ) != 1:
+ raise ValueError(
+ "One of 'scalar_fn', 'scalar_arg', 'scalar_const', or "
+ "'scalar_index' must be specified"
+ )
+ self.scalar_fn = scalar_fn
+ self.scalar_arg = scalar_arg
+ self.scalar_const = scalar_const
+ self.scalar_index = scalar_index
+
+ def to_yaml_custom_dict(self):
+ if self.scalar_fn:
+ scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower())
+ if self.scalar_fn.fn_name:
+ scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name
+ if self.scalar_fn.attr_name:
+ scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name
+ if self.scalar_fn.type_var:
+ scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name
+ scalar_fn_dict["operands"] = list(self.scalar_fn.operands)
+ return dict(scalar_fn=scalar_fn_dict)
+ elif self.scalar_arg:
+ return dict(scalar_arg=self.scalar_arg.arg)
+ elif self.scalar_const:
+ return dict(scalar_const=self.scalar_const.value)
+ elif self.scalar_index:
+ return dict(scalar_index=self.scalar_index.dim)
+ else:
+ raise ValueError(f"Unexpected ScalarExpression type: {self}")
class ScalarAssign(YAMLObject):
- """An assignment to a named argument (LHS of a comprehension)."""
- yaml_tag = "!ScalarAssign"
+ """An assignment to a named argument (LHS of a comprehension)."""
+
+ yaml_tag = "!ScalarAssign"
- def __init__(self, arg: str, value: ScalarExpression):
- self.arg = arg
- self.value = value
+ def __init__(self, arg: str, value: ScalarExpression):
+ self.arg = arg
+ self.value = value
- def to_yaml_custom_dict(self):
- return dict(arg=self.arg, value=self.value)
+ def to_yaml_custom_dict(self):
+ return dict(arg=self.arg, value=self.value)
- def __repr__(self):
- return f"ScalarAssign({self.arg}, {self.value})"
+ def __repr__(self):
+ return f"ScalarAssign({self.arg}, {self.value})"
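A small sanity check of the classes above (illustrative; assumes the in-tree module path for the import):

from mlir.dialects.linalg.opdsl.lang.scalar_expr import ScalarArg

# A ScalarExpression wraps exactly one node kind, and the YAML hook flattens it.
e = ScalarArg("A").expr()
assert e.scalar_arg is not None and e.scalar_fn is None
assert e.to_yaml_custom_dict() == dict(scalar_arg="A")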
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py
index ddac87287e617..4f36029b7fb88 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py
@@ -21,13 +21,11 @@
__all__ = [
"TypeVar",
"TV",
-
# Predefined types.
"I32",
"I64",
"F32",
"F64",
-
# TypeVar aliases.
"T",
"U",
@@ -36,34 +34,34 @@
class TypeVar:
- """A replaceable type variable.
+ """A replaceable type variable.
- Type variables are uniqued by name.
- """
- ALL_TYPEVARS = dict() # type: Dict[str, "TypeVar"]
+ Type variables are uniqued by name.
+ """
- def __new__(cls, name: str):
- existing = cls.ALL_TYPEVARS.get(name)
- if existing is not None:
- return existing
- new = super().__new__(cls)
- new.name = name
- cls.ALL_TYPEVARS[name] = new
- return new
+ ALL_TYPEVARS = dict() # type: Dict[str, "TypeVar"]
- def __repr__(self):
- return f"TypeVar({self.name})"
+ def __new__(cls, name: str):
+ existing = cls.ALL_TYPEVARS.get(name)
+ if existing is not None:
+ return existing
+ new = super().__new__(cls)
+ new.name = name
+ cls.ALL_TYPEVARS[name] = new
+ return new
- @classmethod
- def create_expando(cls):
- """Create an expando class that creates unique type vars on attr access."""
+ def __repr__(self):
+ return f"TypeVar({self.name})"
- class ExpandoTypeVars:
+ @classmethod
+ def create_expando(cls):
+ """Create an expando class that creates unique type vars on attr access."""
- def __getattr__(self, n):
- return cls(n)
+ class ExpandoTypeVars:
+ def __getattr__(self, n):
+ return cls(n)
- return ExpandoTypeVars()
+ return ExpandoTypeVars()
# Expando access via TV.foo
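Concretely, the uniquing plus the TV expando this comment refers to behave like this (sketch; TV is the module-level TypeVar.create_expando() binding exported in __all__ above):

from mlir.dialects.linalg.opdsl.lang.types import TV, TypeVar

T1 = TypeVar("T1")
assert TypeVar("T1") is T1      # uniqued by name in __new__
assert TV.T1 is T1              # attribute access routes through the same constructor
assert repr(T1) == "TypeVar(T1)"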
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py
index 1945eea53f80b..1672656b3a1f8 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py
@@ -6,11 +6,12 @@
import sys
try:
- import yaml
+ import yaml
except ModuleNotFoundError as e:
- raise ModuleNotFoundError(
- f"This tool requires PyYAML but it was not installed. "
- f"Recommend: {sys.executable} -m pip install PyYAML") from e
+ raise ModuleNotFoundError(
+ f"This tool requires PyYAML but it was not installed. "
+ f"Recommend: {sys.executable} -m pip install PyYAML"
+ ) from e
__all__ = [
"yaml_dump",
@@ -20,35 +21,33 @@
class YAMLObject(yaml.YAMLObject):
+ @classmethod
+ def to_yaml(cls, dumper, self):
+ """Default to a custom dictionary mapping."""
+ return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict())
- @classmethod
- def to_yaml(cls, dumper, self):
- """Default to a custom dictionary mapping."""
- return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict())
+ def to_yaml_custom_dict(self):
+ raise NotImplementedError()
- def to_yaml_custom_dict(self):
- raise NotImplementedError()
-
- def as_linalg_yaml(self):
- return yaml_dump(self)
+ def as_linalg_yaml(self):
+ return yaml_dump(self)
def multiline_str_representer(dumper, data):
- if len(data.splitlines()) > 1:
- return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
- else:
- return dumper.represent_scalar('tag:yaml.org,2002:str', data)
+ if len(data.splitlines()) > 1:
+ return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
+ else:
+ return dumper.represent_scalar("tag:yaml.org,2002:str", data)
yaml.add_representer(str, multiline_str_representer)
def yaml_dump(data, sort_keys=False, **kwargs):
- return yaml.dump(data, sort_keys=sort_keys, **kwargs)
+ return yaml.dump(data, sort_keys=sort_keys, **kwargs)
def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs):
- return yaml.dump_all(data,
- sort_keys=sort_keys,
- explicit_start=explicit_start,
- **kwargs)
+ return yaml.dump_all(
+ data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs
+ )
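For context, the representer registered above switches multi-line strings to YAML block style; a quick sketch of the observable behavior (requires PyYAML; assumes the in-tree module path):

from mlir.dialects.linalg.opdsl.lang.yaml_helper import yaml_dump

print(yaml_dump("one line"))            # plain scalar
print(yaml_dump("line one\nline two"))  # emitted with '|' block style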
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 9c96868c10f1a..bac22a2e59db1 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -7,99 +7,113 @@
@linalg_structured_op
-def copy(I=TensorDef(T1),
- O=TensorDef(U, output=True),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
- """Copies the tensor elementwise.
+def copy(
+ I=TensorDef(T1),
+ O=TensorDef(U, output=True),
+ cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+ """Copies the tensor elementwise.
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- O[None] = cast(U, I[None])
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ O[None] = cast(U, I[None])
@linalg_structured_op
-def elemwise_unary(I=TensorDef(T1),
- O=TensorDef(U, output=True),
- fun=UnaryFnAttrDef(default=UnaryFn.exp),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
- """Applies the unary function fun elementwise.
+def elemwise_unary(
+ I=TensorDef(T1),
+ O=TensorDef(U, output=True),
+ fun=UnaryFnAttrDef(default=UnaryFn.exp),
+ cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+ """Applies the unary function fun elementwise.
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- O[None] = fun(cast(U, I[None]))
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ O[None] = fun(cast(U, I[None]))
@linalg_structured_op
-def elemwise_binary(lhs=TensorDef(T1),
- rhs=TensorDef(T2),
- O=TensorDef(U, output=True),
- fun=BinaryFnAttrDef(default=BinaryFn.add),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
- """Applies the binary function fun elementwise.
+def elemwise_binary(
+ lhs=TensorDef(T1),
+ rhs=TensorDef(T2),
+ O=TensorDef(U, output=True),
+ fun=BinaryFnAttrDef(default=BinaryFn.add),
+ cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+ """Applies the binary function fun elementwise.
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
@linalg_structured_op
-def matmul(A=TensorDef(T1, S.M, S.K),
- B=TensorDef(T2, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
- """Performs a matrix multiplication of two 2D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+def matmul(
+ A=TensorDef(T1, S.M, S.K),
+ B=TensorDef(T2, S.K, S.N),
+ C=TensorDef(U, S.M, S.N, output=True),
+ cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+ """Performs a matrix multiplication of two 2D inputs.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.m, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
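For readers new to the comprehension syntax: dimensions that appear only on the right-hand side (here D.k) are reduction dimensions. A rough NumPy analogue of the matmul body above (illustration only, not part of the commit):

import numpy as np

A = np.arange(6.0).reshape(2, 3)
B = np.ones((3, 4))
C = np.zeros((2, 4))
for m in range(2):
    for n in range(4):
        for k in range(3):  # D.k is reduced: it does not index the output C
            C[m, n] += A[m, k] * B[k, n]
assert np.allclose(C, A @ B)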
@linalg_structured_op
-def matmul_unsigned(A=TensorDef(T1, S.M, S.K),
- B=TensorDef(T2, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True)):
- """Performs an unsigned matrix multiplication of two 2D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned(
- U, B[D.k, D.n])
+def matmul_unsigned(
+ A=TensorDef(T1, S.M, S.K),
+ B=TensorDef(T2, S.K, S.N),
+ C=TensorDef(U, S.M, S.N, output=True),
+):
+ """Performs an unsigned matrix multiplication of two 2D inputs.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.m, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned(
+ U, B[D.k, D.n]
+ )
@linalg_structured_op
-def quantized_matmul(A=TensorDef(T1, S.M, S.K),
- B=TensorDef(T2, S.K, S.N),
- AZp=ScalarDef(I32),
- BZp=ScalarDef(I32),
- C=TensorDef(U, S.M, S.N, output=True)):
- """Performs a matrix multiplication of two 2D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output. The quantized variant
- includes zero-point adjustments for the left and right operands of the
- matmul.
- """
- domain(D.m, D.n, D.k)
- C[D.m,
- D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) -
- TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(U, B[D.k, D.n]) -
- TypeFn.cast_signed(U, BZp))
+def quantized_matmul(
+ A=TensorDef(T1, S.M, S.K),
+ B=TensorDef(T2, S.K, S.N),
+ AZp=ScalarDef(I32),
+ BZp=ScalarDef(I32),
+ C=TensorDef(U, S.M, S.N, output=True),
+):
+ """Performs a matrix multiplication of two 2D inputs.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output. The quantized variant
+ includes zero-point adjustments for the left and right operands of the
+ matmul.
+ """
+ domain(D.m, D.n, D.k)
+ C[D.m, D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * (
+ TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp)
+ )
@linalg_structured_op
-def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
- rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
- accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True)):
- """Performs a matrix-matrix-transpose multiplication of two 4D inputs.
+def mmt4d(
+ lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
+ rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
+ accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True),
+):
+ """Performs a matrix-matrix-transpose multiplication of two 4D inputs.
Differences from linalg.matmul:
* The right hand side is transposed, whence the 't' in 'mmt'.
@@ -108,1132 +122,1201 @@ def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
whence the 2+2=4 dimensions. The inner tile dimensions are identified with
'0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
as: MxK tiles, each of shape M0xK0.
- """
- domain(D.m, D.n, D.k, D.m0, D.n0, D.k0)
- implements(ContractionOpInterface)
- accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed(
- TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast_signed(
- TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
+ """
+ domain(D.m, D.n, D.k, D.m0, D.n0, D.k0)
+ implements(ContractionOpInterface)
+ accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed(
+ TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]
+ ) * TypeFn.cast_signed(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
@linalg_structured_op
-def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
- B=TensorDef(T2, Batch, S.K, S.N),
- C=TensorDef(U, Batch, S.M, S.N, output=True)):
- """Performs a batched matrix multiplication of two 3D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.b, D.m,
- D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
- U, B[D.b, D.k, D.n])
+def batch_matmul(
+ A=TensorDef(T1, Batch, S.M, S.K),
+ B=TensorDef(T2, Batch, S.K, S.N),
+ C=TensorDef(U, Batch, S.M, S.N, output=True),
+):
+ """Performs a batched matrix multiplication of two 3D inputs.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.b, D.m, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+ U, B[D.b, D.k, D.n]
+ )
@linalg_structured_op
-def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
- B=TensorDef(T2, Batch, S.K, S.N),
- AZp=ScalarDef(I32),
- BZp=ScalarDef(I32),
- C=TensorDef(U, Batch, S.M, S.N, output=True)):
- """Performs a batched matrix multiplication of two 3D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output. The quantized variant
- includes zero-point adjustments for the left and right operands of the
- matmul.
- """
- domain(D.b, D.m, D.n, D.k)
- C[D.b, D.m, D.n] += (TypeFn.cast_signed(U, A[D.b, D.m, D.k]) -
- TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(
- U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
+def quantized_batch_matmul(
+ A=TensorDef(T1, Batch, S.M, S.K),
+ B=TensorDef(T2, Batch, S.K, S.N),
+ AZp=ScalarDef(I32),
+ BZp=ScalarDef(I32),
+ C=TensorDef(U, Batch, S.M, S.N, output=True),
+):
+ """Performs a batched matrix multiplication of two 3D inputs.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output. The quantized variant
+ includes zero-point adjustments for the left and right operands of the
+ matmul.
+ """
+ domain(D.b, D.m, D.n, D.k)
+ C[D.b, D.m, D.n] += (
+ TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp)
+ ) * (TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
@linalg_structured_op
-def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K),
- B=TensorDef(T2, Batch, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True)):
- """Performs a batch-reduce matrix multiplication of two 3D inputs.
- The partial multiplication results are reduced into a 2D output.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += TypeFn.cast_signed(
- U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n]))
+def batch_reduce_matmul(
+ A=TensorDef(T1, Batch, S.M, S.K),
+ B=TensorDef(T2, Batch, S.K, S.N),
+ C=TensorDef(U, S.M, S.N, output=True),
+):
+ """Performs a batch-reduce matrix multiplication of two 3D inputs.
+ The partial multiplication results are reduced into a 2D output.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.b, D.m, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.m, D.n] += TypeFn.cast_signed(
+ U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
+ )
@linalg_structured_op
-def matvec(A=TensorDef(T1, S.M, S.N),
- y=TensorDef(T2, S.N),
- x=TensorDef(U, S.M, output=True)):
- """Performs a matrix-vector multiplication.
+def matvec(
+ A=TensorDef(T1, S.M, S.N), y=TensorDef(T2, S.N), x=TensorDef(U, S.M, output=True)
+):
+ """Performs a matrix-vector multiplication.
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.m, D.n)
- implements(ContractionOpInterface)
- x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n])
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.m, D.n)
+ implements(ContractionOpInterface)
+ x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n])
@linalg_structured_op
-def vecmat(y=TensorDef(T1, S.M),
- A=TensorDef(T2, S.M, S.N),
- x=TensorDef(U, S.N, output=True)):
- """Performs a vector-matrix multiplication.
+def vecmat(
+ y=TensorDef(T1, S.M), A=TensorDef(T2, S.M, S.N), x=TensorDef(U, S.N, output=True)
+):
+ """Performs a vector-matrix multiplication.
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.n, D.m)
- implements(ContractionOpInterface)
- x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n])
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.n, D.m)
+ implements(ContractionOpInterface)
+ x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n])
@linalg_structured_op
-def batch_matvec(A=TensorDef(T1, Batch, S.M, S.K),
- B=TensorDef(T2, Batch, S.K),
- C=TensorDef(U, Batch, S.M, output=True)):
- """Performs a batched matrix-vector multiplication.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.k)
- implements(ContractionOpInterface)
- C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
- U, B[D.b, D.k])
+def batch_matvec(
+ A=TensorDef(T1, Batch, S.M, S.K),
+ B=TensorDef(T2, Batch, S.K),
+ C=TensorDef(U, Batch, S.M, output=True),
+):
+ """Performs a batched matrix-vector multiplication.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.b, D.m, D.k)
+ implements(ContractionOpInterface)
+ C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+ U, B[D.b, D.k]
+ )
@linalg_structured_op
-def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U,
- output=True)):
- """Performs a dot product of two vectors to a scalar result.
+def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
+ """Performs a dot product of two vectors to a scalar result.
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ContractionOpInterface)
- C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ContractionOpInterface)
+ C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
@linalg_structured_op
-def conv_1d(I=TensorDef(T1, S.OW + S.KW),
- K=TensorDef(T2, S.KW),
- O=TensorDef(U, S.OW, output=True)):
- """Performs 1-D convolution with no channels.
+def conv_1d(
+ I=TensorDef(T1, S.OW + S.KW),
+ K=TensorDef(T2, S.KW),
+ O=TensorDef(U, S.OW, output=True),
+):
+ """Performs 1-D convolution with no channels.
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.ow, D.kw)
- O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(
- U, K[D.kw])
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.ow, D.kw)
+ O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kw])
@linalg_structured_op
-def conv_2d(I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW),
- K=TensorDef(T2, S.KH, S.KW),
- O=TensorDef(U, S.OH, S.OW, output=True)):
- """Performs 2-D convolution with no channels.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.oh, D.ow, D.kh, D.kw)
- O[D.oh, D.ow] += TypeFn.cast_signed(
- U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kh, D.kw])
+def conv_2d(
+ I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW),
+ K=TensorDef(T2, S.KH, S.KW),
+ O=TensorDef(U, S.OH, S.OW, output=True),
+):
+ """Performs 2-D convolution with no channels.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.oh, D.ow, D.kh, D.kw)
+ O[D.oh, D.ow] += TypeFn.cast_signed(
+ U, I[D.oh + D.kh, D.ow + D.kw]
+ ) * TypeFn.cast_signed(U, K[D.kh, D.kw])
@linalg_structured_op
-def conv_3d(I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW),
- K=TensorDef(T2, S.KD, S.KH, S.KW),
- O=TensorDef(U, S.OD, S.OH, S.OW, output=True)):
- """Performs 3-D convolution with no channels.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw)
- O[D.od, D.oh, D.ow] += TypeFn.cast_signed(
- U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(
- U, K[D.kd, D.kh, D.kw])
+def conv_3d(
+ I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW),
+ K=TensorDef(T2, S.KD, S.KH, S.KW),
+ O=TensorDef(U, S.OD, S.OH, S.OW, output=True),
+):
+ """Performs 3-D convolution with no channels.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw)
+ O[D.od, D.oh, D.ow] += TypeFn.cast_signed(
+ U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]
+ ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw])
@linalg_structured_op
-def conv_1d_nwc_wcf(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2, S.KW, S.C, S.F),
- O=TensorDef(U, S.N, S.OW, S.F, output=True),
- strides=IndexAttrDef(S.SW, default=[1]),
- dilations=IndexAttrDef(S.DW, default=[1])):
- """Performs 1-D convolution.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.ow, D.f, D.kw, D.c)
- O[D.n, D.ow, D.f] += TypeFn.cast_signed(
- U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed(
- U, K[D.kw, D.c, D.f])
+def conv_1d_nwc_wcf(
+ I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+ K=TensorDef(T2, S.KW, S.C, S.F),
+ O=TensorDef(U, S.N, S.OW, S.F, output=True),
+ strides=IndexAttrDef(S.SW, default=[1]),
+ dilations=IndexAttrDef(S.DW, default=[1]),
+):
+ """Performs 1-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.ow, D.f, D.kw, D.c)
+ O[D.n, D.ow, D.f] += TypeFn.cast_signed(
+ U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]
+ ) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f])
@linalg_structured_op
-def conv_1d_ncw_fcw(I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
- K=TensorDef(T2, S.F, S.C, S.KW),
- O=TensorDef(U, S.N, S.F, S.OW, output=True),
- strides=IndexAttrDef(S.SW, default=[1]),
- dilations=IndexAttrDef(S.DW, default=[1])):
- """Performs 1-D convolution.
-
- Layout:
- * Input: NCW.
- * Kernel: FCW.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.f, D.ow, D.c, D.kw)
- O[D.n, D.f, D.ow] += TypeFn.cast_signed(
- U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(
- U, K[D.f, D.c, D.kw])
+def conv_1d_ncw_fcw(
+ I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
+ K=TensorDef(T2, S.F, S.C, S.KW),
+ O=TensorDef(U, S.N, S.F, S.OW, output=True),
+ strides=IndexAttrDef(S.SW, default=[1]),
+ dilations=IndexAttrDef(S.DW, default=[1]),
+):
+ """Performs 1-D convolution.
+
+ Layout:
+ * Input: NCW.
+ * Kernel: FCW.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.f, D.ow, D.c, D.kw)
+ O[D.n, D.f, D.ow] += TypeFn.cast_signed(
+ U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]
+ ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kw])
@linalg_structured_op
-def conv_2d_nhwc_hwcf(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
- O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
- """Performs 2-D convolution.
-
- Layout:
- * Input: NHWC.
- * Kernel: HWCF.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
- D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f])
+def conv_2d_nhwc_hwcf(
+ I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+ K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs 2-D convolution.
+
+ Layout:
+ * Input: NHWC.
+ * Kernel: HWCF.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
+ O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+ ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f])
@linalg_structured_op
-def conv_2d_nhwc_fhwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
- O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
- """Performs 2-D convolution.
-
- Layout:
- * Input: NHWC.
- * Kernel: FHWC.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
- D.c]) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c])
+def conv_2d_nhwc_fhwc(
+ I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+ K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs 2-D convolution.
+
+ Layout:
+ * Input: NHWC.
+ * Kernel: FHWC.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
+ O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+ ) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c])
@linalg_structured_op
-def conv_2d_nhwc_hwcf_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
- IZp=ScalarDef(I32),
- KZp=ScalarDef(I32),
- O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
- """Performs 2-D convolution with zero point offsets.
-
- Layout:
- * Input: NHWC.
- * Kernel: HWCF.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output. This includes the zero
- point offsets common to quantized operations.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow,
- D.f] += (TypeFn.cast_signed(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) -
- TypeFn.cast_signed(U, IZp)) * (TypeFn.cast_signed(
- U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp))
+def conv_2d_nhwc_hwcf_q(
+ I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+ K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
+ IZp=ScalarDef(I32),
+ KZp=ScalarDef(I32),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs 2-D convolution with zero point offsets.
+
+ Layout:
+ * Input: NHWC.
+ * Kernel: HWCF.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output. This includes the zero
+ point offsets common to quantized operations.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
+ O[D.n, D.oh, D.ow, D.f] += (
+ TypeFn.cast_signed(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+ )
+ - TypeFn.cast_signed(U, IZp)
+ ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp))
@linalg_structured_op
-def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW),
- K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
- O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
- """Performs 2-D convolution.
-
- Layout:
- * Input: NCHW.
- * Kernel: FCHW.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
- O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed(
- U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
- D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw])
+def conv_2d_nchw_fchw(
+ I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
+ K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
+ O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs 2-D convolution.
+
+ Layout:
+ * Input: NCHW.
+ * Kernel: FCHW.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
+ O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed(
+ U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+ ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw])
@linalg_structured_op
-def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C,
- S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW),
- K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
- O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
- """Performs 2-D grouped convolution.
-
- Layout:
- * Input: NGCHW.
- * Kernel: FGCHW.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
- O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
- U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
- D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
+def conv_2d_ngchw_fgchw(
+ I=TensorDef(
+ T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
+ ),
+ K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
+ O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs 2-D grouped convolution.
+
+ Layout:
+ * Input: NGCHW.
+ * Kernel: FGCHW.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
+ O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
+ U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+ ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
@linalg_structured_op
-def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
- S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
- O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
- strides=IndexAttrDef(S.SD,
- S.SH,
- S.SW,
- default=[1, 1, 1]),
- dilations=IndexAttrDef(S.DD,
- S.DH,
- S.DW,
- default=[1, 1, 1])):
- """Performs 3-D convolution.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
- O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed(
- U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
- D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed(
- U, K[D.kd, D.kh, D.kw, D.c, D.f])
+def conv_3d_ndhwc_dhwcf(
+ I=TensorDef(
+ T1,
+ S.N,
+ S.OD * S.SD + S.KD * S.DD,
+ S.OH * S.SH + S.KH * S.DH,
+ S.OW * S.SW + S.KW * S.DW,
+ S.C,
+ ),
+ K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
+ O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
+ strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+ dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+ """Performs 3-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
+ O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.od * S.SD + D.kd * S.DD,
+ D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW,
+ D.c,
+ ],
+ ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
@linalg_structured_op
-def conv_3d_ndhwc_dhwcf_q(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
- S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
- IZp=ScalarDef(I32),
- KZp=ScalarDef(I32),
- O=TensorDef(U,
- S.N,
- S.OD,
- S.OH,
- S.OW,
- S.F,
- output=True),
- strides=IndexAttrDef(S.SD,
- S.SH,
- S.SW,
- default=[1, 1, 1]),
- dilations=IndexAttrDef(S.DD,
- S.DH,
- S.DW,
- default=[1, 1, 1])):
- """Performs 3-D convolution with zero point offsets.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output. This includes the zero
- point offsets common to quantized operations.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
- O[D.n, D.od, D.oh, D.ow, D.f] += (TypeFn.cast_signed(
- U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
- D.ow * S.SW + D.kw * S.DW, D.c]) - TypeFn.cast_signed(U, IZp)) * (
- TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) -
- TypeFn.cast_signed(U, KZp))
+def conv_3d_ndhwc_dhwcf_q(
+ I=TensorDef(
+ T1,
+ S.N,
+ S.OD * S.SD + S.KD * S.DD,
+ S.OH * S.SH + S.KH * S.DH,
+ S.OW * S.SW + S.KW * S.DW,
+ S.C,
+ ),
+ K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
+ IZp=ScalarDef(I32),
+ KZp=ScalarDef(I32),
+ O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
+ strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+ dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+ """Performs 3-D convolution with zero point offsets.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output. This includes the zero
+ point offsets common to quantized operations.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
+ O[D.n, D.od, D.oh, D.ow, D.f] += (
+ TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.od * S.SD + D.kd * S.DD,
+ D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW,
+ D.c,
+ ],
+ )
+ - TypeFn.cast_signed(U, IZp)
+ ) * (
+ TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
+ - TypeFn.cast_signed(U, KZp)
+ )
@linalg_structured_op
-def conv_3d_ncdhw_fcdhw(I=TensorDef(T1, S.N, S.C, S.OD * S.SD + S.KD * S.DD,
- S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW),
- K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW),
- O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True),
- strides=IndexAttrDef(S.SD,
- S.SH,
- S.SW,
- default=[1, 1, 1]),
- dilations=IndexAttrDef(S.DD,
- S.DH,
- S.DW,
- default=[1, 1, 1])):
- """Performs 3-D convolution.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
- O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed(
- U, I[D.n, D.c, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
- D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(
- U, K[D.f, D.c, D.kd, D.kh, D.kw])
+def conv_3d_ncdhw_fcdhw(
+ I=TensorDef(
+ T1,
+ S.N,
+ S.C,
+ S.OD * S.SD + S.KD * S.DD,
+ S.OH * S.SH + S.KH * S.DH,
+ S.OW * S.SW + S.KW * S.DW,
+ ),
+ K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW),
+ O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True),
+ strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+ dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+ """Performs 3-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
+ O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.c,
+ D.od * S.SD + D.kd * S.DD,
+ D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW,
+ ],
+ ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
@linalg_structured_op
-def depthwise_conv_1d_nwc_wc(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW,
- S.IC),
- K=TensorDef(T2, S.KW, S.IC),
- O=TensorDef(U, S.N, S.OW, S.IC, output=True),
- strides=IndexAttrDef(S.SW, default=[1]),
- dilations=IndexAttrDef(S.DW, default=[1])):
- """Performs depth-wise 1-D convolution.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output. Multiplier is set to 1
- which is a special case for most depthwise convolutions.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.ow, D.ic, D.kw)
- O[D.n, D.ow, D.ic] += \
- TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
- TypeFn.cast_signed(U, K[D.kw, D.ic])
+def depthwise_conv_1d_nwc_wc(
+ I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
+ K=TensorDef(T2, S.KW, S.IC),
+ O=TensorDef(U, S.N, S.OW, S.IC, output=True),
+ strides=IndexAttrDef(S.SW, default=[1]),
+ dilations=IndexAttrDef(S.DW, default=[1]),
+):
+ """Performs depth-wise 1-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output. Multiplier is set to 1
+ which is a special case for most depthwise convolutions.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.ow, D.ic, D.kw)
+ O[D.n, D.ow, D.ic] += TypeFn.cast_signed(
+ U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]
+ ) * TypeFn.cast_signed(U, K[D.kw, D.ic])
@linalg_structured_op
-def depthwise_conv_1d_ncw_cw(I=TensorDef(T1, S.N, S.IC,
- S.OW * S.SW + S.KW * S.DW),
- K=TensorDef(T2, S.IC, S.KW),
- O=TensorDef(U, S.N, S.IC, S.OW, output=True),
- strides=IndexAttrDef(S.SW, default=[1]),
- dilations=IndexAttrDef(S.DW, default=[1])):
- """Performs depth-wise 1-D convolution.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output. Multiplier is set to 1
- which is a special case for most depthwise convolutions.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.ow, D.ic, D.kw)
- O[D.n, D.ic, D.ow] += \
- TypeFn.cast_signed(U, I[D.n, D.ic, D.ow * S.SW + D.kw * S.DW]) * \
- TypeFn.cast_signed(U, K[D.ic, D.kw])
+def depthwise_conv_1d_ncw_cw(
+ I=TensorDef(T1, S.N, S.IC, S.OW * S.SW + S.KW * S.DW),
+ K=TensorDef(T2, S.IC, S.KW),
+ O=TensorDef(U, S.N, S.IC, S.OW, output=True),
+ strides=IndexAttrDef(S.SW, default=[1]),
+ dilations=IndexAttrDef(S.DW, default=[1]),
+):
+ """Performs depth-wise 1-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output. Multiplier is set to 1
+ which is a special case for most depthwise convolutions.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.ow, D.ic, D.kw)
+ O[D.n, D.ic, D.ow] += TypeFn.cast_signed(
+ U, I[D.n, D.ic, D.ow * S.SW + D.kw * S.DW]
+ ) * TypeFn.cast_signed(U, K[D.ic, D.kw])
@linalg_structured_op
-def depthwise_conv_1d_nwc_wcm(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW,
- S.IC),
- K=TensorDef(T2, S.KW, S.IC, S.CM),
- O=TensorDef(U, S.N, S.OW, S.IC, S.CM,
- output=True),
- strides=IndexAttrDef(S.SW, default=[1]),
- dilations=IndexAttrDef(S.DW, default=[1])):
- """Performs depth-wise 1-D convolution.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.ow, D.ic, D.cm, D.kw)
- O[D.n, D.ow, D.ic, D.cm] += \
- TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
- TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm])
+def depthwise_conv_1d_nwc_wcm(
+ I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
+ K=TensorDef(T2, S.KW, S.IC, S.CM),
+ O=TensorDef(U, S.N, S.OW, S.IC, S.CM, output=True),
+ strides=IndexAttrDef(S.SW, default=[1]),
+ dilations=IndexAttrDef(S.DW, default=[1]),
+):
+ """Performs depth-wise 1-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.ow, D.ic, D.cm, D.kw)
+ O[D.n, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
+ U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]
+ ) * TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm])
@linalg_structured_op
-def depthwise_conv_2d_nhwc_hwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.IC),
- K=TensorDef(T2, S.KH, S.KW, S.IC),
- O=TensorDef(U,
- S.N,
- S.OH,
- S.OW,
- S.IC,
- output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH,
- S.DW,
- default=[1, 1])):
- """Performs depth-wise 2-D convolution.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output. Multiplier is set to 1
- which is a special case for most depthwise convolutions.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
- O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
- D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic])
+def depthwise_conv_2d_nhwc_hwc(
+ I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
+ K=TensorDef(T2, S.KH, S.KW, S.IC),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs depth-wise 2-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output. Multiplier is set to 1
+ which is a special case for most depthwise convolutions.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
+ O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
+ ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic])
@linalg_structured_op
-def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC,
- S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW),
- K=TensorDef(T2, S.IC, S.KH, S.KW),
- O=TensorDef(U,
- S.N,
- S.IC,
- S.OH,
- S.OW,
- output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH,
- S.DW,
- default=[1, 1])):
- """Performs depth-wise 2-D convolution.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output. Multiplier is set to 1
- which is a special case for most depthwise convolutions.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
- O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed(
- U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
- D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw])
+def depthwise_conv_2d_nchw_chw(
+ I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
+ K=TensorDef(T2, S.IC, S.KH, S.KW),
+ O=TensorDef(U, S.N, S.IC, S.OH, S.OW, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs depth-wise 2-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output. Multiplier is set to 1
+ which is a special case for most depthwise convolutions.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
+ O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed(
+ U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+ ) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw])
@linalg_structured_op
-def depthwise_conv_2d_nhwc_hwc_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.IC),
- K=TensorDef(T2, S.KH, S.KW, S.IC),
- IZp=ScalarDef(I32),
- KZp=ScalarDef(I32),
- O=TensorDef(U,
- S.N,
- S.OH,
- S.OW,
- S.IC,
- output=True),
- strides=IndexAttrDef(S.SH,
- S.SW,
- default=[1, 1]),
- dilations=IndexAttrDef(S.DH,
- S.DW,
- default=[1, 1])):
- """Performs depth-wise 2-D convolution.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
- O[D.n, D.oh, D.ow, D.ic] += ((TypeFn.cast_signed(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
- TypeFn.cast_signed(U, IZp)) *
- (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) -
- TypeFn.cast_signed(U, KZp)))
+def depthwise_conv_2d_nhwc_hwc_q(
+ I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
+ K=TensorDef(T2, S.KH, S.KW, S.IC),
+ IZp=ScalarDef(I32),
+ KZp=ScalarDef(I32),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs depth-wise 2-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
+ O[D.n, D.oh, D.ow, D.ic] += (
+ TypeFn.cast_signed(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
+ )
+ - TypeFn.cast_signed(U, IZp)
+ ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast_signed(U, KZp))
@linalg_structured_op
-def depthwise_conv_2d_nhwc_hwcm(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.IC),
- K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
- O=TensorDef(U,
- S.N,
- S.OH,
- S.OW,
- S.IC,
- S.CM,
- output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1,
- 1]),
- dilations=IndexAttrDef(S.DH,
- S.DW,
- default=[1, 1])):
- """Performs depth-wise 2-D convolution.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
- O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
- D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm])
+def depthwise_conv_2d_nhwc_hwcm(
+ I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
+ K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs depth-wise 2-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
+ O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
+ ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm])
@linalg_structured_op
-def depthwise_conv_2d_nhwc_hwcm_q(I=TensorDef(T1, S.N,
- S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.IC),
- K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
- IZp=ScalarDef(I32),
- KZp=ScalarDef(I32),
- O=TensorDef(U,
- S.N,
- S.OH,
- S.OW,
- S.IC,
- S.CM,
- output=True),
- strides=IndexAttrDef(S.SH,
- S.SW,
- default=[1, 1]),
- dilations=IndexAttrDef(S.DH,
- S.DW,
- default=[1, 1])):
- """Performs depth-wise 2-D convolution.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
- O[D.n, D.oh, D.ow, D.ic,
- D.cm] += ((TypeFn.cast_signed(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
- TypeFn.cast_signed(U, IZp)) *
- (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) -
- TypeFn.cast_signed(U, KZp)))
+def depthwise_conv_2d_nhwc_hwcm_q(
+ I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
+ K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
+ IZp=ScalarDef(I32),
+ KZp=ScalarDef(I32),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs depth-wise 2-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
+ O[D.n, D.oh, D.ow, D.ic, D.cm] += (
+ TypeFn.cast_signed(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
+ )
+ - TypeFn.cast_signed(U, IZp)
+ ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast_signed(U, KZp))
@linalg_structured_op
-def depthwise_conv_3d_ndhwc_dhwc(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
- S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.IC),
- K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC),
- O=TensorDef(U,
- S.N,
- S.OD,
- S.OH,
- S.OW,
- output=True),
- strides=IndexAttrDef(S.SD,
- S.SH,
- S.SW,
- default=[1, 1, 1]),
- dilations=IndexAttrDef(S.DD,
- S.DH,
- S.DW,
- default=[1, 1, 1])):
- """Performs depth-wise 3-D convolution.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output. Multiplier is set to 1
- which is a special case for most depthwise convolutions.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
- O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
- U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
- D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed(
- U, K[D.kd, D.kh, D.kw, D.ic])
+def depthwise_conv_3d_ndhwc_dhwc(
+ I=TensorDef(
+ T1,
+ S.N,
+ S.OD * S.SD + S.KD * S.DD,
+ S.OH * S.SH + S.KH * S.DH,
+ S.OW * S.SW + S.KW * S.DW,
+ S.IC,
+ ),
+ K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC),
+ O=TensorDef(U, S.N, S.OD, S.OH, S.OW, output=True),
+ strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+ dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+ """Performs depth-wise 3-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output. Multiplier is set to 1
+ which is a special case for most depthwise convolutions.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
+ O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.od * S.SD + D.kd * S.DD,
+ D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW,
+ D.ic,
+ ],
+ ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic])
@linalg_structured_op
-def depthwise_conv_3d_ncdhw_cdhw(I=TensorDef(T1, S.N, S.IC,
- S.OD * S.SD + S.KD * S.DD,
- S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW),
- K=TensorDef(T2, S.IC, S.KD, S.KH, S.KW),
- O=TensorDef(U,
- S.N,
- S.IC,
- S.OD,
- S.OH,
- S.OW,
- output=True),
- strides=IndexAttrDef(S.SD,
- S.SH,
- S.SW,
- default=[1, 1, 1]),
- dilations=IndexAttrDef(S.DD,
- S.DH,
- S.DW,
- default=[1, 1, 1])):
- """Performs depth-wise 3-D convolution.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output. Multiplier is set to 1
- which is a special case for most depthwise convolutions.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
- O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed(
- U, I[D.n, D.ic, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
- D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(
- U, K[D.ic, D.kd, D.kh, D.kw])
+def depthwise_conv_3d_ncdhw_cdhw(
+ I=TensorDef(
+ T1,
+ S.N,
+ S.IC,
+ S.OD * S.SD + S.KD * S.DD,
+ S.OH * S.SH + S.KH * S.DH,
+ S.OW * S.SW + S.KW * S.DW,
+ ),
+ K=TensorDef(T2, S.IC, S.KD, S.KH, S.KW),
+ O=TensorDef(U, S.N, S.IC, S.OD, S.OH, S.OW, output=True),
+ strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+ dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+ """Performs depth-wise 3-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output. Multiplier is set to 1
+ which is a special case for most depthwise convolutions.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
+ O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.ic,
+ D.od * S.SD + D.kd * S.DD,
+ D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW,
+ ],
+ ) * TypeFn.cast_signed(U, K[D.ic, D.kd, D.kh, D.kw])
@linalg_structured_op
-def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, S.N,
- S.OD * S.SD + S.KD * S.DD,
- S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.IC),
- K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM),
- O=TensorDef(U,
- S.N,
- S.OD,
- S.OH,
- S.OW,
- S.CM,
- output=True),
- strides=IndexAttrDef(S.SD,
- S.SH,
- S.SW,
- default=[1, 1, 1]),
- dilations=IndexAttrDef(S.DD,
- S.DH,
- S.DW,
- default=[1, 1, 1])):
- """Performs depth-wise 3-D convolution.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic)
- O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
- U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
- D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed(
- U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
+def depthwise_conv_3d_ndhwc_dhwcm(
+ I=TensorDef(
+ T1,
+ S.N,
+ S.OD * S.SD + S.KD * S.DD,
+ S.OH * S.SH + S.KH * S.DH,
+ S.OW * S.SW + S.KW * S.DW,
+ S.IC,
+ ),
+ K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM),
+ O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.CM, output=True),
+ strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+ dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+ """Performs depth-wise 3-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic)
+ O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.od * S.SD + D.kd * S.DD,
+ D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW,
+ D.ic,
+ ],
+ ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
@linalg_structured_op
-def pooling_nhwc_sum(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
- O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
- """Performs sum pooling.
-
- Layout:
- * Input: NHWC.
- * Kernel: HW.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
- O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
+def pooling_nhwc_sum(
+ I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+ K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs sum pooling.
+
+ Layout:
+ * Input: NHWC.
+ * Kernel: HW.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+ O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+ )
@linalg_structured_op
-def pooling_nchw_sum(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW),
- K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
- O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
- """Performs sum pooling.
-
- Layout:
- * Input: NCHW.
- * Kernel: HW.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
- O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed(
- U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW])
+def pooling_nchw_sum(
+ I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
+ K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+ O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs sum pooling.
+
+ Layout:
+ * Input: NCHW.
+ * Kernel: HW.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
+ O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed(
+ U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+ )
@linalg_structured_op
-def pooling_nhwc_max(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
- O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
- """Performs max pooling.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
- O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](TypeFn.cast_signed(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nhwc_max(
+ I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+ K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs max pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+ O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](
+ TypeFn.cast_signed(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+ )
+ )
@linalg_structured_op
-def pooling_nhwc_max_unsigned(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2,
- S.KH,
- S.KW,
- index_dims=[D.kh, D.kw]),
- O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1,
- 1])):
- """Performs unsigned max pooling.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
- O[D.n, D.oh, D.ow,
- D.c] = ReduceFn.max_unsigned[D.kh, D.kw](TypeFn.cast_unsigned(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nhwc_max_unsigned(
+ I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+ K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs unsigned max pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+ O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
+ TypeFn.cast_unsigned(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+ )
+ )
@linalg_structured_op
-def pooling_nchw_max(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW),
- K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
- O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
- """Performs max pooling.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
- O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](TypeFn.cast_signed(
- U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,]))
+def pooling_nchw_max(
+ I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
+ K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+ O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs max pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
+ O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](
+ TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.c,
+ D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW,
+ ],
+ )
+ )
@linalg_structured_op
-def pooling_nhwc_min(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
- O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
- """Performs min pooling.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
- O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](TypeFn.cast_signed(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nhwc_min(
+ I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+ K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs min pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+ O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](
+ TypeFn.cast_signed(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+ )
+ )
@linalg_structured_op
-def pooling_nhwc_min_unsigned(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2,
- S.KH,
- S.KW,
- index_dims=[D.kh, D.kw]),
- O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1,
- 1])):
- """Performs unsigned min pooling.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
- O[D.n, D.oh, D.ow,
- D.c] = ReduceFn.min_unsigned[D.kh, D.kw](TypeFn.cast_unsigned(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nhwc_min_unsigned(
+ I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+ K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs unsigned min pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+ O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
+ TypeFn.cast_unsigned(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+ )
+ )
+
@linalg_structured_op
-def pooling_nwc_sum(I=TensorDef(T1, S.N,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2, S.KW, index_dims=[D.kw]),
- O=TensorDef(U, S.N, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SW, default=[1]),
- dilations=IndexAttrDef(S.DW, default=[1])):
- """Performs sum pooling.
-
- Layout:
- * Input: NWC.
- * Kernel: W.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.ow, D.c, D.kw)
- O[D.n, D.ow, D.c] += TypeFn.cast_signed(
- U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+def pooling_nwc_sum(
+ I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+ K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+ O=TensorDef(U, S.N, S.OW, S.C, output=True),
+ strides=IndexAttrDef(S.SW, default=[1]),
+ dilations=IndexAttrDef(S.DW, default=[1]),
+):
+ """Performs sum pooling.
+
+ Layout:
+ * Input: NWC.
+ * Kernel: W.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.ow, D.c, D.kw)
+ O[D.n, D.ow, D.c] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
@linalg_structured_op
-def pooling_ncw_sum(I=TensorDef(T1, S.N, S.C,
- S.OW * S.SW + S.KW * S.DW),
- K=TensorDef(T2, S.KW, index_dims=[D.kw]),
- O=TensorDef(U, S.N, S.C, S.OW, output=True),
- strides=IndexAttrDef(S.SW, default=[1]),
- dilations=IndexAttrDef(S.DW, default=[1])):
- """Performs sum pooling.
-
- Layout:
- * Input: NCW.
- * Kernel: W.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.c, D.ow, D.kw)
- O[D.n, D.c, D.ow] += TypeFn.cast_signed(
- U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW])
+def pooling_ncw_sum(
+ I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
+ K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+ O=TensorDef(U, S.N, S.C, S.OW, output=True),
+ strides=IndexAttrDef(S.SW, default=[1]),
+ dilations=IndexAttrDef(S.DW, default=[1]),
+):
+ """Performs sum pooling.
+
+ Layout:
+ * Input: NCW.
+ * Kernel: W.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.c, D.ow, D.kw)
+ O[D.n, D.c, D.ow] += TypeFn.cast_signed(U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW])
@linalg_structured_op
-def pooling_nwc_max(I=TensorDef(T1, S.N,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2, S.KW, index_dims=[D.kw]),
- O=TensorDef(U, S.N, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SW, default=[1]),
- dilations=IndexAttrDef(S.DW, default=[1])):
- """Performs max pooling.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.ow, D.c, D.kw)
- O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]](TypeFn.cast_signed(
- U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nwc_max(
+ I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+ K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+ O=TensorDef(U, S.N, S.OW, S.C, output=True),
+ strides=IndexAttrDef(S.SW, default=[1]),
+ dilations=IndexAttrDef(S.DW, default=[1]),
+):
+ """Performs max pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.ow, D.c, D.kw)
+ O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]](
+ TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+ )
@linalg_structured_op
-def pooling_nwc_max_unsigned(I=TensorDef(T1, S.N,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2,
- S.KW,
- index_dims=[D.kw]),
- O=TensorDef(U, S.N, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SW, default=[1]),
- dilations=IndexAttrDef(S.DW, default=[1])):
- """Performs unsigned max pooling.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.ow, D.c, D.kw)
- O[D.n, D.ow,
- D.c] = ReduceFn.max_unsigned[[D.kw]](TypeFn.cast_unsigned(
- U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nwc_max_unsigned(
+ I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+ K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+ O=TensorDef(U, S.N, S.OW, S.C, output=True),
+ strides=IndexAttrDef(S.SW, default=[1]),
+ dilations=IndexAttrDef(S.DW, default=[1]),
+):
+ """Performs unsigned max pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.ow, D.c, D.kw)
+ O[D.n, D.ow, D.c] = ReduceFn.max_unsigned[[D.kw]](
+ TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+ )
@linalg_structured_op
-def pooling_ncw_max(I=TensorDef(T1, S.N, S.C,
- S.OW * S.SW + S.KW * S.DW),
- K=TensorDef(T2, S.KW, index_dims=[D.kw]),
- O=TensorDef(U, S.N, S.C, S.OW, output=True),
- strides=IndexAttrDef(S.SW, default=[1]),
- dilations=IndexAttrDef(S.DW, default=[1])):
- """Performs max pooling.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.c, D.ow, D.kw)
- O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]](TypeFn.cast_signed(
- U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW,]))
+def pooling_ncw_max(
+ I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
+ K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+ O=TensorDef(U, S.N, S.C, S.OW, output=True),
+ strides=IndexAttrDef(S.SW, default=[1]),
+ dilations=IndexAttrDef(S.DW, default=[1]),
+):
+ """Performs max pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.c, D.ow, D.kw)
+ O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]](
+ TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.c,
+ D.ow * S.SW + D.kw * S.DW,
+ ],
+ )
+ )
@linalg_structured_op
-def pooling_nwc_min(I=TensorDef(T1, S.N,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2, S.KW, index_dims=[D.kw]),
- O=TensorDef(U, S.N, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SW, default=[1]),
- dilations=IndexAttrDef(S.DW, default=[1])):
- """Performs min pooling.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.ow, D.c, D.kw)
- O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]](TypeFn.cast_signed(
- U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nwc_min(
+ I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+ K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+ O=TensorDef(U, S.N, S.OW, S.C, output=True),
+ strides=IndexAttrDef(S.SW, default=[1]),
+ dilations=IndexAttrDef(S.DW, default=[1]),
+):
+ """Performs min pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.ow, D.c, D.kw)
+ O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]](
+ TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+ )
@linalg_structured_op
-def pooling_nwc_min_unsigned(I=TensorDef(T1, S.N,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2,
- S.KW,
- index_dims=[D.kw]),
- O=TensorDef(U, S.N, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SW, default=[1]),
- dilations=IndexAttrDef(S.DW, default=[1])):
- """Performs unsigned min pooling.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.ow, D.c, D.kw)
- O[D.n, D.ow,
- D.c] = ReduceFn.min_unsigned[[D.kw]](TypeFn.cast_unsigned(
- U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
-
+def pooling_nwc_min_unsigned(
+ I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+ K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+ O=TensorDef(U, S.N, S.OW, S.C, output=True),
+ strides=IndexAttrDef(S.SW, default=[1]),
+ dilations=IndexAttrDef(S.DW, default=[1]),
+):
+ """Performs unsigned min pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.ow, D.c, D.kw)
+ O[D.n, D.ow, D.c] = ReduceFn.min_unsigned[[D.kw]](
+ TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+ )
@linalg_structured_op
-def pooling_ndhwc_sum(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
- S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2,
- S.KD,
- S.KH,
- S.KW,
- index_dims=[D.kd, D.kh, D.kw]),
- O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
- dilations=IndexAttrDef(S.DD,
- S.DH,
- S.DW,
- default=[1, 1, 1])):
- """Performs 3D sum pooling.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
- O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed(
- U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
- D.ow * S.SW + D.kw * S.DW, D.c])
+def pooling_ndhwc_sum(
+ I=TensorDef(
+ T1,
+ S.N,
+ S.OD * S.SD + S.KD * S.DD,
+ S.OH * S.SH + S.KH * S.DH,
+ S.OW * S.SW + S.KW * S.DW,
+ S.C,
+ ),
+ K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
+ O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
+ strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+ dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+ """Performs 3D sum pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
+ O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.od * S.SD + D.kd * S.DD,
+ D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW,
+ D.c,
+ ],
+ )
@linalg_structured_op
-def pooling_ndhwc_max(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
- S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2,
- S.KD,
- S.KH,
- S.KW,
- index_dims=[D.kd, D.kh, D.kw]),
- O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
- dilations=IndexAttrDef(S.DD,
- S.DH,
- S.DW,
- default=[1, 1, 1])):
- """Performs 3D max pooling.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
- O[D.n, D.od, D.oh, D.ow,
- D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](TypeFn.cast_signed(
- U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
- D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_ndhwc_max(
+ I=TensorDef(
+ T1,
+ S.N,
+ S.OD * S.SD + S.KD * S.DD,
+ S.OH * S.SH + S.KH * S.DH,
+ S.OW * S.SW + S.KW * S.DW,
+ S.C,
+ ),
+ K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
+ O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
+ strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+ dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+ """Performs 3D max pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
+ O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](
+ TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.od * S.SD + D.kd * S.DD,
+ D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW,
+ D.c,
+ ],
+ )
+ )
@linalg_structured_op
-def pooling_ndhwc_min(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
- S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW, S.C),
- K=TensorDef(T2,
- S.KD,
- S.KH,
- S.KW,
- index_dims=[D.kd, D.kh, D.kw]),
- O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
- dilations=IndexAttrDef(S.DD,
- S.DH,
- S.DW,
- default=[1, 1, 1])):
- """Performs 3D min pooling.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- implements(ConvolutionOpInterface)
- domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
- O[D.n, D.od, D.oh, D.ow,
- D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](TypeFn.cast_signed(
- U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
- D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_ndhwc_min(
+ I=TensorDef(
+ T1,
+ S.N,
+ S.OD * S.SD + S.KD * S.DD,
+ S.OH * S.SH + S.KH * S.DH,
+ S.OW * S.SW + S.KW * S.DW,
+ S.C,
+ ),
+ K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
+ O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
+ strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+ dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+ """Performs 3D min pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
+ O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](
+ TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.od * S.SD + D.kd * S.DD,
+ D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW,
+ D.c,
+ ],
+ )
+ )
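
The access pattern used throughout these pooling definitions, e.g. D.ow * S.SW + D.kw * S.DW, is the usual strided/dilated pooling window. Below is an illustrative NumPy reference for the 1D NWC sum-pooling case; the helper name and the concrete sizes are made up for the example, and only the index arithmetic mirrors the opdsl definitions in this file.

import numpy as np

def pooling_nwc_sum_ref(I, KW, SW=1, DW=1):
    # Reference for O[n, ow, c] += I[n, ow * SW + kw * DW, c], kw in [0, KW).
    N, W, C = I.shape
    OW = (W - (KW - 1) * DW - 1) // SW + 1
    O = np.zeros((N, OW, C), dtype=I.dtype)
    for kw in range(KW):
        # For a fixed kernel offset, the contributing inputs form a strided slice.
        O += I[:, kw * DW : kw * DW + OW * SW : SW, :]
    return O

# Cross-check against a naive loop on a tiny input.
I = np.arange(2 * 10 * 3, dtype=np.float64).reshape(2, 10, 3)
out = pooling_nwc_sum_ref(I, KW=3, SW=2, DW=2)
naive = np.zeros_like(out)
for n in range(2):
    for ow in range(out.shape[1]):
        for c in range(3):
            for kw in range(3):
                naive[n, ow, c] += I[n, ow * 2 + kw * 2, c]
assert np.allclose(out, naive)
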
@linalg_structured_op
def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
- """Fills the output tensor with the given value.
+ """Fills the output tensor with the given value.
- Works for arbitrary ranked output tensors since the operation performs scalar
- accesses only and is thus rank polymorphic. Numeric casting is performed on
- the value operand, promoting it to the same data type as the output.
- """
- implements(FillOpInterface)
- defines(Canonicalizer)
- O[None] = TypeFn.cast_signed(U, value)
+ Works for arbitrary ranked output tensors since the operation performs scalar
+ accesses only and is thus rank polymorphic. Numeric casting is performed on
+ the value operand, promoting it to the same data type as the output.
+ """
+ implements(FillOpInterface)
+ defines(Canonicalizer)
+ O[None] = TypeFn.cast_signed(U, value)
@linalg_structured_op
-def fill_rng_2d(min=ScalarDef(F64),
- max=ScalarDef(F64),
- seed=ScalarDef(I32),
- O=TensorDef(T, S.M, S.N, output=True)):
- """Fills the output tensor with pseudo random numbers.
-
- The operation generations pseudo random numbers using a linear congruential
- generator. It provides no guarantees regarding the distribution of the
- generated random numbers. Instead of generating the random numbers
- sequentially, it instantiates one random number generator per data element
- and runs them in parallel. The seed operand and the indices of the data
- element seed the random number generation. The min and max operands limit
- the range of the generated random numbers.
- """
- domain(D.m, D.n)
- multiplier = TypeFn.cast_signed(I32, const(1103515245))
- increment = TypeFn.cast_signed(I32, const(12345))
- rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
- rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
- inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
- offset = TypeFn.cast_signed(F64, const(2147483647))
- scaling = (max - min) * inv_range
- O[D.m, D.n] = TypeFn.cast_signed(
- T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min)
+def fill_rng_2d(
+ min=ScalarDef(F64),
+ max=ScalarDef(F64),
+ seed=ScalarDef(I32),
+ O=TensorDef(T, S.M, S.N, output=True),
+):
+ """Fills the output tensor with pseudo random numbers.
+
+    The operation generates pseudo random numbers using a linear congruential
+ generator. It provides no guarantees regarding the distribution of the
+ generated random numbers. Instead of generating the random numbers
+ sequentially, it instantiates one random number generator per data element
+ and runs them in parallel. The seed operand and the indices of the data
+ element seed the random number generation. The min and max operands limit
+ the range of the generated random numbers.
+ """
+ domain(D.m, D.n)
+ multiplier = TypeFn.cast_signed(I32, const(1103515245))
+ increment = TypeFn.cast_signed(I32, const(12345))
+ rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
+ rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
+ inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
+ offset = TypeFn.cast_signed(F64, const(2147483647))
+ scaling = (max - min) * inv_range
+ O[D.m, D.n] = TypeFn.cast_signed(
+ T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min
+ )
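
As a plain-Python sketch of the per-element recipe in fill_rng_2d above (illustrative only; the wrap-around helper and function name are made up, the constants come from the definition): each element (m, n) gets its own two-step linear congruential update in int32 arithmetic, and the result is rescaled into [min, max].

def _wrap_i32(x):
    # Emulate two's-complement int32 wrap-around for the LCG steps.
    x &= 0xFFFFFFFF
    return x - 0x100000000 if x >= 0x80000000 else x

def fill_rng_2d_element(m, n, seed, minval, maxval):
    multiplier, increment = 1103515245, 12345
    rand1 = _wrap_i32((m + seed) * multiplier + increment)
    rand2 = _wrap_i32((n + rand1) * multiplier + increment)
    # Map the signed 32-bit value into the requested range.
    inv_range = 2.3283064e-10
    offset = 2147483647.0
    scaling = (maxval - minval) * inv_range
    return (offset + float(rand2)) * scaling + minval

print(fill_rng_2d_element(0, 0, seed=42, minval=0.0, maxval=1.0))
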
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index ca0d479f1f5fc..980f237b19391 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -5,6 +5,8 @@
from ._python_test_ops_gen import *
from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType
+
def register_python_test_dialect(context, load=True):
- from .._mlir_libs import _mlirPythonTest
- _mlirPythonTest.register_python_test_dialect(context, load)
+ from .._mlir_libs import _mlirPythonTest
+
+ _mlirPythonTest.register_python_test_dialect(context, load)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 78956c4370049..b505a490aeb97 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -6,16 +6,18 @@
class FailurePropagationMode(Enum):
- """Propagation mode for silenceable errors."""
- PROPAGATE = 1
- SUPPRESS = 2
+ """Propagation mode for silenceable errors."""
- def _as_int(self):
- if self is FailurePropagationMode.PROPAGATE:
- return 1
+ PROPAGATE = 1
+ SUPPRESS = 2
+
+ def _as_int(self):
+ if self is FailurePropagationMode.PROPAGATE:
+ return 1
+
+ assert self is FailurePropagationMode.SUPPRESS
+ return 2
- assert self is FailurePropagationMode.SUPPRESS
- return 2
from .._transform_ops_gen import *
from ..._mlir_libs._mlirDialectsTransform import *
diff --git a/mlir/python/mlir/execution_engine.py b/mlir/python/mlir/execution_engine.py
index 262545b9ce726..4739231c155ba 100644
--- a/mlir/python/mlir/execution_engine.py
+++ b/mlir/python/mlir/execution_engine.py
@@ -7,37 +7,37 @@
import ctypes
__all__ = [
- "ExecutionEngine",
+ "ExecutionEngine",
]
-class ExecutionEngine(_execution_engine.ExecutionEngine):
- def lookup(self, name):
- """Lookup a function emitted with the `llvm.emit_c_interface`
- attribute and returns a ctype callable.
- Raise a RuntimeError if the function isn't found.
- """
- func = self.raw_lookup("_mlir_ciface_" + name)
- if not func:
- raise RuntimeError("Unknown function " + name)
- prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
- return prototype(func)
+class ExecutionEngine(_execution_engine.ExecutionEngine):
+ def lookup(self, name):
+ """Lookup a function emitted with the `llvm.emit_c_interface`
+        attribute and return a ctypes callable.
+ Raise a RuntimeError if the function isn't found.
+ """
+ func = self.raw_lookup("_mlir_ciface_" + name)
+ if not func:
+ raise RuntimeError("Unknown function " + name)
+ prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
+ return prototype(func)
- def invoke(self, name, *ctypes_args):
- """Invoke a function with the list of ctypes arguments.
- All arguments must be pointers.
- Raise a RuntimeError if the function isn't found.
- """
- func = self.lookup(name)
- packed_args = (ctypes.c_void_p * len(ctypes_args))()
- for argNum in range(len(ctypes_args)):
- packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p)
- func(packed_args)
+ def invoke(self, name, *ctypes_args):
+ """Invoke a function with the list of ctypes arguments.
+ All arguments must be pointers.
+ Raise a RuntimeError if the function isn't found.
+ """
+ func = self.lookup(name)
+ packed_args = (ctypes.c_void_p * len(ctypes_args))()
+ for argNum in range(len(ctypes_args)):
+ packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p)
+ func(packed_args)
- def register_runtime(self, name, ctypes_callback):
- """Register a runtime function available to the jitted code
- under the provided `name`. The `ctypes_callback` must be a
- `CFuncType` that outlives the execution engine.
- """
- callback = ctypes.cast(ctypes_callback, ctypes.c_void_p)
- self.raw_register_runtime("_mlir_ciface_" + name, callback)
+ def register_runtime(self, name, ctypes_callback):
+ """Register a runtime function available to the jitted code
+ under the provided `name`. The `ctypes_callback` must be a
+ `CFuncType` that outlives the execution engine.
+ """
+ callback = ctypes.cast(ctypes_callback, ctypes.c_void_p)
+ self.raw_register_runtime("_mlir_ciface_" + name, callback)
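
The invoke wrapper above packs every argument into an array of void pointers before calling the _mlir_ciface_-prefixed symbol returned by lookup. Below is a self-contained ctypes sketch of just that packing convention; the callee is an ordinary Python callback standing in for JIT-compiled code, which is an assumption made purely for illustration.

import ctypes

# Stand-in for a `_mlir_ciface_*` entry point: it receives one pointer to an
# array of void pointers, one slot per logical argument.
CIFACE_T = ctypes.CFUNCTYPE(None, ctypes.POINTER(ctypes.c_void_p))

def callee(packed_args):
    # Unpack the first slot back into a double, as JIT-ed code would.
    value = ctypes.cast(
        ctypes.c_void_p(packed_args[0]), ctypes.POINTER(ctypes.c_double)
    )[0]
    print("callee received", value)

func = CIFACE_T(callee)

# Mirror ExecutionEngine.invoke: cast each ctypes argument to c_void_p and
# store it into a packed array that is handed to the callee.
ctypes_args = (ctypes.pointer(ctypes.c_double(3.25)),)
packed_args = (ctypes.c_void_p * len(ctypes_args))()
for i, arg in enumerate(ctypes_args):
    packed_args[i] = ctypes.cast(arg, ctypes.c_void_p)
func(packed_args)
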
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index be065d4639519..99c21ff9aaef2 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -8,124 +8,123 @@
# Convenience decorator for registering user-friendly Attribute builders.
def register_attribute_builder(kind):
+ def decorator_builder(func):
+ AttrBuilder.insert(kind, func)
+ return func
- def decorator_builder(func):
- AttrBuilder.insert(kind, func)
- return func
-
- return decorator_builder
+ return decorator_builder
@register_attribute_builder("BoolAttr")
def _boolAttr(x, context):
- return BoolAttr.get(x, context=context)
+ return BoolAttr.get(x, context=context)
@register_attribute_builder("IndexAttr")
def _indexAttr(x, context):
- return IntegerAttr.get(IndexType.get(context=context), x)
+ return IntegerAttr.get(IndexType.get(context=context), x)
@register_attribute_builder("I16Attr")
def _i16Attr(x, context):
- return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
+ return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
@register_attribute_builder("I32Attr")
def _i32Attr(x, context):
- return IntegerAttr.get(IntegerType.get_signless(32, context=context), x)
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), x)
@register_attribute_builder("I64Attr")
def _i64Attr(x, context):
- return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
@register_attribute_builder("SI16Attr")
def _si16Attr(x, context):
- return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)
+ return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)
@register_attribute_builder("SI32Attr")
def _si32Attr(x, context):
- return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)
+ return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)
@register_attribute_builder("F32Attr")
def _f32Attr(x, context):
- return FloatAttr.get_f32(x, context=context)
+ return FloatAttr.get_f32(x, context=context)
@register_attribute_builder("F64Attr")
def _f64Attr(x, context):
- return FloatAttr.get_f64(x, context=context)
+ return FloatAttr.get_f64(x, context=context)
@register_attribute_builder("StrAttr")
def _stringAttr(x, context):
- return StringAttr.get(x, context=context)
+ return StringAttr.get(x, context=context)
@register_attribute_builder("SymbolNameAttr")
def _symbolNameAttr(x, context):
- return StringAttr.get(x, context=context)
+ return StringAttr.get(x, context=context)
@register_attribute_builder("SymbolRefAttr")
def _symbolRefAttr(x, context):
- return FlatSymbolRefAttr.get(x, context=context)
+ return FlatSymbolRefAttr.get(x, context=context)
@register_attribute_builder("ArrayAttr")
def _arrayAttr(x, context):
- return ArrayAttr.get(x, context=context)
+ return ArrayAttr.get(x, context=context)
@register_attribute_builder("I32ArrayAttr")
def _i32ArrayAttr(x, context):
- return ArrayAttr.get([_i32Attr(v, context) for v in x])
+ return ArrayAttr.get([_i32Attr(v, context) for v in x])
@register_attribute_builder("I64ArrayAttr")
def _i64ArrayAttr(x, context):
- return ArrayAttr.get([_i64Attr(v, context) for v in x])
+ return ArrayAttr.get([_i64Attr(v, context) for v in x])
@register_attribute_builder("F32ArrayAttr")
def _f32ArrayAttr(x, context):
- return ArrayAttr.get([_f32Attr(v, context) for v in x])
+ return ArrayAttr.get([_f32Attr(v, context) for v in x])
@register_attribute_builder("F64ArrayAttr")
def _f64ArrayAttr(x, context):
- return ArrayAttr.get([_f64Attr(v, context) for v in x])
+ return ArrayAttr.get([_f64Attr(v, context) for v in x])
@register_attribute_builder("DenseI64ArrayAttr")
def _denseI64ArrayAttr(x, context):
- return DenseI64ArrayAttr.get(x, context=context)
+ return DenseI64ArrayAttr.get(x, context=context)
@register_attribute_builder("TypeAttr")
def _typeAttr(x, context):
- return TypeAttr.get(x, context=context)
+ return TypeAttr.get(x, context=context)
@register_attribute_builder("TypeArrayAttr")
def _typeArrayAttr(x, context):
- return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context)
+ return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context)
try:
- import numpy as np
+ import numpy as np
- @register_attribute_builder("IndexElementsAttr")
- def _indexElementsAttr(x, context):
- return DenseElementsAttr.get(
- np.array(x, dtype=np.int64),
- type=IndexType.get(context=context),
- context=context,
- )
+ @register_attribute_builder("IndexElementsAttr")
+ def _indexElementsAttr(x, context):
+ return DenseElementsAttr.get(
+ np.array(x, dtype=np.int64),
+ type=IndexType.get(context=context),
+ context=context,
+ )
except ImportError:
- pass
+ pass
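
The decorator factory above is how the bindings map ODS attribute kinds to builder callables. A hypothetical registration is sketched below; the kind name I8Attr is invented for the example, and whether re-registering an already-known kind is permitted is not shown by this diff.

from mlir.ir import IntegerAttr, IntegerType, register_attribute_builder

@register_attribute_builder("I8Attr")  # hypothetical kind name
def _i8Attr(x, context):
    # Build a signless 8-bit IntegerAttr from a Python int.
    return IntegerAttr.get(IntegerType.get_signless(8, context=context), x)
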
diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index d70967983c45d..51433d75ac4fb 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -9,131 +9,134 @@
class C128(ctypes.Structure):
- """A ctype representation for MLIR's Double Complex."""
- _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]
+ """A ctype representation for MLIR's Double Complex."""
+
+ _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]
class C64(ctypes.Structure):
- """A ctype representation for MLIR's Float Complex."""
- _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]
+ """A ctype representation for MLIR's Float Complex."""
+
+ _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]
class F16(ctypes.Structure):
- """A ctype representation for MLIR's Float16."""
- _fields_ = [("f16", ctypes.c_int16)]
+ """A ctype representation for MLIR's Float16."""
+
+ _fields_ = [("f16", ctypes.c_int16)]
# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
def as_ctype(dtp):
- """Converts dtype to ctype."""
- if dtp == np.dtype(np.complex128):
- return C128
- if dtp == np.dtype(np.complex64):
- return C64
- if dtp == np.dtype(np.float16):
- return F16
- return np.ctypeslib.as_ctypes_type(dtp)
+ """Converts dtype to ctype."""
+ if dtp == np.dtype(np.complex128):
+ return C128
+ if dtp == np.dtype(np.complex64):
+ return C64
+ if dtp == np.dtype(np.float16):
+ return F16
+ return np.ctypeslib.as_ctypes_type(dtp)
def to_numpy(array):
- """Converts ctypes array back to numpy dtype array."""
- if array.dtype == C128:
- return array.view("complex128")
- if array.dtype == C64:
- return array.view("complex64")
- if array.dtype == F16:
- return array.view("float16")
- return array
+ """Converts ctypes array back to numpy dtype array."""
+ if array.dtype == C128:
+ return array.view("complex128")
+ if array.dtype == C64:
+ return array.view("complex64")
+ if array.dtype == F16:
+ return array.view("float16")
+ return array
def make_nd_memref_descriptor(rank, dtype):
+ class MemRefDescriptor(ctypes.Structure):
+ """Builds an empty descriptor for the given rank/dtype, where rank>0."""
- class MemRefDescriptor(ctypes.Structure):
- """Builds an empty descriptor for the given rank/dtype, where rank>0."""
+ _fields_ = [
+ ("allocated", ctypes.c_longlong),
+ ("aligned", ctypes.POINTER(dtype)),
+ ("offset", ctypes.c_longlong),
+ ("shape", ctypes.c_longlong * rank),
+ ("strides", ctypes.c_longlong * rank),
+ ]
- _fields_ = [
- ("allocated", ctypes.c_longlong),
- ("aligned", ctypes.POINTER(dtype)),
- ("offset", ctypes.c_longlong),
- ("shape", ctypes.c_longlong * rank),
- ("strides", ctypes.c_longlong * rank),
- ]
-
- return MemRefDescriptor
+ return MemRefDescriptor
def make_zero_d_memref_descriptor(dtype):
+ class MemRefDescriptor(ctypes.Structure):
+ """Builds an empty descriptor for the given dtype, where rank=0."""
- class MemRefDescriptor(ctypes.Structure):
- """Builds an empty descriptor for the given dtype, where rank=0."""
-
- _fields_ = [
- ("allocated", ctypes.c_longlong),
- ("aligned", ctypes.POINTER(dtype)),
- ("offset", ctypes.c_longlong),
- ]
+ _fields_ = [
+ ("allocated", ctypes.c_longlong),
+ ("aligned", ctypes.POINTER(dtype)),
+ ("offset", ctypes.c_longlong),
+ ]
- return MemRefDescriptor
+ return MemRefDescriptor
class UnrankedMemRefDescriptor(ctypes.Structure):
- """Creates a ctype struct for memref descriptor"""
- _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
+ """Creates a ctype struct for memref descriptor"""
+
+ _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
def get_ranked_memref_descriptor(nparray):
- """Returns a ranked memref descriptor for the given numpy array."""
- ctp = as_ctype(nparray.dtype)
- if nparray.ndim == 0:
- x = make_zero_d_memref_descriptor(ctp)()
+ """Returns a ranked memref descriptor for the given numpy array."""
+ ctp = as_ctype(nparray.dtype)
+ if nparray.ndim == 0:
+ x = make_zero_d_memref_descriptor(ctp)()
+ x.allocated = nparray.ctypes.data
+ x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
+ x.offset = ctypes.c_longlong(0)
+ return x
+
+ x = make_nd_memref_descriptor(nparray.ndim, ctp)()
x.allocated = nparray.ctypes.data
x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
x.offset = ctypes.c_longlong(0)
- return x
+ x.shape = nparray.ctypes.shape
- x = make_nd_memref_descriptor(nparray.ndim, ctp)()
- x.allocated = nparray.ctypes.data
- x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
- x.offset = ctypes.c_longlong(0)
- x.shape = nparray.ctypes.shape
-
- # Numpy uses byte quantities to express strides, MLIR OTOH uses the
- # torch abstraction which specifies strides in terms of elements.
- strides_ctype_t = ctypes.c_longlong * nparray.ndim
- x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides])
- return x
+ # Numpy uses byte quantities to express strides, MLIR OTOH uses the
+ # torch abstraction which specifies strides in terms of elements.
+ strides_ctype_t = ctypes.c_longlong * nparray.ndim
+ x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides])
+ return x
def get_unranked_memref_descriptor(nparray):
- """Returns a generic/unranked memref descriptor for the given numpy array."""
- d = UnrankedMemRefDescriptor()
- d.rank = nparray.ndim
- x = get_ranked_memref_descriptor(nparray)
- d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
- return d
+ """Returns a generic/unranked memref descriptor for the given numpy array."""
+ d = UnrankedMemRefDescriptor()
+ d.rank = nparray.ndim
+ x = get_ranked_memref_descriptor(nparray)
+ d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
+ return d
def unranked_memref_to_numpy(unranked_memref, np_dtype):
- """Converts unranked memrefs to numpy arrays."""
- ctp = as_ctype(np_dtype)
- descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
- val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
- np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
- strided_arr = np.lib.stride_tricks.as_strided(
- np_arr,
- np.ctypeslib.as_array(val[0].shape),
- np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
- )
- return to_numpy(strided_arr)
+ """Converts unranked memrefs to numpy arrays."""
+ ctp = as_ctype(np_dtype)
+ descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
+ val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
+ np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
+ strided_arr = np.lib.stride_tricks.as_strided(
+ np_arr,
+ np.ctypeslib.as_array(val[0].shape),
+ np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
+ )
+ return to_numpy(strided_arr)
def ranked_memref_to_numpy(ranked_memref):
- """Converts ranked memrefs to numpy arrays."""
- np_arr = np.ctypeslib.as_array(
- ranked_memref[0].aligned, shape=ranked_memref[0].shape)
- strided_arr = np.lib.stride_tricks.as_strided(
- np_arr,
- np.ctypeslib.as_array(ranked_memref[0].shape),
- np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
- )
- return to_numpy(strided_arr)
+ """Converts ranked memrefs to numpy arrays."""
+ np_arr = np.ctypeslib.as_array(
+ ranked_memref[0].aligned, shape=ranked_memref[0].shape
+ )
+ strided_arr = np.lib.stride_tricks.as_strided(
+ np_arr,
+ np.ctypeslib.as_array(ranked_memref[0].shape),
+ np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
+ )
+ return to_numpy(strided_arr)
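
The descriptor helpers above are normally used as a pair: wrap a NumPy array for passing into JIT-ed code, then view the result buffer back as NumPy. A minimal round-trip sketch follows (no MLIR execution involved; array contents are arbitrary, and it assumes the package re-exports these helpers as the sparse tests do via `from mlir import runtime as rt`).

import ctypes
import numpy as np
from mlir import runtime as rt

a = np.arange(6, dtype=np.float64).reshape(2, 3)

# Rank-2 descriptor pointing at a's buffer; strides are stored in elements,
# converted from numpy's byte strides as in get_ranked_memref_descriptor.
desc = rt.get_ranked_memref_descriptor(a)
mem = ctypes.pointer(ctypes.pointer(desc))

# View the same memory back as a numpy array and compare.
b = rt.ranked_memref_to_numpy(mem[0])
assert np.array_equal(a, b)
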
diff --git a/mlir/test/CAPI/lit.local.cfg b/mlir/test/CAPI/lit.local.cfg
index f08a0de488ddd..bb0c17cdbada7 100644
--- a/mlir/test/CAPI/lit.local.cfg
+++ b/mlir/test/CAPI/lit.local.cfg
@@ -1 +1 @@
-config.suffixes.add('.c')
+config.suffixes.add(".c")
diff --git a/mlir/test/Conversion/GPUToCUDA/lit.local.cfg b/mlir/test/Conversion/GPUToCUDA/lit.local.cfg
index 847c3efbf24ac..bc470ccc5733a 100644
--- a/mlir/test/Conversion/GPUToCUDA/lit.local.cfg
+++ b/mlir/test/Conversion/GPUToCUDA/lit.local.cfg
@@ -1,2 +1,2 @@
if not config.run_cuda_tests:
- config.unsupported = True
\ No newline at end of file
+ config.unsupported = True
diff --git a/mlir/test/Conversion/GPUToROCm/lit.local.cfg b/mlir/test/Conversion/GPUToROCm/lit.local.cfg
index 6eb561783b3fb..2f5cc9f3bad97 100644
--- a/mlir/test/Conversion/GPUToROCm/lit.local.cfg
+++ b/mlir/test/Conversion/GPUToROCm/lit.local.cfg
@@ -1,2 +1,2 @@
if not config.run_rocm_tests:
- config.unsupported = True
+ config.unsupported = True
diff --git a/mlir/test/Examples/Toy/Ch6/lit.local.cfg b/mlir/test/Examples/Toy/Ch6/lit.local.cfg
index c5aeb13c427c5..0d9aa1006cf8b 100644
--- a/mlir/test/Examples/Toy/Ch6/lit.local.cfg
+++ b/mlir/test/Examples/Toy/Ch6/lit.local.cfg
@@ -1,5 +1,3 @@
# Requires native execution.
-if 'host-supports-jit' not in config.available_features:
+if "host-supports-jit" not in config.available_features:
config.unsupported = True
-
-
diff --git a/mlir/test/Examples/Toy/Ch7/lit.local.cfg b/mlir/test/Examples/Toy/Ch7/lit.local.cfg
index c5aeb13c427c5..0d9aa1006cf8b 100644
--- a/mlir/test/Examples/Toy/Ch7/lit.local.cfg
+++ b/mlir/test/Examples/Toy/Ch7/lit.local.cfg
@@ -1,5 +1,3 @@
# Requires native execution.
-if 'host-supports-jit' not in config.available_features:
+if "host-supports-jit" not in config.available_features:
config.unsupported = True
-
-
diff --git a/mlir/test/Examples/lit.local.cfg b/mlir/test/Examples/lit.local.cfg
index 97db322f29e29..1a51296e2b60a 100644
--- a/mlir/test/Examples/lit.local.cfg
+++ b/mlir/test/Examples/lit.local.cfg
@@ -1,2 +1,2 @@
if not config.build_examples:
- config.unsupported = True
+ config.unsupported = True
diff --git a/mlir/test/Examples/standalone/lit.local.cfg b/mlir/test/Examples/standalone/lit.local.cfg
index cf7c8ff4d54f7..fe8397c6b9a10 100644
--- a/mlir/test/Examples/standalone/lit.local.cfg
+++ b/mlir/test/Examples/standalone/lit.local.cfg
@@ -1,13 +1,12 @@
# Disable with sanitizers for now, this require some more setup apparently.
-for san in ['asan', 'msan', 'ubsan']:
- if (san in config.available_features):
- config.unsupported = True
+for san in ["asan", "msan", "ubsan"]:
+ if san in config.available_features:
+ config.unsupported = True
config.substitutions.append(("%cmake_exe", config.host_cmake))
config.substitutions.append(("%cmake_generator", config.host_cmake_generator))
config.substitutions.append(("%host_cxx", config.host_cxx))
config.substitutions.append(("%host_cc", config.host_cc))
config.substitutions.append(("%enable_libcxx", config.enable_libcxx))
-config.substitutions.append(
- ("%mlir_cmake_dir", config.mlir_cmake_dir))
+config.substitutions.append(("%mlir_cmake_dir", config.mlir_cmake_dir))
config.substitutions.append(("%llvm_use_linker", config.llvm_use_linker))
diff --git a/mlir/test/Integration/Dialect/Async/CPU/lit.local.cfg b/mlir/test/Integration/Dialect/Async/CPU/lit.local.cfg
index 7215edacf7a83..073f6373d8581 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/Async/CPU/lit.local.cfg
@@ -1,5 +1,5 @@
import sys
# Windows does not have aligned_alloc
-if sys.platform == 'win32':
+if sys.platform == "win32":
config.unsupported = True
diff --git a/mlir/test/Integration/Dialect/LLVMIR/CPU/X86/lit.local.cfg b/mlir/test/Integration/Dialect/LLVMIR/CPU/X86/lit.local.cfg
index 263c8f8a12283..071a13c27769e 100644
--- a/mlir/test/Integration/Dialect/LLVMIR/CPU/X86/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/LLVMIR/CPU/X86/lit.local.cfg
@@ -1,4 +1,4 @@
import platform
-if platform.machine() != 'x86_64':
+if platform.machine() != "x86_64":
config.unsupported = True
diff --git a/mlir/test/Integration/Dialect/LLVMIR/CPU/lit.local.cfg b/mlir/test/Integration/Dialect/LLVMIR/CPU/lit.local.cfg
index 7d1e494eb96e9..3214a1101f084 100644
--- a/mlir/test/Integration/Dialect/LLVMIR/CPU/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/LLVMIR/CPU/lit.local.cfg
@@ -1,18 +1,22 @@
import sys
-lli_cmd = 'lli'
+lli_cmd = "lli"
if config.riscv_emulator_lli_executable:
lli_cmd = config.riscv_emulator_lli_executable
-config.substitutions.append(('%mlir_native_utils_lib_dir',
- config.riscv_emulator_utils_lib_dir or config.mlir_lib_dir))
+config.substitutions.append(
+ (
+ "%mlir_native_utils_lib_dir",
+ config.riscv_emulator_utils_lib_dir or config.mlir_lib_dir,
+ )
+)
if config.riscv_vector_emulator_executable:
# Run test in qemu emulator.
emulation_cmd = config.riscv_vector_emulator_executable
if config.riscv_vector_emulator_options:
- emulation_cmd = emulation_cmd + ' ' + config.riscv_vector_emulator_options
- emulation_cmd = emulation_cmd + ' ' + lli_cmd + ' --march=riscv64 -mattr=+v '
- config.substitutions.append(('%lli', emulation_cmd))
+ emulation_cmd = emulation_cmd + " " + config.riscv_vector_emulator_options
+ emulation_cmd = emulation_cmd + " " + lli_cmd + " --march=riscv64 -mattr=+v "
+ config.substitutions.append(("%lli", emulation_cmd))
else:
- config.substitutions.append(('%lli', lli_cmd))
+ config.substitutions.append(("%lli", lli_cmd))
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/lit.local.cfg b/mlir/test/Integration/Dialect/SparseTensor/CPU/lit.local.cfg
index 9bf49cc8246ec..6e07eb87987a4 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/lit.local.cfg
@@ -2,14 +2,16 @@ import sys
from lit.llvm import llvm_config
# FIXME: %mlir_native_utils_lib_dir is set incorrectly on Windows
-if sys.platform == 'win32':
+if sys.platform == "win32":
config.unsupported = True
# ArmSVE tests must be enabled via build flag.
if config.mlir_run_arm_sve_tests:
- config.substitutions.append(('%ENABLE_VLA', 'true'))
- config.substitutions.append(('%VLA_ARCH_ATTR_OPTIONS', '--march=aarch64 --mattr="+sve"'))
+ config.substitutions.append(("%ENABLE_VLA", "true"))
+ config.substitutions.append(
+ ("%VLA_ARCH_ATTR_OPTIONS", '--march=aarch64 --mattr="+sve"')
+ )
else:
- config.substitutions.append(('%ENABLE_VLA', 'false'))
- config.substitutions.append(('%VLA_ARCH_ATTR_OPTIONS', ''))
- config.substitutions.append(('%mlir_native_utils_lib_dir', config.mlir_lib_dir))
+ config.substitutions.append(("%ENABLE_VLA", "false"))
+ config.substitutions.append(("%VLA_ARCH_ATTR_OPTIONS", ""))
+ config.substitutions.append(("%mlir_native_utils_lib_dir", config.mlir_lib_dir))
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg
index c586aae6475bd..6788ccea3a222 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg
@@ -1,2 +1,2 @@
if not config.enable_cuda_runner or not config.mlir_run_cuda_sm80_tests:
- config.unsupported = True
+ config.unsupported = True
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/lit.local.cfg b/mlir/test/Integration/Dialect/SparseTensor/python/lit.local.cfg
index cf04454dea6ef..361b657dd2d83 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/lit.local.cfg
@@ -1,5 +1,5 @@
# Disable ASAN's leak detection for python OpsDSL tests.
-config.environment['ASAN_OPTIONS'] = 'detect_leaks=0'
+config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
# Only run when python bindings are enabled.
if not config.enable_bindings_python:
- config.unsupported = True
+ config.unsupported = True
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
index 958aa86752a07..1f9b636038318 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
@@ -18,42 +18,45 @@
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler
+
@dsl.linalg_structured_op
def sddmm_dsl(
A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
S=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N),
- C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)):
- C[dsl.D.m,
- dsl.D.n] += S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
+ C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
+):
+ C[dsl.D.m, dsl.D.n] += (
+ S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
+ )
def build_SDDMM(attr: st.EncodingAttr):
- """Build SDDMM kernel.
+ """Build SDDMM kernel.
- This method generates a linalg op with for matrix multiplication using
- just the Python API. Effectively, a generic linalg op is constructed
- that computes C(i,j) += S(i,j) SUM_k A(i,k) B(k,j) for sparse S.
- """
- module = ir.Module.create()
- f64 = ir.F64Type.get()
- a = ir.RankedTensorType.get([8, 8], f64)
- b = ir.RankedTensorType.get([8, 8], f64)
- c = ir.RankedTensorType.get([8, 8], f64)
- s = ir.RankedTensorType.get([8, 8], f64, attr)
- arguments = [a, b, s, c]
- with ir.InsertionPoint(module.body):
+    This method generates a linalg op for matrix multiplication using
+ just the Python API. Effectively, a generic linalg op is constructed
+ that computes C(i,j) += S(i,j) SUM_k A(i,k) B(k,j) for sparse S.
+ """
+ module = ir.Module.create()
+ f64 = ir.F64Type.get()
+ a = ir.RankedTensorType.get([8, 8], f64)
+ b = ir.RankedTensorType.get([8, 8], f64)
+ c = ir.RankedTensorType.get([8, 8], f64)
+ s = ir.RankedTensorType.get([8, 8], f64, attr)
+ arguments = [a, b, s, c]
+ with ir.InsertionPoint(module.body):
- @func.FuncOp.from_py_func(*arguments)
- def sddmm(*args):
- return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]])
+ @func.FuncOp.from_py_func(*arguments)
+ def sddmm(*args):
+ return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]])
- return module
+ return module
def boilerplate(attr: st.EncodingAttr):
- """Returns boilerplate code for main driver."""
- return f"""
+ """Returns boilerplate code for main driver."""
+ return f"""
func.func @main(%a: tensor<8x8xf64>,
%b: tensor<8x8xf64>,
%c: tensor<8x8xf64>) -> tensor<8x8xf64> attributes {{ llvm.emit_c_interface }} {{
@@ -69,92 +72,100 @@ def boilerplate(attr: st.EncodingAttr):
def build_compile_and_run_SDDMMM(attr: st.EncodingAttr, compiler):
- # Build.
- module = build_SDDMM(attr)
- func = str(module.operation.regions[0].blocks[0].operations[0].operation)
- module = ir.Module.parse(func + boilerplate(attr))
-
- # Compile.
- engine = compiler.compile_and_jit(module)
-
- # Set up numpy input and buffer for output.
- a = np.array([[1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
- [1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2],
- [1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3],
- [1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4],
- [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
- [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6],
- [1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7],
- [1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8]], np.float64)
- b = np.ones((8, 8), np.float64)
- c = np.zeros((8, 8), np.float64)
-
- mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
- mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
- mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
-
- # Allocate a MemRefDescriptor to receive the output tensor.
- # The buffer itself is allocated inside the MLIR code generation.
- ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
- mem_out = ctypes.pointer(ctypes.pointer(ref_out))
-
- # Invoke the kernel and get numpy output.
- # Built-in bufferization uses in-out buffers.
- # TODO: replace with inplace comprehensive bufferization.
- engine.invoke('main', mem_out, mem_a, mem_b, mem_c)
-
- # Sanity check on computed result. Only a few elements
- # are sampled from the full dense matrix multiplication.
- full_matmul = np.matmul(a, b)
- expected = np.zeros((8, 8), np.float64)
- expected[0, 0] = 1.0 * full_matmul[0, 0]
- expected[0, 2] = 2.0 * full_matmul[0, 2]
- expected[4, 1] = 3.0 * full_matmul[4, 1]
- c = rt.ranked_memref_to_numpy(mem_out[0])
- if np.allclose(c, expected):
- pass
- else:
- quit(f'FAILURE')
+ # Build.
+ module = build_SDDMM(attr)
+ func = str(module.operation.regions[0].blocks[0].operations[0].operation)
+ module = ir.Module.parse(func + boilerplate(attr))
+
+ # Compile.
+ engine = compiler.compile_and_jit(module)
+
+ # Set up numpy input and buffer for output.
+ a = np.array(
+ [
+ [1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
+ [1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2],
+ [1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3],
+ [1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4],
+ [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
+ [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6],
+ [1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7],
+ [1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8],
+ ],
+ np.float64,
+ )
+ b = np.ones((8, 8), np.float64)
+ c = np.zeros((8, 8), np.float64)
+
+ mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+ mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+ mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+
+ # Allocate a MemRefDescriptor to receive the output tensor.
+ # The buffer itself is allocated inside the MLIR code generation.
+ ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
+ mem_out = ctypes.pointer(ctypes.pointer(ref_out))
+
+ # Invoke the kernel and get numpy output.
+ # Built-in bufferization uses in-out buffers.
+ # TODO: replace with inplace comprehensive bufferization.
+ engine.invoke("main", mem_out, mem_a, mem_b, mem_c)
+
+ # Sanity check on computed result. Only a few elements
+ # are sampled from the full dense matrix multiplication.
+ full_matmul = np.matmul(a, b)
+ expected = np.zeros((8, 8), np.float64)
+ expected[0, 0] = 1.0 * full_matmul[0, 0]
+ expected[0, 2] = 2.0 * full_matmul[0, 2]
+ expected[4, 1] = 3.0 * full_matmul[4, 1]
+ c = rt.ranked_memref_to_numpy(mem_out[0])
+ if np.allclose(c, expected):
+ pass
+ else:
+ quit(f"FAILURE")
def main():
- support_lib = os.getenv('SUPPORT_LIB')
- assert support_lib is not None, 'SUPPORT_LIB is undefined'
- if not os.path.exists(support_lib):
- raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
- support_lib)
-
- # CHECK-LABEL: TEST: testSDDMMM
- print('\nTEST: testSDDMMM')
- with ir.Context() as ctx, ir.Location.unknown():
- count = 0
- # Loop over various ways to compile and annotate the SDDMM kernel with
- # a *single* sparse tensor. Note that we deliberate do not exhaustively
- # search the full state space to reduce runtime of the test. It is
- # straightforward to adapt the code below to explore more combinations.
- levels = [[st.DimLevelType.dense, st.DimLevelType.dense],
- [st.DimLevelType.dense, st.DimLevelType.compressed],
- [st.DimLevelType.compressed, st.DimLevelType.dense],
- [st.DimLevelType.compressed, st.DimLevelType.compressed]]
- orderings = [
- ir.AffineMap.get_permutation([0, 1]),
- ir.AffineMap.get_permutation([1, 0])
- ]
- for level in levels:
- for ordering in orderings:
- for pwidth in [32]:
- for iwidth in [32]:
- for e in [True]:
- attr = st.EncodingAttr.get(level, ordering, None, pwidth,
- iwidth)
- opt = (f'parallelization-strategy=none')
- compiler = sparse_compiler.SparseCompiler(
- options=opt, opt_level=0, shared_libs=[support_lib])
- build_compile_and_run_SDDMMM(attr, compiler)
- count = count + 1
- # CHECK: Passed 8 tests
- print('Passed ', count, 'tests')
-
-
-if __name__ == '__main__':
- main()
+ support_lib = os.getenv("SUPPORT_LIB")
+ assert support_lib is not None, "SUPPORT_LIB is undefined"
+ if not os.path.exists(support_lib):
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+
+ # CHECK-LABEL: TEST: testSDDMMM
+ print("\nTEST: testSDDMMM")
+ with ir.Context() as ctx, ir.Location.unknown():
+ count = 0
+ # Loop over various ways to compile and annotate the SDDMM kernel with
+ # a *single* sparse tensor. Note that we deliberate do not exhaustively
+ # search the full state space to reduce runtime of the test. It is
+ # straightforward to adapt the code below to explore more combinations.
+ levels = [
+ [st.DimLevelType.dense, st.DimLevelType.dense],
+ [st.DimLevelType.dense, st.DimLevelType.compressed],
+ [st.DimLevelType.compressed, st.DimLevelType.dense],
+ [st.DimLevelType.compressed, st.DimLevelType.compressed],
+ ]
+ orderings = [
+ ir.AffineMap.get_permutation([0, 1]),
+ ir.AffineMap.get_permutation([1, 0]),
+ ]
+ for level in levels:
+ for ordering in orderings:
+ for pwidth in [32]:
+ for iwidth in [32]:
+ for e in [True]:
+ attr = st.EncodingAttr.get(
+ level, ordering, None, pwidth, iwidth
+ )
+ opt = f"parallelization-strategy=none"
+ compiler = sparse_compiler.SparseCompiler(
+ options=opt, opt_level=0, shared_libs=[support_lib]
+ )
+ build_compile_and_run_SDDMMM(attr, compiler)
+ count = count + 1
+ # CHECK: Passed 8 tests
+ print("Passed ", count, "tests")
+
+
+if __name__ == "__main__":
+ main()
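
For reference, the expected-value check in the test above follows directly from the SDDMM definition C(i,j) += S(i,j) * SUM_k A(i,k) * B(k,j): only positions where S is nonzero contribute. A plain NumPy restatement (illustrative and independent of MLIR; the three sampled positions match the ones used in the test):

import numpy as np

rng = np.random.default_rng(0)
a = rng.random((8, 8))
b = rng.random((8, 8))
s = np.zeros((8, 8))
s[0, 0], s[0, 2], s[4, 1] = 1.0, 2.0, 3.0  # sampled nonzeros, as in the test

# SDDMM: element-wise scaling of the dense matmul by the sparse sampling matrix.
c = s * (a @ b)

full_matmul = a @ b
expected = np.zeros((8, 8))
expected[0, 0] = 1.0 * full_matmul[0, 0]
expected[0, 2] = 2.0 * full_matmul[0, 2]
expected[4, 1] = 3.0 * full_matmul[4, 1]
assert np.allclose(c, expected)
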
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
index 97954ce08ced1..69f6cdcea967f 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
@@ -18,45 +18,47 @@
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler
+
@dsl.linalg_structured_op
def matmul_dsl(
A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
- C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)):
- C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
+ C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
+):
+ C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
def build_SpMM(attr: st.EncodingAttr):
- """Build SpMM kernel.
+ """Build SpMM kernel.
- This method generates a linalg op with for matrix multiplication using
- just the Python API. Effectively, a generic linalg op is constructed
- that computes C(i,j) += A(i,k) * B(k,j) for annotated matrix A.
- """
- module = ir.Module.create()
- f64 = ir.F64Type.get()
- a = ir.RankedTensorType.get([3, 4], f64, attr)
- b = ir.RankedTensorType.get([4, 2], f64)
- c = ir.RankedTensorType.get([3, 2], f64)
- arguments = [a, b, c]
- with ir.InsertionPoint(module.body):
+    This method generates a linalg op for matrix multiplication using
+ just the Python API. Effectively, a generic linalg op is constructed
+ that computes C(i,j) += A(i,k) * B(k,j) for annotated matrix A.
+ """
+ module = ir.Module.create()
+ f64 = ir.F64Type.get()
+ a = ir.RankedTensorType.get([3, 4], f64, attr)
+ b = ir.RankedTensorType.get([4, 2], f64)
+ c = ir.RankedTensorType.get([3, 2], f64)
+ arguments = [a, b, c]
+ with ir.InsertionPoint(module.body):
- @func.FuncOp.from_py_func(*arguments)
- def spMxM(*args):
- return matmul_dsl(args[0], args[1], outs=[args[2]])
+ @func.FuncOp.from_py_func(*arguments)
+ def spMxM(*args):
+ return matmul_dsl(args[0], args[1], outs=[args[2]])
- return module
+ return module
def boilerplate(attr: st.EncodingAttr):
- """Returns boilerplate main method.
-
- This method sets up a boilerplate main method that takes three tensors
- (a, b, c), converts the first tensor a into s sparse tensor, and then
- calls the sparse kernel for matrix multiplication. For convenience,
- this part is purely done as string input.
- """
- return f"""
+ """Returns boilerplate main method.
+
+ This method sets up a boilerplate main method that takes three tensors
+    (a, b, c), converts the first tensor a into a sparse tensor, and then
+ calls the sparse kernel for matrix multiplication. For convenience,
+ this part is purely done as string input.
+ """
+ return f"""
func.func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>) -> tensor<3x2xf64>
attributes {{ llvm.emit_c_interface }} {{
%a = sparse_tensor.convert %ad : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
@@ -69,82 +71,87 @@ def boilerplate(attr: st.EncodingAttr):
def build_compile_and_run_SpMM(attr: st.EncodingAttr, compiler):
- # Build.
- module = build_SpMM(attr)
- func = str(module.operation.regions[0].blocks[0].operations[0].operation)
- module = ir.Module.parse(func + boilerplate(attr))
-
- # Compile.
- engine = compiler.compile_and_jit(module)
-
- # Set up numpy input and buffer for output.
- a = np.array(
- [[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]],
- np.float64)
- b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
- c = np.zeros((3, 2), np.float64)
-
- mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
- mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
- mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
- # Allocate a MemRefDescriptor to receive the output tensor.
- # The buffer itself is allocated inside the MLIR code generation.
- ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
- mem_out = ctypes.pointer(ctypes.pointer(ref_out))
-
- # Invoke the kernel and get numpy output.
- # Built-in bufferization uses in-out buffers.
- # TODO: replace with inplace comprehensive bufferization.
- engine.invoke('main', mem_out, mem_a, mem_b, mem_c)
-
- # Sanity check on computed result.
- expected = np.matmul(a, b);
- c = rt.ranked_memref_to_numpy(mem_out[0])
- if np.allclose(c, expected):
- pass
- else:
- quit(f'FAILURE')
+ # Build.
+ module = build_SpMM(attr)
+ func = str(module.operation.regions[0].blocks[0].operations[0].operation)
+ module = ir.Module.parse(func + boilerplate(attr))
+
+ # Compile.
+ engine = compiler.compile_and_jit(module)
+
+ # Set up numpy input and buffer for output.
+ a = np.array(
+ [[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]], np.float64
+ )
+ b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
+ c = np.zeros((3, 2), np.float64)
+
+ mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+ mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+ mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+ # Allocate a MemRefDescriptor to receive the output tensor.
+ # The buffer itself is allocated inside the MLIR code generation.
+ ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
+ mem_out = ctypes.pointer(ctypes.pointer(ref_out))
+
+ # Invoke the kernel and get numpy output.
+ # Built-in bufferization uses in-out buffers.
+ # TODO: replace with inplace comprehensive bufferization.
+ engine.invoke("main", mem_out, mem_a, mem_b, mem_c)
+
+ # Sanity check on computed result.
+ expected = np.matmul(a, b)
+ c = rt.ranked_memref_to_numpy(mem_out[0])
+ if np.allclose(c, expected):
+ pass
+ else:
+ quit(f"FAILURE")
def main():
- support_lib = os.getenv('SUPPORT_LIB')
- assert support_lib is not None, 'SUPPORT_LIB is undefined'
- if not os.path.exists(support_lib):
- raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
-
- # CHECK-LABEL: TEST: testSpMM
- print('\nTEST: testSpMM')
- with ir.Context() as ctx, ir.Location.unknown():
- count = 0
- # Loop over various ways to compile and annotate the SpMM kernel with
- # a *single* sparse tensor. Note that we deliberate do not exhaustively
- # search the full state space to reduce runtime of the test. It is
- # straightforward to adapt the code below to explore more combinations.
-
- vl = 1
- e = False
- opt = (f'parallelization-strategy=none')
- levels = [[st.DimLevelType.dense, st.DimLevelType.dense],
- [st.DimLevelType.dense, st.DimLevelType.compressed],
- [st.DimLevelType.compressed, st.DimLevelType.dense],
- [st.DimLevelType.compressed, st.DimLevelType.compressed]]
- orderings = [
- ir.AffineMap.get_permutation([0, 1]),
- ir.AffineMap.get_permutation([1, 0])
- ]
- bitwidths = [0]
- compiler = sparse_compiler.SparseCompiler(
- options=opt, opt_level=0, shared_libs=[support_lib])
- for level in levels:
- for ordering in orderings:
- for pwidth in bitwidths:
- for iwidth in bitwidths:
- attr = st.EncodingAttr.get(level, ordering, None, pwidth, iwidth)
- build_compile_and_run_SpMM(attr, compiler)
- count = count + 1
- # CHECK: Passed 8 tests
- print('Passed ', count, 'tests')
-
-
-if __name__ == '__main__':
- main()
+ support_lib = os.getenv("SUPPORT_LIB")
+ assert support_lib is not None, "SUPPORT_LIB is undefined"
+ if not os.path.exists(support_lib):
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+
+ # CHECK-LABEL: TEST: testSpMM
+ print("\nTEST: testSpMM")
+ with ir.Context() as ctx, ir.Location.unknown():
+ count = 0
+ # Loop over various ways to compile and annotate the SpMM kernel with
+        # a *single* sparse tensor. Note that we deliberately do not exhaustively
+ # search the full state space to reduce runtime of the test. It is
+ # straightforward to adapt the code below to explore more combinations.
+
+ vl = 1
+ e = False
+ opt = f"parallelization-strategy=none"
+ levels = [
+ [st.DimLevelType.dense, st.DimLevelType.dense],
+ [st.DimLevelType.dense, st.DimLevelType.compressed],
+ [st.DimLevelType.compressed, st.DimLevelType.dense],
+ [st.DimLevelType.compressed, st.DimLevelType.compressed],
+ ]
+ orderings = [
+ ir.AffineMap.get_permutation([0, 1]),
+ ir.AffineMap.get_permutation([1, 0]),
+ ]
+ bitwidths = [0]
+ compiler = sparse_compiler.SparseCompiler(
+ options=opt, opt_level=0, shared_libs=[support_lib]
+ )
+ for level in levels:
+ for ordering in orderings:
+ for pwidth in bitwidths:
+ for iwidth in bitwidths:
+ attr = st.EncodingAttr.get(
+ level, ordering, None, pwidth, iwidth
+ )
+ build_compile_and_run_SpMM(attr, compiler)
+ count = count + 1
+ # CHECK: Passed 8 tests
+ print("Passed ", count, "tests")
+
+
+if __name__ == "__main__":
+ main()
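
The ctypes plumbing in build_compile_and_run_SpMM above follows one fixed pattern: each NumPy operand is wrapped in a ranked memref descriptor passed as a pointer-to-pointer, the compiled entry point receives an empty output descriptor, and the result is converted back to NumPy. A minimal sketch of that pattern, assuming an already JIT-compiled `engine` and the `mlir.runtime` helpers used in the test (the helper name `invoke_spmm` is illustrative only):

import ctypes

import numpy as np
from mlir import runtime as rt


def invoke_spmm(engine, a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
    # Wrap each NumPy input in a ranked memref descriptor and pass it as a
    # pointer-to-pointer, matching the llvm.emit_c_interface convention.
    def wrap(arr):
        return ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arr)))

    # The output buffer is allocated by the compiled MLIR code, so only an
    # empty rank-2 f64 descriptor is handed in here.
    ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
    mem_out = ctypes.pointer(ctypes.pointer(ref_out))
    engine.invoke("main", mem_out, wrap(a), wrap(b), wrap(c))
    # Read the computed tensor back into a NumPy array.
    return rt.ranked_memref_to_numpy(mem_out[0])
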
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
index b29b029c7a331..a41bde1ee2d34 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
@@ -57,49 +57,52 @@
def _run_test(support_lib, kernel):
- """Compiles, runs and checks results."""
- compiler = sparse_compiler.SparseCompiler(
- options='', opt_level=2, shared_libs=[support_lib])
- module = ir.Module.parse(kernel)
- engine = compiler.compile_and_jit(module)
-
- # Set up numpy inputs and buffer for output.
- a = np.array(
- [[1.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 6.6, 0.0]],
- np.float64)
- b = np.array(
- [[1.1, 0.0, 0.0, 2.8], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
- np.float64)
-
- mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
- mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
-
- # The sparse tensor output is a pointer to pointer of char.
- out = ctypes.c_char(0)
- mem_out = ctypes.pointer(ctypes.pointer(out))
-
- # Invoke the kernel.
- engine.invoke('main', mem_a, mem_b, mem_out)
-
- # Retrieve and check the result.
- rank, nse, shape, values, indices = test_tools.sparse_tensor_to_coo_tensor(
- support_lib, mem_out[0], np.float64)
-
- # CHECK: PASSED
- if np.allclose(values, [2.2, 2.8, 6.6]) and np.allclose(
- indices, [[0, 0], [0, 3], [2, 2]]):
- print('PASSED')
- else:
- quit('FAILURE')
+ """Compiles, runs and checks results."""
+ compiler = sparse_compiler.SparseCompiler(
+ options="", opt_level=2, shared_libs=[support_lib]
+ )
+ module = ir.Module.parse(kernel)
+ engine = compiler.compile_and_jit(module)
+
+ # Set up numpy inputs and buffer for output.
+ a = np.array(
+ [[1.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 6.6, 0.0]], np.float64
+ )
+ b = np.array(
+ [[1.1, 0.0, 0.0, 2.8], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], np.float64
+ )
+
+ mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+ mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+
+ # The sparse tensor output is a pointer to pointer of char.
+ out = ctypes.c_char(0)
+ mem_out = ctypes.pointer(ctypes.pointer(out))
+
+ # Invoke the kernel.
+ engine.invoke("main", mem_a, mem_b, mem_out)
+
+ # Retrieve and check the result.
+ rank, nse, shape, values, indices = test_tools.sparse_tensor_to_coo_tensor(
+ support_lib, mem_out[0], np.float64
+ )
+
+ # CHECK: PASSED
+ if np.allclose(values, [2.2, 2.8, 6.6]) and np.allclose(
+ indices, [[0, 0], [0, 3], [2, 2]]
+ ):
+ print("PASSED")
+ else:
+ quit("FAILURE")
def test_elementwise_add():
- # Obtain path to runtime support library.
- support_lib = os.getenv('SUPPORT_LIB')
- assert support_lib is not None, 'SUPPORT_LIB is undefined'
- assert os.path.exists(support_lib), f'{support_lib} does not exist'
- with ir.Context() as ctx, ir.Location.unknown():
- _run_test(support_lib, _KERNEL_STR)
+ # Obtain path to runtime support library.
+ support_lib = os.getenv("SUPPORT_LIB")
+ assert support_lib is not None, "SUPPORT_LIB is undefined"
+ assert os.path.exists(support_lib), f"{support_lib} does not exist"
+ with ir.Context() as ctx, ir.Location.unknown():
+ _run_test(support_lib, _KERNEL_STR)
test_elementwise_add()
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index 7d57b1c901948..7d77490080205 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -18,8 +18,8 @@
# TODO: move more into actual IR building.
def boilerplate(attr: st.EncodingAttr):
- """Returns boilerplate main method."""
- return f"""
+ """Returns boilerplate main method."""
+ return f"""
func.func @main(%p : !llvm.ptr<i8>) -> () attributes {{ llvm.emit_c_interface }} {{
%d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],
[1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>
@@ -31,13 +31,13 @@ def boilerplate(attr: st.EncodingAttr):
def expected():
- """Returns expected contents of output.
+ """Returns expected contents of output.
- Regardless of the dimension ordering, compression, and bitwidths that are
- used in the sparse tensor, the output is always lexicographically sorted
- by natural index order.
- """
- return f"""; extended FROSTT format
+ Regardless of the dimension ordering, compression, and bitwidths that are
+ used in the sparse tensor, the output is always lexicographically sorted
+ by natural index order.
+ """
+ return f"""; extended FROSTT format
2 5
10 10
1 1 1
@@ -49,53 +49,55 @@ def expected():
def build_compile_and_run_output(attr: st.EncodingAttr, compiler):
- # Build and Compile.
- module = ir.Module.parse(boilerplate(attr))
- engine = compiler.compile_and_jit(module)
+ # Build and Compile.
+ module = ir.Module.parse(boilerplate(attr))
+ engine = compiler.compile_and_jit(module)
- # Invoke the kernel and compare output.
- with tempfile.TemporaryDirectory() as test_dir:
- out = os.path.join(test_dir, 'out.tns')
- buf = out.encode('utf-8')
- mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
- engine.invoke('main', mem_a)
+ # Invoke the kernel and compare output.
+ with tempfile.TemporaryDirectory() as test_dir:
+ out = os.path.join(test_dir, "out.tns")
+ buf = out.encode("utf-8")
+ mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
+ engine.invoke("main", mem_a)
- actual = open(out).read()
- if actual != expected():
- quit('FAILURE')
+ actual = open(out).read()
+ if actual != expected():
+ quit("FAILURE")
def main():
- support_lib = os.getenv('SUPPORT_LIB')
- assert support_lib is not None, 'SUPPORT_LIB is undefined'
- if not os.path.exists(support_lib):
- raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
- support_lib)
-
- # CHECK-LABEL: TEST: test_output
- print('\nTEST: test_output')
- count = 0
- with ir.Context() as ctx, ir.Location.unknown():
- # Loop over various sparse types: CSR, DCSR, CSC, DCSC.
- levels = [[st.DimLevelType.dense, st.DimLevelType.compressed],
- [st.DimLevelType.compressed, st.DimLevelType.compressed]]
- orderings = [
- ir.AffineMap.get_permutation([0, 1]),
- ir.AffineMap.get_permutation([1, 0])
- ]
- bitwidths = [8, 16, 32, 64]
- compiler = sparse_compiler.SparseCompiler(
- options='', opt_level=2, shared_libs=[support_lib])
- for level in levels:
- for ordering in orderings:
- for bwidth in bitwidths:
- attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth)
- build_compile_and_run_output(attr, compiler)
- count = count + 1
-
- # CHECK: Passed 16 tests
- print('Passed', count, 'tests')
-
-
-if __name__ == '__main__':
- main()
+ support_lib = os.getenv("SUPPORT_LIB")
+ assert support_lib is not None, "SUPPORT_LIB is undefined"
+ if not os.path.exists(support_lib):
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+
+ # CHECK-LABEL: TEST: test_output
+ print("\nTEST: test_output")
+ count = 0
+ with ir.Context() as ctx, ir.Location.unknown():
+ # Loop over various sparse types: CSR, DCSR, CSC, DCSC.
+ levels = [
+ [st.DimLevelType.dense, st.DimLevelType.compressed],
+ [st.DimLevelType.compressed, st.DimLevelType.compressed],
+ ]
+ orderings = [
+ ir.AffineMap.get_permutation([0, 1]),
+ ir.AffineMap.get_permutation([1, 0]),
+ ]
+ bitwidths = [8, 16, 32, 64]
+ compiler = sparse_compiler.SparseCompiler(
+ options="", opt_level=2, shared_libs=[support_lib]
+ )
+ for level in levels:
+ for ordering in orderings:
+ for bwidth in bitwidths:
+ attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth)
+ build_compile_and_run_output(attr, compiler)
+ count = count + 1
+
+ # CHECK: Passed 16 tests
+ print("Passed", count, "tests")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
index 3a04e5b9ab5ca..373f7457e0b5f 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
@@ -28,216 +28,241 @@
# TODO: move this boilerplate to its own module, so it can be used by
# other tests and programs.
class TypeConverter:
- """Converter between NumPy types and MLIR types."""
-
- def __init__(self, context: ir.Context):
- # Note 1: these are numpy "scalar types" (i.e., the values of
- # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
- #
- # Note 2: we must construct the MLIR types in the same context as the
- # types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
- # otherwise, those methods will raise a KeyError.
- types_list = [
- (np.float64, ir.F64Type.get(context=context)),
- (np.float32, ir.F32Type.get(context=context)),
- (np.int64, ir.IntegerType.get_signless(64, context=context)),
- (np.int32, ir.IntegerType.get_signless(32, context=context)),
- (np.int16, ir.IntegerType.get_signless(16, context=context)),
- (np.int8, ir.IntegerType.get_signless(8, context=context)),
- ]
- self._sc2ir = dict(types_list)
- self._ir2sc = dict(( (ir,sc) for sc,ir in types_list ))
-
- def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
- """Returns the MLIR equivalent of a NumPy dtype."""
- try:
- return self.sctype_to_irtype(dtype.type)
- except KeyError as e:
- raise KeyError(f'Unknown dtype: {dtype}') from e
-
- def sctype_to_irtype(self, sctype) -> ir.Type:
- """Returns the MLIR equivalent of a NumPy scalar type."""
- if sctype in self._sc2ir:
- return self._sc2ir[sctype]
- else:
- raise KeyError(f'Unknown sctype: {sctype}')
-
- def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
- """Returns the NumPy dtype equivalent of an MLIR type."""
- return np.dtype(self.irtype_to_sctype(tp))
-
- def irtype_to_sctype(self, tp: ir.Type):
- """Returns the NumPy scalar-type equivalent of an MLIR type."""
- if tp in self._ir2sc:
- return self._ir2sc[tp]
- else:
- raise KeyError(f'Unknown ir.Type: {tp}')
-
- def get_RankedTensorType_of_nparray(self, nparray: np.ndarray) -> ir.RankedTensorType:
- """Returns the ir.RankedTensorType of a NumPy array. Note that NumPy
- arrays can only be converted to/from dense tensors, not sparse tensors."""
- # TODO: handle strides as well?
- return ir.RankedTensorType.get(nparray.shape,
- self.dtype_to_irtype(nparray.dtype))
+ """Converter between NumPy types and MLIR types."""
+
+ def __init__(self, context: ir.Context):
+ # Note 1: these are numpy "scalar types" (i.e., the values of
+ # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
+ #
+ # Note 2: we must construct the MLIR types in the same context as the
+ # types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
+ # otherwise, those methods will raise a KeyError.
+ types_list = [
+ (np.float64, ir.F64Type.get(context=context)),
+ (np.float32, ir.F32Type.get(context=context)),
+ (np.int64, ir.IntegerType.get_signless(64, context=context)),
+ (np.int32, ir.IntegerType.get_signless(32, context=context)),
+ (np.int16, ir.IntegerType.get_signless(16, context=context)),
+ (np.int8, ir.IntegerType.get_signless(8, context=context)),
+ ]
+ self._sc2ir = dict(types_list)
+ self._ir2sc = dict(((ir, sc) for sc, ir in types_list))
+
+ def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
+ """Returns the MLIR equivalent of a NumPy dtype."""
+ try:
+ return self.sctype_to_irtype(dtype.type)
+ except KeyError as e:
+ raise KeyError(f"Unknown dtype: {dtype}") from e
+
+ def sctype_to_irtype(self, sctype) -> ir.Type:
+ """Returns the MLIR equivalent of a NumPy scalar type."""
+ if sctype in self._sc2ir:
+ return self._sc2ir[sctype]
+ else:
+ raise KeyError(f"Unknown sctype: {sctype}")
+
+ def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
+ """Returns the NumPy dtype equivalent of an MLIR type."""
+ return np.dtype(self.irtype_to_sctype(tp))
+
+ def irtype_to_sctype(self, tp: ir.Type):
+ """Returns the NumPy scalar-type equivalent of an MLIR type."""
+ if tp in self._ir2sc:
+ return self._ir2sc[tp]
+ else:
+ raise KeyError(f"Unknown ir.Type: {tp}")
+
+ def get_RankedTensorType_of_nparray(
+ self, nparray: np.ndarray
+ ) -> ir.RankedTensorType:
+ """Returns the ir.RankedTensorType of a NumPy array. Note that NumPy
+ arrays can only be converted to/from dense tensors, not sparse tensors."""
+ # TODO: handle strides as well?
+ return ir.RankedTensorType.get(
+ nparray.shape, self.dtype_to_irtype(nparray.dtype)
+ )
+
# ===----------------------------------------------------------------------=== #
+
class StressTest:
- def __init__(self, tyconv: TypeConverter):
- self._tyconv = tyconv
- self._roundtripTp = None
- self._module = None
- self._engine = None
-
- def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
- assert self._roundtripTp is not None, \
- 'StressTest: uninitialized roundtrip type'
- if tp != self._roundtripTp:
- raise AssertionError(
- f"Type is not equal to the roundtrip type.\n"
- f"\tExpected: {self._roundtripTp}\n"
- f"\tFound: {tp}\n")
-
- def build(self, types: List[ir.Type]):
- """Builds the ir.Module. The module has only the @main function,
- which will convert the input through the list of types and then back
- to the initial type. The roundtrip type must be a dense tensor."""
- assert self._module is None, 'StressTest: must not call build() repeatedly'
- self._module = ir.Module.create()
- with ir.InsertionPoint(self._module.body):
- tp0 = types.pop(0)
- self._roundtripTp = tp0
- # TODO: assert dense? assert element type is recognised by the TypeConverter?
- types.append(tp0)
- funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
- funcOp = func.FuncOp(name='main', type=funcTp)
- funcOp.attributes['llvm.emit_c_interface'] = ir.UnitAttr.get()
- with ir.InsertionPoint(funcOp.add_entry_block()):
- arg0 = funcOp.entry_block.arguments[0]
- self._assertEqualsRoundtripTp(arg0.type)
- v = st.ConvertOp(types.pop(0), arg0)
- for tp in types:
- w = st.ConvertOp(tp, v)
- # Release intermediate tensors before they fall out of scope.
- bufferization.DeallocTensorOp(v.result)
- v = w
- self._assertEqualsRoundtripTp(v.result.type)
- func.ReturnOp(v)
- return self
-
- def writeTo(self, filename):
- """Write the ir.Module to the given file. If the file already exists,
- then raises an error. If the filename is None, then is a no-op."""
- assert self._module is not None, \
- 'StressTest: must call build() before writeTo()'
- if filename is None:
- # Silent no-op, for convenience.
- return self
- if os.path.exists(filename):
- raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), filename)
- with open(filename, 'w') as f:
- f.write(str(self._module))
- return self
-
- def compile(self, compiler):
- """Compile the ir.Module."""
- assert self._module is not None, \
- 'StressTest: must call build() before compile()'
- assert self._engine is None, \
- 'StressTest: must not call compile() repeatedly'
- self._engine = compiler.compile_and_jit(self._module)
- return self
-
- def run(self, np_arg0: np.ndarray) -> np.ndarray:
- """Runs the test on the given numpy array, and returns the resulting
- numpy array."""
- assert self._engine is not None, \
- 'StressTest: must call compile() before run()'
- self._assertEqualsRoundtripTp(
- self._tyconv.get_RankedTensorType_of_nparray(np_arg0))
- np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
- self._assertEqualsRoundtripTp(
- self._tyconv.get_RankedTensorType_of_nparray(np_out))
- mem_arg0 = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0)))
- mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_out)))
- self._engine.invoke('main', mem_out, mem_arg0)
- return rt.ranked_memref_to_numpy(mem_out[0])
+ def __init__(self, tyconv: TypeConverter):
+ self._tyconv = tyconv
+ self._roundtripTp = None
+ self._module = None
+ self._engine = None
+
+ def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
+ assert self._roundtripTp is not None, "StressTest: uninitialized roundtrip type"
+ if tp != self._roundtripTp:
+ raise AssertionError(
+ f"Type is not equal to the roundtrip type.\n"
+ f"\tExpected: {self._roundtripTp}\n"
+ f"\tFound: {tp}\n"
+ )
+
+ def build(self, types: List[ir.Type]):
+ """Builds the ir.Module. The module has only the @main function,
+ which will convert the input through the list of types and then back
+ to the initial type. The roundtrip type must be a dense tensor."""
+ assert self._module is None, "StressTest: must not call build() repeatedly"
+ self._module = ir.Module.create()
+ with ir.InsertionPoint(self._module.body):
+ tp0 = types.pop(0)
+ self._roundtripTp = tp0
+ # TODO: assert dense? assert element type is recognised by the TypeConverter?
+ types.append(tp0)
+ funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
+ funcOp = func.FuncOp(name="main", type=funcTp)
+ funcOp.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+ with ir.InsertionPoint(funcOp.add_entry_block()):
+ arg0 = funcOp.entry_block.arguments[0]
+ self._assertEqualsRoundtripTp(arg0.type)
+ v = st.ConvertOp(types.pop(0), arg0)
+ for tp in types:
+ w = st.ConvertOp(tp, v)
+ # Release intermediate tensors before they fall out of scope.
+ bufferization.DeallocTensorOp(v.result)
+ v = w
+ self._assertEqualsRoundtripTp(v.result.type)
+ func.ReturnOp(v)
+ return self
+
+ def writeTo(self, filename):
+ """Write the ir.Module to the given file. If the file already exists,
+        an error is raised. If the filename is None, this is a no-op."""
+ assert (
+ self._module is not None
+ ), "StressTest: must call build() before writeTo()"
+ if filename is None:
+ # Silent no-op, for convenience.
+ return self
+ if os.path.exists(filename):
+ raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), filename)
+ with open(filename, "w") as f:
+ f.write(str(self._module))
+ return self
+
+ def compile(self, compiler):
+ """Compile the ir.Module."""
+ assert (
+ self._module is not None
+ ), "StressTest: must call build() before compile()"
+ assert self._engine is None, "StressTest: must not call compile() repeatedly"
+ self._engine = compiler.compile_and_jit(self._module)
+ return self
+
+ def run(self, np_arg0: np.ndarray) -> np.ndarray:
+ """Runs the test on the given numpy array, and returns the resulting
+ numpy array."""
+ assert self._engine is not None, "StressTest: must call compile() before run()"
+ self._assertEqualsRoundtripTp(
+ self._tyconv.get_RankedTensorType_of_nparray(np_arg0)
+ )
+ np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
+ self._assertEqualsRoundtripTp(
+ self._tyconv.get_RankedTensorType_of_nparray(np_out)
+ )
+ mem_arg0 = ctypes.pointer(
+ ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0))
+ )
+ mem_out = ctypes.pointer(
+ ctypes.pointer(rt.get_ranked_memref_descriptor(np_out))
+ )
+ self._engine.invoke("main", mem_out, mem_arg0)
+ return rt.ranked_memref_to_numpy(mem_out[0])
+
# ===----------------------------------------------------------------------=== #
+
def main():
- """
- USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]
-
- The environment variable SUPPORT_LIB must be set to point to the
- libmlir_c_runner_utils shared library. There are two optional
- arguments, for debugging purposes. The first argument specifies where
- to write out the raw/generated ir.Module. The second argument specifies
- where to write out the compiled version of that ir.Module.
- """
- support_lib = os.getenv('SUPPORT_LIB')
- assert support_lib is not None, 'SUPPORT_LIB is undefined'
- if not os.path.exists(support_lib):
- raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
-
- # CHECK-LABEL: TEST: test_stress
- print("\nTEST: test_stress")
- with ir.Context() as ctx, ir.Location.unknown():
- # Disable direct sparse2sparse conversion, because it doubles the time!
- # TODO: While direct s2s is far too slow for per-commit testing,
- # we should have some framework ensure that we run this test with
- # `s2s=0` on a regular basis, to ensure that it does continue to work.
- # TODO: be sure to test s2s=0 together with singletons.
- s2s = 1
- sparsification_options = (
- f'parallelization-strategy=none '
- f's2s-strategy={s2s}')
- compiler = sparse_compiler.SparseCompiler(
- options=sparsification_options, opt_level=0, shared_libs=[support_lib])
- f64 = ir.F64Type.get()
- # Be careful about increasing this because
- # len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2
- shape = range(2, 3)
- rank = len(shape)
- # All combinations.
- # TODO: add singleton here too; which requires updating how `np_arg0`
- # is initialized below.
- levels = list(itertools.product(*itertools.repeat(
- [st.DimLevelType.dense, st.DimLevelType.compressed], rank)))
- # All permutations.
- orderings = list(map(ir.AffineMap.get_permutation,
- itertools.permutations(range(rank))))
- bitwidths = [0]
- # The first type must be a dense tensor for numpy conversion to work.
- types = [ir.RankedTensorType.get(shape, f64)]
- for level in levels:
- for ordering in orderings:
- for pwidth in bitwidths:
- for iwidth in bitwidths:
- attr = st.EncodingAttr.get(level, ordering, None, pwidth, iwidth)
- types.append(ir.RankedTensorType.get(shape, f64, attr))
- #
- # For exhaustiveness we should have one or more StressTest, such
- # that their paths cover all 2*n*(n-1) directed pairwise combinations
- # of the `types` set. However, since n is already superexponential,
- # such exhaustiveness would be prohibitive for a test that runs on
- # every commit. So for now we'll just pick one particular path that
- # at least hits all n elements of the `types` set.
- #
- tyconv = TypeConverter(ctx)
- size = 1
- for d in shape:
- size *= d
- np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
- np_out = (
- StressTest(tyconv).build(types).writeTo(
- sys.argv[1] if len(sys.argv) > 1 else None).compile(compiler)
- .writeTo(sys.argv[2] if len(sys.argv) > 2 else None).run(np_arg0))
- # CHECK: Passed
- if np.allclose(np_out, np_arg0):
- print('Passed')
- else:
- sys.exit('FAILURE')
-
-if __name__ == '__main__':
- main()
+ """
+ USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]
+
+ The environment variable SUPPORT_LIB must be set to point to the
+ libmlir_c_runner_utils shared library. There are two optional
+ arguments, for debugging purposes. The first argument specifies where
+ to write out the raw/generated ir.Module. The second argument specifies
+ where to write out the compiled version of that ir.Module.
+ """
+ support_lib = os.getenv("SUPPORT_LIB")
+ assert support_lib is not None, "SUPPORT_LIB is undefined"
+ if not os.path.exists(support_lib):
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+
+ # CHECK-LABEL: TEST: test_stress
+ print("\nTEST: test_stress")
+ with ir.Context() as ctx, ir.Location.unknown():
+ # Disable direct sparse2sparse conversion, because it doubles the time!
+ # TODO: While direct s2s is far too slow for per-commit testing,
+ # we should have some framework ensure that we run this test with
+ # `s2s=0` on a regular basis, to ensure that it does continue to work.
+ # TODO: be sure to test s2s=0 together with singletons.
+ s2s = 1
+ sparsification_options = f"parallelization-strategy=none " f"s2s-strategy={s2s}"
+ compiler = sparse_compiler.SparseCompiler(
+ options=sparsification_options, opt_level=0, shared_libs=[support_lib]
+ )
+ f64 = ir.F64Type.get()
+ # Be careful about increasing this because
+ # len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2
+ shape = range(2, 3)
+ rank = len(shape)
+ # All combinations.
+        # TODO: add singleton here too, which requires updating how `np_arg0`
+ # is initialized below.
+ levels = list(
+ itertools.product(
+ *itertools.repeat(
+ [st.DimLevelType.dense, st.DimLevelType.compressed], rank
+ )
+ )
+ )
+ # All permutations.
+ orderings = list(
+ map(ir.AffineMap.get_permutation, itertools.permutations(range(rank)))
+ )
+ bitwidths = [0]
+ # The first type must be a dense tensor for numpy conversion to work.
+ types = [ir.RankedTensorType.get(shape, f64)]
+ for level in levels:
+ for ordering in orderings:
+ for pwidth in bitwidths:
+ for iwidth in bitwidths:
+ attr = st.EncodingAttr.get(
+ level, ordering, None, pwidth, iwidth
+ )
+ types.append(ir.RankedTensorType.get(shape, f64, attr))
+ #
+ # For exhaustiveness we should have one or more StressTest, such
+ # that their paths cover all 2*n*(n-1) directed pairwise combinations
+ # of the `types` set. However, since n is already superexponential,
+ # such exhaustiveness would be prohibitive for a test that runs on
+ # every commit. So for now we'll just pick one particular path that
+ # at least hits all n elements of the `types` set.
+ #
+ tyconv = TypeConverter(ctx)
+ size = 1
+ for d in shape:
+ size *= d
+ np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
+ np_out = (
+ StressTest(tyconv)
+ .build(types)
+ .writeTo(sys.argv[1] if len(sys.argv) > 1 else None)
+ .compile(compiler)
+ .writeTo(sys.argv[2] if len(sys.argv) > 2 else None)
+ .run(np_arg0)
+ )
+ # CHECK: Passed
+ if np.allclose(np_out, np_arg0):
+ print("Passed")
+ else:
+ sys.exit("FAILURE")
+
+
+if __name__ == "__main__":
+ main()
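
The comment above the `shape` assignment bounds how many tensor types the stress test chains together: len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2. For the settings used here that count stays small; a quick check of the arithmetic (the helper name `num_types` is illustrative only):

import math


def num_types(num_level_choices: int, rank: int, num_bitwidths: int) -> int:
    # One leading dense tensor type, plus one sparse type per combination of
    # level types, dimension ordering, pointer width, and index width.
    return 1 + num_level_choices**rank * math.factorial(rank) * num_bitwidths**2


# shape = range(2, 3) gives rank 1, with two level choices and one bitwidth:
assert num_types(2, 1, 1) == 3
# Already at rank 2 the chain grows to 1 + 2^2 * 2! * 1^2 = 9 conversions.
assert num_types(2, 2, 1) == 9
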
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py b/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py
index f5b0ab60e85e9..785d42cadbbe9 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py
@@ -11,65 +11,71 @@
@functools.lru_cache()
def _get_c_shared_lib(lib_name: str):
- """Loads and returns the requested C shared library.
+ """Loads and returns the requested C shared library.
- Args:
- lib_name: A string representing the C shared library.
+ Args:
+ lib_name: A string representing the C shared library.
- Returns:
- The C shared library.
+ Returns:
+ The C shared library.
- Raises:
- OSError: If there is any problem in loading the shared library.
- ValueError: If the shared library doesn't contain the needed routine.
- """
- # This raises OSError exception if there is any problem in loading the shared
- # library.
- c_lib = ctypes.CDLL(lib_name)
+ Raises:
+ OSError: If there is any problem in loading the shared library.
+ ValueError: If the shared library doesn't contain the needed routine.
+ """
+ # This raises OSError exception if there is any problem in loading the shared
+ # library.
+ c_lib = ctypes.CDLL(lib_name)
- try:
- c_lib.convertFromMLIRSparseTensorF64.restype = ctypes.c_void_p
- except Exception as e:
- raise ValueError('Missing function convertFromMLIRSparseTensorF64 from '
- f'the C shared library: {e} ') from e
+ try:
+ c_lib.convertFromMLIRSparseTensorF64.restype = ctypes.c_void_p
+ except Exception as e:
+ raise ValueError(
+ "Missing function convertFromMLIRSparseTensorF64 from "
+ f"the C shared library: {e} "
+ ) from e
- return c_lib
+ return c_lib
def sparse_tensor_to_coo_tensor(support_lib, sparse, dtype):
- """Converts a sparse tensor to COO-flavored format.
+ """Converts a sparse tensor to COO-flavored format.
- Args:
- support_lib: A string for the supporting C shared library.
- sparse: A ctypes.pointer to the sparse tensor descriptor.
- dtype: The numpy data type for the tensor elements.
+ Args:
+ support_lib: A string for the supporting C shared library.
+ sparse: A ctypes.pointer to the sparse tensor descriptor.
+ dtype: The numpy data type for the tensor elements.
- Returns:
- A tuple that contains the following values:
- rank: An integer for the rank of the tensor.
- nse: An integer for the number of non-zero values in the tensor.
- shape: A 1D numpy array of integers, for the shape of the tensor.
- values: A 1D numpy array, for the non-zero values in the tensor.
- indices: A 2D numpy array of integers, representing the indices for the
- non-zero values in the tensor.
+ Returns:
+ A tuple that contains the following values:
+ rank: An integer for the rank of the tensor.
+ nse: An integer for the number of non-zero values in the tensor.
+ shape: A 1D numpy array of integers, for the shape of the tensor.
+ values: A 1D numpy array, for the non-zero values in the tensor.
+ indices: A 2D numpy array of integers, representing the indices for the
+ non-zero values in the tensor.
- Raises:
- OSError: If there is any problem in loading the shared library.
- ValueError: If the shared library doesn't contain the needed routine.
- """
- c_lib = _get_c_shared_lib(support_lib)
+ Raises:
+ OSError: If there is any problem in loading the shared library.
+ ValueError: If the shared library doesn't contain the needed routine.
+ """
+ c_lib = _get_c_shared_lib(support_lib)
- rank = ctypes.c_ulonglong(0)
- nse = ctypes.c_ulonglong(0)
- shape = ctypes.POINTER(ctypes.c_ulonglong)()
- values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
- indices = ctypes.POINTER(ctypes.c_ulonglong)()
- c_lib.convertFromMLIRSparseTensorF64(sparse, ctypes.byref(rank),
- ctypes.byref(nse), ctypes.byref(shape),
- ctypes.byref(values),
- ctypes.byref(indices))
- # Convert the returned values to the corresponding numpy types.
- shape = np.ctypeslib.as_array(shape, shape=[rank.value])
- values = np.ctypeslib.as_array(values, shape=[nse.value])
- indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
- return rank, nse, shape, values, indices
+ rank = ctypes.c_ulonglong(0)
+ nse = ctypes.c_ulonglong(0)
+ shape = ctypes.POINTER(ctypes.c_ulonglong)()
+ values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
+ indices = ctypes.POINTER(ctypes.c_ulonglong)()
+ c_lib.convertFromMLIRSparseTensorF64(
+ sparse,
+ ctypes.byref(rank),
+ ctypes.byref(nse),
+ ctypes.byref(shape),
+ ctypes.byref(values),
+ ctypes.byref(indices),
+ )
+ # Convert the returned values to the corresponding numpy types.
+ shape = np.ctypeslib.as_array(shape, shape=[rank.value])
+ values = np.ctypeslib.as_array(values, shape=[nse.value])
+ indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
+ return rank, nse, shape, values, indices
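
The elementwise-add test earlier in this patch is the consumer of this helper: it passes the pointer-to-pointer-of-char filled in by the kernel together with the support library path, and gets plain NumPy data back. A minimal sketch of that call, assuming the module is importable as `np_to_sparse_tensor` and that `support_lib` and `mem_out` are set up as in that test:

import numpy as np

import np_to_sparse_tensor


def check_coo(support_lib, mem_out) -> bool:
    # mem_out[0] is the opaque MLIR sparse-tensor pointer produced by the kernel.
    rank, nse, shape, values, indices = np_to_sparse_tensor.sparse_tensor_to_coo_tensor(
        support_lib, mem_out[0], np.float64
    )
    # For the elementwise-add test this yields three non-zeros:
    # values ~ [2.2, 2.8, 6.6] at indices [[0, 0], [0, 3], [2, 2]].
    return np.allclose(values, [2.2, 2.8, 6.6]) and np.allclose(
        indices, [[0, 0], [0, 3], [2, 2]]
    )
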
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py b/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py
index 25004f9492dbc..d549a9a0954c6 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py
@@ -9,30 +9,31 @@
from mlir import passmanager
from typing import Sequence
+
class SparseCompiler:
- """Sparse compiler class for compiling and building MLIR modules."""
-
- def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
- pipeline = f'builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})'
- self.pipeline = pipeline
- self.opt_level = opt_level
- self.shared_libs = shared_libs
-
- def __call__(self, module: ir.Module):
- """Convenience application method."""
- self.compile(module)
-
- def compile(self, module: ir.Module):
- """Compiles the module by invoking the sparse copmiler pipeline."""
- passmanager.PassManager.parse(self.pipeline).run(module.operation)
-
- def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
- """Wraps the module in a JIT execution engine."""
- return execution_engine.ExecutionEngine(
- module, opt_level=self.opt_level, shared_libs=self.shared_libs)
-
- def compile_and_jit(self,
- module: ir.Module) -> execution_engine.ExecutionEngine:
- """Compiles and jits the module."""
- self.compile(module)
- return self.jit(module)
+ """Sparse compiler class for compiling and building MLIR modules."""
+
+ def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+ pipeline = f"builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})"
+ self.pipeline = pipeline
+ self.opt_level = opt_level
+ self.shared_libs = shared_libs
+
+ def __call__(self, module: ir.Module):
+ """Convenience application method."""
+ self.compile(module)
+
+ def compile(self, module: ir.Module):
+        """Compiles the module by invoking the sparse compiler pipeline."""
+ passmanager.PassManager.parse(self.pipeline).run(module.operation)
+
+ def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+ """Wraps the module in a JIT execution engine."""
+ return execution_engine.ExecutionEngine(
+ module, opt_level=self.opt_level, shared_libs=self.shared_libs
+ )
+
+ def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+ """Compiles and jits the module."""
+ self.compile(module)
+ return self.jit(module)
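
All of the integration tests in this patch drive SparseCompiler the same way: construct it with a pipeline option string and the runtime support library, then compile-and-JIT a parsed module and invoke its entry point. A minimal sketch of that flow, assuming the class is importable as `sparse_compiler` and that SUPPORT_LIB points at the shared runtime library (the kernel text and the ctypes arguments are placeholders supplied by the caller):

import os

from mlir import ir

import sparse_compiler


def run_kernel(kernel_text: str, *ctype_args):
    support_lib = os.getenv("SUPPORT_LIB")
    assert support_lib is not None, "SUPPORT_LIB is undefined"
    with ir.Context(), ir.Location.unknown():
        compiler = sparse_compiler.SparseCompiler(
            options="", opt_level=2, shared_libs=[support_lib]
        )
        module = ir.Module.parse(kernel_text)
        # compile() runs the sparse-compiler pass pipeline in place and jit()
        # wraps the result in an ExecutionEngine; invoke() runs the kernel.
        engine = compiler.compile_and_jit(module)
        engine.invoke("main", *ctype_args)
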
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/lit.local.cfg b/mlir/test/Integration/Dialect/SparseTensor/taco/lit.local.cfg
index 7137d0fba95f8..f1bbcf486bc27 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/lit.local.cfg
@@ -1,5 +1,5 @@
# Disable ASAN's leak detection for python taco tests.
-config.environment['ASAN_OPTIONS'] = 'detect_leaks=0'
+config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
# Only run when python bindings are enabled.
if not config.enable_bindings_python:
- config.unsupported = True
+ config.unsupported = True
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py
index 88b13ae161b0d..2d558f8d6ddff 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py
@@ -46,10 +46,10 @@
# Perform the MTTKRP computation and write the result to file.
with tempfile.TemporaryDirectory() as test_dir:
- golden_file = os.path.join(_SCRIPT_PATH, "data/gold_A.tns")
- out_file = os.path.join(test_dir, "A.tns")
- pt.write(out_file, A)
- #
- # CHECK: Compare result True
- #
- print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
+ golden_file = os.path.join(_SCRIPT_PATH, "data/gold_A.tns")
+ out_file = os.path.join(test_dir, "A.tns")
+ pt.write(out_file, A)
+ #
+ # CHECK: Compare result True
+ #
+ print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
index ba4ea9c1e6d0a..ef94ea9900fe4 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
@@ -46,13 +46,13 @@
# Force evaluation of the kernels by writing out X and Y.
with tempfile.TemporaryDirectory() as test_dir:
- x_file = os.path.join(test_dir, "X.tns")
- y_file = os.path.join(test_dir, "Y.tns")
- pt.write(x_file, X)
- pt.write(y_file, Y)
- #
- # CHECK: Compare result True True
- #
- x_data = utils.file_as_string(x_file)
- y_data = utils.file_as_string(y_file)
- print(f"Compare result {x_data == expected} {y_data == expected}")
+ x_file = os.path.join(test_dir, "X.tns")
+ y_file = os.path.join(test_dir, "Y.tns")
+ pt.write(x_file, X)
+ pt.write(y_file, Y)
+ #
+ # CHECK: Compare result True True
+ #
+ x_data = utils.file_as_string(x_file)
+ y_data = utils.file_as_string(y_file)
+ print(f"Compare result {x_data == expected} {y_data == expected}")
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py
index 10309cb706f78..02bbbc096e7a3 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py
@@ -26,10 +26,10 @@
# Force evaluation of the kernel by writing out C.
with tempfile.TemporaryDirectory() as test_dir:
- golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns")
- out_file = os.path.join(test_dir, "C.tns")
- pt.write(out_file, C)
- #
- # CHECK: Compare result True
- #
- print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
+ golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns")
+ out_file = os.path.join(test_dir, "C.tns")
+ pt.write(out_file, C)
+ #
+ # CHECK: Compare result True
+ #
+ print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py
index de150ea68efd5..2038a473ae530 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py
@@ -47,10 +47,10 @@
# Perform the SpMV computation and write the result to file
with tempfile.TemporaryDirectory() as test_dir:
- golden_file = os.path.join(_SCRIPT_PATH, "data/gold_y.tns")
- out_file = os.path.join(test_dir, "y.tns")
- pt.write(out_file, y)
- #
- # CHECK: Compare result True
- #
- print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
+ golden_file = os.path.join(_SCRIPT_PATH, "data/gold_y.tns")
+ out_file = os.path.join(test_dir, "y.tns")
+ pt.write(out_file, y)
+ #
+ # CHECK: Compare result True
+ #
+ print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_Tensor.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_Tensor.py
index c1e6c87940a02..cd24e0dbb0a43 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_Tensor.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_Tensor.py
@@ -18,11 +18,10 @@
alpha = pt.tensor(42.0)
# Set up some sparse tensors with different dim annotations and ordering.
-S = pt.tensor([8, 8, 8],
- pt.format([pt.compressed, pt.dense, pt.compressed], [1, 0, 2]))
-X = pt.tensor([8, 8, 8],
- pt.format([pt.compressed, pt.compressed, pt.compressed],
- [1, 0, 2]))
+S = pt.tensor([8, 8, 8], pt.format([pt.compressed, pt.dense, pt.compressed], [1, 0, 2]))
+X = pt.tensor(
+ [8, 8, 8], pt.format([pt.compressed, pt.compressed, pt.compressed], [1, 0, 2])
+)
S.insert([0, 0, 0], 2.0)
S.insert([1, 1, 1], 3.0)
S.insert([4, 4, 4], 4.0)
@@ -32,16 +31,14 @@
# Set up tensors with a dense last dimension. This results in a full
# enveloping storage of all last "rows" with one or more nonzeros.
-T = pt.tensor([1, 2, 3, 4, 5],
- pt.format([
- pt.compressed, pt.compressed, pt.compressed, pt.compressed,
- pt.dense
- ]))
-Y = pt.tensor([1, 2, 3, 4, 5],
- pt.format([
- pt.compressed, pt.compressed, pt.compressed, pt.compressed,
- pt.dense
- ]))
+T = pt.tensor(
+ [1, 2, 3, 4, 5],
+ pt.format([pt.compressed, pt.compressed, pt.compressed, pt.compressed, pt.dense]),
+)
+Y = pt.tensor(
+ [1, 2, 3, 4, 5],
+ pt.format([pt.compressed, pt.compressed, pt.compressed, pt.compressed, pt.dense]),
+)
T.insert([0, 1, 2, 3, 4], -2.0)
Y[i, j, k, l, m] = alpha[0] * T[i, j, k, l, m]
@@ -85,18 +82,18 @@
# Force evaluation of the kernel by writing out X.
with tempfile.TemporaryDirectory() as test_dir:
- x_file = os.path.join(test_dir, 'X.tns')
- pt.write(x_file, X)
- y_file = os.path.join(test_dir, 'Y.tns')
- pt.write(y_file, Y)
- z_file = os.path.join(test_dir, 'Z.tns')
- pt.write(z_file, Z)
- #
- # CHECK: Compare result True True True
- #
- x_data = utils.file_as_string(x_file)
- y_data = utils.file_as_string(y_file)
- z_data = utils.file_as_string(z_file)
- print(
- f'Compare result {x_data == x_expected} {y_data == y_expected} {z_data == z_expected}'
- )
+ x_file = os.path.join(test_dir, "X.tns")
+ pt.write(x_file, X)
+ y_file = os.path.join(test_dir, "Y.tns")
+ pt.write(y_file, Y)
+ z_file = os.path.join(test_dir, "Z.tns")
+ pt.write(z_file, Z)
+ #
+ # CHECK: Compare result True True True
+ #
+ x_data = utils.file_as_string(x_file)
+ y_data = utils.file_as_string(y_file)
+ z_data = utils.file_as_string(z_file)
+ print(
+ f"Compare result {x_data == x_expected} {y_data == y_expected} {z_data == z_expected}"
+ )
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
index 60b91de42157e..206ffa9316d48 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
@@ -12,7 +12,7 @@
i, j = pt.get_index_vars(2)
A = pt.tensor([2, 3])
-S = pt.tensor(3) # S is a scalar tensor.
+S = pt.tensor(3) # S is a scalar tensor.
B = pt.tensor([2, 3], compressed)
A.insert([0, 1], 10)
A.insert([1, 2], 40)
@@ -26,11 +26,11 @@
# Sum all the values in A.
S[0] = A[i, j]
-passed += (S.get_scalar_value() == 50.0)
+passed += S.get_scalar_value() == 50.0
indices, values = S.get_coordinates_and_values()
-passed += (len(indices)==0)
-passed += (values == 50.0)
+passed += len(indices) == 0
+passed += values == 50.0
# CHECK: Number of passed: 5
print("Number of passed:", passed)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py
index 8fd545b91710f..b0fed50f8b5db 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py
@@ -12,20 +12,20 @@
passed = 0
all_types = [pt.complex64, pt.complex128]
for t in all_types:
- i, j = pt.get_index_vars(2)
- A = pt.tensor([2, 3], dtype=t)
- B = pt.tensor([2, 3], dtype=t)
- C = pt.tensor([2, 3], compressed, dtype=t)
- A.insert([0, 1], 10 + 20j)
- A.insert([1, 2], 40 + 0.5j)
- B.insert([0, 0], 20)
- B.insert([1, 2], 30 + 15j)
- C[i, j] = A[i, j] + B[i, j]
+ i, j = pt.get_index_vars(2)
+ A = pt.tensor([2, 3], dtype=t)
+ B = pt.tensor([2, 3], dtype=t)
+ C = pt.tensor([2, 3], compressed, dtype=t)
+ A.insert([0, 1], 10 + 20j)
+ A.insert([1, 2], 40 + 0.5j)
+ B.insert([0, 0], 20)
+ B.insert([1, 2], 30 + 15j)
+ C[i, j] = A[i, j] + B[i, j]
- indices, values = C.get_coordinates_and_values()
- passed += isinstance(values[0], t.value)
- passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
- passed += np.allclose(values, [20, 10 + 20j, 70 + 15.5j])
+ indices, values = C.get_coordinates_and_values()
+ passed += isinstance(values[0], t.value)
+ passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
+ passed += np.allclose(values, [20, 10 + 20j, 70 + 15.5j])
# CHECK: Number of passed: 6
print("Number of passed:", passed)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py
index cec687ff4de5c..4ba2836dd4616 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py
@@ -12,24 +12,22 @@
dense = pt.dense
passed = 0
-all_types = [
- pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64
-]
+all_types = [pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64]
for t in all_types:
- i, j = pt.get_index_vars(2)
- A = pt.tensor([2, 3], dtype=t)
- B = pt.tensor([2, 3], dtype=t)
- C = pt.tensor([2, 3], compressed, dtype=t)
- A.insert([0, 1], 10)
- A.insert([1, 2], 40)
- B.insert([0, 0], 20)
- B.insert([1, 2], 30)
- C[i, j] = A[i, j] + B[i, j]
+ i, j = pt.get_index_vars(2)
+ A = pt.tensor([2, 3], dtype=t)
+ B = pt.tensor([2, 3], dtype=t)
+ C = pt.tensor([2, 3], compressed, dtype=t)
+ A.insert([0, 1], 10)
+ A.insert([1, 2], 40)
+ B.insert([0, 0], 20)
+ B.insert([1, 2], 30)
+ C[i, j] = A[i, j] + B[i, j]
- indices, values = C.get_coordinates_and_values()
- passed += isinstance(values[0], t.value)
- passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
- passed += np.allclose(values, [20.0, 10.0, 70.0])
+ indices, values = C.get_coordinates_and_values()
+ passed += isinstance(values[0], t.value)
+ passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
+ passed += np.allclose(values, [20.0, 10.0, 70.0])
# CHECK: Number of passed: 21
print("Number of passed:", passed)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_true_dense_tensor_algebra.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_true_dense_tensor_algebra.py
index a138678a02e86..78bce344e3b6f 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_true_dense_tensor_algebra.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_true_dense_tensor_algebra.py
@@ -10,8 +10,8 @@
i, j = pt.get_index_vars(2)
# Both tensors are true dense tensors.
-A = pt.from_array(np.full([2,3], 1, dtype=np.float64))
-B = pt.from_array(np.full([2,3], 2, dtype=np.float64))
+A = pt.from_array(np.full([2, 3], 1, dtype=np.float64))
+B = pt.from_array(np.full([2, 3], 2, dtype=np.float64))
# Define the result tensor as a true dense tensor. The parameter is_dense=True
# is an MLIR-PyTACO extension.
C = pt.tensor([2, 3], dtype=pt.float64, is_dense=True)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index 44d28b08d8b30..b3194f7edecd5 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -65,19 +65,20 @@
class Type(enum.Enum):
- """The data types supported by TACO.
+ """The data types supported by TACO.
- We use numpy data types to implement the enum data types.
- """
- INT8 = np.int8
- INT16 = np.int16
- INT32 = np.int32
- INT64 = np.int64
- FLOAT16 = np.float16
- FLOAT32 = np.float32
- FLOAT64 = np.float64
- COMPLEX64 = np.complex64
- COMPLEX128 = np.complex128
+ We use numpy data types to implement the enum data types.
+ """
+
+ INT8 = np.int8
+ INT16 = np.int16
+ INT32 = np.int32
+ INT64 = np.int64
+ FLOAT16 = np.float16
+ FLOAT32 = np.float32
+ FLOAT64 = np.float64
+ COMPLEX64 = np.complex64
+ COMPLEX128 = np.complex128
# All floating point type enums.
@@ -88,1732 +89,1810 @@ class Type(enum.Enum):
_COMPLEX_TYPES = (Type.COMPLEX64, Type.COMPLEX128)
# Type alias for any numpy type used to implement the runtime support for the
# enum data types.
-_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float16,
- np.float32, np.float64, np.complex64, np.complex128]
+_AnyRuntimeType = Union[
+ np.int8,
+ np.int16,
+ np.int32,
+ np.int64,
+ np.float16,
+ np.float32,
+ np.float64,
+ np.complex64,
+ np.complex128,
+]
@dataclasses.dataclass(frozen=True)
class DType:
- """The data type class.
+ """The data type class.
- We support the TACO API dtype class with an alias of this class.
+ We support the TACO API dtype class with an alias of this class.
- The following methods are defined by the TACO API:
- is_float: Returns whether the data type represents a floating point value.
- is_int: Returns whether the data type represents an integral value.
+ The following methods are defined by the TACO API:
+ is_float: Returns whether the data type represents a floating point value.
+ is_int: Returns whether the data type represents an integral value.
- Attributes:
- kind: A Type enum representing the data type.
- value: The numpy data type for the TACO data type.
- """
- kind: Type = Type.FLOAT32
+ Attributes:
+ kind: A Type enum representing the data type.
+ value: The numpy data type for the TACO data type.
+ """
- def is_float(self) -> bool:
- """Returns whether the data type represents a floating point value."""
- return self.kind in _FLOAT_TYPES
+ kind: Type = Type.FLOAT32
- def is_int(self) -> bool:
- """Returns whether the data type represents an integral value."""
- return self.kind in _INT_TYPES
+ def is_float(self) -> bool:
+ """Returns whether the data type represents a floating point value."""
+ return self.kind in _FLOAT_TYPES
- def is_complex(self) -> bool:
- """Returns whether the data type represents a complex value."""
- return self.kind in _COMPLEX_TYPES
+ def is_int(self) -> bool:
+ """Returns whether the data type represents an integral value."""
+ return self.kind in _INT_TYPES
- @property
- def value(self) -> _AnyRuntimeType:
- """Returns the numpy dtype for the data type."""
- return self.kind.value
+ def is_complex(self) -> bool:
+ """Returns whether the data type represents a complex value."""
+ return self.kind in _COMPLEX_TYPES
+
+ @property
+ def value(self) -> _AnyRuntimeType:
+ """Returns the numpy dtype for the data type."""
+ return self.kind.value
def _dtype_to_mlir_str(dtype: DType) -> str:
- """Returns the MLIR string for the given dtype."""
- dtype_to_str = {
- Type.INT16: "i8",
- Type.INT16: "i16",
- Type.INT32: "i32",
- Type.INT64: "i64",
- Type.FLOAT16: "f16",
- Type.FLOAT32: "f32",
- Type.FLOAT64: "f64",
- Type.COMPLEX64: "complex<f32>",
- Type.COMPLEX128: "complex<f64>"
- }
- return dtype_to_str[dtype.kind]
+ """Returns the MLIR string for the given dtype."""
+ dtype_to_str = {
+        Type.INT8: "i8",
+ Type.INT16: "i16",
+ Type.INT32: "i32",
+ Type.INT64: "i64",
+ Type.FLOAT16: "f16",
+ Type.FLOAT32: "f32",
+ Type.FLOAT64: "f64",
+ Type.COMPLEX64: "complex<f32>",
+ Type.COMPLEX128: "complex<f64>",
+ }
+ return dtype_to_str[dtype.kind]
def _nptype_to_taco_type(ty: np.dtype) -> DType:
- """Returns the TACO type for the given numpy type."""
- nptype_to_dtype = {
- np.int8: Type.INT8,
- np.int16: Type.INT16,
- np.int32: Type.INT32,
- np.int64: Type.INT64,
- np.float16: Type.FLOAT16,
- np.float32: Type.FLOAT32,
- np.float64: Type.FLOAT64,
- np.complex64: Type.COMPLEX64,
- np.complex128: Type.COMPLEX128
- }
- return DType(nptype_to_dtype[ty])
+ """Returns the TACO type for the given numpy type."""
+ nptype_to_dtype = {
+ np.int8: Type.INT8,
+ np.int16: Type.INT16,
+ np.int32: Type.INT32,
+ np.int64: Type.INT64,
+ np.float16: Type.FLOAT16,
+ np.float32: Type.FLOAT32,
+ np.float64: Type.FLOAT64,
+ np.complex64: Type.COMPLEX64,
+ np.complex128: Type.COMPLEX128,
+ }
+ return DType(nptype_to_dtype[ty])
def _mlir_type_from_taco_type(dtype: DType) -> ir.Type:
- """Returns the MLIR type corresponding to the given TACO type."""
- dtype_to_irtype = {
- Type.INT8: ir.IntegerType.get_signless(8),
- Type.INT16: ir.IntegerType.get_signless(16),
- Type.INT32: ir.IntegerType.get_signless(32),
- Type.INT64: ir.IntegerType.get_signless(64),
- Type.FLOAT16: ir.F16Type.get(),
- Type.FLOAT32: ir.F32Type.get(),
- Type.FLOAT64: ir.F64Type.get(),
- Type.COMPLEX64: ir.ComplexType.get(ir.F32Type.get()),
- Type.COMPLEX128: ir.ComplexType.get(ir.F64Type.get())
- }
- return dtype_to_irtype[dtype.kind]
+ """Returns the MLIR type corresponding to the given TACO type."""
+ dtype_to_irtype = {
+ Type.INT8: ir.IntegerType.get_signless(8),
+ Type.INT16: ir.IntegerType.get_signless(16),
+ Type.INT32: ir.IntegerType.get_signless(32),
+ Type.INT64: ir.IntegerType.get_signless(64),
+ Type.FLOAT16: ir.F16Type.get(),
+ Type.FLOAT32: ir.F32Type.get(),
+ Type.FLOAT64: ir.F64Type.get(),
+ Type.COMPLEX64: ir.ComplexType.get(ir.F32Type.get()),
+ Type.COMPLEX128: ir.ComplexType.get(ir.F64Type.get()),
+ }
+ return dtype_to_irtype[dtype.kind]
+
def _ctype_pointer_from_array(array: np.ndarray) -> ctypes.pointer:
- """Returns the ctype pointer for the given numpy array."""
- return ctypes.pointer(
- ctypes.pointer(runtime.get_ranked_memref_descriptor(array)))
+ """Returns the ctype pointer for the given numpy array."""
+ return ctypes.pointer(ctypes.pointer(runtime.get_ranked_memref_descriptor(array)))
class ModeFormat(enum.Enum):
- """The tensor dimension storage format class.
+ """The tensor dimension storage format class.
- We support the TACO API mode_format class with an alias of this class.
+ We support the TACO API mode_format class with an alias of this class.
+
+ In TACO, a tensor dimension is called a mode and the storage format for a
+ tensor dimension is called a mode format.
+ """
- In TACO, a tensor dimension is called a mode and the storage format for a
- tensor dimension is called a mode format.
- """
- DENSE = sparse_tensor.DimLevelType.dense
- COMPRESSED = sparse_tensor.DimLevelType.compressed
+ DENSE = sparse_tensor.DimLevelType.dense
+ COMPRESSED = sparse_tensor.DimLevelType.compressed
-def _mode_format_operation(a: ModeFormat, b: ModeFormat,
- op: _LogicalOp) -> ModeFormat:
- """Implements the given operator on ModeFormat."""
- return (ModeFormat.COMPRESSED
- if op(a == ModeFormat.COMPRESSED, b == ModeFormat.COMPRESSED) else
- ModeFormat.DENSE)
+def _mode_format_operation(a: ModeFormat, b: ModeFormat, op: _LogicalOp) -> ModeFormat:
+ """Implements the given operator on ModeFormat."""
+ return (
+ ModeFormat.COMPRESSED
+ if op(a == ModeFormat.COMPRESSED, b == ModeFormat.COMPRESSED)
+ else ModeFormat.DENSE
+ )
def _mode_format_estimator(op: _BinaryOp) -> _ModeFormatOp:
- """Produces a ModeFormat operator for the given binary operator.
+ """Produces a ModeFormat operator for the given binary operator.
- The ModeFormat operator is used as a heuristic to derive the destination
- dimension sparsity from the source dimension sparsity. In particular, if the
- binary operator produces a disjunction of the zero values from its source
- operands, such as the MUL operator, we return a ModeFormat operator that
- uses operator.or_. That is, we estimate that a dimension for the MUL
- operation result to be sparse if either of its source operands is sparse.
+ The ModeFormat operator is used as a heuristic to derive the destination
+ dimension sparsity from the source dimension sparsity. In particular, if the
+ binary operator produces a disjunction of the zero values from its source
+ operands, such as the MUL operator, we return a ModeFormat operator that
+    uses operator.or_. That is, we estimate that a dimension of the MUL
+    operation result will be sparse if either of its source operands is sparse.
- On the other hand, if the binary operator produces a conjunction of the
- zero values from its source operands, such as the ADD operator, we return
- a ModeFormat operator that uses operator.and_. In this case, we estimate
- that a dimension for the ADD operation result to be sparse if both of its
- source operands are sparse.
+ On the other hand, if the binary operator produces a conjunction of the
+ zero values from its source operands, such as the ADD operator, we return
+ a ModeFormat operator that uses operator.and_. In this case, we estimate
+    that a dimension of the ADD operation result will be sparse if both of its
+ source operands are sparse.
- Args:
- op: A _BinaryOp object representing a supporting operator on tensors.
+ Args:
+ op: A _BinaryOp object representing a supporting operator on tensors.
- Returns:
- A ModeFormatOp for estimating the destination dimension sparsity from
- the source dimension sparsity.
- """
- conjunction = functools.partial(_mode_format_operation, op=operator.and_)
- disjunction = functools.partial(_mode_format_operation, op=operator.or_)
- return conjunction if op(0, 1) != 0 else disjunction
+ Returns:
+ A ModeFormatOp for estimating the destination dimension sparsity from
+ the source dimension sparsity.
+ """
+ conjunction = functools.partial(_mode_format_operation, op=operator.and_)
+ disjunction = functools.partial(_mode_format_operation, op=operator.or_)
+ return conjunction if op(0, 1) != 0 else disjunction
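# A minimal sketch of the estimator above, assuming the module-level operator
# import: MUL maps to the disjunctive rule (sparse if either operand dimension
# is sparse) while ADD maps to the conjunctive rule (sparse only if both are).
mul_rule = _mode_format_estimator(operator.mul)
assert mul_rule(ModeFormat.DENSE, ModeFormat.COMPRESSED) == ModeFormat.COMPRESSED
add_rule = _mode_format_estimator(operator.add)
assert add_rule(ModeFormat.DENSE, ModeFormat.COMPRESSED) == ModeFormat.DENSE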
def _all_instance_of(collection: Iterable, cls: Any) -> bool:
- """Returns true if all elements of the iterable is an instance of cls."""
- return all(isinstance(e, cls) for e in collection)
+ """Returns true if all elements of the iterable is an instance of cls."""
+ return all(isinstance(e, cls) for e in collection)
def _identity_ordering(rank: int) -> List[int]:
- """Returns the identity ordering for tensor of given rank."""
- return list(range(rank))
+ """Returns the identity ordering for tensor of given rank."""
+ return list(range(rank))
@dataclasses.dataclass(frozen=True)
class ModeOrdering:
- """The tensor dimension ordering class.
-
- We support the TACO API mode_ordering class with an alias of this class.
+ """The tensor dimension ordering class.
- Attributes:
- ordering: A list of integers representing the ordering of the tensor
- dimensions.
- """
- ordering: List[int]
+ We support the TACO API mode_ordering class with an alias of this class.
- def __post_init__(self) -> None:
- """Verifies the value in ordering.
-
- Raises:
- ValueError: If ordering is not a list of integers.
+ Attributes:
+ ordering: A list of integers representing the ordering of the tensor
+ dimensions.
"""
- if (not isinstance(self.ordering, list) or
- not _all_instance_of(self.ordering, int)):
- raise ValueError("Ordering must be a list of integers: "
- f"{self.ordering}")
- # Check that ordering is a permutation of the dimension numbers.
- if sorted(self.ordering) != _identity_ordering(self.rank()):
- raise ValueError(f"Invalid ordering: {self.ordering} != "
- f"permutation{_identity_ordering(self.rank())}.")
-
- def rank(self) -> int:
- """Returns the number of dimensions represented by the ordering."""
- return len(self.ordering)
+ ordering: List[int]
-@dataclasses.dataclass(frozen=True)
-class ModeFormatPack:
- """The tensor dimension format class.
+ def __post_init__(self) -> None:
+ """Verifies the value in ordering.
- We support the TACO API mode_format_pack class with an alias of this class.
+ Raises:
+ ValueError: If ordering is not a list of integers.
+ """
+ if not isinstance(self.ordering, list) or not _all_instance_of(
+ self.ordering, int
+ ):
+ raise ValueError("Ordering must be a list of integers: " f"{self.ordering}")
+ # Check that ordering is a permutation of the dimension numbers.
+ if sorted(self.ordering) != _identity_ordering(self.rank()):
+ raise ValueError(
+ f"Invalid ordering: {self.ordering} != "
+ f"permutation{_identity_ordering(self.rank())}."
+ )
- The storage format of a tensor contains one mode_format for each tensor
- dimension.
+ def rank(self) -> int:
+ """Returns the number of dimensions represented by the ordering."""
+ return len(self.ordering)
- Attributes:
- formats: A list of ModeFormat representing the storage format for each of
- the tensor dimension.
- """
- formats: List[ModeFormat]
- def __post_init__(self) -> None:
- """Verifies the value in formats.
-
- Raises:
- ValueError: If formats is not a list of ModeFormats.
- """
- if (not isinstance(self.formats, list) or
- not _all_instance_of(self.formats, ModeFormat)):
- raise ValueError("Formats must be a list of ModeFormat: "
- f"{self.formats}")
-
- def rank(self) -> int:
- """Returns the number of dimensions represented by the format pack."""
- return len(self.formats)
+@dataclasses.dataclass(frozen=True)
+class ModeFormatPack:
+ """The tensor dimension format class.
+ We support the TACO API mode_format_pack class with an alias of this class.
-@dataclasses.dataclass
-class Format:
- """The tensor format class defined by the TACO API.
-
- Attributes:
- format_pack: A ModeFormatPack representing the storage format for the tensor
- dimensions.
- ordering: A ModeOrdering representing the tensor dimension ordering in the
- storage.
- """
- format_pack: ModeFormatPack
- ordering: Optional[ModeOrdering] = None
-
- def __post_init__(self) -> None:
- """Verifies and fixes up the values in format_pack and ordering.
-
-    Verifies and fixes up the values in format_pack and ordering to support the
- initializer syntax defined by the TACO API. If format_pack is a list of
- ModeFormat, replaces it with ModeFormatPack constructed from the list. If
- ordering is not provided, set ordering to the natural ordering for the rank
- corresponding to format_pack.
+ The storage format of a tensor contains one mode_format for each tensor
+ dimension.
- Raises:
- ValueError: If format_pack is not an instance of ModeFormatPack or if
- ordering is not an instance of ModeOrdering.
+ Attributes:
+ formats: A list of ModeFormat representing the storage format for each of
+ the tensor dimension.
"""
- if isinstance(self.format_pack, list):
- if not _all_instance_of(self.format_pack, ModeFormat):
- raise ValueError(f"Expected a list of ModeFormat: {self.format_pack}")
- self.format_pack = ModeFormatPack(self.format_pack)
- if not isinstance(self.format_pack, ModeFormatPack):
- raise ValueError(f"Expected ModeFormatpack: {self.format_pack}")
-
- if self.ordering is None:
- self.ordering = ModeOrdering(list(range(self.rank())))
- if isinstance(self.ordering, list):
- if not _all_instance_of(self.ordering, int):
- raise ValueError(f"Expected a list of integer: {self.ordering}")
- self.ordering = ModeOrdering(self.ordering)
- if not isinstance(self.ordering, ModeOrdering):
- raise ValueError(f"Expected ModeOrdering: {self.ordering}")
-
- if self.format_pack.rank() != self.ordering.rank():
- raise ValueError("Inconsistent ModeFormatPack and ModeOrdering: "
- f"len({self.format_pack}) != "
- f"len({self.ordering})")
-
- def rank(self) -> int:
- """Returns the number of dimensions represented by the format."""
- return self.format_pack.rank()
-
- def get_permutation_and_sparsity(self) -> Tuple[np.ndarray, np.ndarray]:
- """Constructs the numpy arrays for the permutation and sparsity."""
- perm = np.array(self.ordering.ordering, dtype=np.ulonglong)
- a = [f.value for f in self.format_pack.formats]
- sparse = np.array(a, dtype=np.uint8)
- return (perm, sparse)
-
- def mlir_tensor_attr(self) -> Optional[sparse_tensor.EncodingAttr]:
- """Constructs the MLIR attributes for the tensor format."""
- order = (
- range(self.rank()) if
- (self.ordering is None) else self.ordering.ordering)
- mlir_storage_format = [f.value for f in self.format_pack.formats]
- return sparse_tensor.EncodingAttr.get(mlir_storage_format,
- ir.AffineMap.get_permutation(order),
- None, _POS_WIDTH, _CRD_WIDTH)
-
-
-def _make_format(formats: List[ModeFormat],
- ordering: Optional[List[int]] = None) -> Format:
- """Constructs a format from a list of ModeFormat and an optional ordering.
-
- Args:
- formats: A list of ModeFormat, one for each dimension of a tensor.
- ordering: An optional list of integer, for the ordering of the tensor
- dimensions. When an ordering is not given, the identity ordering is used.
-
- Returns:
- A tensor format object.
-
- Raises:
- ValueError: If formats is not a list of ModeFormat or the length of formats
- is not consistent with the len of ordering.
- """
- ordering = ordering or _identity_ordering(len(formats))
- return Format(ModeFormatPack(formats), ModeOrdering(ordering))
+ formats: List[ModeFormat]
-class IndexExpr(abc.ABC):
- """The index notation base class.
+ def __post_init__(self) -> None:
+ """Verifies the value in formats.
- We support the TACO API index_expression class with an alias of this class.
- """
+ Raises:
+ ValueError: If formats is not a list of ModeFormats.
+ """
+ if not isinstance(self.formats, list) or not _all_instance_of(
+ self.formats, ModeFormat
+ ):
+ raise ValueError("Formats must be a list of ModeFormat: " f"{self.formats}")
- def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr":
- """Verifies the RHS operand and returns a binary expression.
+ def rank(self) -> int:
+ """Returns the number of dimensions represented by the format pack."""
+ return len(self.formats)
- Args:
- rhs: The RHS of the binary operation, which could be any Python object
- from user inputs.
- op: A _BinaryOp object representing the binary operator.
- Raises:
- ValueError: If rhs is not an IndexExpr.
- """
- if not isinstance(rhs, IndexExpr):
- raise ValueError(f"Expected IndexExpr: {rhs}")
- return _BinaryExpr(op, self, rhs)
-
- def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
- """Build a unary expression.
+@dataclasses.dataclass
+class Format:
+ """The tensor format class defined by the TACO API.
- Args:
- op: A _UnaryOp object representing the unary operation.
+ Attributes:
+ format_pack: A ModeFormatPack representing the storage format for the tensor
+ dimensions.
+ ordering: A ModeOrdering representing the tensor dimension ordering in the
+ storage.
"""
- return _UnaryExpr(op, self)
- def __add__(self, rhs) -> "_BinaryExpr":
- """Defines the operator +.
+ format_pack: ModeFormatPack
+ ordering: Optional[ModeOrdering] = None
+
+ def __post_init__(self) -> None:
+ """Verifies and fixes up the values in format_pack and ordering.
+
+        Verifies and fixes up the values in format_pack and ordering to support the
+ initializer syntax defined by the TACO API. If format_pack is a list of
+ ModeFormat, replaces it with ModeFormatPack constructed from the list. If
+ ordering is not provided, set ordering to the natural ordering for the rank
+ corresponding to format_pack.
+
+ Raises:
+ ValueError: If format_pack is not an instance of ModeFormatPack or if
+ ordering is not an instance of ModeOrdering.
+ """
+ if isinstance(self.format_pack, list):
+ if not _all_instance_of(self.format_pack, ModeFormat):
+ raise ValueError(f"Expected a list of ModeFormat: {self.format_pack}")
+ self.format_pack = ModeFormatPack(self.format_pack)
+ if not isinstance(self.format_pack, ModeFormatPack):
+ raise ValueError(f"Expected ModeFormatpack: {self.format_pack}")
+
+ if self.ordering is None:
+ self.ordering = ModeOrdering(list(range(self.rank())))
+ if isinstance(self.ordering, list):
+ if not _all_instance_of(self.ordering, int):
+ raise ValueError(f"Expected a list of integer: {self.ordering}")
+ self.ordering = ModeOrdering(self.ordering)
+ if not isinstance(self.ordering, ModeOrdering):
+ raise ValueError(f"Expected ModeOrdering: {self.ordering}")
+
+ if self.format_pack.rank() != self.ordering.rank():
+ raise ValueError(
+ "Inconsistent ModeFormatPack and ModeOrdering: "
+ f"len({self.format_pack}) != "
+ f"len({self.ordering})"
+ )
+
+ def rank(self) -> int:
+ """Returns the number of dimensions represented by the format."""
+ return self.format_pack.rank()
+
+ def get_permutation_and_sparsity(self) -> Tuple[np.ndarray, np.ndarray]:
+ """Constructs the numpy arrays for the permutation and sparsity."""
+ perm = np.array(self.ordering.ordering, dtype=np.ulonglong)
+ a = [f.value for f in self.format_pack.formats]
+ sparse = np.array(a, dtype=np.uint8)
+ return (perm, sparse)
+
+ def mlir_tensor_attr(self) -> Optional[sparse_tensor.EncodingAttr]:
+ """Constructs the MLIR attributes for the tensor format."""
+ order = (
+ range(self.rank()) if (self.ordering is None) else self.ordering.ordering
+ )
+ mlir_storage_format = [f.value for f in self.format_pack.formats]
+ return sparse_tensor.EncodingAttr.get(
+ mlir_storage_format,
+ ir.AffineMap.get_permutation(order),
+ None,
+ _POS_WIDTH,
+ _CRD_WIDTH,
+ )
+
+
+def _make_format(
+ formats: List[ModeFormat], ordering: Optional[List[int]] = None
+) -> Format:
+ """Constructs a format from a list of ModeFormat and an optional ordering.
Args:
- rhs: The value being added, which could be any Python object from user
- inputs.
+ formats: A list of ModeFormat, one for each dimension of a tensor.
+ ordering: An optional list of integer, for the ordering of the tensor
+ dimensions. When an ordering is not given, the identity ordering is used.
Returns:
- A _BinaryExpr object representing the operation.
+ A tensor format object.
Raises:
- ValueError: If rhs is not an IndexExpr.
+ ValueError: If formats is not a list of ModeFormat or the length of formats
+ is not consistent with the len of ordering.
"""
- return self._verify_operand_and_build_expr(rhs, operator.add)
+ ordering = ordering or _identity_ordering(len(formats))
+ return Format(ModeFormatPack(formats), ModeOrdering(ordering))
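# A minimal sketch of the two ways to spell a 2-d CSR-like format with the
# classes above (names are illustrative only): the TACO-style initializer
# shorthand accepts a plain list of ModeFormat, while _make_format also takes
# an explicit ordering.
csr_like = Format([ModeFormat.DENSE, ModeFormat.COMPRESSED])
also_csr_like = _make_format([ModeFormat.DENSE, ModeFormat.COMPRESSED], [0, 1])
assert csr_like.rank() == also_csr_like.rank() == 2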
- def __mul__(self, rhs) -> "_BinaryExpr":
- """Defines the operator *.
-
- Args:
- rhs: The value being multiplied, which could be any Python object from
- user inputs.
- Returns:
- A _BinaryExpr object representing the operation.
+class IndexExpr(abc.ABC):
+ """The index notation base class.
- Raises:
- ValueError: If rhs is not an IndexExpr.
+ We support the TACO API index_expression class with an alias of this class.
"""
- return self._verify_operand_and_build_expr(rhs, operator.mul)
- def __abs__(self) -> "_UnaryExpr":
- """Defines the operator abs.
-
- Returns:
- A _UnaryExpr object representing the operation.
- """
- return self._build_unary_expr(operator.abs)
+ def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr":
+ """Verifies the RHS operand and returns a binary expression.
+
+ Args:
+ rhs: The RHS of the binary operation, which could be any Python object
+ from user inputs.
+ op: A _BinaryOp object representing the binary operator.
+
+ Raises:
+ ValueError: If rhs is not an IndexExpr.
+ """
+ if not isinstance(rhs, IndexExpr):
+ raise ValueError(f"Expected IndexExpr: {rhs}")
+ return _BinaryExpr(op, self, rhs)
+
+ def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
+ """Build a unary expression.
+
+ Args:
+ op: A _UnaryOp object representing the unary operation.
+ """
+ return _UnaryExpr(op, self)
+
+ def __add__(self, rhs) -> "_BinaryExpr":
+ """Defines the operator +.
+
+ Args:
+ rhs: The value being added, which could be any Python object from user
+ inputs.
+
+ Returns:
+ A _BinaryExpr object representing the operation.
+
+ Raises:
+ ValueError: If rhs is not an IndexExpr.
+ """
+ return self._verify_operand_and_build_expr(rhs, operator.add)
+
+ def __mul__(self, rhs) -> "_BinaryExpr":
+ """Defines the operator *.
+
+ Args:
+ rhs: The value being multiplied, which could be any Python object from
+ user inputs.
+
+ Returns:
+ A _BinaryExpr object representing the operation.
+
+ Raises:
+ ValueError: If rhs is not an IndexExpr.
+ """
+ return self._verify_operand_and_build_expr(rhs, operator.mul)
+
+ def __abs__(self) -> "_UnaryExpr":
+ """Defines the operator abs.
+
+ Returns:
+ A _UnaryExpr object representing the operation.
+ """
+ return self._build_unary_expr(operator.abs)
+
+ def __neg__(self) -> "_UnaryExpr":
+ """Defines the operator neg.
+
+ Returns:
+ A _UnaryExpr object representing the operation.
+ """
+ return self._build_unary_expr(operator.neg)
+
+ def __sub__(self, rhs) -> "_BinaryExpr":
+ """Defines the operator -.
+
+ Args:
+ rhs: The value being subtracted, which could be any Python object from
+ user inputs.
+
+ Returns:
+ A _BinaryExpr object representing the operation.
+
+ Raises:
+ ValueError: If rhs is not an IndexExpr.
+ """
+ return self._verify_operand_and_build_expr(rhs, operator.sub)
+
+ @abc.abstractmethod
+ def _visit(
+ self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None
+ ) -> None:
+ """A post-order visitor.
+
+ Args:
+ func: A callable applied to each node in the expression tree.
+ args: The variable-length arguments passed to the callable. These
+ arguments are grouped as an iterable and will be unpacked before passing
+ to the callable. This is to enable the keyword argument only syntax
+ after this argument.
+ leaf_checker: A callable object to identify nodes that should be treated
+ as leaf nodes to support partial tree visiting.
+ """
+ pass
+
+ @abc.abstractmethod
+ def _emit_expression(
+ self,
+ expr_to_opnd: Dict["IndexExpr", lang.OperandDef],
+ expr_to_info: _ExprInfoDict,
+ ) -> lang.ScalarExpression:
+ """Emits MLIR for the expression tree.
+
+ Args:
+ expr_to_opnd: A dictionary for looking up structured op input operands for
+ the input nodes of the structured op.
+ expr_to_info: A dictionary for looking up code generation information for
+ expressions.
+
+ Returns:
+ A linalg dialect ScalarExpression for the expression.
+ """
+ pass
+
+ @abc.abstractmethod
+ def dtype(self) -> DType:
+ """Returns the data type for the result of the expression."""
+ pass
+
+ def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None:
+ """Emits a structured op in the linalg dialect for the expression tree.
+
+ We define a DefineOpcallable in the domain specific language for the linalg
+ dialect and execute the callable to generate the structured op. Self is the
+ root of the expression tree for the structured op.
+
+ Args:
+ expr_to_info: A dictionary for looking up code generation information for
+ expressions.
+ """
+ op_info = expr_to_info[self].structop_info
+ op_name = op_info.dst_name
+ op_def = lang.LinalgOpDef(name=op_name)
+ op_callable = lang.DefinedOpCallable(op_name, op_def)
+
+ # Collect the input expression nodes for the structured op.
+ expr_inputs = []
+ self._visit(
+ _gather_structured_op_input,
+ (self, expr_to_info, expr_inputs),
+ leaf_checker=_is_structured_op_leaf,
+ )
+
+ # Create a linalg structured op operand for each input expression node and
+ # build a dictionary for looking up the information.
+ expr_to_input_opnd = {
+ e: _emit_structured_op_input(e, expr_to_info, op_def) for e in expr_inputs
+ }
+
+ # Emit the expression tree, which produces the value assigned to the
+ # destination tensor.
+ value = self._emit_expression(expr_to_input_opnd, expr_to_info)
+ # Emit the structured op representation for the destination tensor.
+ dst_opnd = _emit_operand(
+ op_def,
+ op_info.dst_indices,
+ op_info.dst_name,
+ lang.OperandKind.OUTPUT_TENSOR,
+ )
+ dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
+ dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
+
+ expr_info = expr_to_info[self]
+ # If the structured op reduces some indices, explicitly represent the
+ # reduction. This is done by generating a ReduceFn for the dimensions being
+ # reduced in the linalg dialect and calling the function with the value
+ # being reduced. We only support add reduction currently.
+ if expr_info.reduce_indices:
+ reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices)
+ value = lang.ReduceFn.add[reduce_dims](value)
+
+ # Emit the assignment as a comprehension in the linalg dialect.
+ comp = lang.Comprehension((dst_use, value))
+ op_def.comprehensions.append(comp)
+
+ # The structured op in the linalg dialect requires an explicit
+ # initialization for the destination tensor. Emit MLIR to initialize the
+ # destination tensor.
+ init = op_info.emit_tensor_init()
+
+ # Collect MLIR values for the linalg input operands, with the assumption
+ # that dictionary preserves the insertion order.
+ args = [
+ expr_to_info[expr].mlir_value for expr, opnd in expr_to_input_opnd.items()
+ ]
+ # Execute the DefineOpcallable object for the linalg dialect operation to
+ # emit MLIR for the linalg structured op.
+ expr_info.mlir_value = op_callable(*args, outs=[init])
+
+ def _identify_structured_ops(
+ self,
+ expr_to_info: _ExprInfoDict,
+ dst: "Tensor",
+ dst_indices: Tuple["IndexVar", ...],
+ ) -> List["IndexExpr"]:
+ """Returns expression nodes for the roots of the identified structured ops.
+
+ A structured op in the linalg dialect only supports reduction performed on
+        the whole expression. If the expression tree contains reductions that are
+ performed on part of the expression tree, the expression tree needs to be
+ implemented with multiple structured ops. This routine identifies all the
+ expression nodes that contain reduction as the root of structured ops in the
+ linalg dialect.
+
+ Args:
+ expr_to_info: A dictionary for looking up code generation information for
+ expressions.
+ dst: A destination Tensor that accepts the value of the expression tree.
+ dst_indices: The indices used by the destination index expression.
+
+ Returns:
+ An ordered list of IndexExpr for the root expressions of the structured
+ ops, where child expressions go before parent expressions that use their
+ results.
+ """
+ reduce_indices = tuple(set(expr_to_info[self].src_indices) - set(dst_indices))
+ for reduce_index in reduce_indices:
+ _mark_structured_op_root(self, reduce_index, expr_to_info)
+
+ self._visit(_accumulate_reduce_indices, (expr_to_info,))
+ structop_roots = []
+ self._visit(_gather_structured_op, (expr_to_info, structop_roots))
+
+ # Handle the root of the top level expression.
+ if not structop_roots or structop_roots[-1] != self:
+ # The top level expression is not a reduction. Add the top level
+ # expression as a structured op root.
+ structop_roots.append(self)
+
+ # Use user specified information for the destination tensor to build an
+ # _StructOpInfo for the top level expression.
+ expr_to_info[self].structop_info = _StructOpInfo(
+ dst_indices, tuple(dst.shape), dst.dtype, dst.name, dst.format
+ )
+
+ return structop_roots
+
+ def _validate_and_collect_expr_info(
+ self,
+ dst: "Tensor",
+ dst_indices: Tuple["IndexVar", ...],
+ ) -> _ExprInfoDict:
+ """Propagates expression information for validation.
+
+ Propagates the indices used by child expression nodes to parent expression
+ nodes. Also collects and validates the sizes for the dimensions
+ corresponding to the indices.
+
+ Args:
+ dst: A destination Tensor that accepts the value of the expression tree.
+ dst_indices: The indices used by the destination index expression.
+
+ Raises:
+ ValueError if there is any inconsistency in indices or dimensional
+ values.
+
+ Returns:
+ A dictionary of (IndexExpr, _ExprInfo).
+ """
+ expr_to_info = {}
+ # Validate the expression tree and construct expression information.
+ self._visit(_validate_and_collect_expr_info, (expr_to_info,))
+
+ # Validate the destination dimension information.
+ info = expr_to_info[self]
+ index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)}
+ for (
+ i,
+ d,
+ ) in zip(dst_indices, dst.shape):
+ if i not in index_to_dim_info:
+ raise ValueError(
+ "Destination IndexVar not used in the " f"source expression: {i}"
+ )
+ else:
+ if d != index_to_dim_info[i].dim and index_to_dim_info[i].dim != -1:
+ raise ValueError(
+ f"Inconsistent destination dimension for {i}: "
+ f"{d} vs {index_to_dim_info[i].dim}"
+ )
+
+ return expr_to_info
+
+ def _emit_assignment(
+ self,
+ module: ir.Module,
+ dst: "Tensor",
+ dst_indices: Tuple["IndexVar", ...],
+ expr_to_info: _ExprInfoDict,
+ input_accesses: List["Access"],
+ ) -> None:
+ """Emits an MLIR function for assigning the expression to a tensor."""
+ input_types = [a.tensor.mlir_tensor_type() for a in input_accesses]
+
+ # Build the kernel for the operations.
+ with ir.InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME)
+ def linalg_funcop(*args):
+ # Set up the mapping from the Access nodes to their MLIR values.
+ for e, mlir in zip(input_accesses, args):
+ expr_to_info[e].mlir_value = mlir
+
+ # Emit structured ops in the linalg dialect to implement the assignment.
+ for structop_root in self._identify_structured_ops(
+ expr_to_info, dst, dst_indices
+ ):
+ structop_root._emit_structured_op(expr_to_info)
+ dst._record_stats(expr_to_info[structop_root].structop_info)
+
+ # The function returns the MLIR value of the root expression.
+ return expr_to_info[self].mlir_value
+
+ linalg_funcop.func_op.attributes[
+ "llvm.emit_c_interface"
+ ] = ir.UnitAttr.get()
+
+ def get_input_accesses(self) -> List["Access"]:
+ """Compute the list of input accesses for the expression."""
+ input_accesses = []
+ self._visit(_gather_input_accesses_index_vars, (input_accesses,))
+ return input_accesses
+
+ def compile(
+ self,
+ dst: "Tensor",
+ dst_indices: Tuple["IndexVar", ...],
+ ) -> execution_engine.ExecutionEngine:
+ """Compiles the tensor assignment dst[dst_indices] = expression.
+
+ Args:
+ dst: The destination tensor.
+ dst_indices: The tuple of IndexVar used to access the destination tensor.
+
+ Returns:
+ The execution engine for the tensor assignment.
+
+ Raises:
+ ValueError: If the expression is not proper or not supported.
+ """
+ expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
+ input_accesses = self.get_input_accesses()
+
+ # Build and compile the module to produce the execution engine.
+ with ir.Context(), ir.Location.unknown():
+ module = ir.Module.create()
+ self._emit_assignment(
+ module, dst, dst_indices, expr_to_info, input_accesses
+ )
+ engine = utils.compile_and_build_engine(module)
+
+ return engine
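# A rough usage sketch for the expression machinery above, assuming the Tensor
# indexing support defined further down in this file; the overloads build the
# expression tree lazily and compile() lowers the recorded assignment to a
# linalg-based kernel when the result is needed (e.g. by to_array or to_file).
#
#   i, j, k = get_index_vars(3)
#   A = Tensor([8, 4])                   # sparse by default (ModeFormat.COMPRESSED)
#   B = Tensor([4, 8])
#   C = Tensor([8, 8], ModeFormat.DENSE)
#   C[i, j] = A[i, k] * B[k, j]          # records a deferred assignment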
- def __neg__(self) -> "_UnaryExpr":
- """Defines the operator neg.
- Returns:
- A _UnaryExpr object representing the operation.
- """
- return self._build_unary_expr(operator.neg)
+class _AtomicCounter:
+ """An atomic counter."""
- def __sub__(self, rhs) -> "_BinaryExpr":
- """Defines the operator -.
+ def __init__(self):
+ self._counter = 0
+ self._counter_lock = threading.Lock()
- Args:
- rhs: The value being subtracted, which could be any Python object from
- user inputs.
+ def increment(self) -> int:
+ """Increments the counter by one and returns the old value."""
+ old_value = self._counter
+ with self._counter_lock:
+ self._counter = self._counter + 1
+ return old_value
- Returns:
- A _BinaryExpr object representing the operation.
- Raises:
- ValueError: If rhs is not an IndexExpr.
- """
- return self._verify_operand_and_build_expr(rhs, operator.sub)
-
- @abc.abstractmethod
- def _visit(self,
- func: _ExprVisitor,
- args,
- *,
- leaf_checker: _SubtreeLeafChecker = None) -> None:
- """A post-order visitor.
-
- Args:
- func: A callable applied to each node in the expression tree.
- args: The variable-length arguments passed to the callable. These
- arguments are grouped as an iterable and will be unpacked before passing
- to the callable. This is to enable the keyword argument only syntax
- after this argument.
- leaf_checker: A callable object to identify nodes that should be treated
- as leaf nodes to support partial tree visiting.
- """
- pass
+class IndexVar(IndexExpr):
+ """The tensor index class.
- @abc.abstractmethod
- def _emit_expression(
- self,
- expr_to_opnd: Dict["IndexExpr", lang.OperandDef],
- expr_to_info: _ExprInfoDict,
- ) -> lang.ScalarExpression:
- """Emits MLIR for the expression tree.
+ We support the TACO API index_var class with an alias of this class.
- Args:
- expr_to_opnd: A dictionary for looking up structured op input operands for
- the input nodes of the structured op.
- expr_to_info: A dictionary for looking up code generation information for
- expressions.
+ An IndexVar object represents an index variable in tensor index notation.
- Returns:
- A linalg dialect ScalarExpression for the expression.
+ Attributes:
+ name: A unique string name of the IndexVar.
"""
- pass
- @abc.abstractmethod
- def dtype(self) -> DType:
- """Returns the data type for the result of the expression."""
- pass
+ _counter = _AtomicCounter()
- def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None:
- """Emits a structured op in the linalg dialect for the expression tree.
+ def __init__(self):
+ id = self._counter.increment()
+ self._name = f"{_TACO_INDEX_PREFIX}{id}"
- We define a DefineOpcallable in the domain specific language for the linalg
- dialect and execute the callable to generate the structured op. Self is the
- root of the expression tree for the structured op.
+ def __repr__(self) -> str:
+ return f"IndexVar(name={repr(self._name)})"
- Args:
- expr_to_info: A dictionary for looking up code generation information for
- expressions.
- """
- op_info = expr_to_info[self].structop_info
- op_name = op_info.dst_name
- op_def = lang.LinalgOpDef(name=op_name)
- op_callable = lang.DefinedOpCallable(op_name, op_def)
-
- # Collect the input expression nodes for the structured op.
- expr_inputs = []
- self._visit(
- _gather_structured_op_input,
- (self, expr_to_info, expr_inputs),
- leaf_checker=_is_structured_op_leaf,
- )
+ @property
+ def name(self) -> str:
+ """Returns the name of the IndexVar."""
+ return self._name
- # Create a linalg structured op operand for each input expression node and
- # build a dictionary for looking up the information.
- expr_to_input_opnd = {
- e: _emit_structured_op_input(e, expr_to_info, op_def)
- for e in expr_inputs
- }
+ def _visit(
+ self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None
+ ) -> None:
+ """A post-order visitor."""
+ if leaf_checker:
+ assert leaf_checker(self, *args)
+ func(self, *args)
- # Emit the expression tree, which produces the value assigned to the
- # destination tensor.
- value = self._emit_expression(expr_to_input_opnd, expr_to_info)
- # Emit the structured op representation for the destination tensor.
- dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name,
- lang.OperandKind.OUTPUT_TENSOR)
- dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
- dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
-
- expr_info = expr_to_info[self]
- # If the structured op reduces some indices, explicitly represent the
- # reduction. This is done by generating a ReduceFn for the dimensions being
- # reduced in the linalg dialect and calling the function with the value
- # being reduced. We only support add reduction currently.
- if expr_info.reduce_indices:
- reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices)
- value = lang.ReduceFn.add[reduce_dims](value)
-
- # Emit the assignment as a comprehension in the linalg dialect.
- comp = lang.Comprehension((dst_use, value))
- op_def.comprehensions.append(comp)
-
- # The structured op in the linalg dialect requires an explicit
- # initialization for the destination tensor. Emit MLIR to initialize the
- # destination tensor.
- init = op_info.emit_tensor_init()
-
- # Collect MLIR values for the linalg input operands, with the assumption
- # that dictionary preserves the insertion order.
- args = [
- expr_to_info[expr].mlir_value
- for expr, opnd in expr_to_input_opnd.items()
- ]
- # Execute the DefineOpcallable object for the linalg dialect operation to
- # emit MLIR for the linalg structured op.
- expr_info.mlir_value = op_callable(*args, outs=[init])
-
- def _identify_structured_ops(
- self,
- expr_to_info: _ExprInfoDict,
- dst: "Tensor",
- dst_indices: Tuple["IndexVar", ...],
- ) -> List["IndexExpr"]:
- """Returns expression nodes for the roots of the identified structured ops.
-
- A structured op in the linalg dialect only supports reduction performed on
-    the whole expression. If the expression tree contains reductions that are
- performed on part of the expression tree, the expression tree needs to be
- implemented with multiple structured ops. This routine identifies all the
- expression nodes that contain reduction as the root of structured ops in the
- linalg dialect.
+ def _emit_expression(
+ self,
+ expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+ expr_to_info: _ExprInfoDict,
+ ) -> lang.ScalarExpression:
+ """Emits a index value casted to the data type of the tensor expression."""
+ dim = getattr(lang.D, self.name)
+ index = lang.index(dim)
+ int_value = lang.TypeFn.cast_unsigned(lang.TV.I64, index)
+ return lang.TypeFn.cast_unsigned(lang.T, int_value)
- Args:
- expr_to_info: A dictionary for looking up code generation information for
- expressions.
- dst: A destination Tensor that accepts the value of the expression tree.
- dst_indices: The indices used by the destination index expression.
+ def dtype(self) -> DType:
+ """Returns the data type for the index value.
- Returns:
- An ordered list of IndexExpr for the root expressions of the structured
- ops, where child expressions go before parent expressions that use their
- results.
- """
- reduce_indices = tuple(
- set(expr_to_info[self].src_indices) - set(dst_indices))
- for reduce_index in reduce_indices:
- _mark_structured_op_root(self, reduce_index, expr_to_info)
-
- self._visit(_accumulate_reduce_indices, (expr_to_info,))
- structop_roots = []
- self._visit(_gather_structured_op, (expr_to_info, structop_roots))
-
- # Handle the root of the top level expression.
- if not structop_roots or structop_roots[-1] != self:
- # The top level expression is not a reduction. Add the top level
- # expression as a structured op root.
- structop_roots.append(self)
-
- # Use user specified information for the destination tensor to build an
- # _StructOpInfo for the top level expression.
- expr_to_info[self].structop_info = _StructOpInfo(dst_indices,
- tuple(dst.shape),
- dst.dtype, dst.name,
- dst.format)
-
- return structop_roots
-
- def _validate_and_collect_expr_info(
- self,
- dst: "Tensor",
- dst_indices: Tuple["IndexVar", ...],
- ) -> _ExprInfoDict:
- """Propagates expression information for validation.
-
- Propagates the indices used by child expression nodes to parent expression
- nodes. Also collects and validates the sizes for the dimensions
- corresponding to the indices.
+ This is unreachable for IndexVar.
+ """
+ assert 0
- Args:
- dst: A destination Tensor that accepts the value of the expression tree.
- dst_indices: The indices used by the destination index expression.
- Raises:
- ValueError if there is any inconsistency in indices or dimensional
- values.
+def get_index_vars(n: int) -> List[IndexVar]:
+ """Returns a list of n IndexVar.
- Returns:
- A dictionary of (IndexExpr, _ExprInfo).
- """
- expr_to_info = {}
- # Validate the expression tree and construct expression information.
- self._visit(_validate_and_collect_expr_info, (expr_to_info,))
-
- # Validate the destination dimension information.
- info = expr_to_info[self]
- index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)}
- for i, d, in zip(dst_indices, dst.shape):
- if i not in index_to_dim_info:
- raise ValueError("Destination IndexVar not used in the "
- f"source expression: {i}")
- else:
- if d != index_to_dim_info[i].dim and index_to_dim_info[i].dim != -1:
- raise ValueError(f"Inconsistent destination dimension for {i}: "
- f"{d} vs {index_to_dim_info[i].dim}")
-
- return expr_to_info
-
- def _emit_assignment(
- self,
- module: ir.Module,
- dst: "Tensor",
- dst_indices: Tuple["IndexVar", ...],
- expr_to_info: _ExprInfoDict,
- input_accesses: List["Access"],
- ) -> None:
- """Emits an MLIR function for assigning the expression to a tensor."""
- input_types = [a.tensor.mlir_tensor_type() for a in input_accesses]
-
- # Build the kernel for the operations.
- with ir.InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME)
- def linalg_funcop(*args):
- # Set up the mapping from the Access nodes to their MLIR values.
- for e, mlir in zip(input_accesses, args):
- expr_to_info[e].mlir_value = mlir
-
- # Emit structured ops in the linalg dialect to implement the assignment.
- for structop_root in self._identify_structured_ops(
- expr_to_info, dst, dst_indices):
- structop_root._emit_structured_op(expr_to_info)
- dst._record_stats(expr_to_info[structop_root].structop_info)
-
- # The function returns the MLIR value of the root expression.
- return expr_to_info[self].mlir_value
-
- linalg_funcop.func_op.attributes[
- "llvm.emit_c_interface"] = ir.UnitAttr.get()
-
- def get_input_accesses(self) -> List["Access"]:
- """Compute the list of input accesses for the expression."""
- input_accesses = []
- self._visit(_gather_input_accesses_index_vars, (input_accesses,))
- return input_accesses
-
- def compile(
- self,
- dst: "Tensor",
- dst_indices: Tuple["IndexVar", ...],
- ) -> execution_engine.ExecutionEngine:
- """Compiles the tensor assignment dst[dst_indices] = expression.
+ This routine is defined by the TACO API.
Args:
- dst: The destination tensor.
- dst_indices: The tuple of IndexVar used to access the destination tensor.
+ n: An integer representing the number of IndexVar to get.
Returns:
- The execution engine for the tensor assignment.
+ A list of IndexVar.
Raises:
- ValueError: If the expression is not proper or not supported.
+ ValueError: if n is not a positive integer.
"""
- expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
- input_accesses = self.get_input_accesses()
-
- # Build and compile the module to produce the execution engine.
- with ir.Context(), ir.Location.unknown():
- module = ir.Module.create()
- self._emit_assignment(module, dst, dst_indices, expr_to_info,
- input_accesses)
- engine = utils.compile_and_build_engine(module)
-
- return engine
-
-
-class _AtomicCounter:
- """An atomic counter."""
-
- def __init__(self):
- self._counter = 0
- self._counter_lock = threading.Lock()
-
- def increment(self) -> int:
- """Increments the counter by one and returns the old value."""
- old_value = self._counter
- with self._counter_lock:
- self._counter = self._counter + 1
- return old_value
-
-
-class IndexVar(IndexExpr):
- """The tensor index class.
-
- We support the TACO API index_var class with an alias of this class.
-
- An IndexVar object represents an index variable in tensor index notation.
-
- Attributes:
- name: A unique string name of the IndexVar.
- """
- _counter = _AtomicCounter()
-
- def __init__(self):
- id = self._counter.increment()
- self._name = f"{_TACO_INDEX_PREFIX}{id}"
-
- def __repr__(self) -> str:
- return f"IndexVar(name={repr(self._name)})"
-
- @property
- def name(self) -> str:
- """Returns the name of the IndexVar."""
- return self._name
-
- def _visit(self,
- func: _ExprVisitor,
- args,
- *,
- leaf_checker: _SubtreeLeafChecker = None) -> None:
- """A post-order visitor."""
- if leaf_checker:
- assert leaf_checker(self, *args)
- func(self, *args)
-
- def _emit_expression(
- self,
- expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
- expr_to_info: _ExprInfoDict,
- ) -> lang.ScalarExpression:
- """Emits a index value casted to the data type of the tensor expression."""
- dim = getattr(lang.D, self.name)
- index = lang.index(dim)
- int_value = lang.TypeFn.cast_unsigned(lang.TV.I64, index)
- return lang.TypeFn.cast_unsigned(lang.T, int_value)
-
- def dtype(self) -> DType:
- """Returns the data type for the index value.
-
- This is unreachable for IndexVar.
- """
- assert 0
-
-
-def get_index_vars(n: int) -> List[IndexVar]:
- """Returns a list of n IndexVar.
-
- This routine is defined by the TACO API.
-
- Args:
- n: An integer representing the number of IndexVar to get.
-
- Returns:
- A list of IndexVar.
-
- Raises:
- ValueError: if n is not a positive integer.
- """
- if not isinstance(n, int) or n <= 0:
- raise ValueError(f"Expected an integer: {n}.")
- # If lock contention ever becomes an issue, we could implement a bulk getter
- # that returns a range by only claiming the lock once.
- return [IndexVar() for i in range(n)]
+ if not isinstance(n, int) or n <= 0:
+ raise ValueError(f"Expected an integer: {n}.")
+ # If lock contention ever becomes an issue, we could implement a bulk getter
+ # that returns a range by only claiming the lock once.
+ return [IndexVar() for i in range(n)]
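# A minimal sketch of the helper above: each IndexVar minted here draws a
# unique name from the atomic counter, so the variables can be told apart
# when the expression tree is lowered.
i, j = get_index_vars(2)
assert i.name != j.name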
def _mlir_symbols_from_index_vars(
- index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.SymbolDef, ...]:
- """Returns a tuple of MLIR symbols for the given tuple of index_var."""
- return tuple(getattr(lang.S, i.name) for i in index_vars)
+ index_vars: Tuple[IndexVar, ...]
+) -> Tuple[lang.SymbolDef, ...]:
+ """Returns a tuple of MLIR symbols for the given tuple of index_var."""
+ return tuple(getattr(lang.S, i.name) for i in index_vars)
def _mlir_dimensions_from_index_vars(
- index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.DimDef, ...]:
- """Returns a tuple of MLIR dimensions for the given tuple of index_var."""
- return tuple(getattr(lang.D, i.name) for i in index_vars)
+ index_vars: Tuple[IndexVar, ...]
+) -> Tuple[lang.DimDef, ...]:
+ """Returns a tuple of MLIR dimensions for the given tuple of index_var."""
+ return tuple(getattr(lang.D, i.name) for i in index_vars)
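# A minimal sketch of the two helpers above, with <name> standing for the
# unique IndexVar name: given index variables i and j,
#   _mlir_symbols_from_index_vars((i, j))    -> (lang.S.<i.name>, lang.S.<j.name>)
#   _mlir_dimensions_from_index_vars((i, j)) -> (lang.D.<i.name>, lang.D.<j.name>)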
def _mlir_tensor_type(
- dtype: DType, shape: Tuple[int, ...],
- attr: Optional[sparse_tensor.EncodingAttr]) -> ir.RankedTensorType:
- """Returns an MLIR tensor type.
-
- Args:
- dtype: An DType object for the element data type of the tensor.
- shape: A tuple of integer for the shape of the tensor.
- attr: An optional MLIR sparse tensor attribute, only provided if the tensor
- is a sparse tensor.
-
- Returns:
- An MLIR ranked tensor type.
- """
- ir_type = _mlir_type_from_taco_type(dtype)
- return ir.RankedTensorType.get(shape, ir_type, attr)
-
-
-@dataclasses.dataclass(frozen=True)
-class _StructOpInfo:
- """Information for generating a structured op in the linalg dialect.
-
- This information is associated with an expression node that serves as the
- root for an expression subtree implemented with a structured op.
-
- Attributes:
- dst_indices: A tuple of IndexVar, representing the result dimensions of the
- structured op. This is used to construct the temporary variable for the
- tensor to hold the structured op result.
- dst_dims: A tuple of int, representing the result shape of the structured
- op.
- dst_dtype: A DType representing the data type of the structured op result.
- dst_name: A string representing the name of the structured op result.
- dst_format: An optional Format object representing the destination tensor
- format. None represents a true dense tensor.
- """
- dst_indices: Tuple[IndexVar, ...]
- dst_dims: Tuple[int, ...]
- dst_dtype: DType
- dst_name: str
- dst_format: Optional[Format]
-
- def __post_init__(self) -> None:
- """Verifies the integrity of the attribute values."""
- assert len(self.dst_indices) == len(self.dst_dims)
-
- def emit_tensor_init(self) -> ir.RankedTensorType:
- """Returns an initialization for the destination tensor."""
- if self.dst_format is None or self.dst_format.rank() == 0:
- # Initialize the dense tensor.
- ir_type = _mlir_type_from_taco_type(self.dst_dtype)
- empty = tensor.EmptyOp(self.dst_dims, ir_type).result
- zero = arith.ConstantOp(ir_type, 0.0)
- return linalg.fill(zero, outs=[empty])
-
- # Initialize the sparse tensor.
- mlir_type = _mlir_tensor_type(self.dst_dtype, self.dst_dims,
- self.dst_format.mlir_tensor_attr())
- index_type = ir.IndexType.get()
- return bufferization.AllocTensorOp(mlir_type, [], None, None, None)
-
-
-class _Stats:
- """Information to describe how a tensor expression is implemented.
-
- Currently, we only record the temporary tensors introduced for splitting the
- original expression.
- """
-
- def __init__(self):
- self._temps = []
-
- def __repr__(self) -> str:
- return f"_Stats({repr(self._temps)})"
-
- def add_element(self, structop: _StructOpInfo):
- """Adds a temporary tensor."""
- self._temps.append(structop)
-
- def get_total(self) -> int:
- """Gets the total number of temporary tensors."""
- return len(self._temps)
-
- def _get_element(self, idx: int) -> _StructOpInfo:
- """Gets the ith temporary tensor."""
- assert idx < self.get_total()
- return self._temps[idx]
-
- def get_dimensions(self, idx: int) -> Tuple[int]:
- """Gets the dimensions for the ith temporary tensor."""
- return self._get_element(idx).dst_dims
-
- def get_formats(self, idx: int) -> Tuple[ModeFormat]:
- """Gets the ModeFormats for the ith temporary tensor."""
- return tuple(self._get_element(idx).dst_format.format_pack.formats)
-
-
-class _SparseValueInfo(enum.Enum):
- """Describes how a sparse tensor value is stored.
-  _UNPACKED: The sparse tensor value is stored as (coordinates, values) in
- Python.
- _PACKED: The sparse tensor value is stored as a C pointer to a packed MLIR
- sparse tensor.
- """
- _UNPACKED = 0
- _PACKED = 1
-
-
-@dataclasses.dataclass(frozen=True)
-class _Assignment:
- """Records an assignment to a tensor T as T[indices] = expression."""
- indices: Tuple["IndexVar", ...]
- expression: "IndexExpr"
-
-
-class Tensor:
- """The tensor class.
-
- We support the TACO API tensor class with an alias of this class.
-
- This class is part of the TACO API with the following methods:
- insert: Inserts a value to the given coordinate in the tensor.
- to_array: Returns a numpy ndarray for the tensor.
-
-  TACO API also defines the following attributes for the class:
- dtype: A dtype object representing the data type of the tensor.
- format: A format object representing the storage format of the tensor.
- name: A string object representing the name of the tensor.
- order: An integral rank of the tensor.
- shape: A list of integers representing the shape of the tensor.
-
- We currently ignore the tensor dimension ordering for dense tensor.
- """
- _counter = _AtomicCounter()
-
- def _get_unique_name(self) -> str:
- """Returns a unique name for creating a new Tensor."""
- return f"{_TACO_TENSOR_PREFIX}{self._counter.increment()}"
-
- def _init_format(self, fmt: Union[ModeFormat, List[ModeFormat],
- Format]) -> None:
- """Process the fmt argument for the Tensor constructor.
-
- Args:
- fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If
- this argument is a ModeFormat, uses this ModeFormat for all the tensor
- dimensions. If this argument is a list of ModeFormat, the len of the
- list should equal to the rank of the tensor. If this argument is a
- format, uses it for the format of the tensor.
-
- Raises:
-      ValueError: If fmt is not one of the expected types or is inconsistent
-        with the rank of the tensor. This is because fmt could be a user's
- input.
- """
- if isinstance(fmt, ModeFormat):
- self._format = _make_format([fmt] * self.order)
- elif isinstance(fmt, list):
- if len(fmt) == self.order and isinstance(fmt[0], ModeFormat):
- self._format = _make_format(fmt)
- else:
- raise ValueError("Inconsistent shape and format: "
- f"{self._shape}, {fmt}.")
- elif isinstance(fmt, Format):
- if fmt.rank() != self.order:
- raise ValueError("Inconsistent shape and format: "
- f"{self._shape}, {fmt}.")
- else:
- self._format = fmt
- else:
- raise ValueError(f"Invalid format argument: {fmt}.")
-
- def __init__(self,
- value_or_shape: Optional[Union[List[int], Tuple[int, ...],
- complex, float, int]] = None,
- fmt: Optional[Union[ModeFormat, List[ModeFormat],
- Format]] = None,
- dtype: Optional[DType] = None,
- name: Optional[str] = None,
- is_dense: bool = False):
- """The tensor constructor interface defined by TACO API.
-
- Args:
- value_or_shape: This argument is optional and can be int, float,
- List[int], or Tuple[int, ...]. If this argument is an int or float,
- creates a scalar tensor and initializes it with the value. If this
- argument is a list or tuple of int, uses it as the shape to create a
- tensor.
- fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If
- this argument is a ModeFormat, uses this ModeFormat for all the tensor
- dimensions. If this argument is a list of ModeFormat, the len of the
- list should equal to the rank of the tensor. If this argument is a
- format, uses it for the format of the tensor.
- dtype: An object of dtype, representing the data type of the tensor.
- name: A string name of the tensor. If a name is not given, creates a
- unique name for the tensor.
- is_dense: A boolean variable to indicate whether the tensor is a dense
- tensor without any sparsity annotation.
-
- Raises:
- ValueError: If there is any inconsistency among the input arguments.
- """
- # Take care of the argument default values common to both sparse tensors
- # and dense tensors.
- dtype = dtype or DType(Type.FLOAT32)
- self._name = name or self._get_unique_name()
- self._assignment = None
- self._engine = None
- self._sparse_value_location = _SparseValueInfo._UNPACKED
- self._dense_storage = None
- self._dtype = dtype
-
- if is_dense:
- assert (fmt is None)
- assert (isinstance(value_or_shape, tuple) or isinstance(
- value_or_shape, list)) and _all_instance_of(value_or_shape, int)
- self._shape = value_or_shape
- self._format = None
- return
-
- fmt = fmt or ModeFormat.COMPRESSED
- # We currently use _coords and _values to host the sparse tensor value with
- # COO format, and _dense_storage to host the dense tensor value. We don't
- # support the conversion between the two storages.
- self._coords = []
- self._values = []
- self._stats = _Stats()
- if value_or_shape is None or isinstance(value_or_shape, int) or isinstance(
- value_or_shape, float) or isinstance(value_or_shape, complex):
- # Create a scalar tensor and ignore the fmt parameter.
- self._shape = []
- self._format = _make_format([], [])
- if value_or_shape is not None:
- self._dense_storage = np.array(value_or_shape, dtype=self._dtype.value)
- elif (isinstance(value_or_shape, tuple) or isinstance(
- value_or_shape, list)) and _all_instance_of(value_or_shape, int):
- # Create a tensor with the specified shape and format.
- self._shape = list(value_or_shape)
- self._init_format(fmt)
- else:
- raise ValueError("Invalid first argument. "
- "Must be a tuple or list for a shape or a single value"
- f"if initializing a scalar tensor: {value_or_shape}.")
-
- def _set_packed_sparse_tensor(self, pointer: ctypes.c_void_p) -> None:
- """Records the MLIR sparse tensor pointer."""
- self._sparse_value_location = _SparseValueInfo._PACKED
- self._packed_sparse_value = pointer
-
- def is_unpacked(self) -> bool:
- """Returns true if the tensor value is not packed as MLIR sparse tensor."""
- return (self._sparse_value_location == _SparseValueInfo._UNPACKED)
-
- def unpack(self) -> None:
- """Unpacks the MLIR sparse tensor representation."""
- if self.is_dense() or self.is_unpacked():
- return
-
- # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
- # values and verify the values.
- rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor(
- self._packed_sparse_value, self._dtype.value)
- assert rank == self.order
- assert np.array_equal(self.shape, shape)
- assert nse == len(values)
- self._coords = indices
- self._values = values
- self._sparse_value_location = _SparseValueInfo._UNPACKED
-
- def __repr__(self) -> str:
- self._sync_value()
- self.unpack()
- value_str = (f"{repr(self._dense_storage)})" if self.is_dense() else
- f"{repr(self._coords)} {repr(self._values)})")
- return (f"Tensor(_name={repr(self._name)} "
- f"_dtype={repr(self._dtype)} : ") + value_str
-
- def insert(self, coords: List[int], val: Union[complex, float, int]) -> None:
- """Inserts a value to the given coordinate.
+ dtype: DType, shape: Tuple[int, ...], attr: Optional[sparse_tensor.EncodingAttr]
+) -> ir.RankedTensorType:
+ """Returns an MLIR tensor type.
Args:
- coords: A list of integer coordinates. The length of the list must be the
- same as the rank of the tensor.
- val: A value being inserted. It is either an integral or a floating point
- value. This value will be converted to the data type of the tensor.
-
- Raises:
- ValueError: When there is any problem in the parameters.
- """
- if self.is_dense():
- raise ValueError("Insert method is not supported for dense tensors.")
- if self._assignment != None or not self.is_unpacked():
- raise ValueError(
- "Can't use Insert method for a tensor constructed from a file.")
- if not isinstance(coords, list):
- raise ValueError(f"Non list coordinate detected: {coords}.")
- if not _all_instance_of(coords, int):
- raise ValueError(f"Non integer coordinate detected: {coords}.")
- if (len(coords) != self.order or
- any([c < 0 or c >= self._shape[i] for i, c in enumerate(coords)])):
- raise ValueError("Invalid coordinate for rank: "
- f"{self.order}, {coords}.")
-
- if not isinstance(val, int) and not isinstance(
- val, float) and not isinstance(val, complex):
- raise ValueError(f"Value is neither int nor float: {val}.")
-
- self._coords.append(tuple(coords))
- self._values.append(self._dtype.value(val))
-
- def is_dense(self) -> bool:
- """Returns true if the tensor doesn't have sparsity annotation."""
- return self.order == 0 or self._format is None
-
- def to_array(self) -> np.ndarray:
- """Returns the numpy array for the Tensor.
-
-    This is currently only implemented for dense Tensor.
- """
- if not self.is_dense():
- raise ValueError("Conversion from non-dense Tensor "
- "to numpy array not supported yet.")
-
- self._sync_value()
-
- return self._dense_storage
-
- @staticmethod
- def from_array(array: np.ndarray) -> "Tensor":
- """Returns a dense tensor with the value copied from the input array.
-
- We currently only support the conversion of float32 and float64 numpy arrays
- to Tensor.
-
- Args:
- array: The numpy array that provides the data type, shape and value for
- the tensor.
+ dtype: An DType object for the element data type of the tensor.
+ shape: A tuple of integer for the shape of the tensor.
+ attr: An optional MLIR sparse tensor attribute, only provided if the tensor
+ is a sparse tensor.
Returns:
- A Tensor object.
-
- Raises:
- ValueError if the data type of the numpy array is not supported.
+ An MLIR ranked tensor type.
"""
- if array.dtype != np.float32 and array.dtype != np.float64:
- raise ValueError(f"Expected floating point value type: {array.dtype}.")
- t = Tensor(
- array.shape,
- dtype=_nptype_to_taco_type(array.dtype.type),
- is_dense=True)
- t._dense_storage = np.copy(array)
- return t
-
- @staticmethod
- def from_coo(
- coordinates: List[Tuple[int, ...]],
- values: List[_AnyRuntimeType],
- fmt: Format,
- dtype: DType,
- ) -> "Tensor":
- """Converts coordinates and values to a sparse tensor representation.
+ ir_type = _mlir_type_from_taco_type(dtype)
+ return ir.RankedTensorType.get(shape, ir_type, attr)
- Args:
- coordinates: A list of coordinates with non-zero values.
- values: The non-zero values.
- fmt: The tensor storage format.
- dtype: The tensor element data type.
- Returns:
- A tensor with the given non-zero values and storage format. The shape of
- the tensor has the minimum size for each dimension to make the given
- coordinates valid.
- """
- assert (isinstance(coordinates, List) and
- _all_instance_of(coordinates, Tuple))
- assert (isinstance(values, List) and _all_instance_of(values, dtype.value))
- assert isinstance(fmt, Format)
-
- rank = fmt.rank()
- assert all(len(c) == rank and _all_instance_of(c, int) for c in coordinates)
-
- # Find the maximum coordinate value for each dimension.
- max_coordinate = list(map(max, zip(*coordinates)))
-    # The size of each dimension is one more than such a maximum coordinate
- # value.
- shape = [c + 1 for c in max_coordinate]
- t = Tensor(shape, fmt, dtype=dtype)
- t._coords = coordinates
- t._values = values
-
- return tensor
-
- @staticmethod
- def from_file(
- filename: str,
- fmt: Format,
- dtype: DType,
- ) -> "Tensor":
- """Constructs a sparse tensor using the COO-flavored values from a file.
-
- Args:
- filename: A string for the name of the file that contains the sparse
- tensor data.
- fmt: The tensor storage format.
- dtype: The tensor element data type.
-
- Returns:
- A tensor with the given non-zero values and storage format. The tensor
- value is stored as an MLIR sparse tensor.
+@dataclasses.dataclass(frozen=True)
+class _StructOpInfo:
+ """Information for generating a structured op in the linalg dialect.
+
+ This information is associated with an expression node that serves as the
+ root for an expression subtree implemented with a structured op.
+
+ Attributes:
+ dst_indices: A tuple of IndexVar, representing the result dimensions of the
+ structured op. This is used to construct the temporary variable for the
+ tensor to hold the structured op result.
+ dst_dims: A tuple of int, representing the result shape of the structured
+ op.
+ dst_dtype: A DType representing the data type of the structured op result.
+ dst_name: A string representing the name of the structured op result.
+ dst_format: An optional Format object representing the destination tensor
+ format. None represents a true dense tensor.
"""
- sparse_tensor, shape = utils.create_sparse_tensor(filename,
- fmt.format_pack.formats,
- _dtype_to_mlir_str(dtype))
- t = Tensor(shape.tolist(), fmt, dtype=dtype)
- t._set_packed_sparse_tensor(sparse_tensor)
-
- return t
- def to_file(self, filename: str) -> None:
- """Output the tensor value to a file.
+ dst_indices: Tuple[IndexVar, ...]
+ dst_dims: Tuple[int, ...]
+ dst_dtype: DType
+ dst_name: str
+ dst_format: Optional[Format]
+
+ def __post_init__(self) -> None:
+ """Verifies the integrity of the attribute values."""
+ assert len(self.dst_indices) == len(self.dst_dims)
+
+ def emit_tensor_init(self) -> ir.RankedTensorType:
+ """Returns an initialization for the destination tensor."""
+ if self.dst_format is None or self.dst_format.rank() == 0:
+ # Initialize the dense tensor.
+ ir_type = _mlir_type_from_taco_type(self.dst_dtype)
+ empty = tensor.EmptyOp(self.dst_dims, ir_type).result
+ zero = arith.ConstantOp(ir_type, 0.0)
+ return linalg.fill(zero, outs=[empty])
+
+ # Initialize the sparse tensor.
+ mlir_type = _mlir_tensor_type(
+ self.dst_dtype, self.dst_dims, self.dst_format.mlir_tensor_attr()
+ )
+ index_type = ir.IndexType.get()
+ return bufferization.AllocTensorOp(mlir_type, [], None, None, None)
- This method evaluates any pending assignment to the tensor and outputs the
- tensor value.
- Args:
- filename: A string file name.
+class _Stats:
+ """Information to describe how a tensor expression is implemented.
- Raises:
- ValueError: If the tensor is dense, or an unpacked sparse tensor.
+ Currently, we only record the temporary tensors introduced for splitting the
+ original expression.
"""
- self._sync_value()
-
- if self.is_dense():
- raise ValueError("Writing dense tensors without sparsity annotation to "
- "file is not supported.")
- if self.is_unpacked():
- raise ValueError("Writing unpacked sparse tensors to file is not "
- "supported.")
+ def __init__(self):
+ self._temps = []
- utils.output_sparse_tensor(self._packed_sparse_value, filename,
- self._format.format_pack.formats,
- _dtype_to_mlir_str(self._dtype))
+ def __repr__(self) -> str:
+ return f"_Stats({repr(self._temps)})"
- @property
- def dtype(self) -> DType:
- """Returns the data type for the Tensor."""
- return self._dtype
+ def add_element(self, structop: _StructOpInfo):
+ """Adds a temporary tensor."""
+ self._temps.append(structop)
- @property
- def format(self) -> Format:
- """Returns the storage format for the Tensor."""
- return self._format
+ def get_total(self) -> int:
+ """Gets the total number of temporary tensors."""
+ return len(self._temps)
- @property
- def name(self) -> str:
- """Returns the name for the Tensor."""
- return self._name
+ def _get_element(self, idx: int) -> _StructOpInfo:
+ """Gets the ith temporary tensor."""
+ assert idx < self.get_total()
+ return self._temps[idx]
- @property
- def order(self) -> int:
- """Returns the rank of the Tensor."""
- return len(self._shape)
+ def get_dimensions(self, idx: int) -> Tuple[int]:
+ """Gets the dimensions for the ith temporary tensor."""
+ return self._get_element(idx).dst_dims
- @property
- def shape(self) -> List[int]:
- """Returns the shape of the Tensor."""
- return self._shape
+ def get_formats(self, idx: int) -> Tuple[ModeFormat]:
+ """Gets the ModeFormats for the ith temporary tensor."""
+ return tuple(self._get_element(idx).dst_format.format_pack.formats)
- def _verify_and_normalize_indices(self, indices) -> Tuple[IndexVar, ...]:
- """Verifies and normalizes the indices to access the tensor.
- Args:
- indices: The index expression used to access a tensor, which could be any
- Python object from user inputs.
-
- Returns:
- A tuple of IndexVar.
-
- Raises:
- ValueError: If indices is not 0 for scalar tensors, or not an IndexVar or
- a tuple of IndexVar for other tensors.
+class _SparseValueInfo(enum.Enum):
+ """Describes how a sparse tensor value is stored.
+ _UNPACKED: The sparse tensor value is stored as (coordinates, values) in
+ Python.
+ _PACKED: The sparse tensor value is stored as a C pointer to a packed MLIR
+ sparse tensor.
"""
- if self.order == 0:
- if not isinstance(indices, int) or indices != 0:
- raise ValueError(f"Expected 0 to index scalar tensors: {indices}")
- return ()
-
- if isinstance(indices, IndexVar):
- return (indices,)
- elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
- return indices
- raise ValueError(f"Expected IndexVars: {indices}")
+ _UNPACKED = 0
+ _PACKED = 1
- def __getitem__(self, key) -> "Access":
- """Verifies and processes a tensor access.
- In the tensor index notation, a tensor access T[i, j] is represented as
- retrieving a value with key (i, j) from the tensor object T in Python. This
- routine verifies the key for the tensor access and returns a tensor access
- object.
-
- Args:
- key: The key used to access the tensor, which could be any Python object
- from user inputs.
+@dataclasses.dataclass(frozen=True)
+class _Assignment:
+ """Records an assignment to a tensor T as T[indices] = expression."""
- Returns:
- The corresponding tensor access object.
+ indices: Tuple["IndexVar", ...]
+ expression: "IndexExpr"
- Raises:
- ValueError: If key is not an IndexVar or a tuple of IndexVar.
- """
- indices = self._verify_and_normalize_indices(key)
- return Access(self, indices)
- def __setitem__(self, key, value) -> None:
- """Verifies and processes a tensor assignment.
+class Tensor:
+ """The tensor class.
- In the tensor index notation, a tensor assignment "T[i, j] = ..." is
- represented as setting a value for a tensor object T via key (i, j) in
- Python. This routine verifies the key, evaluates the value, and assigns the
- value to the tensor.
+ We support the TACO API tensor class with an alias of this class.
- We only support assignment of dense tensor currently.
+ This class is part of the TACO API with the following methods:
+ insert: Inserts a value to the given coordinate in the tensor.
+ to_array: Returns a numpy ndarray for the tensor.
- Args:
- key: The key used to access the tensor, which could be any Python object
- from user inputs.
- value: The value assigned to the tensor, which could be any Python object
- from user inputs.
+ The TACO API also defines the following attributes for the class:
+ dtype: A dtype object representing the data type of the tensor.
+ format: A format object representing the storage format of the tensor.
+ name: A string object representing the name of the tensor.
+ order: An integer representing the rank of the tensor.
+ shape: A list of integers representing the shape of the tensor.
- Raises:
- ValueError: If tensor is not a dense tensor, or the key is not an IndexVar
- or a tuple of IndexVar, or the length of the indices is not the same as
- the rank of the tensor.
+ We currently ignore the tensor dimension ordering for dense tensors.
"""
- indices = self._verify_and_normalize_indices(key)
- if len(indices) != self.order:
- raise ValueError("Mismatch between indices and tensor rank: "
- f"len({indices}) != {self.order}.")
- self._assignment = _Assignment(indices, value)
- self._engine = None
-
- def compile(self, force_recompile: bool = False) -> None:
- """Compiles the tensor assignment to an execution engine.
-
- Calling compile the second time does not do anything unless
- force_recompile is True.
+ _counter = _AtomicCounter()
+
+ def _get_unique_name(self) -> str:
+ """Returns a unique name for creating a new Tensor."""
+ return f"{_TACO_TENSOR_PREFIX}{self._counter.increment()}"
+
+ def _init_format(self, fmt: Union[ModeFormat, List[ModeFormat], Format]) -> None:
+ """Process the fmt argument for the Tensor constructor.
+
+ Args:
+ fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If
+ this argument is a ModeFormat, uses this ModeFormat for all the tensor
+ dimensions. If this argument is a list of ModeFormat, the length of the
+ list should equal the rank of the tensor. If this argument is a
+ format, uses it for the format of the tensor.
+
+ Raises:
+ ValueError: If fmt is not one of the expected types or is inconsistent
+ with the rank of the tensor. This is because fmt could be a user
+ input.
+ """
+ if isinstance(fmt, ModeFormat):
+ self._format = _make_format([fmt] * self.order)
+ elif isinstance(fmt, list):
+ if len(fmt) == self.order and isinstance(fmt[0], ModeFormat):
+ self._format = _make_format(fmt)
+ else:
+ raise ValueError(
+ "Inconsistent shape and format: " f"{self._shape}, {fmt}."
+ )
+ elif isinstance(fmt, Format):
+ if fmt.rank() != self.order:
+ raise ValueError(
+ "Inconsistent shape and format: " f"{self._shape}, {fmt}."
+ )
+ else:
+ self._format = fmt
+ else:
+ raise ValueError(f"Invalid format argument: {fmt}.")
+
+ def __init__(
+ self,
+ value_or_shape: Optional[
+ Union[List[int], Tuple[int, ...], complex, float, int]
+ ] = None,
+ fmt: Optional[Union[ModeFormat, List[ModeFormat], Format]] = None,
+ dtype: Optional[DType] = None,
+ name: Optional[str] = None,
+ is_dense: bool = False,
+ ):
+ """The tensor constructor interface defined by TACO API.
+
+ Args:
+ value_or_shape: This argument is optional and can be int, float,
+ List[int], or Tuple[int, ...]. If this argument is an int or float,
+ creates a scalar tensor and initializes it with the value. If this
+ argument is a list or tuple of int, uses it as the shape to create a
+ tensor.
+ fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If
+ this argument is a ModeFormat, uses this ModeFormat for all the tensor
+ dimensions. If this argument is a list of ModeFormat, the length of the
+ list should equal the rank of the tensor. If this argument is a
+ format, uses it for the format of the tensor.
+ dtype: An object of dtype, representing the data type of the tensor.
+ name: A string name of the tensor. If a name is not given, creates a
+ unique name for the tensor.
+ is_dense: A boolean variable to indicate whether the tensor is a dense
+ tensor without any sparsity annotation.
+
+ Raises:
+ ValueError: If there is any inconsistency among the input arguments.
+ """
+ # Take care of the argument default values common to both sparse tensors
+ # and dense tensors.
+ dtype = dtype or DType(Type.FLOAT32)
+ self._name = name or self._get_unique_name()
+ self._assignment = None
+ self._engine = None
+ self._sparse_value_location = _SparseValueInfo._UNPACKED
+ self._dense_storage = None
+ self._dtype = dtype
+
+ if is_dense:
+ assert fmt is None
+ assert (
+ isinstance(value_or_shape, tuple) or isinstance(value_or_shape, list)
+ ) and _all_instance_of(value_or_shape, int)
+ self._shape = value_or_shape
+ self._format = None
+ return
+
+ fmt = fmt or ModeFormat.COMPRESSED
+ # We currently use _coords and _values to host the sparse tensor value with
+ # COO format, and _dense_storage to host the dense tensor value. We don't
+ # support the conversion between the two storages.
+ self._coords = []
+ self._values = []
+ self._stats = _Stats()
+ if (
+ value_or_shape is None
+ or isinstance(value_or_shape, int)
+ or isinstance(value_or_shape, float)
+ or isinstance(value_or_shape, complex)
+ ):
+ # Create a scalar tensor and ignore the fmt parameter.
+ self._shape = []
+ self._format = _make_format([], [])
+ if value_or_shape is not None:
+ self._dense_storage = np.array(value_or_shape, dtype=self._dtype.value)
+ elif (
+ isinstance(value_or_shape, tuple) or isinstance(value_or_shape, list)
+ ) and _all_instance_of(value_or_shape, int):
+ # Create a tensor with the specified shape and format.
+ self._shape = list(value_or_shape)
+ self._init_format(fmt)
+ else:
+ raise ValueError(
+ "Invalid first argument. "
+ "Must be a tuple or list for a shape or a single value"
+ f"if initializing a scalar tensor: {value_or_shape}."
+ )
+
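# A minimal usage sketch for the constructor documented above, assuming the
# names defined in this module (Tensor, ModeFormat) are in scope:
example_scalar = Tensor(3.0)                             # rank-0 tensor holding 3.0
example_sparse = Tensor([8, 8], ModeFormat.COMPRESSED)   # 8x8 tensor, compressed modes
example_dense = Tensor([4, 4], is_dense=True)            # dense tensor, no sparsity annotation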
+ def _set_packed_sparse_tensor(self, pointer: ctypes.c_void_p) -> None:
+ """Records the MLIR sparse tensor pointer."""
+ self._sparse_value_location = _SparseValueInfo._PACKED
+ self._packed_sparse_value = pointer
+
+ def is_unpacked(self) -> bool:
+ """Returns true if the tensor value is not packed as MLIR sparse tensor."""
+ return self._sparse_value_location == _SparseValueInfo._UNPACKED
+
+ def unpack(self) -> None:
+ """Unpacks the MLIR sparse tensor representation."""
+ if self.is_dense() or self.is_unpacked():
+ return
+
+ # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
+ # values and verify the values.
+ rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor(
+ self._packed_sparse_value, self._dtype.value
+ )
+ assert rank == self.order
+ assert np.array_equal(self.shape, shape)
+ assert nse == len(values)
+ self._coords = indices
+ self._values = values
+ self._sparse_value_location = _SparseValueInfo._UNPACKED
+
+ def __repr__(self) -> str:
+ self._sync_value()
+ self.unpack()
+ value_str = (
+ f"{repr(self._dense_storage)})"
+ if self.is_dense()
+ else f"{repr(self._coords)} {repr(self._values)})"
+ )
+ return (
+ f"Tensor(_name={repr(self._name)} " f"_dtype={repr(self._dtype)} : "
+ ) + value_str
+
+ def insert(self, coords: List[int], val: Union[complex, float, int]) -> None:
+ """Inserts a value to the given coordinate.
+
+ Args:
+ coords: A list of integer coordinates. The length of the list must be the
+ same as the rank of the tensor.
+ val: A value being inserted. It is either an integral or a floating point
+ value. This value will be converted to the data type of the tensor.
+
+ Raises:
+ ValueError: When there is any problem in the parameters.
+ """
+ if self.is_dense():
+ raise ValueError("Insert method is not supported for dense tensors.")
+ if self._assignment != None or not self.is_unpacked():
+ raise ValueError(
+ "Can't use Insert method for a tensor constructed from a file."
+ )
+ if not isinstance(coords, list):
+ raise ValueError(f"Non list coordinate detected: {coords}.")
+ if not _all_instance_of(coords, int):
+ raise ValueError(f"Non integer coordinate detected: {coords}.")
+ if len(coords) != self.order or any(
+ [c < 0 or c >= self._shape[i] for i, c in enumerate(coords)]
+ ):
+ raise ValueError("Invalid coordinate for rank: " f"{self.order}, {coords}.")
+
+ if (
+ not isinstance(val, int)
+ and not isinstance(val, float)
+ and not isinstance(val, complex)
+ ):
+ raise ValueError(f"Value is neither int nor float: {val}.")
+
+ self._coords.append(tuple(coords))
+ self._values.append(self._dtype.value(val))
+
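# An illustrative sketch of insert() as documented above; the tensor must be
# an unpacked sparse tensor, i.e. not dense and not loaded from a file:
coo_example = Tensor([4, 4], ModeFormat.COMPRESSED)
coo_example.insert([0, 1], 2.0)    # stores 2.0 at coordinate (0, 1)
coo_example.insert([3, 2], -1.5)   # stores -1.5 at coordinate (3, 2)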
+ def is_dense(self) -> bool:
+ """Returns true if the tensor doesn't have sparsity annotation."""
+ return self.order == 0 or self._format is None
+
+ def to_array(self) -> np.ndarray:
+ """Returns the numpy array for the Tensor.
+
+ This is currently only implemented for dense Tensor.
+ """
+ if not self.is_dense():
+ raise ValueError(
+ "Conversion from non-dense Tensor " "to numpy array not supported yet."
+ )
+
+ self._sync_value()
+
+ return self._dense_storage
+
+ @staticmethod
+ def from_array(array: np.ndarray) -> "Tensor":
+ """Returns a dense tensor with the value copied from the input array.
+
+ We currently only support the conversion of float32 and float64 numpy arrays
+ to Tensor.
+
+ Args:
+ array: The numpy array that provides the data type, shape and value for
+ the tensor.
+
+ Returns:
+ A Tensor object.
+
+ Raises:
+ ValueError if the data type of the numpy array is not supported.
+ """
+ if array.dtype != np.float32 and array.dtype != np.float64:
+ raise ValueError(f"Expected floating point value type: {array.dtype}.")
+ t = Tensor(
+ array.shape, dtype=_nptype_to_taco_type(array.dtype.type), is_dense=True
+ )
+ t._dense_storage = np.copy(array)
+ return t
+
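# A small example of from_array() as documented above; only float32/float64
# numpy arrays are accepted, and np is already imported by this module:
array_example = np.ones((2, 3), dtype=np.float64)
tensor_from_array = Tensor.from_array(array_example)
assert tensor_from_array.to_array().shape == (2, 3)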
+ @staticmethod
+ def from_coo(
+ coordinates: List[Tuple[int, ...]],
+ values: List[_AnyRuntimeType],
+ fmt: Format,
+ dtype: DType,
+ ) -> "Tensor":
+ """Converts coordinates and values to a sparse tensor representation.
+
+ Args:
+ coordinates: A list of coordinates with non-zero values.
+ values: The non-zero values.
+ fmt: The tensor storage format.
+ dtype: The tensor element data type.
+
+ Returns:
+ A tensor with the given non-zero values and storage format. The shape of
+ the tensor has the minimum size for each dimension to make the given
+ coordinates valid.
+ """
+ assert isinstance(coordinates, List) and _all_instance_of(coordinates, Tuple)
+ assert isinstance(values, List) and _all_instance_of(values, dtype.value)
+ assert isinstance(fmt, Format)
+
+ rank = fmt.rank()
+ assert all(len(c) == rank and _all_instance_of(c, int) for c in coordinates)
+
+ # Find the maximum coordinate value for each dimension.
+ max_coordinate = list(map(max, zip(*coordinates)))
+ # The size of each dimension is one more than the maximum coordinate
+ # value in that dimension.
+ shape = [c + 1 for c in max_coordinate]
+ t = Tensor(shape, fmt, dtype=dtype)
+ t._coords = coordinates
+ t._values = values
+
+ return t
+
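# A hedged sketch of from_coo() as documented above, using the module's
# _make_format helper to build a rank-2 storage format; values must be
# instances of dtype.value (np.float32 here), and the inferred shape is
# [3, 4], i.e. one more than the maximum coordinate in each dimension:
coo_format = _make_format([ModeFormat.COMPRESSED, ModeFormat.COMPRESSED])
tensor_from_coo = Tensor.from_coo(
    [(0, 0), (2, 3)],
    [np.float32(1.0), np.float32(5.0)],
    coo_format,
    DType(Type.FLOAT32),
)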
+ @staticmethod
+ def from_file(
+ filename: str,
+ fmt: Format,
+ dtype: DType,
+ ) -> "Tensor":
+ """Constructs a sparse tensor using the COO-flavored values from a file.
+
+ Args:
+ filename: A string for the name of the file that contains the sparse
+ tensor data.
+ fmt: The tensor storage format.
+ dtype: The tensor element data type.
+
+ Returns:
+ A tensor with the given non-zero values and storage format. The tensor
+ value is stored as an MLIR sparse tensor.
+ """
+ sparse_tensor, shape = utils.create_sparse_tensor(
+ filename, fmt.format_pack.formats, _dtype_to_mlir_str(dtype)
+ )
+ t = Tensor(shape.tolist(), fmt, dtype=dtype)
+ t._set_packed_sparse_tensor(sparse_tensor)
+
+ return t
+
+ def to_file(self, filename: str) -> None:
+ """Output the tensor value to a file.
+
+ This method evaluates any pending assignment to the tensor and outputs the
+ tensor value.
+
+ Args:
+ filename: A string file name.
+
+ Raises:
+ ValueError: If the tensor is dense, or an unpacked sparse tensor.
+ """
+ self._sync_value()
+
+ if self.is_dense():
+ raise ValueError(
+ "Writing dense tensors without sparsity annotation to "
+ "file is not supported."
+ )
+
+ if self.is_unpacked():
+ raise ValueError(
+ "Writing unpacked sparse tensors to file is not " "supported."
+ )
+
+ utils.output_sparse_tensor(
+ self._packed_sparse_value,
+ filename,
+ self._format.format_pack.formats,
+ _dtype_to_mlir_str(self._dtype),
+ )
+
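# A sketch of the file round trip documented above; the path is a hypothetical
# placeholder and must name an existing file in the format expected by
# utils.create_sparse_tensor, so the calls are left commented out:
# file_format = _make_format([ModeFormat.COMPRESSED, ModeFormat.COMPRESSED])
# loaded = Tensor.from_file("/tmp/example.tns", file_format, DType(Type.FLOAT32))
# loaded.to_file("/tmp/example_copy.tns")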
+ @property
+ def dtype(self) -> DType:
+ """Returns the data type for the Tensor."""
+ return self._dtype
+
+ @property
+ def format(self) -> Format:
+ """Returns the storage format for the Tensor."""
+ return self._format
+
+ @property
+ def name(self) -> str:
+ """Returns the name for the Tensor."""
+ return self._name
+
+ @property
+ def order(self) -> int:
+ """Returns the rank of the Tensor."""
+ return len(self._shape)
+
+ @property
+ def shape(self) -> List[int]:
+ """Returns the shape of the Tensor."""
+ return self._shape
+
+ def _verify_and_normalize_indices(self, indices) -> Tuple[IndexVar, ...]:
+ """Verifies and normalizes the indices to access the tensor.
+
+ Args:
+ indices: The index expression used to access a tensor, which could be any
+ Python object from user inputs.
+
+ Returns:
+ A tuple of IndexVar.
+
+ Raises:
+ ValueError: If indices is not 0 for scalar tensors, or not an IndexVar or
+ a tuple of IndexVar for other tensors.
+ """
+ if self.order == 0:
+ if not isinstance(indices, int) or indices != 0:
+ raise ValueError(f"Expected 0 to index scalar tensors: {indices}")
+ return ()
+
+ if isinstance(indices, IndexVar):
+ return (indices,)
+ elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
+ return indices
+
+ raise ValueError(f"Expected IndexVars: {indices}")
+
+ def __getitem__(self, key) -> "Access":
+ """Verifies and processes a tensor access.
+
+ In the tensor index notation, a tensor access T[i, j] is represented as
+ retrieving a value with key (i, j) from the tensor object T in Python. This
+ routine verifies the key for the tensor access and returns a tensor access
+ object.
+
+ Args:
+ key: The key used to access the tensor, which could be any Python object
+ from user inputs.
+
+ Returns:
+ The corresponding tensor access object.
+
+ Raises:
+ ValueError: If key is not an IndexVar or a tuple of IndexVar.
+ """
+ indices = self._verify_and_normalize_indices(key)
+ return Access(self, indices)
+
+ def __setitem__(self, key, value) -> None:
+ """Verifies and processes a tensor assignment.
+
+ In the tensor index notation, a tensor assignment "T[i, j] = ..." is
+ represented as setting a value for a tensor object T via key (i, j) in
+ Python. This routine verifies the key, evaluates the value, and assigns the
+ value to the tensor.
+
+ We only support assignment of dense tensor currently.
+
+ Args:
+ key: The key used to access the tensor, which could be any Python object
+ from user inputs.
+ value: The value assigned to the tensor, which could be any Python object
+ from user inputs.
+
+ Raises:
+ ValueError: If tensor is not a dense tensor, or the key is not an IndexVar
+ or a tuple of IndexVar, or the length of the indices is not the same as
+ the rank of the tensor.
+ """
+ indices = self._verify_and_normalize_indices(key)
+ if len(indices) != self.order:
+ raise ValueError(
+ "Mismatch between indices and tensor rank: "
+ f"len({indices}) != {self.order}."
+ )
+
+ self._assignment = _Assignment(indices, value)
+ self._engine = None
+
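# A minimal sketch of the index notation handled by __getitem__/__setitem__
# above; constructing IndexVar() with no arguments is an assumption made only
# for this sketch:
i, j = IndexVar(), IndexVar()
B = Tensor.from_array(np.identity(3, dtype=np.float32))
A = Tensor([3, 3], is_dense=True)
A[i, j] = B[i, j] + B[i, j]   # records a pending _Assignment on A
A.evaluate()                  # compiles and runs the pending assignment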
+ def compile(self, force_recompile: bool = False) -> None:
+ """Compiles the tensor assignment to an execution engine.
+
+ Calling compile the second time does not do anything unless
+ force_recompile is True.
+
+ Args:
+ force_recompile: A boolean value to enable recompilation, such as for the
+ purpose of timing.
+
+ Raises:
+ ValueError: If the assignment is not proper or not supported.
+ """
+ if self._assignment is None or (
+ self._engine is not None and not force_recompile
+ ):
+ return
+
+ self._engine = self._assignment.expression.compile(
+ self, self._assignment.indices
+ )
+
+ def compute(self) -> None:
+ """Executes the engine for the tensor assignment.
+
+ Raises:
+ ValueError: If the assignment hasn't been compiled yet.
+ """
+ if self._assignment is None:
+ return
+
+ if self._engine is None:
+ raise ValueError("Need to invoke compile() before invoking compute().")
+
+ input_accesses = self._assignment.expression.get_input_accesses()
+ # Gather the pointers for the input buffers.
+ input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
+ if self.is_dense():
+ # The pointer to receive dense output is the first argument to the
+ # execution engine.
+ arg_pointers = [self.dense_dst_ctype_pointer()] + input_pointers
+ else:
+ # The pointer to receive the sparse tensor output is the last argument
+ # to the execution engine and is a pointer to pointer of char.
+ arg_pointers = input_pointers + [
+ ctypes.pointer(ctypes.pointer(ctypes.c_char(0)))
+ ]
+
+ # Invoke the execution engine to run the module.
+ self._engine.invoke(_ENTRY_NAME, *arg_pointers)
+
+ # Retrieve the result.
+ if self.is_dense():
+ result = runtime.ranked_memref_to_numpy(arg_pointers[0][0])
+ assert isinstance(result, np.ndarray)
+ self._dense_storage = result
+ else:
+ self._set_packed_sparse_tensor(arg_pointers[-1][0])
+
+ self._assignment = None
+ self._engine = None
+
+ def evaluate(self) -> None:
+ """Evaluates the tensor assignment."""
+ self.compile()
+ self.compute()
+
+ def _sync_value(self) -> None:
+ """Updates the tensor value by evaluating the pending assignment."""
+ if self._assignment is not None:
+ self.evaluate()
+
+ def mlir_tensor_type(self) -> ir.RankedTensorType:
+ """Returns the MLIR type for the tensor."""
+ mlir_attr = (
+ None
+ if (self._format is None or self.order == 0)
+ else self._format.mlir_tensor_attr()
+ )
+ return _mlir_tensor_type(self._dtype, tuple(self._shape), mlir_attr)
+
+ def dense_dst_ctype_pointer(self) -> ctypes.pointer:
+ """Returns the ctypes pointer for the pointer to an MemRefDescriptor.
+
+ For a dense tensor output, the MLIR compiler allocates the storage for
+ the tensor. This routine returns the pointer to an MLIR MemRefDescriptor for
+ receiving the tensor.
+ """
+ assert self.is_dense()
+ mem_ref_desc = runtime.make_nd_memref_descriptor(
+ self.order, np.ctypeslib.as_ctypes_type(self.dtype.value)
+ )()
+ return ctypes.pointer(ctypes.pointer(mem_ref_desc))
+
+ def ctype_pointer(self) -> ctypes.pointer:
+ """Returns the ctypes pointer for the pointer to the input tensor."""
+ if self.is_dense():
+ if self._dense_storage is None:
+ self._dense_storage = np.zeros(self._shape, self._dtype.value)
+ return _ctype_pointer_from_array(self._dense_storage)
+
+ if self.is_unpacked():
+ shape = np.array(self._shape, np.int64)
+ indices = np.array(self._coords, np.int64)
+ values = np.array(self._values, self._dtype.value)
+ perm, sparse = self.format.get_permutation_and_sparsity()
+ ptr = utils.coo_tensor_to_sparse_tensor(
+ shape, values, indices, perm, sparse
+ )
+ else:
+ ptr = self._packed_sparse_value
+
+ return ctypes.pointer(ctypes.cast(ptr, ctypes.c_void_p))
+
+ def get_scalar_value(self) -> _AnyRuntimeType:
+ """Returns the value for the scalar tensor.
+
+ This method also evaluates the assignment to the tensor.
+
+ Raises:
+ ValueError: If the tensor is not a scalar.
+ """
+ if self.order != 0:
+ raise ValueError(f"Expected a scalar tensor, got: rank={self.order}")
+
+ self._sync_value()
+ return self._dense_storage
+
+ def get_coordinates_and_values(
+ self,
+ ) -> Tuple[List[Tuple[int, ...]], List[_AnyRuntimeType]]:
+ """Returns the coordinates and values for the non-zero elements.
+
+ This method also evaluates the assignment to the tensor and unpacks the
+ sparse tensor.
+ """
+ self._sync_value()
+
+ if not self.is_dense():
+ self.unpack()
+ return (self._coords, self._values)
+
+ if self.order == 0:
+ return ([], self._dense_storage)
+
+ # Coordinates for non-zero elements, grouped by dimensions.
+ coords_by_dims = self._dense_storage.nonzero()
+ # Coordinates for non-zero elements, grouped by elements.
+ coords = np.transpose(coords_by_dims)
+ values = self._dense_storage[coords_by_dims]
+ return (coords, values)
+
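# An illustrative sketch of get_coordinates_and_values() for an unpacked
# sparse tensor populated through insert():
nz_example = Tensor([4, 4], ModeFormat.COMPRESSED)
nz_example.insert([1, 2], 3.0)
nz_coords, nz_values = nz_example.get_coordinates_and_values()   # nz_coords == [(1, 2)]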
+ def _record_stats(self, structop: "_StructOpInfo"):
+ """Collects information for temporary tensors."""
+ # Exclude user specified destination tensors.
+ if structop.dst_name == self.name:
+ return
+
+ self._stats.add_element(structop)
+
+
+def _emit_operand(
+ op_def: lang.LinalgOpDef,
+ indices: Tuple[IndexVar, ...],
+ name: str,
+ kind: lang.OperandKind,
+) -> lang.OperandDef:
+ """Emits an operand for a tensor access in the current linalg operation.
Args:
- force_recompile: A boolean value to enable recompilation, such as for the
- purpose of timing.
-
- Raises:
- ValueError: If the assignment is not proper or not supported.
- """
- if self._assignment is None or (self._engine is not None and
- not force_recompile):
- return
-
- self._engine = self._assignment.expression.compile(self,
- self._assignment.indices)
+ op_def: A LinalgOpDef representing the current linalg dialect operation.
+ indices: A tuple of IndexVar used to access the tensor.
+ name: A unique string name of the tensor.
+ kind: An OperandKind for the operand.
- def compute(self) -> None:
- """Executes the engine for the tensor assignment.
-
- Raises:
- ValueError: If the assignment hasn't been compiled yet.
- """
- if self._assignment is None:
- return
-
- if self._engine is None:
- raise ValueError("Need to invoke compile() before invoking compute().")
-
- input_accesses = self._assignment.expression.get_input_accesses()
- # Gather the pointers for the input buffers.
- input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
- if self.is_dense():
- # The pointer to receive dense output is the first argument to the
- # execution engine.
- arg_pointers = [self.dense_dst_ctype_pointer()] + input_pointers
- else:
- # The pointer to receive the sparse tensor output is the last argument
- # to the execution engine and is a pointer to pointer of char.
- arg_pointers = input_pointers + [
- ctypes.pointer(ctypes.pointer(ctypes.c_char(0)))
- ]
-
- # Invoke the execution engine to run the module.
- self._engine.invoke(_ENTRY_NAME, *arg_pointers)
-
- # Retrieve the result.
- if self.is_dense():
- result = runtime.ranked_memref_to_numpy(arg_pointers[0][0])
- assert isinstance(result, np.ndarray)
- self._dense_storage = result
- else:
- self._set_packed_sparse_tensor(arg_pointers[-1][0])
-
- self._assignment = None
- self._engine = None
-
- def evaluate(self) -> None:
- """Evaluates the tensor assignment."""
- self.compile()
- self.compute()
-
- def _sync_value(self) -> None:
- """Updates the tensor value by evaluating the pending assignment."""
- if self._assignment is not None:
- self.evaluate()
-
- def mlir_tensor_type(self) -> ir.RankedTensorType:
- """Returns the MLIR type for the tensor."""
- mlir_attr = (None if (self._format is None or self.order == 0) else
- self._format.mlir_tensor_attr())
- return _mlir_tensor_type(self._dtype, tuple(self._shape), mlir_attr)
-
- def dense_dst_ctype_pointer(self) -> ctypes.pointer:
- """Returns the ctypes pointer for the pointer to an MemRefDescriptor.
-
- For a dense tensor output, the MLIR compiler allocates the storage for
- the tensor. This routine returns the pointer to an MLIR MemRefDescriptor for
- receiving the tensor.
- """
- assert self.is_dense()
- mem_ref_desc = runtime.make_nd_memref_descriptor(
- self.order, np.ctypeslib.as_ctypes_type(self.dtype.value))()
- return ctypes.pointer(ctypes.pointer(mem_ref_desc))
-
- def ctype_pointer(self) -> ctypes.pointer:
- """Returns the ctypes pointer for the pointer to the input tensor."""
- if self.is_dense():
- if self._dense_storage is None:
- self._dense_storage = np.zeros(self._shape, self._dtype.value)
- return _ctype_pointer_from_array(self._dense_storage)
-
- if self.is_unpacked():
- shape = np.array(self._shape, np.int64)
- indices = np.array(self._coords, np.int64)
- values = np.array(self._values, self._dtype.value)
- perm, sparse = self.format.get_permutation_and_sparsity()
- ptr = utils.coo_tensor_to_sparse_tensor(shape, values, indices, perm,
- sparse)
- else:
- ptr = self._packed_sparse_value
-
- return ctypes.pointer(ctypes.cast(ptr, ctypes.c_void_p))
-
- def get_scalar_value(self) -> _AnyRuntimeType:
- """Returns the value for the scalar tensor.
-
- This method also evaluates the assignment to the tensor.
-
- Raises:
- ValueError: If the tensor is not a scalar.
- """
- if self.order != 0:
- raise ValueError(f"Expected a scalar tensor, got: rank={self.order}")
-
- self._sync_value()
- return self._dense_storage
-
-
- def get_coordinates_and_values(
- self) -> Tuple[List[Tuple[int, ...]], List[_AnyRuntimeType]]:
- """Returns the coordinates and values for the non-zero elements.
-
- This method also evaluates the assignment to the tensor and unpack the
- sparse tensor.
+ Returns:
+ An OperandDef representing the operand.
"""
- self._sync_value()
-
- if not self.is_dense():
- self.unpack()
- return (self._coords, self._values)
-
- if self.order == 0:
- return ([], self._dense_storage)
-
- # Coordinates for non-zero elements, grouped by dimensions.
- coords_by_dims = self._dense_storage.nonzero()
- # Coordinates for non-zero elements, grouped by elements.
- coords = np.transpose(coords_by_dims)
- values = self._dense_storage[coords_by_dims]
- return (coords, values)
-
- def _record_stats(self, structop: "_StructOpInfo"):
- """Collects information for temporary tensors."""
- # Exclude user specified destination tensors.
- if structop.dst_name == self.name:
- return
-
- self._stats.add_element(structop)
-
-
-def _emit_operand(op_def: lang.LinalgOpDef, indices: Tuple[IndexVar, ...],
- name: str, kind: lang.OperandKind) -> lang.OperandDef:
- """Emits an operand for a tensor access in the current linalg operation.
-
- Args:
- op_def: A LinalgOpDef representing the current linalg dialect operation.
- indices: A tuple of IndexVar used to access the tensor.
- name: A unique string name of the tensor.
- kind: An OperandKind for the operand.
-
- Returns:
- An OperandDef representing the operand.
- """
- dim_sym = _mlir_symbols_from_index_vars(indices)
- opnd = lang.OperandDef(kind, lang.T, dim_sym)
- op_def.add_operand(name, opnd)
- return opnd
+ dim_sym = _mlir_symbols_from_index_vars(indices)
+ opnd = lang.OperandDef(kind, lang.T, dim_sym)
+ op_def.add_operand(name, opnd)
+ return opnd
@dataclasses.dataclass(frozen=True)
class _DimInfo:
- """Information for an operand dimension.
+ """Information for an operand dimension.
- Attributes:
- dim: An integer for the size of the dimension.
- mode_format: A ModeFormat for the dimension sparsity.
- """
- dim: int
- mode_format: ModeFormat
+ Attributes:
+ dim: An integer for the size of the dimension.
+ mode_format: A ModeFormat for the dimension sparsity.
+ """
+
+ dim: int
+ mode_format: ModeFormat
def _get_dummy_dim_info() -> _DimInfo:
- """Constructs the _DimInfo for an index used in tensor expressions."""
- return _DimInfo(-1, ModeFormat.DENSE)
+ """Constructs the _DimInfo for an index used in tensor expressions."""
+ return _DimInfo(-1, ModeFormat.DENSE)
@dataclasses.dataclass()
class _ExprInfo:
- """Expression information for validation and code generation.
-
- Attributes:
- src_indices: A tuple of IndexVar for the indices used by the tensors in the
- expression tree.
- dim_infos: A tuple of _DimInfo, representing the dimension information
- corresponding to the src_indices.
- reduce_indices: A set of IndexVar for the indices reduced by the expression.
- acc_reduce_indices: An accumulated set of IndexVar for the indices reduced
- by the expression and its children.
- structop_info: Information to support the code generation for a structured
- op in the linalg dialect, if the corresponding expression node is the root
- of a subtree for a structured op.
- mlir_value: The MLIR value generated for the structured op.
- """
- src_indices: Tuple[IndexVar, ...]
- dim_infos: Tuple[_DimInfo, ...]
- reduce_indices: Optional[Set[IndexVar]] = None
- acc_reduce_indices: Optional[Set[IndexVar]] = None
- structop_info: Optional[_StructOpInfo] = None
- mlir_value: Optional[ir.Value] = None
-
- def __post_init__(self) -> None:
- """Verifies and fix up attribute values.
-
- Verifies the consistency of the attributes and modifies the default values
- to support convenient initializer syntax.
+ """Expression information for validation and code generation.
+
+ Attributes:
+ src_indices: A tuple of IndexVar for the indices used by the tensors in the
+ expression tree.
+ dim_infos: A tuple of _DimInfo, representing the dimension information
+ corresponding to the src_indices.
+ reduce_indices: A set of IndexVar for the indices reduced by the expression.
+ acc_reduce_indices: An accumulated set of IndexVar for the indices reduced
+ by the expression and its children.
+ structop_info: Information to support the code generation for a structured
+ op in the linalg dialect, if the corresponding expression node is the root
+ of a subtree for a structured op.
+ mlir_value: The MLIR value generated for the structured op.
"""
- assert len(self.src_indices) == len(self.dim_infos)
- self.reduce_indices = self.reduce_indices or set()
- self.acc_reduce_indices = self.acc_reduce_indices or set()
+ src_indices: Tuple[IndexVar, ...]
+ dim_infos: Tuple[_DimInfo, ...]
+ reduce_indices: Optional[Set[IndexVar]] = None
+ acc_reduce_indices: Optional[Set[IndexVar]] = None
+ structop_info: Optional[_StructOpInfo] = None
+ mlir_value: Optional[ir.Value] = None
-@dataclasses.dataclass(frozen=True)
-class Access(IndexExpr):
- """The tensor access class.
+ def __post_init__(self) -> None:
+ """Verifies and fix up attribute values.
+
+ Verifies the consistency of the attributes and modifies the default values
+ to support convenient initializer syntax.
+ """
+ assert len(self.src_indices) == len(self.dim_infos)
+ self.reduce_indices = self.reduce_indices or set()
+ self.acc_reduce_indices = self.acc_reduce_indices or set()
- We support the TACO API access class with an alias of this class.
- Attributes:
- tensor: A Tensor being accessed.
- indices: A tuple of IndexVar, representing the indices used to access the
- Tensor.
- """
- tensor: Tensor
- indices: Tuple[IndexVar, ...]
+@dataclasses.dataclass(frozen=True)
+class Access(IndexExpr):
+ """The tensor access class.
- def __post_init__(self) -> None:
- """Verifies the tensor and indices for a tensor access.
+ We support the TACO API access class with an alias of this class.
- Raises:
- ValueError: If indices is not a list of IndexVar or the len of indices
- doesn't equal to the rank of the tensor.
+ Attributes:
+ tensor: A Tensor being accessed.
+ indices: A tuple of IndexVar, representing the indices used to access the
+ Tensor.
"""
- if (not isinstance(self.indices, tuple) or
- not _all_instance_of(self.indices, IndexVar)):
- raise ValueError(f"Indices contain non IndexVar: {str(self.indices)}.")
- if self.tensor.order != len(self.indices):
- raise ValueError("Invalid indices for rank: "
- f"str{self.tensor.order} != len({str(self.indices)}).")
-
- def __repr__(self) -> str:
- # The Tensor __repr__ method evaluates the pending assignment to the tensor.
- # We want to define the __repr__ method here to avoid such evaluation of the
- # tensor assignment.
- indices_str = ", ".join(map(lambda i: i.name, self.indices))
- return (f"Tensor({self.tensor.name}) " f"Indices({indices_str})")
-
- def _emit_expression(
- self,
- expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
- expr_to_info: _ExprInfoDict,
- ) -> lang.ScalarExpression:
- """Emits a linalg dialect TensorUse expression for the tensor access."""
- assert self in expr_to_opnd
- dims = _mlir_dimensions_from_index_vars(self.indices)
- return lang.TensorUse(expr_to_opnd[self], dims)
-
- def _visit(self,
- func: _ExprVisitor,
- args,
- *,
- leaf_checker: _SubtreeLeafChecker = None) -> None:
- if leaf_checker:
- assert leaf_checker(self, *args)
- func(self, *args)
-
- def dtype(self) -> DType:
- return self.tensor.dtype
+
+ tensor: Tensor
+ indices: Tuple[IndexVar, ...]
+
+ def __post_init__(self) -> None:
+ """Verifies the tensor and indices for a tensor access.
+
+ Raises:
+ ValueError: If indices is not a list of IndexVar or the len of indices
+ doesn't equal to the rank of the tensor.
+ """
+ if not isinstance(self.indices, tuple) or not _all_instance_of(
+ self.indices, IndexVar
+ ):
+ raise ValueError(f"Indices contain non IndexVar: {str(self.indices)}.")
+ if self.tensor.order != len(self.indices):
+ raise ValueError(
+ "Invalid indices for rank: "
+ f"str{self.tensor.order} != len({str(self.indices)})."
+ )
+
+ def __repr__(self) -> str:
+ # The Tensor __repr__ method evaluates the pending assignment to the tensor.
+ # We want to define the __repr__ method here to avoid such evaluation of the
+ # tensor assignment.
+ indices_str = ", ".join(map(lambda i: i.name, self.indices))
+ return f"Tensor({self.tensor.name}) " f"Indices({indices_str})"
+
+ def _emit_expression(
+ self,
+ expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+ expr_to_info: _ExprInfoDict,
+ ) -> lang.ScalarExpression:
+ """Emits a linalg dialect TensorUse expression for the tensor access."""
+ assert self in expr_to_opnd
+ dims = _mlir_dimensions_from_index_vars(self.indices)
+ return lang.TensorUse(expr_to_opnd[self], dims)
+
+ def _visit(
+ self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None
+ ) -> None:
+ if leaf_checker:
+ assert leaf_checker(self, *args)
+ func(self, *args)
+
+ def dtype(self) -> DType:
+ return self.tensor.dtype
def _gather_input_accesses_index_vars(
expr: IndexExpr,
input_accesses: List[Access],
) -> None:
- """Collects Access nodes."""
- if isinstance(expr, Access) and expr not in input_accesses:
- input_accesses.append(expr)
+ """Collects Access nodes."""
+ if isinstance(expr, Access) and expr not in input_accesses:
+ input_accesses.append(expr)
def _op_ceil(__a: Any) -> Any:
- """A _UnaryOp object for operation ceil."""
- pass
+ """A _UnaryOp object for operation ceil."""
+ pass
def _op_floor(__a: Any) -> Any:
- """A _UnaryOp object for operation floor."""
- pass
+ """A _UnaryOp object for operation floor."""
+ pass
def _op_unary_to_callable(op: _UnaryOp) -> lang.UnaryFnType:
- """Returns the linalg dialect function object for the given operation."""
- op_to_callable = {
- operator.abs: lang.UnaryFn.abs,
- operator.neg: lang.UnaryFn.negf,
- _op_ceil: lang.UnaryFn.ceil,
- _op_floor: lang.UnaryFn.floor,
- }
- return op_to_callable[op]
+ """Returns the linalg dialect function object for the given operation."""
+ op_to_callable = {
+ operator.abs: lang.UnaryFn.abs,
+ operator.neg: lang.UnaryFn.negf,
+ _op_ceil: lang.UnaryFn.ceil,
+ _op_floor: lang.UnaryFn.floor,
+ }
+ return op_to_callable[op]
@dataclasses.dataclass(frozen=True)
class _UnaryExpr(IndexExpr):
- """The representation for a Unary operation.
-
- Attributes:
- op: A _UnaryOp representing the operation.
- a: An IndexExpr representing the operand for the operation.
- """
- op: _BinaryOp
- a: IndexExpr
-
- def __post_init__(self) -> None:
- """Verifies that the operand being added is an IndexExpr."""
- assert isinstance(self.a, IndexExpr)
-
- def _emit_expression(
- self,
- expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
- expr_to_info: _ExprInfoDict,
- ) -> lang.ScalarExpression:
- """Emits the expression tree and returns the expression."""
- # The current expression node is an internal node of the structured op.
- if self not in expr_to_opnd:
- a = self.a._emit_expression(expr_to_opnd, expr_to_info)
- return _op_unary_to_callable(self.op)(a)
-
- # The current expression is a leaf node of the structured op. That is, it is
- # a temporary tensor generated by its child structured op.
- op_info = expr_to_info[self].structop_info
- assert op_info is not None
- dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
- return lang.TensorUse(expr_to_opnd[self], dims)
-
- def _visit(self,
- func: _ExprVisitor,
- args,
- *,
- leaf_checker: _SubtreeLeafChecker = None) -> None:
- """A post-order visitor."""
- if leaf_checker is None or not leaf_checker(self, *args):
- self.a._visit(func, args, leaf_checker=leaf_checker)
- func(self, *args)
-
- def dtype(self) -> DType:
- """Returns the data type of the operation."""
- return self.a.dtype()
+ """The representation for a Unary operation.
+
+ Attributes:
+ op: A _UnaryOp representing the operation.
+ a: An IndexExpr representing the operand for the operation.
+ """
+
+ op: _UnaryOp
+ a: IndexExpr
+
+ def __post_init__(self) -> None:
+ """Verifies that the operand being added is an IndexExpr."""
+ assert isinstance(self.a, IndexExpr)
+
+ def _emit_expression(
+ self,
+ expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+ expr_to_info: _ExprInfoDict,
+ ) -> lang.ScalarExpression:
+ """Emits the expression tree and returns the expression."""
+ # The current expression node is an internal node of the structured op.
+ if self not in expr_to_opnd:
+ a = self.a._emit_expression(expr_to_opnd, expr_to_info)
+ return _op_unary_to_callable(self.op)(a)
+
+ # The current expression is a leaf node of the structured op. That is, it is
+ # a temporary tensor generated by its child structured op.
+ op_info = expr_to_info[self].structop_info
+ assert op_info is not None
+ dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
+ return lang.TensorUse(expr_to_opnd[self], dims)
+
+ def _visit(
+ self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None
+ ) -> None:
+ """A post-order visitor."""
+ if leaf_checker is None or not leaf_checker(self, *args):
+ self.a._visit(func, args, leaf_checker=leaf_checker)
+ func(self, *args)
+
+ def dtype(self) -> DType:
+ """Returns the data type of the operation."""
+ return self.a.dtype()
def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType:
- """Returns the linalg dialect function object for the given operation."""
- op_to_callable = {
- operator.add: lang.BinaryFn.add,
- operator.sub: lang.BinaryFn.sub,
- operator.mul: lang.BinaryFn.mul,
- }
- return op_to_callable[op]
+ """Returns the linalg dialect function object for the given operation."""
+ op_to_callable = {
+ operator.add: lang.BinaryFn.add,
+ operator.sub: lang.BinaryFn.sub,
+ operator.mul: lang.BinaryFn.mul,
+ }
+ return op_to_callable[op]
+
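# A tiny example of the mapping above: Python operator functions resolve to
# linalg OpDSL callables when emitting structured op bodies.
assert _op_to_callable(operator.mul) is lang.BinaryFn.mul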
@dataclasses.dataclass(frozen=True)
class _BinaryExpr(IndexExpr):
- """The representation for a binary operation.
-
- Attributes:
- op: A _BinaryOp representing the binary operation.
- a: An IndexExpr representing the first operand of the operation.
- b: An IndexExpr representing the second operand of the operation.
- """
- op: _BinaryOp
- a: IndexExpr
- b: IndexExpr
-
- def __post_init__(self) -> None:
- """Verifies that the operands being added are IndexExpr."""
- assert isinstance(self.a, IndexExpr) and isinstance(self.b, IndexExpr)
-
- def _emit_expression(
- self,
- expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
- expr_to_info: _ExprInfoDict,
- ) -> lang.ScalarExpression:
- """Emits the expression tree and returns the expression."""
- # The current expression node is an internal node of the structured op.
- if self not in expr_to_opnd:
- a = self.a._emit_expression(expr_to_opnd, expr_to_info)
- b = self.b._emit_expression(expr_to_opnd, expr_to_info)
- return _op_to_callable(self.op)(a, b)
-
- # The current expression is a leaf node of the structured op. That is, it is
- # a temporary tensor generated by its child structured op.
- op_info = expr_to_info[self].structop_info
- assert op_info is not None
- dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
- return lang.TensorUse(expr_to_opnd[self], dims)
-
- def _visit(self,
- func: _ExprVisitor,
- args,
- *,
- leaf_checker: _SubtreeLeafChecker = None) -> None:
- """A post-order visitor."""
- if leaf_checker is None or not leaf_checker(self, *args):
- self.a._visit(func, args, leaf_checker=leaf_checker)
- self.b._visit(func, args, leaf_checker=leaf_checker)
- func(self, *args)
-
- def dtype(self) -> DType:
- """Returns the data type of the binary operation."""
- return self.a.dtype()
+ """The representation for a binary operation.
+
+ Attributes:
+ op: A _BinaryOp representing the binary operation.
+ a: An IndexExpr representing the first operand of the operation.
+ b: An IndexExpr representing the second operand of the operation.
+ """
+
+ op: _BinaryOp
+ a: IndexExpr
+ b: IndexExpr
+
+ def __post_init__(self) -> None:
+ """Verifies that the operands being added are IndexExpr."""
+ assert isinstance(self.a, IndexExpr) and isinstance(self.b, IndexExpr)
+
+ def _emit_expression(
+ self,
+ expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+ expr_to_info: _ExprInfoDict,
+ ) -> lang.ScalarExpression:
+ """Emits the expression tree and returns the expression."""
+ # The current expression node is an internal node of the structured op.
+ if self not in expr_to_opnd:
+ a = self.a._emit_expression(expr_to_opnd, expr_to_info)
+ b = self.b._emit_expression(expr_to_opnd, expr_to_info)
+ return _op_to_callable(self.op)(a, b)
+
+ # The current expression is a leaf node of the structured op. That is, it is
+ # a temporary tensor generated by its child structured op.
+ op_info = expr_to_info[self].structop_info
+ assert op_info is not None
+ dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
+ return lang.TensorUse(expr_to_opnd[self], dims)
+
+ def _visit(
+ self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None
+ ) -> None:
+ """A post-order visitor."""
+ if leaf_checker is None or not leaf_checker(self, *args):
+ self.a._visit(func, args, leaf_checker=leaf_checker)
+ self.b._visit(func, args, leaf_checker=leaf_checker)
+ func(self, *args)
+
+ def dtype(self) -> DType:
+ """Returns the data type of the binary operation."""
+ return self.a.dtype()
def _validate_and_collect_dim_info(
@@ -1822,105 +1901,104 @@ def _validate_and_collect_dim_info(
dim_infos: Tuple[_DimInfo, ...],
expr: _BinaryExpr,
) -> None:
- """Validates and collects the dimension information for an index notation.
-
- Validates (indices, dim_infos) against the information collected from other
- source operands and is represented by index_to_dim_info. In particular, we
- ensure that each IndexVar corresponds to only one dimension size. We also
- aggregate the new information represented in (indices, dim_infos) to
- index_to_dim_info.
-
- Args:
- index_to_dim: A dictionary of (IndexVar, _DimInfo) collected from the
- previous operands.
- indices: The IndexVars to be validated.
- dim_infos: The dimension information for the IndexVars to be validated.
- expr: The binary expression where (indices, dim_infos) is used.
-
- Raises:
- ValueError if there is any problem in the IndexVars or dimensional values.
- """
- assert len(indices) == len(dim_infos)
- for i, d in zip(indices, dim_infos):
- if i not in index_to_dim_info:
- index_to_dim_info[i] = d
- else:
- dim = index_to_dim_info[i].dim
- if dim == -1 or d.dim == -1:
- dim = dim if dim != -1 else d.dim
- elif dim != d.dim:
- raise ValueError(f"Inconsistent source dimension for {i}: "
- f"{d.dim} vs {dim}")
- mode_format = _mode_format_estimator(expr.op)(
- index_to_dim_info[i].mode_format, d.mode_format)
- index_to_dim_info[i] = _DimInfo(d.dim, mode_format)
+ """Validates and collects the dimension information for an index notation.
+
+ Validates (indices, dim_infos) against the information collected from other
+ source operands, which is represented by index_to_dim_info. In particular, we
+ ensure that each IndexVar corresponds to only one dimension size. We also
+ aggregate the new information represented in (indices, dim_infos) to
+ index_to_dim_info.
+
+ Args:
+ index_to_dim: A dictionary of (IndexVar, _DimInfo) collected from the
+ previous operands.
+ indices: The IndexVars to be validated.
+ dim_infos: The dimension information for the IndexVars to be validated.
+ expr: The binary expression where (indices, dim_infos) is used.
+
+ Raises:
+ ValueError if there is any problem in the IndexVars or dimensional values.
+ """
+ assert len(indices) == len(dim_infos)
+ for i, d in zip(indices, dim_infos):
+ if i not in index_to_dim_info:
+ index_to_dim_info[i] = d
+ else:
+ dim = index_to_dim_info[i].dim
+ if dim == -1 or d.dim == -1:
+ dim = dim if dim != -1 else d.dim
+ elif dim != d.dim:
+ raise ValueError(
+ f"Inconsistent source dimension for {i}: " f"{d.dim} vs {dim}"
+ )
+ mode_format = _mode_format_estimator(expr.op)(
+ index_to_dim_info[i].mode_format, d.mode_format
+ )
+ index_to_dim_info[i] = _DimInfo(d.dim, mode_format)
def _validate_and_collect_expr_info(
expr: IndexExpr,
expr_to_info: _ExprInfoDict,
) -> None:
- """Validates dimension information and constructs _ExprInfo.
-
- Validates that dimensional values for the same IndexVar are the same. Collects
- a list of IndexVar used by the expression and their corresponding dimensional
- values. Constructs an _ExprInfo object to record the information for the
- IndexExpr.
-
- This routine is passed to the post-order visitor as an _ExprVisitor object.
-
- Args:
- expr: The IndexExpr being validated.
- expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
- expression information.
-
- Raises:
- ValueError if there is any problem in the IndexVars or dimensional values.
- """
- # Objects of class Access can be shared by different expressions. Avoid
- # processing Access objects multiple times by skipping the processing if expr
- # is already in the dictionary.
- if expr in expr_to_info:
- return
-
- if isinstance(expr, IndexVar):
- src_indices = expr, # A tuple with one element.
- dim_infos = _get_dummy_dim_info(), # A tuple with one element.
- elif isinstance(expr, Access):
- src_indices = expr.indices
- src_dims = tuple(expr.tensor.shape)
- if expr.tensor.format is None:
- # Treat each dimension of a dense tensor as DENSE for the purpose of
- # calculating temporary tensor storage format.
- mode_formats = tuple([ModeFormat.DENSE] * len(src_dims))
+ """Validates dimension information and constructs _ExprInfo.
+
+ Validates that dimensional values for the same IndexVar are the same. Collects
+ a list of IndexVar used by the expression and their corresponding dimensional
+ values. Constructs an _ExprInfo object to record the information for the
+ IndexExpr.
+
+ This routine is passed to the post-order visitor as an _ExprVisitor object.
+
+ Args:
+ expr: The IndexExpr being validated.
+ expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
+ expression information.
+
+ Raises:
+ ValueError if there is any problem in the IndexVars or dimensional values.
+ """
+ # Objects of class Access can be shared by different expressions. Avoid
+ # processing Access objects multiple times by skipping the processing if expr
+ # is already in the dictionary.
+ if expr in expr_to_info:
+ return
+
+ if isinstance(expr, IndexVar):
+ src_indices = (expr,) # A tuple with one element.
+ dim_infos = (_get_dummy_dim_info(),) # A tuple with one element.
+ elif isinstance(expr, Access):
+ src_indices = expr.indices
+ src_dims = tuple(expr.tensor.shape)
+ if expr.tensor.format is None:
+ # Treat each dimension of a dense tensor as DENSE for the purpose of
+ # calculating temporary tensor storage format.
+ mode_formats = tuple([ModeFormat.DENSE] * len(src_dims))
+ else:
+ mode_formats = tuple(expr.tensor.format.format_pack.formats)
+ assert len(src_dims) == len(mode_formats)
+ dim_infos = tuple([_DimInfo(d, m) for d, m in zip(src_dims, mode_formats)])
+ elif isinstance(expr, _UnaryExpr):
+ a_info = expr_to_info[expr.a]
+ index_to_dim_info = {i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)}
+ # Here we rely on the fact that dictionaries keep the insertion order for
+ # keys and values.
+ src_indices = tuple(index_to_dim_info.keys())
+ dim_infos = tuple(index_to_dim_info.values())
else:
- mode_formats = tuple(expr.tensor.format.format_pack.formats)
- assert len(src_dims) == len(mode_formats)
- dim_infos = tuple([_DimInfo(d, m) for d, m in zip(src_dims, mode_formats)])
- elif isinstance(expr, _UnaryExpr):
- a_info = expr_to_info[expr.a]
- index_to_dim_info = {
- i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)
- }
- # Here we rely on the fact that dictionaries keep the insertion order for
- # keys and values.
- src_indices = tuple(index_to_dim_info.keys())
- dim_infos = tuple(index_to_dim_info.values())
- else:
- assert isinstance(expr, _BinaryExpr)
- a_info = expr_to_info[expr.a]
- index_to_dim_info = {
- i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)
- }
- b_info = expr_to_info[expr.b]
- _validate_and_collect_dim_info(index_to_dim_info, b_info.src_indices,
- b_info.dim_infos, expr)
- # Here we rely on the fact that dictionaries keep the insertion order for
- # keys and values.
- src_indices = tuple(index_to_dim_info.keys())
- dim_infos = tuple(index_to_dim_info.values())
+ assert isinstance(expr, _BinaryExpr)
+ a_info = expr_to_info[expr.a]
+ index_to_dim_info = {i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)}
+ b_info = expr_to_info[expr.b]
+ _validate_and_collect_dim_info(
+ index_to_dim_info, b_info.src_indices, b_info.dim_infos, expr
+ )
+ # Here we rely on the fact that dictionaries keep the insertion order for
+ # keys and values.
+ src_indices = tuple(index_to_dim_info.keys())
+ dim_infos = tuple(index_to_dim_info.values())
- expr_to_info[expr] = _ExprInfo(src_indices, dim_infos)
+ expr_to_info[expr] = _ExprInfo(src_indices, dim_infos)
def _mark_structured_op_root(
@@ -1928,90 +2006,92 @@ def _mark_structured_op_root(
reduce_index: IndexVar,
expr_to_info: _ExprInfoDict,
) -> None:
- """Identifies the root expression for a structured op in the linalg dialect.
-
- An linalg structured op can only perform reduction on the whole expression.
- For a TACO tensor algebra expression, the reduction on an IndexVar is done at
- the smallest expression that contains all the uses of the IndexVar. If such an
- expression is only part of the whole expression, we need to split this
- sub-expression tree out from its parent and implement the sub-expression as a
- structured op.
-
- This routine identifies the root expression node for performing a reduction on
- the given IndexVar. If the reduction of the given IndexVar should be performed
- on expression X, then the IndexVar is added to expr_to_info[X].reduce_indices
-
- Args:
- expr: The root IndexExpr for the tensor algebra expression.
- reduce_index: The IndexVar which we want to find out the proper expression
- to perform a reduction.
- expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
-
- Raises:
- ValueError: If the expression is not proper or not supported.
- """
- expr_info = expr_to_info[expr]
- if isinstance(expr, Access):
- # Handle simple reduction expression in the format of A[i] = B[i, j].
- if reduce_index in expr_info.src_indices:
- expr_info.reduce_indices.add(reduce_index)
- return
- elif isinstance(expr, IndexVar):
- # A[i] = B[i] + j is not allowed.
- raise ValueError(f"IndexVar is not part of the iteration domain: {expr}.")
-
- assert (isinstance(expr, _BinaryExpr))
- a_info = expr_to_info[expr.a]
- b_info = expr_to_info[expr.b]
-
- if reduce_index in a_info.src_indices and reduce_index in b_info.src_indices:
- expr_info.reduce_indices.add(reduce_index)
- return
-
- if reduce_index in a_info.src_indices:
- _mark_structured_op_root(expr.a, reduce_index, expr_to_info)
- elif reduce_index in b_info.src_indices:
- _mark_structured_op_root(expr.b, reduce_index, expr_to_info)
- else:
- assert False, "Unreachable path"
+ """Identifies the root expression for a structured op in the linalg dialect.
+ A linalg structured op can only perform a reduction on the whole expression.
+ For a TACO tensor algebra expression, the reduction on an IndexVar is done at
+ the smallest expression that contains all the uses of the IndexVar. If such an
+ expression is only part of the whole expression, we need to split this
+ sub-expression tree out from its parent and implement the sub-expression as a
+ structured op.
-def _accumulate_reduce_indices(
- expr: IndexExpr,
- expr_to_info: _ExprInfoDict,
-) -> None:
- """Propagates reduction indices from child expressions to parent expressions.
+ This routine identifies the root expression node for performing a reduction on
+ the given IndexVar. If the reduction of the given IndexVar should be performed
+ on expression X, then the IndexVar is added to expr_to_info[X].reduce_indices
- This routine is passed to the post-order visitor as an _ExprVisitor object.
+ Args:
+ expr: The root IndexExpr for the tensor algebra expression.
+ reduce_index: The IndexVar for which we want to find the proper expression
+ to perform the reduction.
+ expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
- Args:
- expr: The IndexExpr being visited.
- expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
- expression information.
- """
- assert expr in expr_to_info
- expr_info = expr_to_info[expr]
+ Raises:
+ ValueError: If the expression is not proper or not supported.
+ """
+ expr_info = expr_to_info[expr]
+ if isinstance(expr, Access):
+ # Handle simple reduction expression in the format of A[i] = B[i, j].
+ if reduce_index in expr_info.src_indices:
+ expr_info.reduce_indices.add(reduce_index)
+ return
+ elif isinstance(expr, IndexVar):
+ # A[i] = B[i] + j is not allowed.
+ raise ValueError(f"IndexVar is not part of the iteration domain: {expr}.")
- if isinstance(expr, _BinaryExpr):
+ assert isinstance(expr, _BinaryExpr)
a_info = expr_to_info[expr.a]
b_info = expr_to_info[expr.b]
- expr_info.acc_reduce_indices = (
- a_info.acc_reduce_indices | b_info.acc_reduce_indices
- | expr_info.reduce_indices)
- elif isinstance(expr, _UnaryExpr):
- a_info = expr_to_info[expr.a]
- expr_info.acc_reduce_indices = (
- a_info.acc_reduce_indices | expr_info.reduce_indices)
- elif isinstance(expr, IndexVar):
- # If an IndexVar is reducing itself, it means the IndexVar is outside the
- # iteration domain. This usage is not allowed and we should emit an error
- # before reaching here.
- assert not expr_info.reduce_indices
- else:
- assert isinstance(expr, Access)
- # Handle simple reduction expression in the format of A[i] = B[i, j].
- expr_info.acc_reduce_indices = expr_info.reduce_indices
+ if reduce_index in a_info.src_indices and reduce_index in b_info.src_indices:
+ expr_info.reduce_indices.add(reduce_index)
+ return
+
+ if reduce_index in a_info.src_indices:
+ _mark_structured_op_root(expr.a, reduce_index, expr_to_info)
+ elif reduce_index in b_info.src_indices:
+ _mark_structured_op_root(expr.b, reduce_index, expr_to_info)
+ else:
+ assert False, "Unreachable path"
+
+
+def _accumulate_reduce_indices(
+ expr: IndexExpr,
+ expr_to_info: _ExprInfoDict,
+) -> None:
+ """Propagates reduction indices from child expressions to parent expressions.
+
+ This routine is passed to the post-order visitor as an _ExprVisitor object.
+
+ Args:
+ expr: The IndexExpr being visited.
+ expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
+ expression information.
+ """
+ assert expr in expr_to_info
+ expr_info = expr_to_info[expr]
+
+ if isinstance(expr, _BinaryExpr):
+ a_info = expr_to_info[expr.a]
+ b_info = expr_to_info[expr.b]
+ expr_info.acc_reduce_indices = (
+ a_info.acc_reduce_indices
+ | b_info.acc_reduce_indices
+ | expr_info.reduce_indices
+ )
+ elif isinstance(expr, _UnaryExpr):
+ a_info = expr_to_info[expr.a]
+ expr_info.acc_reduce_indices = (
+ a_info.acc_reduce_indices | expr_info.reduce_indices
+ )
+ elif isinstance(expr, IndexVar):
+ # If an IndexVar is reducing itself, it means the IndexVar is outside the
+ # iteration domain. This usage is not allowed and we should emit an error
+ # before reaching here.
+ assert not expr_info.reduce_indices
+ else:
+ assert isinstance(expr, Access)
+ # Handle simple reduction expression in the format of A[i] = B[i, j].
+ expr_info.acc_reduce_indices = expr_info.reduce_indices
def _gather_structured_op(
@@ -2019,42 +2099,42 @@ def _gather_structured_op(
expr_to_info: _ExprInfoDict,
structop_roots: List[IndexExpr],
) -> None:
- """Adds structured op root expression information to structop_roots.
-
- This routine is passed to the post-order visitor as an _ExprVisitor object.
-
- Args:
- expr: The IndexExpr being visited.
- expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
- structop_roots: The resulting list of IndexExpr that are the roots for
- linalg structured ops.
- """
- if not expr_to_info[expr].reduce_indices:
- return
-
- # If the expression is the root for reducing some indices, collect the indices
- # and dimensions for the reduction result.
- dst_indices = []
- dst_dims = []
- mode_fmts = []
- for i, d in zip(expr_to_info[expr].src_indices, expr_to_info[expr].dim_infos):
- if i not in expr_to_info[expr].acc_reduce_indices:
- dst_indices.append(i)
- dst_dims.append(d.dim)
- mode_fmts.append(d.mode_format)
-
- # Add the information to the dictionary.
- op_info = _StructOpInfo(
- tuple(dst_indices),
- tuple(dst_dims),
- expr.dtype(),
- f"temp{len(structop_roots)}",
- _make_format(mode_fmts),
- )
- expr_to_info[expr].structop_info = op_info
-
- # Add the expression to the list of structured op roots.
- structop_roots.append(expr)
+ """Adds structured op root expression information to structop_roots.
+
+ This routine is passed to the post-order visitor as an _ExprVisitor object.
+
+ Args:
+ expr: The IndexExpr being visited.
+ expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
+ structop_roots: The resulting list of IndexExpr that are the roots for
+ linalg structured ops.
+ """
+ if not expr_to_info[expr].reduce_indices:
+ return
+
+ # If the expression is the root for reducing some indices, collect the indices
+ # and dimensions for the reduction result.
+ dst_indices = []
+ dst_dims = []
+ mode_fmts = []
+ for i, d in zip(expr_to_info[expr].src_indices, expr_to_info[expr].dim_infos):
+ if i not in expr_to_info[expr].acc_reduce_indices:
+ dst_indices.append(i)
+ dst_dims.append(d.dim)
+ mode_fmts.append(d.mode_format)
+
+ # Add the information to the dictionary.
+ op_info = _StructOpInfo(
+ tuple(dst_indices),
+ tuple(dst_dims),
+ expr.dtype(),
+ f"temp{len(structop_roots)}",
+ _make_format(mode_fmts),
+ )
+ expr_to_info[expr].structop_info = op_info
+
+ # Add the expression to the list of structured op roots.
+ structop_roots.append(expr)
def _is_structured_op_leaf(
@@ -2063,29 +2143,31 @@ def _is_structured_op_leaf(
expr_to_info: _ExprInfoDict,
*unused_args,
) -> bool:
- """Returns true iff the expression is a leaf node for a structured op.
+ """Returns true iff the expression is a leaf node for a structured op.
- The root of a structured op is a leaf of its parent structured op that uses
- its result. An expression node is a leaf node for the current structured op if
- it is an Access node or the root for a structured op that is not the current
- structured op.
+ The root of a structured op is a leaf of its parent structured op that uses
+ its result. An expression node is a leaf node for the current structured op if
+ it is an Access node or the root for a structured op that is not the current
+ structured op.
- This routine is passed to the post-order visitor as a _SubtreeLeafChecker
- object. Because the post-order visitor passes the same parameters to both
- _SubtreeLeafChecker and _ExprVisitor, this routine may receive unused
- parameters.
+ This routine is passed to the post-order visitor as a _SubtreeLeafChecker
+ object. Because the post-order visitor passes the same parameters to both
+ _SubtreeLeafChecker and _ExprVisitor, this routine may receive unused
+ parameters.
- Args:
- expr: The IndexExpr being visited.
- root: The root of the current structured op.
- expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
+ Args:
+ expr: The IndexExpr being visited.
+ root: The root of the current structured op.
+ expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
- Returns:
- True if the current IndexExpr is a leaf for the current structured op.
- """
- return (expr != root and
- expr_to_info[expr].structop_info is not None) or isinstance(
- expr, Access) or isinstance(expr, IndexVar)
+ Returns:
+ True if the current IndexExpr is a leaf for the current structured op.
+ """
+ return (
+ (expr != root and expr_to_info[expr].structop_info is not None)
+ or isinstance(expr, Access)
+ or isinstance(expr, IndexVar)
+ )
def _gather_structured_op_input(
@@ -2094,26 +2176,28 @@ def _gather_structured_op_input(
expr_to_info: _ExprInfoDict,
structop_inputs: List[IndexExpr],
) -> None:
- """Adds the IndexExpr to structop_inputs if it is an input.
+ """Adds the IndexExpr to structop_inputs if it is an input.
- If the current IndexExpr is an input for the current structured op, adds it to
- structop_inputs. The current IndexExpr is an input if it is an Access node or
- if it is the root for a structured op that is not the current structured op.
+ If the current IndexExpr is an input for the current structured op, adds it to
+ structop_inputs. The current IndexExpr is an input if it is an Access node or
+ if it is the root for a structured op that is not the current structured op.
- This routine is passed to the post-order visitor as an _ExprVisitor object.
+ This routine is passed to the post-order visitor as an _ExprVisitor object.
- Args:
- expr: The IndexExpr being visited.
- root: The root of the current structured op.
- expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
- structop_inputs: The resulting list of IndexExpr that provide input to the
- current structured op.
- """
- if ((expr != root or isinstance(expr, Access)) and
- expr not in structop_inputs) and (isinstance(expr, Access) or
- (expr in expr_to_info and
- expr_to_info[expr].structop_info)):
- structop_inputs.append(expr)
+ Args:
+ expr: The IndexExpr being visited.
+ root: The root of the current structured op.
+ expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
+ structop_inputs: The resulting list of IndexExpr that provide input to the
+ current structured op.
+ """
+ if (
+ (expr != root or isinstance(expr, Access)) and expr not in structop_inputs
+ ) and (
+ isinstance(expr, Access)
+ or (expr in expr_to_info and expr_to_info[expr].structop_info)
+ ):
+ structop_inputs.append(expr)
def _emit_structured_op_input(
@@ -2121,35 +2205,35 @@ def _emit_structured_op_input(
expr_to_info: _ExprInfoDict,
op_def: lang.LinalgOpDef,
) -> lang.OperandDef:
- """Emits OperandDef in the linalg dialect for the input IndexExpr.
-
- Args:
- expr: The input IndexExpr for the current structured op.
- expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
- op_def: The linalg operation for the current structured op.
-
- Returns:
- An OperandDef in the linalg dialect for the input IndexExpr.
- """
- op_info = expr_to_info[expr].structop_info
- if op_info and not isinstance(expr, Access):
- # The input is a temporary tensor produced by another structured op.
- indices = op_info.dst_indices
- name = op_info.dst_name
- else:
- # The input is a user provided tensor.
- assert isinstance(expr, Access)
- indices = expr.indices
- name = expr.tensor.name
-
- dim_sym = _mlir_symbols_from_index_vars(indices)
- opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym)
- op_def.add_operand(name, opnd)
- return opnd
+ """Emits OperandDef in the linalg dialect for the input IndexExpr.
+
+ Args:
+ expr: The input IndexExpr for the current structured op.
+ expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
+ op_def: The linalg operation for the current structured op.
+
+ Returns:
+ An OperandDef in the linalg dialect for the input IndexExpr.
+ """
+ op_info = expr_to_info[expr].structop_info
+ if op_info and not isinstance(expr, Access):
+ # The input is a temporary tensor produced by another structured op.
+ indices = op_info.dst_indices
+ name = op_info.dst_name
+ else:
+ # The input is a user provided tensor.
+ assert isinstance(expr, Access)
+ indices = expr.indices
+ name = expr.tensor.name
+
+ dim_sym = _mlir_symbols_from_index_vars(indices)
+ opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym)
+ op_def.add_operand(name, opnd)
+ return opnd
def _check_and_build_unary(a: Access, op: _UnaryOp) -> "_UnaryExpr":
- """Build a unary operation ceil.
+ """Build a unary operation ceil.
Args:
a: The operand, which could be any Python object from user inputs.
@@ -2161,13 +2245,13 @@ def _check_and_build_unary(a: Access, op: _UnaryOp) -> "_UnaryExpr":
Raises:
ValueError: If a is not an IndexExpr.
"""
- if not isinstance(a, Access):
- raise ValueError(f"Expected an Access Operand: {a}")
- return a._build_unary_expr(op)
+ if not isinstance(a, Access):
+ raise ValueError(f"Expected an Access Operand: {a}")
+ return a._build_unary_expr(op)
def ceil(a: Access) -> "_UnaryExpr":
- """Defines the operation ceil.
+ """Defines the operation ceil.
Args:
a: The operand, which could be any Python object from user inputs.
@@ -2178,11 +2262,11 @@ def ceil(a: Access) -> "_UnaryExpr":
Raises:
ValueError: If a is not an IndexExpr.
"""
- return _check_and_build_unary(a, _op_ceil)
+ return _check_and_build_unary(a, _op_ceil)
def floor(a: Access) -> "_UnaryExpr":
- """Defines the operation floor.
+ """Defines the operation floor.
Args:
a: The operand, which could be any Python object from user inputs.
@@ -2193,4 +2277,4 @@ def floor(a: Access) -> "_UnaryExpr":
Raises:
ValueError: If a is not an IndexExpr.
"""
- return _check_and_build_unary(a, _op_floor)
+ return _check_and_build_unary(a, _op_floor)
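
For orientation, a minimal usage sketch for the ceil/floor helpers reformatted above
(assumptions: mlir_pytaco is importable as a module from the taco tools directory, and
Tensor/IndexVar behave as in the unit tests later in this diff):

    import mlir_pytaco

    i = mlir_pytaco.IndexVar()
    A = mlir_pytaco.Tensor([3])
    A.insert([0], 1.5)
    B = mlir_pytaco.Tensor([3])
    # A[i] is an Access; ceil wraps it in a _UnaryExpr that assigning to B evaluates.
    B[i] = mlir_pytaco.ceil(A[i])
    B.evaluate()
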
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py
index e6a7d8e1b4b85..785401c25dc87 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py
@@ -31,50 +31,52 @@
_TNS_FILENAME_SUFFIX = ".tns"
-def read(filename: str, fmt: Format,
- dtype: DType = DType(Type.FLOAT32)) -> Tensor:
- """Inputs a tensor from a given file.
-
- The name suffix of the file specifies the format of the input tensor. We
- currently only support the .mtx format for sparse tensors.
-
- Args:
- filename: A string input filename.
- fmt: The storage format of the tensor.
- dtype: The data type, defaulting to float32.
-
- Raises:
- ValueError: If filename doesn't end with .mtx or .tns, or fmt is not an
- instance of Format or fmt is not a sparse tensor.
- """
- if (not isinstance(filename, str) or
- (not filename.endswith(_MTX_FILENAME_SUFFIX) and
- not filename.endswith(_TNS_FILENAME_SUFFIX))):
- raise ValueError("Expected string filename ends with "
- f"{_MTX_FILENAME_SUFFIX} or {_TNS_FILENAME_SUFFIX}: "
- f"{filename}.")
-
- return Tensor.from_file(filename, fmt, dtype)
+def read(filename: str, fmt: Format, dtype: DType = DType(Type.FLOAT32)) -> Tensor:
+ """Inputs a tensor from a given file.
+
+ The name suffix of the file specifies the format of the input tensor. We
+ currently only support the .mtx format for sparse tensors.
+
+ Args:
+ filename: A string input filename.
+ fmt: The storage format of the tensor.
+ dtype: The data type, defaulting to float32.
+
+ Raises:
+ ValueError: If filename doesn't end with .mtx or .tns, or fmt is not an
+ instance of Format or fmt is not a sparse tensor.
+ """
+ if not isinstance(filename, str) or (
+ not filename.endswith(_MTX_FILENAME_SUFFIX)
+ and not filename.endswith(_TNS_FILENAME_SUFFIX)
+ ):
+ raise ValueError(
+ "Expected string filename ends with "
+ f"{_MTX_FILENAME_SUFFIX} or {_TNS_FILENAME_SUFFIX}: "
+ f"{filename}."
+ )
+
+ return Tensor.from_file(filename, fmt, dtype)
def write(filename: str, tensor: Tensor) -> None:
- """Outputs a tensor to a given file.
-
- The name suffix of the file specifies the format of the output. We currently
- only support .tns format.
-
- Args:
- filename: A string output filename.
- tensor: The tensor to output.
-
- Raises:
- ValueError: If filename doesn't end with .tns or tensor is not a Tensor.
- """
- if (not isinstance(filename, str) or
- not filename.endswith(_TNS_FILENAME_SUFFIX)):
- raise ValueError("Expected string filename ends with"
- f" {_TNS_FILENAME_SUFFIX}: {filename}.")
- if not isinstance(tensor, Tensor):
- raise ValueError(f"Expected a Tensor object: {tensor}.")
-
- tensor.to_file(filename)
+ """Outputs a tensor to a given file.
+
+ The name suffix of the file specifies the format of the output. We currently
+ only support .tns format.
+
+ Args:
+ filename: A string output filename.
+ tensor: The tensor to output.
+
+ Raises:
+ ValueError: If filename doesn't end with .tns or tensor is not a Tensor.
+ """
+ if not isinstance(filename, str) or not filename.endswith(_TNS_FILENAME_SUFFIX):
+ raise ValueError(
+ "Expected string filename ends with" f" {_TNS_FILENAME_SUFFIX}: {filename}."
+ )
+ if not isinstance(tensor, Tensor):
+ raise ValueError(f"Expected a Tensor object: {tensor}.")
+
+ tensor.to_file(filename)
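
A hedged sketch of driving the read/write helpers above (assumptions: the import paths
and the ModeFormat.COMPRESSED spelling, which are not shown in this diff; "matrix.mtx"
is a placeholder input file):

    import mlir_pytaco
    import mlir_pytaco_io

    compressed = mlir_pytaco.ModeFormat.COMPRESSED      # assumed enum member
    fmt = mlir_pytaco.Format([compressed, compressed])  # sparse format, as read() requires
    A = mlir_pytaco_io.read("matrix.mtx", fmt)          # dtype defaults to float32
    mlir_pytaco_io.write("copy.tns", A)                 # only .tns output is supported
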
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
index 988c57b3b33f2..1e1061b8b858d 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
@@ -36,190 +36,234 @@
@functools.lru_cache()
def _get_support_lib_name() -> str:
- """Gets the string name for the supporting C shared library."""
- return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
+ """Gets the string name for the supporting C shared library."""
+ return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
@functools.lru_cache()
def _get_sparse_compiler() -> mlir_sparse_compiler.SparseCompiler:
- """Gets the MLIR sparse compiler with default setting."""
- return mlir_sparse_compiler.SparseCompiler(
- options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()])
+ """Gets the MLIR sparse compiler with default setting."""
+ return mlir_sparse_compiler.SparseCompiler(
+ options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()]
+ )
def _record_support_funcs(
- ty: np.dtype, to_func: _SupportFunc, from_func: _SupportFunc,
- ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]]) -> None:
- """Records the two supporting functions for a given data type."""
- to_func.restype = ctypes.c_void_p
- from_func.restype = ctypes.c_void_p
- ty_to_funcs[ty] = (to_func, from_func)
+ ty: np.dtype,
+ to_func: _SupportFunc,
+ from_func: _SupportFunc,
+ ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]],
+) -> None:
+ """Records the two supporting functions for a given data type."""
+ to_func.restype = ctypes.c_void_p
+ from_func.restype = ctypes.c_void_p
+ ty_to_funcs[ty] = (to_func, from_func)
@functools.lru_cache()
def _get_support_func_locator() -> _SupportFuncLocator:
- """Constructs a function to locate the supporting functions for a data type.
-
- Loads the supporting C shared library with the needed routines. Constructs a
- dictionary from the supported data types to the routines for the data types,
- and then a function to look up the dictionary for a given data type.
-
- The name of the supporting C shared library is either provided by an
- an environment variable or a default value.
-
- Returns:
- The function to look up the supporting functions for a given data type.
-
- Raises:
- OSError: If there is any problem in loading the shared library.
- ValueError: If the shared library doesn't contain the needed routines.
- """
- # This raises OSError exception if there is any problem in loading the shared
- # library.
- c_lib = ctypes.CDLL(_get_support_lib_name())
-
- type_to_funcs = {}
- try:
- support_types = [(np.int8, c_lib.convertToMLIRSparseTensorI8,
- c_lib.convertFromMLIRSparseTensorI8),
- (np.int16, c_lib.convertToMLIRSparseTensorI16,
- c_lib.convertFromMLIRSparseTensorI16),
- (np.int32, c_lib.convertToMLIRSparseTensorI32,
- c_lib.convertFromMLIRSparseTensorI32),
- (np.int64, c_lib.convertToMLIRSparseTensorI64,
- c_lib.convertFromMLIRSparseTensorI64),
- (np.float16, c_lib.convertToMLIRSparseTensorF16,
- c_lib.convertFromMLIRSparseTensorF16),
- (np.float32, c_lib.convertToMLIRSparseTensorF32,
- c_lib.convertFromMLIRSparseTensorF32),
- (np.float64, c_lib.convertToMLIRSparseTensorF64,
- c_lib.convertFromMLIRSparseTensorF64),
- (np.complex64, c_lib.convertToMLIRSparseTensorC32,
- c_lib.convertFromMLIRSparseTensorC32),
- (np.complex128, c_lib.convertToMLIRSparseTensorC64,
- c_lib.convertFromMLIRSparseTensorC64)]
- except Exception as e:
- raise ValueError(f"Missing supporting function: {e}") from e
- for i, info in enumerate(support_types):
- _record_support_funcs(info[0], info[1], info[2], type_to_funcs)
-
- def get_support_funcs(ty: np.dtype):
- funcs = type_to_funcs[ty]
- assert funcs is not None
- return funcs
-
- return get_support_funcs
+ """Constructs a function to locate the supporting functions for a data type.
+
+ Loads the supporting C shared library with the needed routines. Constructs a
+ dictionary from the supported data types to the routines for the data types,
+ and then a function to look up the dictionary for a given data type.
+
+ The name of the supporting C shared library is either provided by an
+ environment variable or a default value.
+
+ Returns:
+ The function to look up the supporting functions for a given data type.
+
+ Raises:
+ OSError: If there is any problem in loading the shared library.
+ ValueError: If the shared library doesn't contain the needed routines.
+ """
+ # This raises OSError exception if there is any problem in loading the shared
+ # library.
+ c_lib = ctypes.CDLL(_get_support_lib_name())
+
+ type_to_funcs = {}
+ try:
+ support_types = [
+ (
+ np.int8,
+ c_lib.convertToMLIRSparseTensorI8,
+ c_lib.convertFromMLIRSparseTensorI8,
+ ),
+ (
+ np.int16,
+ c_lib.convertToMLIRSparseTensorI16,
+ c_lib.convertFromMLIRSparseTensorI16,
+ ),
+ (
+ np.int32,
+ c_lib.convertToMLIRSparseTensorI32,
+ c_lib.convertFromMLIRSparseTensorI32,
+ ),
+ (
+ np.int64,
+ c_lib.convertToMLIRSparseTensorI64,
+ c_lib.convertFromMLIRSparseTensorI64,
+ ),
+ (
+ np.float16,
+ c_lib.convertToMLIRSparseTensorF16,
+ c_lib.convertFromMLIRSparseTensorF16,
+ ),
+ (
+ np.float32,
+ c_lib.convertToMLIRSparseTensorF32,
+ c_lib.convertFromMLIRSparseTensorF32,
+ ),
+ (
+ np.float64,
+ c_lib.convertToMLIRSparseTensorF64,
+ c_lib.convertFromMLIRSparseTensorF64,
+ ),
+ (
+ np.complex64,
+ c_lib.convertToMLIRSparseTensorC32,
+ c_lib.convertFromMLIRSparseTensorC32,
+ ),
+ (
+ np.complex128,
+ c_lib.convertToMLIRSparseTensorC64,
+ c_lib.convertFromMLIRSparseTensorC64,
+ ),
+ ]
+ except Exception as e:
+ raise ValueError(f"Missing supporting function: {e}") from e
+ for i, info in enumerate(support_types):
+ _record_support_funcs(info[0], info[1], info[2], type_to_funcs)
+
+ def get_support_funcs(ty: np.dtype):
+ funcs = type_to_funcs[ty]
+ assert funcs is not None
+ return funcs
+
+ return get_support_funcs
def sparse_tensor_to_coo_tensor(
sparse_tensor: ctypes.c_void_p,
dtype: np.dtype,
) -> Tuple[int, int, np.ndarray, np.ndarray, np.ndarray]:
- """Converts an MLIR sparse tensor to a COO-flavored format tensor.
-
- Args:
- sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor.
- dtype: The numpy data type for the tensor elements.
-
- Returns:
- A tuple that contains the following values for the COO-flavored format
- tensor:
- rank: An integer for the rank of the tensor.
- nse: An integer for the number of non-zero values in the tensor.
- shape: A 1D numpy array of integers, for the shape of the tensor.
- values: A 1D numpy array, for the non-zero values in the tensor.
- indices: A 2D numpy array of integers, representing the indices for the
- non-zero values in the tensor.
-
- Raises:
- OSError: If there is any problem in loading the shared library.
- ValueError: If the shared library doesn't contain the needed routines.
- """
- convert_from = _get_support_func_locator()(dtype)[1]
- rank = ctypes.c_ulonglong(0)
- nse = ctypes.c_ulonglong(0)
- shape = ctypes.POINTER(ctypes.c_ulonglong)()
-
- values = ctypes.POINTER(runtime.as_ctype(np.dtype(dtype)))()
- indices = ctypes.POINTER(ctypes.c_ulonglong)()
- convert_from(sparse_tensor, ctypes.byref(rank), ctypes.byref(nse),
- ctypes.byref(shape), ctypes.byref(values), ctypes.byref(indices))
-
- # Convert the returned values to the corresponding numpy types.
- shape = np.ctypeslib.as_array(shape, shape=[rank.value])
- values = runtime.to_numpy(np.ctypeslib.as_array(values, shape=[nse.value]))
- indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
- return rank.value, nse.value, shape, values, indices
-
-
-def coo_tensor_to_sparse_tensor(np_shape: np.ndarray, np_values: np.ndarray,
- np_indices: np.ndarray, np_perm: np.ndarray,
- np_sparse: np.ndarray) -> int:
- """Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.
-
- Args:
- np_shape: A 1D numpy array of integers, for the shape of the tensor.
- np_values: A 1D numpy array, for the non-zero values in the tensor.
- np_indices: A 2D numpy array of integers, representing the indices for the
- non-zero values in the tensor.
- np_perm: A 1D numpy array of integers, representing the storage ordering
- for the dimensions.
- np_sparse: A 1D numpy array of uint8, representing the sparsity values
- for the dimensions.
-
- Returns:
- An integer for the non-null ctypes.c_void_p to the MLIR sparse tensor
- descriptor.
-
- Raises:
- OSError: If there is any problem in loading the shared library.
- ValueError: If the shared library doesn't contain the needed routines.
- """
-
- r = len(np_shape)
- rank = ctypes.c_ulonglong(r)
- nse = ctypes.c_ulonglong(len(np_values))
- shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
- values = np_values.ctypes.data_as(
- ctypes.POINTER(runtime.as_ctype(np.dtype(np_values.dtype))))
- indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
-
- perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
- sparse = np_sparse.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
-
- convert_to = _get_support_func_locator()(np_values.dtype.type)[0]
- ptr = convert_to(rank, nse, shape, values, indices, perm, sparse)
- assert ptr is not None, "Problem with calling convertToMLIRSparseTensorF64"
- return ptr
-
-
-def compile_and_build_engine(
- module: ir.Module) -> execution_engine.ExecutionEngine:
- """Compiles an MLIR module and builds a JIT execution engine.
-
- Args:
- module: The MLIR module.
-
- Returns:
- A JIT execution engine for the MLIR module.
-
- """
- return _get_sparse_compiler().compile_and_jit(module)
+ """Converts an MLIR sparse tensor to a COO-flavored format tensor.
+
+ Args:
+ sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor.
+ dtype: The numpy data type for the tensor elements.
+
+ Returns:
+ A tuple that contains the following values for the COO-flavored format
+ tensor:
+ rank: An integer for the rank of the tensor.
+ nse: An integer for the number of non-zero values in the tensor.
+ shape: A 1D numpy array of integers, for the shape of the tensor.
+ values: A 1D numpy array, for the non-zero values in the tensor.
+ indices: A 2D numpy array of integers, representing the indices for the
+ non-zero values in the tensor.
+
+ Raises:
+ OSError: If there is any problem in loading the shared library.
+ ValueError: If the shared library doesn't contain the needed routines.
+ """
+ convert_from = _get_support_func_locator()(dtype)[1]
+ rank = ctypes.c_ulonglong(0)
+ nse = ctypes.c_ulonglong(0)
+ shape = ctypes.POINTER(ctypes.c_ulonglong)()
+
+ values = ctypes.POINTER(runtime.as_ctype(np.dtype(dtype)))()
+ indices = ctypes.POINTER(ctypes.c_ulonglong)()
+ convert_from(
+ sparse_tensor,
+ ctypes.byref(rank),
+ ctypes.byref(nse),
+ ctypes.byref(shape),
+ ctypes.byref(values),
+ ctypes.byref(indices),
+ )
+
+ # Convert the returned values to the corresponding numpy types.
+ shape = np.ctypeslib.as_array(shape, shape=[rank.value])
+ values = runtime.to_numpy(np.ctypeslib.as_array(values, shape=[nse.value]))
+ indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
+ return rank.value, nse.value, shape, values, indices
+
+
+def coo_tensor_to_sparse_tensor(
+ np_shape: np.ndarray,
+ np_values: np.ndarray,
+ np_indices: np.ndarray,
+ np_perm: np.ndarray,
+ np_sparse: np.ndarray,
+) -> int:
+ """Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.
+
+ Args:
+ np_shape: A 1D numpy array of integers, for the shape of the tensor.
+ np_values: A 1D numpy array, for the non-zero values in the tensor.
+ np_indices: A 2D numpy array of integers, representing the indices for the
+ non-zero values in the tensor.
+ np_perm: A 1D numpy array of integers, representing the storage ordering
+ for the dimensions.
+ np_sparse: A 1D numpy array of uint8, representing the sparsity values
+ for the dimensions.
+
+ Returns:
+ An integer for the non-null ctypes.c_void_p to the MLIR sparse tensor
+ descriptor.
+
+ Raises:
+ OSError: If there is any problem in loading the shared library.
+ ValueError: If the shared library doesn't contain the needed routines.
+ """
+
+ r = len(np_shape)
+ rank = ctypes.c_ulonglong(r)
+ nse = ctypes.c_ulonglong(len(np_values))
+ shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
+ values = np_values.ctypes.data_as(
+ ctypes.POINTER(runtime.as_ctype(np.dtype(np_values.dtype)))
+ )
+ indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
+
+ perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
+ sparse = np_sparse.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
+
+ convert_to = _get_support_func_locator()(np_values.dtype.type)[0]
+ ptr = convert_to(rank, nse, shape, values, indices, perm, sparse)
+ assert ptr is not None, "Problem with calling convertToMLIRSparseTensorF64"
+ return ptr
+
+
+def compile_and_build_engine(module: ir.Module) -> execution_engine.ExecutionEngine:
+ """Compiles an MLIR module and builds a JIT execution engine.
+
+ Args:
+ module: The MLIR module.
+
+ Returns:
+ A JIT execution engine for the MLIR module.
+
+ """
+ return _get_sparse_compiler().compile_and_jit(module)
class _SparseTensorDescriptor(ctypes.Structure):
- """A C structure for an MLIR sparse tensor."""
- _fields_ = [
- # A pointer for the MLIR sparse tensor storage.
- ("storage", ctypes.POINTER(ctypes.c_ulonglong)),
- # An MLIR MemRef descriptor for the shape of the sparse tensor.
- ("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)),
- ]
+ """A C structure for an MLIR sparse tensor."""
+
+ _fields_ = [
+ # A pointer for the MLIR sparse tensor storage.
+ ("storage", ctypes.POINTER(ctypes.c_ulonglong)),
+ # An MLIR MemRef descriptor for the shape of the sparse tensor.
+ ("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)),
+ ]
def _output_one_dim(dim: int, rank: int, shape: str, type: str) -> str:
- """Produces the MLIR text code to output the size for the given dimension."""
- return f"""
+ """Produces the MLIR text code to output the size for the given dimension."""
+ return f"""
%c{dim} = arith.constant {dim} : index
%d{dim} = tensor.dim %t, %c{dim} : tensor<{shape}x{type}, #enc>
memref.store %d{dim}, %b[%c{dim}] : memref<{rank}xindex>
@@ -233,26 +277,29 @@ def _output_one_dim(dim: int, rank: int, shape: str, type: str) -> str:
# (2) Use scf.for instead of an unrolled loop to write out the dimension sizes
# when tensor.dim supports non-constant dimension value.
def _get_create_sparse_tensor_kernel(
- sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str) -> str:
- """Creates an MLIR text kernel to contruct a sparse tensor from a file.
+ sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str
+) -> str:
+ """Creates an MLIR text kernel to contruct a sparse tensor from a file.
- The kernel returns a _SparseTensorDescriptor structure.
- """
- rank = len(sparsity_codes)
+ The kernel returns a _SparseTensorDescriptor structure.
+ """
+ rank = len(sparsity_codes)
- # Use ? to represent a dimension in the dynamic shape string representation.
- shape = "x".join(map(lambda d: "?", range(rank)))
+ # Use ? to represent a dimension in the dynamic shape string representation.
+ shape = "x".join(map(lambda d: "?", range(rank)))
- # Convert the encoded sparsity values to a string representation.
- sparsity = ", ".join(
- map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes))
+ # Convert the encoded sparsity values to a string representation.
+ sparsity = ", ".join(
+ map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes)
+ )
- # Get the MLIR text code to write the dimension sizes to the output buffer.
- output_dims = "\n".join(
- map(lambda d: _output_one_dim(d, rank, shape, type), range(rank)))
+ # Get the MLIR text code to write the dimension sizes to the output buffer.
+ output_dims = "\n".join(
+ map(lambda d: _output_one_dim(d, rank, shape, type), range(rank))
+ )
- # Return the MLIR text kernel.
- return f"""
+ # Return the MLIR text kernel.
+ return f"""
!Ptr = !llvm.ptr<i8>
#enc = #sparse_tensor.encoding<{{
lvlTypes = [ {sparsity} ]
@@ -266,69 +313,69 @@ def _get_create_sparse_tensor_kernel(
}}"""
-def create_sparse_tensor(filename: str,
- sparsity: Sequence[sparse_tensor.DimLevelType],
- type: str) -> Tuple[ctypes.c_void_p, np.ndarray]:
- """Creates an MLIR sparse tensor from the input file.
+def create_sparse_tensor(
+ filename: str, sparsity: Sequence[sparse_tensor.DimLevelType], type: str
+) -> Tuple[ctypes.c_void_p, np.ndarray]:
+ """Creates an MLIR sparse tensor from the input file.
- Args:
- filename: A string for the name of the file that contains the tensor data in
- a COO-flavored format.
- sparsity: A sequence of DimLevelType values, one for each dimension of the
- tensor.
+ Args:
+ filename: A string for the name of the file that contains the tensor data in
+ a COO-flavored format.
+ sparsity: A sequence of DimLevelType values, one for each dimension of the
+ tensor.
- Returns:
- A Tuple containing the following values:
- storage: A ctypes.c_void_p for the MLIR sparse tensor storage.
- shape: A 1D numpy array of integers, for the shape of the tensor.
+ Returns:
+ A Tuple containing the following values:
+ storage: A ctypes.c_void_p for the MLIR sparse tensor storage.
+ shape: A 1D numpy array of integers, for the shape of the tensor.
- Raises:
- OSError: If there is any problem in loading the supporting C shared library.
- ValueError: If the shared library doesn't contain the needed routine.
- """
- with ir.Context() as ctx, ir.Location.unknown():
- module = _get_create_sparse_tensor_kernel(sparsity, type)
- module = ir.Module.parse(module)
- engine = compile_and_build_engine(module)
+ Raises:
+ OSError: If there is any problem in loading the supporting C shared library.
+ ValueError: If the shared library doesn't contain the needed routine.
+ """
+ with ir.Context() as ctx, ir.Location.unknown():
+ module = _get_create_sparse_tensor_kernel(sparsity, type)
+ module = ir.Module.parse(module)
+ engine = compile_and_build_engine(module)
- # A sparse tensor descriptor to receive the kernel result.
- c_tensor_desc = _SparseTensorDescriptor()
- # Convert the filename to a byte stream.
- c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
+ # A sparse tensor descriptor to receive the kernel result.
+ c_tensor_desc = _SparseTensorDescriptor()
+ # Convert the filename to a byte stream.
+ c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
- arg_pointers = [
- ctypes.byref(ctypes.pointer(c_tensor_desc)),
- ctypes.byref(c_filename)
- ]
+ arg_pointers = [
+ ctypes.byref(ctypes.pointer(c_tensor_desc)),
+ ctypes.byref(c_filename),
+ ]
- # Invoke the execution engine to run the module and return the result.
- engine.invoke(_ENTRY_NAME, *arg_pointers)
- shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape))
- return c_tensor_desc.storage, shape
+ # Invoke the execution engine to run the module and return the result.
+ engine.invoke(_ENTRY_NAME, *arg_pointers)
+ shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape))
+ return c_tensor_desc.storage, shape
# TODO: With better support from MLIR, we may improve the current implementation
# by using Python code to generate the kernel instead of doing MLIR text code
# stitching.
def _get_output_sparse_tensor_kernel(
- sparsity_codes: Sequence[sparse_tensor.DimLevelType],
- type: str) -> str:
- """Creates an MLIR text kernel to output a sparse tensor to a file.
+ sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str
+) -> str:
+ """Creates an MLIR text kernel to output a sparse tensor to a file.
- The kernel returns void.
- """
- rank = len(sparsity_codes)
+ The kernel returns void.
+ """
+ rank = len(sparsity_codes)
- # Use ? to represent a dimension in the dynamic shape string representation.
- shape = "x".join(map(lambda d: "?", range(rank)))
+ # Use ? to represent a dimension in the dynamic shape string representation.
+ shape = "x".join(map(lambda d: "?", range(rank)))
- # Convert the encoded sparsity values to a string representation.
- sparsity = ", ".join(
- map(lambda s: '"compressed"'
- if s.value else '"dense"', sparsity_codes))
+ # Convert the encoded sparsity values to a string representation.
+ sparsity = ", ".join(
+ map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes)
+ )
- # Return the MLIR text kernel.
- return f"""
+ # Return the MLIR text kernel.
+ return f"""
!Ptr = !llvm.ptr<i8>
#enc = #sparse_tensor.encoding<{{
lvlTypes = [ {sparsity} ]
@@ -340,35 +387,38 @@ def _get_output_sparse_tensor_kernel(
}}"""
-def output_sparse_tensor(tensor: ctypes.c_void_p, filename: str,
- sparsity: Sequence[sparse_tensor.DimLevelType],
- type: str) -> None:
- """Outputs an MLIR sparse tensor to the given file.
-
- Args:
- tensor: A C pointer to the MLIR sparse tensor.
- filename: A string for the name of the file that contains the tensor data in
- a COO-flavored format.
- sparsity: A sequence of DimLevelType values, one for each dimension of the
- tensor.
- type: The MLIR string for the data type.
-
- Raises:
- OSError: If there is any problem in loading the supporting C shared library.
- ValueError: If the shared library doesn't contain the needed routine.
- """
- with ir.Context() as ctx, ir.Location.unknown():
- module = _get_output_sparse_tensor_kernel(sparsity, type)
- module = ir.Module.parse(module)
- engine = compile_and_build_engine(module)
-
- # Convert the filename to a byte stream.
- c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
-
- arg_pointers = [
- ctypes.byref(ctypes.cast(tensor, ctypes.c_void_p)),
- ctypes.byref(c_filename)
- ]
-
- # Invoke the execution engine to run the module and return the result.
- engine.invoke(_ENTRY_NAME, *arg_pointers)
+def output_sparse_tensor(
+ tensor: ctypes.c_void_p,
+ filename: str,
+ sparsity: Sequence[sparse_tensor.DimLevelType],
+ type: str,
+) -> None:
+ """Outputs an MLIR sparse tensor to the given file.
+
+ Args:
+ tensor: A C pointer to the MLIR sparse tensor.
+ filename: A string for the name of the file that contains the tensor data in
+ a COO-flavored format.
+ sparsity: A sequence of DimLevelType values, one for each dimension of the
+ tensor.
+ type: The MLIR string for the data type.
+
+ Raises:
+ OSError: If there is any problem in loading the supporting C shared library.
+ ValueError: If the shared library doesn't contain the needed routine.
+ """
+ with ir.Context() as ctx, ir.Location.unknown():
+ module = _get_output_sparse_tensor_kernel(sparsity, type)
+ module = ir.Module.parse(module)
+ engine = compile_and_build_engine(module)
+
+ # Convert the filename to a byte stream.
+ c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
+
+ arg_pointers = [
+ ctypes.byref(ctypes.cast(tensor, ctypes.c_void_p)),
+ ctypes.byref(c_filename),
+ ]
+
+ # Invoke the execution engine to run the module and return the result.
+ engine.invoke(_ENTRY_NAME, *arg_pointers)
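
For context, a rough round-trip sketch for the conversion helpers above (assumptions:
the import path, that the supporting C shared library is locatable, and that a nonzero
uint8 sparsity code means a compressed dimension):

    import ctypes
    import numpy as np
    import mlir_pytaco_utils as utils

    shape = np.array([2, 2], dtype=np.uint64)
    values = np.array([1.0, 2.0], dtype=np.float64)
    indices = np.array([[0, 0], [1, 1]], dtype=np.uint64)
    perm = np.array([0, 1], dtype=np.uint64)
    sparse = np.array([1, 1], dtype=np.uint8)   # per-dimension sparsity codes

    # Build the MLIR sparse tensor, then read it back as COO data.
    ptr = utils.coo_tensor_to_sparse_tensor(shape, values, indices, perm, sparse)
    rank, nse, out_shape, out_values, out_indices = utils.sparse_tensor_to_coo_tensor(
        ctypes.c_void_p(ptr), np.float64)
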
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py
index 69db28d4bccd5..8f193b81bb07c 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py
@@ -13,29 +13,29 @@
class SparseCompiler:
- """Sparse compiler class for compiling and building MLIR modules."""
-
- def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
- pipeline = f'builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})'
- self.pipeline = pipeline
- self.opt_level = opt_level
- self.shared_libs = shared_libs
-
- def __call__(self, module: ir.Module):
- """Convenience application method."""
- self.compile(module)
-
- def compile(self, module: ir.Module):
- """Compiles the module by invoking the sparse copmiler pipeline."""
- passmanager.PassManager.parse(self.pipeline).run(module.operation)
-
- def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
- """Wraps the module in a JIT execution engine."""
- return execution_engine.ExecutionEngine(
- module, opt_level=self.opt_level, shared_libs=self.shared_libs)
-
- def compile_and_jit(self,
- module: ir.Module) -> execution_engine.ExecutionEngine:
- """Compiles and jits the module."""
- self.compile(module)
- return self.jit(module)
+ """Sparse compiler class for compiling and building MLIR modules."""
+
+ def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+ pipeline = f"builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})"
+ self.pipeline = pipeline
+ self.opt_level = opt_level
+ self.shared_libs = shared_libs
+
+ def __call__(self, module: ir.Module):
+ """Convenience application method."""
+ self.compile(module)
+
+ def compile(self, module: ir.Module):
+ """Compiles the module by invoking the sparse copmiler pipeline."""
+ passmanager.PassManager.parse(self.pipeline).run(module.operation)
+
+ def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+ """Wraps the module in a JIT execution engine."""
+ return execution_engine.ExecutionEngine(
+ module, opt_level=self.opt_level, shared_libs=self.shared_libs
+ )
+
+ def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+ """Compiles and jits the module."""
+ self.compile(module)
+ return self.jit(module)
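
A minimal sketch of the SparseCompiler wrapper above in use (assumptions: the MLIR
Python bindings are importable as mlir, and the shared-library name and opt level are
illustrative placeholders):

    from mlir import ir
    import mlir_sparse_compiler

    compiler = mlir_sparse_compiler.SparseCompiler(
        options="", opt_level=2, shared_libs=["libmlir_c_runner_utils.so"])
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.parse("module {}")
        # compile_and_jit() runs the sparse-compiler pipeline, then wraps the
        # module in an ExecutionEngine.
        engine = compiler.compile_and_jit(module)
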
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py
index 466c9df042984..1be88fa8bd709 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py
@@ -8,38 +8,40 @@
def compare_sparse_tns(expected: str, actual: str, rtol: float = 0.0001) -> bool:
- """Compares sparse tensor actual output file with expected output file.
+ """Compares sparse tensor actual output file with expected output file.
- This routine assumes the input files are in FROSTT format. See
- http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format.
+ This routine assumes the input files are in FROSTT format. See
+ http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format.
- It also assumes the first line in the output file is a comment line.
+ It also assumes the first line in the output file is a comment line.
- """
- with open(actual, "r") as actual_f:
- with open(expected, "r") as expected_f:
- # Skip the first comment line.
- _ = actual_f.readline()
- _ = expected_f.readline()
+ """
+ with open(actual, "r") as actual_f:
+ with open(expected, "r") as expected_f:
+ # Skip the first comment line.
+ _ = actual_f.readline()
+ _ = expected_f.readline()
- # Compare the two lines of meta data
- if (actual_f.readline() != expected_f.readline() or
- actual_f.readline() != expected_f.readline()):
- return False
+ # Compare the two lines of meta data
+ if (
+ actual_f.readline() != expected_f.readline()
+ or actual_f.readline() != expected_f.readline()
+ ):
+ return False
- actual_data = np.loadtxt(actual, np.float64, skiprows=3)
- expected_data = np.loadtxt(expected, np.float64, skiprows=3)
- return np.allclose(actual_data, expected_data, rtol=rtol)
+ actual_data = np.loadtxt(actual, np.float64, skiprows=3)
+ expected_data = np.loadtxt(expected, np.float64, skiprows=3)
+ return np.allclose(actual_data, expected_data, rtol=rtol)
def file_as_string(file: str) -> str:
- """Returns contents of file as string."""
- with open(file, "r") as f:
- return f.read()
+ """Returns contents of file as string."""
+ with open(file, "r") as f:
+ return f.read()
def run_test(f):
- """Prints the test name and runs the test."""
- print(f.__name__)
- f()
- return f
+ """Prints the test name and runs the test."""
+ print(f.__name__)
+ f()
+ return f
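
For orientation, a hedged sketch of the testing helpers above (assumptions: the import
path, and the two .tns files, which are placeholders assumed to exist in FROSTT format):

    import testing_utils

    @testing_utils.run_test        # prints the test name, then invokes the function
    def test_output_matches_golden():
        assert testing_utils.compare_sparse_tns("expected.tns", "actual.tns")
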
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
index 5b7e648b97957..45ce446478dee 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
@@ -18,509 +18,630 @@
def _init_3d(T, I, J, K):
- for i in range(I):
- for j in range(J):
- for k in range(K):
- T.insert([i, j, k], i + j + k + 1)
+ for i in range(I):
+ for j in range(J):
+ for k in range(K):
+ T.insert([i, j, k], i + j + k + 1)
def _init_2d(T, I, J):
- for i in range(I):
- for j in range(J):
- T.insert([i, j], i + j + 1)
+ for i in range(I):
+ for j in range(J):
+ T.insert([i, j], i + j + 1)
def _init_1d_with_value(T, I, v):
- for i in range(I):
- T.insert([i], v)
+ for i in range(I):
+ T.insert([i], v)
def test_expect_error(name, code, error):
- """Executes the code then verifies the expected error message."""
- try:
- exec(code)
- except ValueError as e:
- passed = "passed" if (str(e).startswith(error)) else "failed"
- print(f"test_{name}: {passed}")
+ """Executes the code then verifies the expected error message."""
+ try:
+ exec(code)
+ except ValueError as e:
+ passed = "passed" if (str(e).startswith(error)) else "failed"
+ print(f"test_{name}: {passed}")
# CHECK-LABEL: test_tensor_dtype
@testing_utils.run_test
def test_tensor_dtype():
- passed = mlir_pytaco.DType(mlir_pytaco.Type.INT16).is_int()
- passed += mlir_pytaco.DType(mlir_pytaco.Type.INT32).is_int()
- passed += mlir_pytaco.DType(mlir_pytaco.Type.INT64).is_int()
- passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32).is_float()
- passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64).is_float()
- # CHECK: Number of passed: 5
- print("Number of passed:", passed)
+ passed = mlir_pytaco.DType(mlir_pytaco.Type.INT16).is_int()
+ passed += mlir_pytaco.DType(mlir_pytaco.Type.INT32).is_int()
+ passed += mlir_pytaco.DType(mlir_pytaco.Type.INT64).is_int()
+ passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32).is_float()
+ passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64).is_float()
+ # CHECK: Number of passed: 5
+ print("Number of passed:", passed)
# CHECK: test_mode_ordering_not_int: passed
-test_expect_error("mode_ordering_not_int",
- "m = mlir_pytaco.ModeOrdering(['x'])",
- "Ordering must be a list of integers")
+test_expect_error(
+ "mode_ordering_not_int",
+ "m = mlir_pytaco.ModeOrdering(['x'])",
+ "Ordering must be a list of integers",
+)
# CHECK: test_mode_ordering_not_permutation: passed
-test_expect_error("mode_ordering_not_permutation",
- "m = mlir_pytaco.ModeOrdering([2, 1])", "Invalid ordering")
+test_expect_error(
+ "mode_ordering_not_permutation",
+ "m = mlir_pytaco.ModeOrdering([2, 1])",
+ "Invalid ordering",
+)
# CHECK: test_mode_format_invalid: passed
-test_expect_error("mode_format_invalid",
- "m = mlir_pytaco.ModeFormatPack(['y'])",
- "Formats must be a list of ModeFormat")
+test_expect_error(
+ "mode_format_invalid",
+ "m = mlir_pytaco.ModeFormatPack(['y'])",
+ "Formats must be a list of ModeFormat",
+)
# CHECK: test_expect_mode_format_pack: passed
-test_expect_error("expect_mode_format_pack", ("""
+test_expect_error(
+ "expect_mode_format_pack",
+ (
+ """
mode_ordering = mlir_pytaco.ModeOrdering([0, 1, 2])
f = mlir_pytaco.Format(["x"], mode_ordering)
- """), "Expected a list of ModeFormat")
+ """
+ ),
+ "Expected a list of ModeFormat",
+)
# CHECK: test_expect_mode_ordering: passed
-test_expect_error("expect_mode_ordering", ("""
+test_expect_error(
+ "expect_mode_ordering",
+ (
+ """
mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED])
f = mlir_pytaco.Format(mode_format_pack, "x")
- """), "Expected ModeOrdering")
+ """
+ ),
+ "Expected ModeOrdering",
+)
# CHECK: test_inconsistent_mode_format_pack_and_mode_ordering: passed
-test_expect_error("inconsistent_mode_format_pack_and_mode_ordering", ("""
+test_expect_error(
+ "inconsistent_mode_format_pack_and_mode_ordering",
+ (
+ """
mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED])
mode_ordering = mlir_pytaco.ModeOrdering([0, 1, 2])
f = mlir_pytaco.Format(mode_format_pack, mode_ordering)
- """), "Inconsistent ModeFormatPack and ModeOrdering")
+ """
+ ),
+ "Inconsistent ModeFormatPack and ModeOrdering",
+)
# CHECK-LABEL: test_format_default_ordering
@testing_utils.run_test
def test_format_default_ordering():
- f = mlir_pytaco.Format([_COMPRESSED, _COMPRESSED])
- passed = 0
- passed += np.array_equal(f.ordering.ordering, [0, 1])
- # CHECK: Number of passed: 1
- print("Number of passed:", passed)
+ f = mlir_pytaco.Format([_COMPRESSED, _COMPRESSED])
+ passed = 0
+ passed += np.array_equal(f.ordering.ordering, [0, 1])
+ # CHECK: Number of passed: 1
+ print("Number of passed:", passed)
# CHECK-LABEL: test_format_explicit_ordering
@testing_utils.run_test
def test_format_explicit_ordering():
- f = mlir_pytaco.Format([_COMPRESSED, _DENSE], [1, 0])
- passed = 0
- passed += np.array_equal(f.ordering.ordering, [1, 0])
- # CHECK: Number of passed: 1
- print("Number of passed:", passed)
+ f = mlir_pytaco.Format([_COMPRESSED, _DENSE], [1, 0])
+ passed = 0
+ passed += np.array_equal(f.ordering.ordering, [1, 0])
+ # CHECK: Number of passed: 1
+ print("Number of passed:", passed)
# CHECK-LABEL: test_index_var
@testing_utils.run_test
def test_index_var():
- i = mlir_pytaco.IndexVar()
- j = mlir_pytaco.IndexVar()
- passed = (i.name != j.name)
+ i = mlir_pytaco.IndexVar()
+ j = mlir_pytaco.IndexVar()
+ passed = i.name != j.name
- vars = mlir_pytaco.get_index_vars(10)
- passed += (len(vars) == 10)
- passed += (all([isinstance(e, mlir_pytaco.IndexVar) for e in vars]))
+ vars = mlir_pytaco.get_index_vars(10)
+ passed += len(vars) == 10
+ passed += all([isinstance(e, mlir_pytaco.IndexVar) for e in vars])
- # CHECK: Number of passed: 3
- print("Number of passed:", passed)
+ # CHECK: Number of passed: 3
+ print("Number of passed:", passed)
# CHECK: test_tensor_invalid_first_argument: passed
-test_expect_error("tensor_invalid_first_argument",
- "t = mlir_pytaco.Tensor('f')", "Invalid first argument")
+test_expect_error(
+ "tensor_invalid_first_argument",
+ "t = mlir_pytaco.Tensor('f')",
+ "Invalid first argument",
+)
# CHECK: test_tensor_inconsistent_shape_and_format: passed
-test_expect_error("tensor_inconsistent_shape_and_format", ("""
+test_expect_error(
+ "tensor_inconsistent_shape_and_format",
+ (
+ """
mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED])
mode_ordering = mlir_pytaco.ModeOrdering([0, 1])
f = mlir_pytaco.Format(mode_format_pack, mode_ordering)
t = mlir_pytaco.Tensor([3], f)
- """), "Inconsistent shape and format")
+ """
+ ),
+ "Inconsistent shape and format",
+)
# CHECK: test_tensor_invalid_format: passed
-test_expect_error("tensor_invalid_format", "t = mlir_pytaco.Tensor([3], 'f')",
- "Invalid format argument")
+test_expect_error(
+ "tensor_invalid_format",
+ "t = mlir_pytaco.Tensor([3], 'f')",
+ "Invalid format argument",
+)
# CHECK: test_tensor_insert_nonlist_coordinate: passed
-test_expect_error("tensor_insert_nonlist_coordinate", ("""
+test_expect_error(
+ "tensor_insert_nonlist_coordinate",
+ (
+ """
t = mlir_pytaco.Tensor([3])
t.insert(1, 0)
- """), "Non list coordinate detected")
+ """
+ ),
+ "Non list coordinate detected",
+)
# CHECK: test_tensor_insert_too_much_coordinate: passed
-test_expect_error("tensor_insert_too_much_coordinate", ("""
+test_expect_error(
+ "tensor_insert_too_much_coordinate",
+ (
+ """
t = mlir_pytaco.Tensor([3])
t.insert([0, 0], 0)
- """), "Invalid coordinate")
+ """
+ ),
+ "Invalid coordinate",
+)
# CHECK: test_tensor_insert_coordinate_outof_range: passed
-test_expect_error("tensor_insert_coordinate_outof_range", ("""
+test_expect_error(
+ "tensor_insert_coordinate_outof_range",
+ (
+ """
t = mlir_pytaco.Tensor([1, 1])
t.insert([1, 0], 0)
- """), "Invalid coordinate")
+ """
+ ),
+ "Invalid coordinate",
+)
# CHECK: test_tensor_insert_coordinate_nonint: passed
-test_expect_error("tensor_insert_coordinate_nonint", ("""
+test_expect_error(
+ "tensor_insert_coordinate_nonint",
+ (
+ """
t = mlir_pytaco.Tensor([1, 1])
t.insert([0, "xy"], 0)
- """), "Non integer coordinate detected")
+ """
+ ),
+ "Non integer coordinate detected",
+)
# CHECK: test_tensor_insert_invalid_value: passed
-test_expect_error("tensor_insert_invalid_value", ("""
+test_expect_error(
+ "tensor_insert_invalid_value",
+ (
+ """
t = mlir_pytaco.Tensor([1, 1])
t.insert([0, 0], "x")
- """), "Value is neither int nor float")
+ """
+ ),
+ "Value is neither int nor float",
+)
# CHECK: test_access_non_index_var_index: passed
-test_expect_error("access_non_index_var_index", ("""
+test_expect_error(
+ "access_non_index_var_index",
+ (
+ """
t = mlir_pytaco.Tensor([5, 6])
i = mlir_pytaco.IndexVar()
a = mlir_pytaco.Access(t, (i, "j"))
- """), "Indices contain non IndexVar")
+ """
+ ),
+ "Indices contain non IndexVar",
+)
# CHECK: test_access_inconsistent_rank_indices: passed
-test_expect_error("access_inconsistent_rank_indices", ("""
+test_expect_error(
+ "access_inconsistent_rank_indices",
+ (
+ """
t = mlir_pytaco.Tensor([5, 6])
i = mlir_pytaco.IndexVar()
a = mlir_pytaco.Access(t, (i,))
- """), "Invalid indices for rank")
+ """
+ ),
+ "Invalid indices for rank",
+)
# CHECK: test_access_invalid_indices_for_rank: passed
-test_expect_error("access_invalid_indices_for_rank", ("""
+test_expect_error(
+ "access_invalid_indices_for_rank",
+ (
+ """
t = mlir_pytaco.Tensor([5, 6])
i, j, k = mlir_pytaco.get_index_vars(3)
a = mlir_pytaco.Access(t, (i,j, k))
- """), "Invalid indices for rank")
+ """
+ ),
+ "Invalid indices for rank",
+)
# CHECK: test_invalid_indices: passed
-test_expect_error("invalid_indices", ("""
+test_expect_error(
+ "invalid_indices",
+ (
+ """
i, j = mlir_pytaco.get_index_vars(2)
A = mlir_pytaco.Tensor([2, 3])
B = mlir_pytaco.Tensor([2, 3])
C = mlir_pytaco.Tensor([2, 3], _DENSE)
C[i, j] = A[1, j] + B[i, j]
- """), "Expected IndexVars")
+ """
+ ),
+ "Expected IndexVars",
+)
# CHECK: test_inconsistent_rank_indices: passed
-test_expect_error("inconsistent_rank_indices", ("""
+test_expect_error(
+ "inconsistent_rank_indices",
+ (
+ """
i, j = mlir_pytaco.get_index_vars(2)
A = mlir_pytaco.Tensor([2, 3])
C = mlir_pytaco.Tensor([2, 3], _DENSE)
C[i, j] = A[i]
- """), "Invalid indices for rank")
+ """
+ ),
+ "Invalid indices for rank",
+)
# CHECK: test_destination_index_not_used_in_source: passed
-test_expect_error("destination_index_not_used_in_source", ("""
+test_expect_error(
+ "destination_index_not_used_in_source",
+ (
+ """
i, j = mlir_pytaco.get_index_vars(2)
A = mlir_pytaco.Tensor([3])
C = mlir_pytaco.Tensor([3], _DENSE)
C[j] = A[i]
C.evaluate()
- """), "Destination IndexVar not used in the source expression")
+ """
+ ),
+ "Destination IndexVar not used in the source expression",
+)
# CHECK: test_destination_dim_not_consistent_with_source: passed
-test_expect_error("destination_dim_not_consistent_with_source", ("""
+test_expect_error(
+ "destination_dim_not_consistent_with_source",
+ (
+ """
i = mlir_pytaco.IndexVar()
A = mlir_pytaco.Tensor([3])
C = mlir_pytaco.Tensor([5], _DENSE)
C[i] = A[i]
C.evaluate()
- """), "Inconsistent destination dimension for IndexVar")
+ """
+ ),
+ "Inconsistent destination dimension for IndexVar",
+)
# CHECK: test_inconsistent_source_dim: passed
-test_expect_error("inconsistent_source_dim", ("""
+test_expect_error(
+ "inconsistent_source_dim",
+ (
+ """
i = mlir_pytaco.IndexVar()
A = mlir_pytaco.Tensor([3])
B = mlir_pytaco.Tensor([5])
C = mlir_pytaco.Tensor([3], _DENSE)
C[i] = A[i] + B[i]
C.evaluate()
- """), "Inconsistent source dimension for IndexVar")
+ """
+ ),
+ "Inconsistent source dimension for IndexVar",
+)
# CHECK: test_index_var_outside_domain: passed
-test_expect_error("index_var_outside_domain", ("""
+test_expect_error(
+ "index_var_outside_domain",
+ (
+ """
i, j = mlir_pytaco.get_index_vars(2)
A = mlir_pytaco.Tensor([3])
B = mlir_pytaco.Tensor([3])
B[i] = A[i] + j
B.evaluate()
- """), "IndexVar is not part of the iteration domain")
+ """
+ ),
+ "IndexVar is not part of the iteration domain",
+)
# CHECK-LABEL: test_tensor_all_dense_sparse
@testing_utils.run_test
def test_tensor_all_dense_sparse():
- a = mlir_pytaco.Tensor([4], [_DENSE])
- passed = (not a.is_dense())
- passed += (a.order == 1)
- passed += (a.shape[0] == 4)
- # CHECK: Number of passed: 3
- print("Number of passed:", passed)
+ a = mlir_pytaco.Tensor([4], [_DENSE])
+ passed = not a.is_dense()
+ passed += a.order == 1
+ passed += a.shape[0] == 4
+ # CHECK: Number of passed: 3
+ print("Number of passed:", passed)
# CHECK-LABEL: test_tensor_true_dense
@testing_utils.run_test
def test_tensor_true_dense():
- a = mlir_pytaco.Tensor.from_array(np.random.uniform(size=5))
- passed = a.is_dense()
- passed += (a.order == 1)
- passed += (a.shape[0] == 5)
- # CHECK: Number of passed: 3
- print("Number of passed:", passed)
+ a = mlir_pytaco.Tensor.from_array(np.random.uniform(size=5))
+ passed = a.is_dense()
+ passed += a.order == 1
+ passed += a.shape[0] == 5
+ # CHECK: Number of passed: 3
+ print("Number of passed:", passed)
# CHECK-LABEL: test_tensor_copy
@testing_utils.run_test
def test_tensor_copy():
- i, j = mlir_pytaco.get_index_vars(2)
- I = 2
- J = 3
- A = mlir_pytaco.Tensor([I, J])
- A.insert([0, 1], 5.0)
- A.insert([1, 2], 6.0)
- B = mlir_pytaco.Tensor([I, J])
- B[i, j] = A[i, j]
- passed = (B._assignment is not None)
- passed += (B._engine is None)
- try:
+ i, j = mlir_pytaco.get_index_vars(2)
+ I = 2
+ J = 3
+ A = mlir_pytaco.Tensor([I, J])
+ A.insert([0, 1], 5.0)
+ A.insert([1, 2], 6.0)
+ B = mlir_pytaco.Tensor([I, J])
+ B[i, j] = A[i, j]
+ passed = B._assignment is not None
+ passed += B._engine is None
+ try:
+ B.compute()
+ except ValueError as e:
+ passed += str(e).startswith("Need to invoke compile")
+ B.compile()
+ passed += B._engine is not None
B.compute()
- except ValueError as e:
- passed += (str(e).startswith("Need to invoke compile"))
- B.compile()
- passed += (B._engine is not None)
- B.compute()
- passed += (B._assignment is None)
- passed += (B._engine is None)
- indices, values = B.get_coordinates_and_values()
- passed += np.array_equal(indices, [[0, 1], [1, 2]])
- passed += np.allclose(values, [5.0, 6.0])
- # No temporary tensor is used.
- passed += (B._stats.get_total() == 0)
- # CHECK: Number of passed: 9
- print("Number of passed:", passed)
+ passed += B._assignment is None
+ passed += B._engine is None
+ indices, values = B.get_coordinates_and_values()
+ passed += np.array_equal(indices, [[0, 1], [1, 2]])
+ passed += np.allclose(values, [5.0, 6.0])
+ # No temporary tensor is used.
+ passed += B._stats.get_total() == 0
+ # CHECK: Number of passed: 9
+ print("Number of passed:", passed)
# CHECK-LABEL: test_tensor_trivial_reduction
@testing_utils.run_test
def test_tensor_trivial_reduction():
- i, j = mlir_pytaco.get_index_vars(2)
- I = 2
- J = 3
- A = mlir_pytaco.Tensor([I, J])
- A.insert([0, 1], 5.0)
- A.insert([0, 2], 3.0)
- A.insert([1, 2], 6.0)
- B = mlir_pytaco.Tensor([I])
- B[i] = A[i, j]
- indices, values = B.get_coordinates_and_values()
- passed = np.array_equal(indices, [[0], [1]])
- passed += np.allclose(values, [8.0, 6.0])
- # No temporary tensor is used.
- passed += (B._stats.get_total() == 0)
-
- # CHECK: Number of passed: 3
- print("Number of passed:", passed)
+ i, j = mlir_pytaco.get_index_vars(2)
+ I = 2
+ J = 3
+ A = mlir_pytaco.Tensor([I, J])
+ A.insert([0, 1], 5.0)
+ A.insert([0, 2], 3.0)
+ A.insert([1, 2], 6.0)
+ B = mlir_pytaco.Tensor([I])
+ B[i] = A[i, j]
+ indices, values = B.get_coordinates_and_values()
+ passed = np.array_equal(indices, [[0], [1]])
+ passed += np.allclose(values, [8.0, 6.0])
+ # No temporary tensor is used.
+ passed += B._stats.get_total() == 0
+
+ # CHECK: Number of passed: 3
+ print("Number of passed:", passed)
# CHECK-LABEL: test_binary_add
@testing_utils.run_test
def test_binary_add():
- i = mlir_pytaco.IndexVar()
- A = mlir_pytaco.Tensor([4])
- B = mlir_pytaco.Tensor([4])
- C = mlir_pytaco.Tensor([4])
- A.insert([1], 10)
- A.insert([2], 1)
- B.insert([3], 20)
- B.insert([2], 2)
- C[i] = A[i] + B[i]
- indices, values = C.get_coordinates_and_values()
- passed = np.array_equal(indices, [[1], [2], [3]])
- passed += np.array_equal(values, [10., 3., 20.])
- # No temporary tensor is used.
- passed += (C._stats.get_total() == 0)
- # CHECK: Number of passed: 3
- print("Number of passed:", passed)
+ i = mlir_pytaco.IndexVar()
+ A = mlir_pytaco.Tensor([4])
+ B = mlir_pytaco.Tensor([4])
+ C = mlir_pytaco.Tensor([4])
+ A.insert([1], 10)
+ A.insert([2], 1)
+ B.insert([3], 20)
+ B.insert([2], 2)
+ C[i] = A[i] + B[i]
+ indices, values = C.get_coordinates_and_values()
+ passed = np.array_equal(indices, [[1], [2], [3]])
+ passed += np.array_equal(values, [10.0, 3.0, 20.0])
+ # No temporary tensor is used.
+ passed += C._stats.get_total() == 0
+ # CHECK: Number of passed: 3
+ print("Number of passed:", passed)
# CHECK-LABEL: test_binary_add_sub
@testing_utils.run_test
def test_binary_add_sub():
- i = mlir_pytaco.IndexVar()
- j = mlir_pytaco.IndexVar()
- A = mlir_pytaco.Tensor([2, 3])
- B = mlir_pytaco.Tensor([2, 3])
- C = mlir_pytaco.Tensor([2, 3])
- D = mlir_pytaco.Tensor([2, 3])
- A.insert([0, 1], 10)
- A.insert([1, 2], 40)
- B.insert([0, 0], 20)
- B.insert([1, 2], 30)
- C.insert([0, 1], 5)
- C.insert([1, 2], 7)
- D[i, j] = A[i, j] + B[i, j] - C[i, j]
- indices, values = D.get_coordinates_and_values()
- passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
- passed += np.array_equal(values, [20., 5., 63.])
- # No temporary tensor is used.
- passed += (D._stats.get_total() == 0)
- # CHECK: Number of passed: 3
- print("Number of passed:", passed)
+ i = mlir_pytaco.IndexVar()
+ j = mlir_pytaco.IndexVar()
+ A = mlir_pytaco.Tensor([2, 3])
+ B = mlir_pytaco.Tensor([2, 3])
+ C = mlir_pytaco.Tensor([2, 3])
+ D = mlir_pytaco.Tensor([2, 3])
+ A.insert([0, 1], 10)
+ A.insert([1, 2], 40)
+ B.insert([0, 0], 20)
+ B.insert([1, 2], 30)
+ C.insert([0, 1], 5)
+ C.insert([1, 2], 7)
+ D[i, j] = A[i, j] + B[i, j] - C[i, j]
+ indices, values = D.get_coordinates_and_values()
+ passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
+ passed += np.array_equal(values, [20.0, 5.0, 63.0])
+ # No temporary tensor is used.
+ passed += D._stats.get_total() == 0
+ # CHECK: Number of passed: 3
+ print("Number of passed:", passed)
# CHECK-LABEL: test_binary_mul_add
@testing_utils.run_test
def test_binary_mul_add():
- i = mlir_pytaco.IndexVar()
- j = mlir_pytaco.IndexVar()
- A = mlir_pytaco.Tensor([2, 3])
- B = mlir_pytaco.Tensor([2, 3])
- C = mlir_pytaco.Tensor([2, 3])
- D = mlir_pytaco.Tensor([2, 3])
- A.insert([0, 1], 10)
- A.insert([1, 2], 40)
- B.insert([0, 0], 20)
- B.insert([1, 2], 30)
- C.insert([0, 1], 5)
- C.insert([1, 2], 7)
- D[i, j] = A[i, j] * C[i, j] + B[i, j]
- indices, values = D.get_coordinates_and_values()
- passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
- passed += np.array_equal(values, [20., 50., 310.])
- # No temporary tensor is used.
- passed += (D._stats.get_total() == 0)
- # CHECK: Number of passed: 3
- print("Number of passed:", passed)
+ i = mlir_pytaco.IndexVar()
+ j = mlir_pytaco.IndexVar()
+ A = mlir_pytaco.Tensor([2, 3])
+ B = mlir_pytaco.Tensor([2, 3])
+ C = mlir_pytaco.Tensor([2, 3])
+ D = mlir_pytaco.Tensor([2, 3])
+ A.insert([0, 1], 10)
+ A.insert([1, 2], 40)
+ B.insert([0, 0], 20)
+ B.insert([1, 2], 30)
+ C.insert([0, 1], 5)
+ C.insert([1, 2], 7)
+ D[i, j] = A[i, j] * C[i, j] + B[i, j]
+ indices, values = D.get_coordinates_and_values()
+ passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
+ passed += np.array_equal(values, [20.0, 50.0, 310.0])
+ # No temporary tensor is used.
+ passed += D._stats.get_total() == 0
+ # CHECK: Number of passed: 3
+ print("Number of passed:", passed)
# CHECK-LABEL: test_binary_add_reduce_at_root
@testing_utils.run_test
def test_binary_add_reduce_at_root():
- i = mlir_pytaco.IndexVar()
- j = mlir_pytaco.IndexVar()
- A = mlir_pytaco.Tensor([2, 3])
- B = mlir_pytaco.Tensor([2, 3])
- C = mlir_pytaco.Tensor([2], _DENSE)
- A.insert([0, 1], 10)
- A.insert([1, 2], 40)
- B.insert([0, 0], 20)
- B.insert([1, 2], 30)
- C[i] = A[i, j] + B[i, j]
- indices, values = C.get_coordinates_and_values()
- passed = np.array_equal(indices, [[0], [1]])
- passed += np.array_equal(values, [30., 70.])
- # No temporary tensor is used.
- passed += (C._stats.get_total() == 0)
- # CHECK: Number of passed: 3
- print("Number of passed:", passed)
+ i = mlir_pytaco.IndexVar()
+ j = mlir_pytaco.IndexVar()
+ A = mlir_pytaco.Tensor([2, 3])
+ B = mlir_pytaco.Tensor([2, 3])
+ C = mlir_pytaco.Tensor([2], _DENSE)
+ A.insert([0, 1], 10)
+ A.insert([1, 2], 40)
+ B.insert([0, 0], 20)
+ B.insert([1, 2], 30)
+ C[i] = A[i, j] + B[i, j]
+ indices, values = C.get_coordinates_and_values()
+ passed = np.array_equal(indices, [[0], [1]])
+ passed += np.array_equal(values, [30.0, 70.0])
+ # No temporary tensor is used.
+ passed += C._stats.get_total() == 0
+ # CHECK: Number of passed: 3
+ print("Number of passed:", passed)
# CHECK-LABEL: test_binary_add_reduce_at_child
@testing_utils.run_test
def test_binary_add_reduce_at_child():
- i = mlir_pytaco.IndexVar()
- j = mlir_pytaco.IndexVar()
- I = 2
- J = 3
- A = mlir_pytaco.Tensor([I, J])
- B = mlir_pytaco.Tensor([J])
- C = mlir_pytaco.Tensor([I])
- D = mlir_pytaco.Tensor([I], _DENSE)
-
- _init_2d(A, I, J)
- _init_1d_with_value(C, I, 2)
- _init_1d_with_value(B, J, 1)
-
- D[i] = A[i, j] * B[j] + C[i]
- indices, values = D.get_coordinates_and_values()
- passed = np.array_equal(indices, [[0], [1]])
- passed += np.array_equal(values, [8., 11.])
-
- # The expression is implemented as:
- # temp0[i] = A[i, j] * B[i]
- # D[i] = temp0[i] + C[i]
- # Check the temporary tensor introduced by the implementation.
- stats = D._stats
- passed += (stats.get_total() == 1)
- passed += (stats.get_formats(0) == (_COMPRESSED,))
- passed += (stats.get_dimensions(0) == (I,))
- # CHECK: Number of passed: 5
- print("Number of passed:", passed)
+ i = mlir_pytaco.IndexVar()
+ j = mlir_pytaco.IndexVar()
+ I = 2
+ J = 3
+ A = mlir_pytaco.Tensor([I, J])
+ B = mlir_pytaco.Tensor([J])
+ C = mlir_pytaco.Tensor([I])
+ D = mlir_pytaco.Tensor([I], _DENSE)
+
+ _init_2d(A, I, J)
+ _init_1d_with_value(C, I, 2)
+ _init_1d_with_value(B, J, 1)
+
+ D[i] = A[i, j] * B[j] + C[i]
+ indices, values = D.get_coordinates_and_values()
+ passed = np.array_equal(indices, [[0], [1]])
+ passed += np.array_equal(values, [8.0, 11.0])
+
+ # The expression is implemented as:
+ # temp0[i] = A[i, j] * B[i]
+ # D[i] = temp0[i] + C[i]
+ # Check the temporary tensor introduced by the implementation.
+ stats = D._stats
+ passed += stats.get_total() == 1
+ passed += stats.get_formats(0) == (_COMPRESSED,)
+ passed += stats.get_dimensions(0) == (I,)
+ # CHECK: Number of passed: 5
+ print("Number of passed:", passed)
# CHECK-LABEL: test_binary_add_reduce_3d_1
@testing_utils.run_test
def test_binary_add_reduce_3d_1():
- i, j, k, l = mlir_pytaco.get_index_vars(4)
- I = 2
- J = 3
- K = 4
- L = 5
- A = mlir_pytaco.Tensor([I, J, K])
- B = mlir_pytaco.Tensor([I, J, L])
- C = mlir_pytaco.Tensor([K])
- D = mlir_pytaco.Tensor([L])
- E = mlir_pytaco.Tensor([I], _DENSE)
-
- _init_3d(A, I, J, K)
- _init_3d(B, I, J, L)
- _init_1d_with_value(C, K, 1)
- _init_1d_with_value(D, L, 2)
-
- E[i] = A[i, j, k] * C[k] + B[i, j, l] * D[l]
- indices, values = E.get_coordinates_and_values()
- passed = np.array_equal(indices, [[0], [1]])
- passed += np.array_equal(values, [162., 204.])
-
- # The expression is implemented as:
- # temp0[i, j] = A[i, j, k] * C[k]
- # temp1[i, j] = B[i, j, l] * D[l]
- # E[i] = temp0[i, j] + temp1[i, j]
- # Check the two temporary tensors introduced by the implementation.
- stats = E._stats
- passed += (stats.get_total() == 2)
- passed += (stats.get_formats(0) == (_COMPRESSED, _COMPRESSED))
- passed += (stats.get_dimensions(0) == (I, J))
- passed += (stats.get_formats(1) == (_COMPRESSED, _COMPRESSED))
- passed += (stats.get_dimensions(1) == (I, J))
- # CHECK: Number of passed: 7
- print("Number of passed:", passed)
+ i, j, k, l = mlir_pytaco.get_index_vars(4)
+ I = 2
+ J = 3
+ K = 4
+ L = 5
+ A = mlir_pytaco.Tensor([I, J, K])
+ B = mlir_pytaco.Tensor([I, J, L])
+ C = mlir_pytaco.Tensor([K])
+ D = mlir_pytaco.Tensor([L])
+ E = mlir_pytaco.Tensor([I], _DENSE)
+
+ _init_3d(A, I, J, K)
+ _init_3d(B, I, J, L)
+ _init_1d_with_value(C, K, 1)
+ _init_1d_with_value(D, L, 2)
+
+ E[i] = A[i, j, k] * C[k] + B[i, j, l] * D[l]
+ indices, values = E.get_coordinates_and_values()
+ passed = np.array_equal(indices, [[0], [1]])
+ passed += np.array_equal(values, [162.0, 204.0])
+
+ # The expression is implemented as:
+ # temp0[i, j] = A[i, j, k] * C[k]
+ # temp1[i, j] = B[i, j, l] * D[l]
+ # E[i] = temp0[i, j] + temp1[i, j]
+ # Check the two temporary tensors introduced by the implementation.
+ stats = E._stats
+ passed += stats.get_total() == 2
+ passed += stats.get_formats(0) == (_COMPRESSED, _COMPRESSED)
+ passed += stats.get_dimensions(0) == (I, J)
+ passed += stats.get_formats(1) == (_COMPRESSED, _COMPRESSED)
+ passed += stats.get_dimensions(1) == (I, J)
+ # CHECK: Number of passed: 7
+ print("Number of passed:", passed)
# CHECK-LABEL: test_binary_add_reduce_3d_2
@testing_utils.run_test
def test_binary_add_reduce_3d_2():
- i, j, k, l = mlir_pytaco.get_index_vars(4)
- I = 2
- J = 3
- K = 4
- L = 5
- A = mlir_pytaco.Tensor([I, J, K], [_COMPRESSED, _COMPRESSED, _DENSE])
- B = mlir_pytaco.Tensor([I, L, K], [_DENSE, _COMPRESSED, _COMPRESSED])
- C = mlir_pytaco.Tensor([J, K], [_COMPRESSED, _COMPRESSED])
- D = mlir_pytaco.Tensor([L])
- E = mlir_pytaco.Tensor([I], _DENSE)
-
- _init_3d(A, I, J, K)
- _init_3d(B, I, L, K)
- _init_2d(C, J, K)
- _init_1d_with_value(D, L, 2)
-
- E[i] = A[i, j, k] + C[j, k] + B[i, l, k] * D[l]
- indices, values = E.get_coordinates_and_values()
- passed = np.array_equal(indices, [[0], [1]])
- passed += np.array_equal(values, [264., 316.])
-
- # The expression is implemented as:
- # temp0[i, k] = A[i, j, k] + C[j, k]
- # temp1[i, k] = B[i, l, k] * D[l]
- # E[i] = temp0[i, k] + temp1[i, k]
- # Check the two temporary tensors introduced by the implementation.
- stats = E._stats
- passed += (stats.get_total() == 2)
- passed += (stats.get_formats(0) == (_COMPRESSED, _DENSE))
- passed += (stats.get_dimensions(0) == (I, K))
- passed += (stats.get_formats(1) == (_DENSE, _COMPRESSED))
- passed += (stats.get_dimensions(1) == (I, K))
- # CHECK: Number of passed: 7
- print("Number of passed:", passed)
+ i, j, k, l = mlir_pytaco.get_index_vars(4)
+ I = 2
+ J = 3
+ K = 4
+ L = 5
+ A = mlir_pytaco.Tensor([I, J, K], [_COMPRESSED, _COMPRESSED, _DENSE])
+ B = mlir_pytaco.Tensor([I, L, K], [_DENSE, _COMPRESSED, _COMPRESSED])
+ C = mlir_pytaco.Tensor([J, K], [_COMPRESSED, _COMPRESSED])
+ D = mlir_pytaco.Tensor([L])
+ E = mlir_pytaco.Tensor([I], _DENSE)
+
+ _init_3d(A, I, J, K)
+ _init_3d(B, I, L, K)
+ _init_2d(C, J, K)
+ _init_1d_with_value(D, L, 2)
+
+ E[i] = A[i, j, k] + C[j, k] + B[i, l, k] * D[l]
+ indices, values = E.get_coordinates_and_values()
+ passed = np.array_equal(indices, [[0], [1]])
+ passed += np.array_equal(values, [264.0, 316.0])
+
+ # The expression is implemented as:
+ # temp0[i, k] = A[i, j, k] + C[j, k]
+ # temp1[i, k] = B[i, l, k] * D[l]
+ # E[i] = temp0[i, k] + temp1[i, k]
+ # Check the two temporary tensors introduced by the implementation.
+ stats = E._stats
+ passed += stats.get_total() == 2
+ passed += stats.get_formats(0) == (_COMPRESSED, _DENSE)
+ passed += stats.get_dimensions(0) == (I, K)
+ passed += stats.get_formats(1) == (_DENSE, _COMPRESSED)
+ passed += stats.get_dimensions(1) == (I, K)
+ # CHECK: Number of passed: 7
+ print("Number of passed:", passed)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py
index cce97d6ef1e25..1d5274759b6a9 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py
@@ -32,21 +32,21 @@
# CHECK-LABEL: test_read_mtx_matrix_general
@testing_utils.run_test
def test_read_mtx_matrix_general():
- with tempfile.TemporaryDirectory() as test_dir:
- file_name = os.path.join(test_dir, "data.mtx")
- with open(file_name, "w") as file:
- file.write(_MTX_DATA)
- a = mlir_pytaco_io.read(file_name, _FORMAT)
- passed = 0
- # The value of a is stored as an MLIR sparse tensor.
- passed += (not a.is_unpacked())
- a.unpack()
- passed += (a.is_unpacked())
- coords, values = a.get_coordinates_and_values()
- passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
- passed += np.allclose(values, [2.0, 3.0, 4.0])
- # CHECK: 4
- print(passed)
+ with tempfile.TemporaryDirectory() as test_dir:
+ file_name = os.path.join(test_dir, "data.mtx")
+ with open(file_name, "w") as file:
+ file.write(_MTX_DATA)
+ a = mlir_pytaco_io.read(file_name, _FORMAT)
+ passed = 0
+ # The value of a is stored as an MLIR sparse tensor.
+ passed += not a.is_unpacked()
+ a.unpack()
+ passed += a.is_unpacked()
+ coords, values = a.get_coordinates_and_values()
+ passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
+ passed += np.allclose(values, [2.0, 3.0, 4.0])
+ # CHECK: 4
+ print(passed)
_TNS_DATA = """2 3
@@ -60,57 +60,57 @@ def test_read_mtx_matrix_general():
# CHECK-LABEL: test_read_tns
@testing_utils.run_test
def test_read_tns():
- with tempfile.TemporaryDirectory() as test_dir:
- file_name = os.path.join(test_dir, "data.tns")
- with open(file_name, "w") as file:
- file.write(_TNS_DATA)
- a = mlir_pytaco_io.read(file_name, _FORMAT)
- passed = 0
- # The value of a is stored as an MLIR sparse tensor.
- passed += (not a.is_unpacked())
- a.unpack()
- passed += (a.is_unpacked())
- coords, values = a.get_coordinates_and_values()
- passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
- passed += np.allclose(values, [2.0, 3.0, 4.0])
- # CHECK: 4
- print(passed)
+ with tempfile.TemporaryDirectory() as test_dir:
+ file_name = os.path.join(test_dir, "data.tns")
+ with open(file_name, "w") as file:
+ file.write(_TNS_DATA)
+ a = mlir_pytaco_io.read(file_name, _FORMAT)
+ passed = 0
+ # The value of a is stored as an MLIR sparse tensor.
+ passed += not a.is_unpacked()
+ a.unpack()
+ passed += a.is_unpacked()
+ coords, values = a.get_coordinates_and_values()
+ passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
+ passed += np.allclose(values, [2.0, 3.0, 4.0])
+ # CHECK: 4
+ print(passed)
# CHECK-LABEL: test_write_unpacked_tns
@testing_utils.run_test
def test_write_unpacked_tns():
- a = mlir_pytaco.Tensor([2, 3])
- a.insert([0, 1], 10)
- a.insert([1, 2], 40)
- a.insert([0, 0], 20)
- with tempfile.TemporaryDirectory() as test_dir:
- file_name = os.path.join(test_dir, "data.tns")
- try:
- mlir_pytaco_io.write(file_name, a)
- except ValueError as e:
- # CHECK: Writing unpacked sparse tensors to file is not supported
- print(e)
+ a = mlir_pytaco.Tensor([2, 3])
+ a.insert([0, 1], 10)
+ a.insert([1, 2], 40)
+ a.insert([0, 0], 20)
+ with tempfile.TemporaryDirectory() as test_dir:
+ file_name = os.path.join(test_dir, "data.tns")
+ try:
+ mlir_pytaco_io.write(file_name, a)
+ except ValueError as e:
+ # CHECK: Writing unpacked sparse tensors to file is not supported
+ print(e)
# CHECK-LABEL: test_write_packed_tns
@testing_utils.run_test
def test_write_packed_tns():
- a = mlir_pytaco.Tensor([2, 3])
- a.insert([0, 1], 10)
- a.insert([1, 2], 40)
- a.insert([0, 0], 20)
- b = mlir_pytaco.Tensor([2, 3])
- i, j = mlir_pytaco.get_index_vars(2)
- b[i, j] = a[i, j] + a[i, j]
- with tempfile.TemporaryDirectory() as test_dir:
- file_name = os.path.join(test_dir, "data.tns")
- mlir_pytaco_io.write(file_name, b)
- with open(file_name, "r") as file:
- lines = file.readlines()
- passed = 0
- # Skip the comment line in the output.
- if lines[1:] == ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]:
- passed = 1
- # CHECK: 1
- print(passed)
+ a = mlir_pytaco.Tensor([2, 3])
+ a.insert([0, 1], 10)
+ a.insert([1, 2], 40)
+ a.insert([0, 0], 20)
+ b = mlir_pytaco.Tensor([2, 3])
+ i, j = mlir_pytaco.get_index_vars(2)
+ b[i, j] = a[i, j] + a[i, j]
+ with tempfile.TemporaryDirectory() as test_dir:
+ file_name = os.path.join(test_dir, "data.tns")
+ mlir_pytaco_io.write(file_name, b)
+ with open(file_name, "r") as file:
+ lines = file.readlines()
+ passed = 0
+ # Skip the comment line in the output.
+ if lines[1:] == ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]:
+ passed = 1
+ # CHECK: 1
+ print(passed)
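
As a side note for readers of the test_write_packed_tns hunk above, the expected file lines can be derived by hand: b doubles every entry of a, coordinates are written 1-based, and the comment line at the top of the file is skipped. The sketch below reuses the values from the test itself; the reading of the two header lines as "rank nse" followed by the shape is an assumption based on the TNS helper that appears later in this patch.

# Illustrative reconstruction of the expected .tns lines checked above.
# Entries come from the a.insert(...) calls; b[i, j] = a[i, j] + a[i, j].
a_entries = {(0, 1): 10, (1, 2): 40, (0, 0): 20}
b_entries = {coord: 2 * v for coord, v in a_entries.items()}
header = ["2 3", "2 3"]  # assumed: "rank nse" on the first line, the shape on the second
body = [f"{i + 1} {j + 1} {v}" for (i, j), v in sorted(b_entries.items())]
print("\n".join(header + body))
# 2 3
# 2 3
# 1 1 40
# 1 2 20
# 2 3 80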
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_utils.py
index 13259698b1b12..1344f4aa741ab 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_utils.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_utils.py
@@ -20,79 +20,93 @@
def _to_string(s: Sequence[int]) -> str:
- """Converts a sequence of integer to a space separated value string."""
- return " ".join(map(lambda e: str(e), s))
+ """Converts a sequence of integer to a space separated value string."""
+ return " ".join(map(lambda e: str(e), s))
def _add_one(s: Sequence[int]) -> Sequence[int]:
- """Adds one to each element in the sequence of integer."""
- return [i + 1 for i in s]
+ """Adds one to each element in the sequence of integer."""
+ return [i + 1 for i in s]
@dataclasses.dataclass(frozen=True)
class _SparseTensorCOO:
- """Values for a COO-flavored format sparse tensor.
-
- Attributes:
- rank: An integer rank for the tensor.
- nse: An integer for the number of non-zero values.
- shape: A sequence of integer for the dimension size.
- values: A sequence of float for the non-zero values of the tensor.
- indices: A sequence of coordinate, each coordinate is a sequence of integer.
- """
- rank: int
- nse: int
- shape: Sequence[int]
- values: Sequence[float]
- indices: Sequence[Sequence[int]]
+ """Values for a COO-flavored format sparse tensor.
+
+ Attributes:
+ rank: An integer rank for the tensor.
+ nse: An integer for the number of non-zero values.
+ shape: A sequence of integer for the dimension size.
+ values: A sequence of float for the non-zero values of the tensor.
+ indices: A sequence of coordinate, each coordinate is a sequence of integer.
+ """
+
+ rank: int
+ nse: int
+ shape: Sequence[int]
+ values: Sequence[float]
+ indices: Sequence[Sequence[int]]
def _coo_values_to_tns_format(t: _SparseTensorCOO) -> str:
- """Converts a sparse tensor COO-flavored values to TNS text format."""
- # The coo_value_str contains one line for each (coordinate value) pair.
- # Indices are 1-based in TNS text format but 0-based in MLIR.
- coo_value_str = "\n".join(
- map(lambda i: _to_string(_add_one(t.indices[i])) + " " + str(t.values[i]),
- range(t.nse)))
-
- # Returns the TNS text format representation for the tensor.
- return f"""{t.rank} {t.nse}
+ """Converts a sparse tensor COO-flavored values to TNS text format."""
+ # The coo_value_str contains one line for each (coordinate value) pair.
+ # Indices are 1-based in TNS text format but 0-based in MLIR.
+ coo_value_str = "\n".join(
+ map(
+ lambda i: _to_string(_add_one(t.indices[i])) + " " + str(t.values[i]),
+ range(t.nse),
+ )
+ )
+
+ # Returns the TNS text format representation for the tensor.
+ return f"""{t.rank} {t.nse}
{_to_string(t.shape)}
{coo_value_str}
"""
def _implement_read_tns_test(
- t: _SparseTensorCOO,
- sparsity_codes: Sequence[sparse_tensor.DimLevelType]) -> int:
- tns_data = _coo_values_to_tns_format(t)
-
- # Write sparse tensor data to a file.
- with tempfile.TemporaryDirectory() as test_dir:
- file_name = os.path.join(test_dir, "data.tns")
- with open(file_name, "w") as file:
- file.write(tns_data)
-
- # Read the data from the file and construct an MLIR sparse tensor.
- sparse_tensor, o_shape = pytaco_utils.create_sparse_tensor(
- file_name, sparsity_codes, "f64")
-
- passed = 0
-
- # Verify the output shape for the tensor.
- if np.array_equal(o_shape, t.shape):
- passed += 1
-
- # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
- # values and verify the values.
- o_rank, o_nse, o_shape, o_values, o_indices = (
- pytaco_utils.sparse_tensor_to_coo_tensor(sparse_tensor, np.float64))
- if o_rank == t.rank and o_nse == t.nse and np.array_equal(
- o_shape, t.shape) and np.allclose(o_values, t.values) and np.array_equal(
- o_indices, t.indices):
- passed += 1
-
- return passed
+ t: _SparseTensorCOO, sparsity_codes: Sequence[sparse_tensor.DimLevelType]
+) -> int:
+ tns_data = _coo_values_to_tns_format(t)
+
+ # Write sparse tensor data to a file.
+ with tempfile.TemporaryDirectory() as test_dir:
+ file_name = os.path.join(test_dir, "data.tns")
+ with open(file_name, "w") as file:
+ file.write(tns_data)
+
+ # Read the data from the file and construct an MLIR sparse tensor.
+ sparse_tensor, o_shape = pytaco_utils.create_sparse_tensor(
+ file_name, sparsity_codes, "f64"
+ )
+
+ passed = 0
+
+ # Verify the output shape for the tensor.
+ if np.array_equal(o_shape, t.shape):
+ passed += 1
+
+ # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
+ # values and verify the values.
+ (
+ o_rank,
+ o_nse,
+ o_shape,
+ o_values,
+ o_indices,
+ ) = pytaco_utils.sparse_tensor_to_coo_tensor(sparse_tensor, np.float64)
+ if (
+ o_rank == t.rank
+ and o_nse == t.nse
+ and np.array_equal(o_shape, t.shape)
+ and np.allclose(o_values, t.values)
+ and np.array_equal(o_indices, t.indices)
+ ):
+ passed += 1
+
+ return passed
# A 2D sparse tensor data in COO-flavored format.
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
index 12c97dbfa61ba..70b4b66f4378d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
@@ -5,11 +5,11 @@ if not config.mlir_run_amx_tests:
config.unsupported = True
# No JIT on win32.
-if sys.platform == 'win32':
+if sys.platform == "win32":
config.unsupported = True
if config.intel_sde_executable:
# Run test in emulator (Intel SDE): AMX needs Sapphire Rapids CPU.
- config.substitutions.append(('%lli', config.intel_sde_executable + ' -spr -- lli'))
+ config.substitutions.append(("%lli", config.intel_sde_executable + " -spr -- lli"))
else:
- config.substitutions.append(('%lli', 'lli'))
+ config.substitutions.append(("%lli", "lli"))
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/lit.local.cfg
index 0423fc03da5f9..296b4419438e8 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/lit.local.cfg
@@ -5,5 +5,5 @@ if not config.mlir_run_arm_sme_tests:
config.unsupported = True
# No JIT on win32.
-if sys.platform == 'win32':
+if sys.platform == "win32":
config.unsupported = True
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/lit.local.cfg
index 8a0d884509c80..37d3a74874ce4 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/lit.local.cfg
@@ -5,5 +5,5 @@ if not config.mlir_run_arm_sve_tests:
config.unsupported = True
# No JIT on win32.
-if sys.platform == 'win32':
+if sys.platform == "win32":
config.unsupported = True
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg
index 0e22874db1d18..bde815616b2db 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg
@@ -5,11 +5,11 @@ if not config.mlir_run_x86vector_tests:
config.unsupported = True
# No JIT on win32.
-if sys.platform == 'win32':
+if sys.platform == "win32":
config.unsupported = True
if config.intel_sde_executable:
# Run test in emulator (Intel SDE).
- config.substitutions.append(('%lli', config.intel_sde_executable + ' -tgl -- lli'))
+ config.substitutions.append(("%lli", config.intel_sde_executable + " -tgl -- lli"))
else:
- config.substitutions.append(('%lli', 'lli'))
+ config.substitutions.append(("%lli", "lli"))
diff --git a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg
index 0bdebfedeee36..acb8dd43f50b4 100644
--- a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg
@@ -1,2 +1,2 @@
if not config.enable_cuda_runner:
- config.unsupported = True
+ config.unsupported = True
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/TensorCore/lit.local.cfg
index 451b9fcbed3df..3bd70245df7df 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/lit.local.cfg
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/lit.local.cfg
@@ -2,4 +2,4 @@ import sys
# TensorCore tests must be enabled via build flag.
if not config.mlir_run_cuda_tensor_core_tests:
- config.unsupported = True
+ config.unsupported = True
diff --git a/mlir/test/Integration/GPU/CUDA/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/lit.local.cfg
index 0bdebfedeee36..acb8dd43f50b4 100644
--- a/mlir/test/Integration/GPU/CUDA/lit.local.cfg
+++ b/mlir/test/Integration/GPU/CUDA/lit.local.cfg
@@ -1,2 +1,2 @@
if not config.enable_cuda_runner:
- config.unsupported = True
+ config.unsupported = True
diff --git a/mlir/test/Integration/GPU/ROCM/lit.local.cfg b/mlir/test/Integration/GPU/ROCM/lit.local.cfg
index b0d086f9d4d51..e1f864857c5c1 100644
--- a/mlir/test/Integration/GPU/ROCM/lit.local.cfg
+++ b/mlir/test/Integration/GPU/ROCM/lit.local.cfg
@@ -1,4 +1,4 @@
if not config.enable_rocm_runner or not config.rocm_test_chipset:
- config.unsupported = True
+ config.unsupported = True
-config.substitutions.append(('%chip', config.rocm_test_chipset))
+config.substitutions.append(("%chip", config.rocm_test_chipset))
diff --git a/mlir/test/Integration/lit.local.cfg b/mlir/test/Integration/lit.local.cfg
index 80a862a1ce664..1b4a323871d75 100644
--- a/mlir/test/Integration/lit.local.cfg
+++ b/mlir/test/Integration/lit.local.cfg
@@ -3,8 +3,9 @@ from lit.llvm import llvm_config
if not config.mlir_include_integration_tests:
config.unsupported = True
+
def configure_aarch64_lli_cmd():
- lli_cmd = 'lli'
+ lli_cmd = "lli"
# NOTE: If the SVE tests are disabled and the SME tests are enabled to run
# under emulation, the SVE specific RUN lines in the SparseTensor tests
@@ -12,8 +13,12 @@ def configure_aarch64_lli_cmd():
if not (config.mlir_run_arm_sve_tests or config.mlir_run_arm_sme_tests):
return lli_cmd
- config.substitutions.append(('%mlir_native_utils_lib_dir',
- config.arm_emulator_utils_lib_dir or config.mlir_lib_dir))
+ config.substitutions.append(
+ (
+ "%mlir_native_utils_lib_dir",
+ config.arm_emulator_utils_lib_dir or config.mlir_lib_dir,
+ )
+ )
if config.arm_emulator_executable:
if config.arm_emulator_lli_executable:
@@ -23,16 +28,22 @@ def configure_aarch64_lli_cmd():
# when running under an emulator. If the user didn't specify an lli
# executable, use absolute path %llvm_tools_dir/lli.
lli_cmd = llvm_config.use_llvm_tool(
- 'lli', search_env='LLI', required=True,
- search_paths=[config.llvm_tools_dir], use_installed=False
+ "lli",
+ search_env="LLI",
+ required=True,
+ search_paths=[config.llvm_tools_dir],
+ use_installed=False,
)
# Run test in emulator (qemu or armie)
- emulation_cmd = f'{config.arm_emulator_executable} {config.arm_emulator_options}'
- lli_cmd = f'{emulation_cmd} {lli_cmd}'
+ emulation_cmd = (
+ f"{config.arm_emulator_executable} {config.arm_emulator_options}"
+ )
+ lli_cmd = f"{emulation_cmd} {lli_cmd}"
return lli_cmd
+
aarch64_lli_cmd = configure_aarch64_lli_cmd()
# Configure the following AArch64 substitutions:
@@ -52,5 +63,5 @@ aarch64_lli_cmd = configure_aarch64_lli_cmd()
# could be used in the SparseTensor tests where necessary, but the meaning
# conveyed by the substitution name would be a misnomer if the host target
# is not AArch64 and MLIR_RUN_ARM_SVE_TESTS=OFF.
-config.substitutions.append(('%lli_aarch64_cmd', aarch64_lli_cmd))
-config.substitutions.append(('%lli_host_or_aarch64_cmd', aarch64_lli_cmd))
+config.substitutions.append(("%lli_aarch64_cmd", aarch64_lli_cmd))
+config.substitutions.append(("%lli_host_or_aarch64_cmd", aarch64_lli_cmd))
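
For orientation, here is a rough sketch of what the two AArch64 substitutions above expand to when emulation is configured. It mirrors configure_aarch64_lli_cmd() from the hunk above; the emulator path, options, and lli location below are hypothetical placeholders, not values taken from this patch.

# Hypothetical configuration values, for illustration only.
arm_emulator_executable = "/usr/bin/qemu-aarch64"
arm_emulator_options = "-L /usr/aarch64-linux-gnu"
lli_cmd = "/build/bin/lli"  # assumed path found via llvm_config.use_llvm_tool

# Same composition as configure_aarch64_lli_cmd() above.
emulation_cmd = f"{arm_emulator_executable} {arm_emulator_options}"
lli_cmd = f"{emulation_cmd} {lli_cmd}"
print(lli_cmd)
# /usr/bin/qemu-aarch64 -L /usr/aarch64-linux-gnu /build/bin/lli
# Both %lli_aarch64_cmd and %lli_host_or_aarch64_cmd then expand to this string.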
diff --git a/mlir/test/Unit/lit.cfg.py b/mlir/test/Unit/lit.cfg.py
index 5b66517b1788e..1898b72adf002 100644
--- a/mlir/test/Unit/lit.cfg.py
+++ b/mlir/test/Unit/lit.cfg.py
@@ -8,43 +8,43 @@
import lit.formats
# name: The name of this test suite.
-config.name = 'MLIR-Unit'
+config.name = "MLIR-Unit"
# suffixes: A list of file extensions to treat as test files.
config.suffixes = []
# test_source_root: The root path where tests are located.
# test_exec_root: The root path where tests should be run.
-config.test_exec_root = os.path.join(config.mlir_obj_root, 'unittests')
+config.test_exec_root = os.path.join(config.mlir_obj_root, "unittests")
config.test_source_root = config.test_exec_root
# testFormat: The test format to use to interpret tests.
-config.test_format = lit.formats.GoogleTest(config.llvm_build_mode, 'Tests')
+config.test_format = lit.formats.GoogleTest(config.llvm_build_mode, "Tests")
# Propagate the temp directory. Windows requires this because it uses \Windows\
# if none of these are present.
-if 'TMP' in os.environ:
- config.environment['TMP'] = os.environ['TMP']
-if 'TEMP' in os.environ:
- config.environment['TEMP'] = os.environ['TEMP']
+if "TMP" in os.environ:
+ config.environment["TMP"] = os.environ["TMP"]
+if "TEMP" in os.environ:
+ config.environment["TEMP"] = os.environ["TEMP"]
# Propagate HOME as it can be used to override incorrect homedir in passwd
# that causes the tests to fail.
-if 'HOME' in os.environ:
- config.environment['HOME'] = os.environ['HOME']
+if "HOME" in os.environ:
+ config.environment["HOME"] = os.environ["HOME"]
# Propagate sanitizer options.
for var in [
- 'ASAN_SYMBOLIZER_PATH',
- 'HWASAN_SYMBOLIZER_PATH',
- 'MSAN_SYMBOLIZER_PATH',
- 'TSAN_SYMBOLIZER_PATH',
- 'UBSAN_SYMBOLIZER_PATH',
- 'ASAN_OPTIONS',
- 'HWASAN_OPTIONS',
- 'MSAN_OPTIONS',
- 'TSAN_OPTIONS',
- 'UBSAN_OPTIONS',
+ "ASAN_SYMBOLIZER_PATH",
+ "HWASAN_SYMBOLIZER_PATH",
+ "MSAN_SYMBOLIZER_PATH",
+ "TSAN_SYMBOLIZER_PATH",
+ "UBSAN_SYMBOLIZER_PATH",
+ "ASAN_OPTIONS",
+ "HWASAN_OPTIONS",
+ "MSAN_OPTIONS",
+ "TSAN_OPTIONS",
+ "UBSAN_OPTIONS",
]:
if var in os.environ:
config.environment[var] = os.environ[var]
diff --git a/mlir/test/lib/Dialect/Test/lit.local.cfg b/mlir/test/lib/Dialect/Test/lit.local.cfg
index edb5b44b2e2fe..65a7f202dc82a 100644
--- a/mlir/test/lib/Dialect/Test/lit.local.cfg
+++ b/mlir/test/lib/Dialect/Test/lit.local.cfg
@@ -1 +1 @@
-config.suffixes.remove('.td')
\ No newline at end of file
+config.suffixes.remove(".td")
diff --git a/mlir/test/lib/Dialect/Transform/lit.local.cfg b/mlir/test/lib/Dialect/Transform/lit.local.cfg
index edb5b44b2e2fe..65a7f202dc82a 100644
--- a/mlir/test/lib/Dialect/Transform/lit.local.cfg
+++ b/mlir/test/lib/Dialect/Transform/lit.local.cfg
@@ -1 +1 @@
-config.suffixes.remove('.td')
\ No newline at end of file
+config.suffixes.remove(".td")
diff --git a/mlir/test/lib/Tools/PDLL/lit.local.cfg b/mlir/test/lib/Tools/PDLL/lit.local.cfg
index 8cfe5cd834f06..8ffccee1d6d79 100644
--- a/mlir/test/lib/Tools/PDLL/lit.local.cfg
+++ b/mlir/test/lib/Tools/PDLL/lit.local.cfg
@@ -1 +1 @@
-config.suffixes.remove('.pdll')
+config.suffixes.remove(".pdll")
diff --git a/mlir/test/lib/Transforms/lit.local.cfg b/mlir/test/lib/Transforms/lit.local.cfg
index 8cfe5cd834f06..8ffccee1d6d79 100644
--- a/mlir/test/lib/Transforms/lit.local.cfg
+++ b/mlir/test/lib/Transforms/lit.local.cfg
@@ -1 +1 @@
-config.suffixes.remove('.pdll')
+config.suffixes.remove(".pdll")
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 1fc2e319d19fd..ad0b0d5567779 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -16,21 +16,32 @@
# Configuration file for the 'lit' test runner.
# name: The name of this test suite.
-config.name = 'MLIR'
+config.name = "MLIR"
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
# suffixes: A list of file extensions to treat as test files.
-config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.yaml', '.test', '.pdll', '.c']
+config.suffixes = [
+ ".td",
+ ".mlir",
+ ".toy",
+ ".ll",
+ ".tc",
+ ".py",
+ ".yaml",
+ ".test",
+ ".pdll",
+ ".c",
+]
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
# test_exec_root: The root path where tests should be run.
-config.test_exec_root = os.path.join(config.mlir_obj_root, 'test')
+config.test_exec_root = os.path.join(config.mlir_obj_root, "test")
-config.substitutions.append(('%PATH%', config.environment['PATH']))
-config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
+config.substitutions.append(("%PATH%", config.environment["PATH"]))
+config.substitutions.append(("%shlibext", config.llvm_shlib_ext))
config.substitutions.append(("%mlir_src_root", config.mlir_src_root))
config.substitutions.append(("%host_cxx", config.host_cxx))
config.substitutions.append(("%host_cc", config.host_cc))
@@ -40,94 +51,109 @@
# substitution of the same name and the found path.
# Correctly handles the platforms shared library directory and naming conventions.
def add_runtime(name):
- path = ''
- for prefix in ['', 'lib']:
- path = os.path.join(config.llvm_shlib_dir, f'{prefix}{name}{config.llvm_shlib_ext}')
+ path = ""
+ for prefix in ["", "lib"]:
+ path = os.path.join(
+ config.llvm_shlib_dir, f"{prefix}{name}{config.llvm_shlib_ext}"
+ )
if os.path.isfile(path):
break
- return ToolSubst(f'%{name}', path)
+ return ToolSubst(f"%{name}", path)
-llvm_config.with_system_environment(
- ['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
+llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])
llvm_config.use_default_substitutions()
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
# subdirectories contain auxiliary inputs for various tests in their parent
# directories.
-config.excludes = ['Inputs', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt',
- 'lit.cfg.py', 'lit.site.cfg.py']
+config.excludes = [
+ "Inputs",
+ "CMakeLists.txt",
+ "README.txt",
+ "LICENSE.txt",
+ "lit.cfg.py",
+ "lit.site.cfg.py",
+]
# Tweak the PATH to include the tools dir.
-llvm_config.with_environment('PATH', config.mlir_tools_dir, append_path=True)
-llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
+llvm_config.with_environment("PATH", config.mlir_tools_dir, append_path=True)
+llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)
tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir]
tools = [
- 'mlir-tblgen',
- 'mlir-translate',
- 'mlir-lsp-server',
- 'mlir-capi-execution-engine-test',
- 'mlir-capi-ir-test',
- 'mlir-capi-llvm-test',
- 'mlir-capi-pass-test',
- 'mlir-capi-pdl-test',
- 'mlir-capi-quant-test',
- 'mlir-capi-sparse-tensor-test',
- 'mlir-capi-transform-test',
- 'mlir-cpu-runner',
- add_runtime('mlir_runner_utils'),
- add_runtime('mlir_c_runner_utils'),
- add_runtime('mlir_async_runtime'),
- 'mlir-linalg-ods-yaml-gen',
- 'mlir-reduce',
- 'mlir-pdll',
- 'not',
+ "mlir-tblgen",
+ "mlir-translate",
+ "mlir-lsp-server",
+ "mlir-capi-execution-engine-test",
+ "mlir-capi-ir-test",
+ "mlir-capi-llvm-test",
+ "mlir-capi-pass-test",
+ "mlir-capi-pdl-test",
+ "mlir-capi-quant-test",
+ "mlir-capi-sparse-tensor-test",
+ "mlir-capi-transform-test",
+ "mlir-cpu-runner",
+ add_runtime("mlir_runner_utils"),
+ add_runtime("mlir_c_runner_utils"),
+ add_runtime("mlir_async_runtime"),
+ "mlir-linalg-ods-yaml-gen",
+ "mlir-reduce",
+ "mlir-pdll",
+ "not",
]
if config.enable_spirv_cpu_runner:
- tools.extend(['mlir-spirv-cpu-runner', add_runtime('mlir_test_spirv_cpu_runner_c_wrappers')])
+ tools.extend(
+ ["mlir-spirv-cpu-runner", add_runtime("mlir_test_spirv_cpu_runner_c_wrappers")]
+ )
if config.enable_vulkan_runner:
- tools.extend([add_runtime('vulkan-runtime-wrappers')])
+ tools.extend([add_runtime("vulkan-runtime-wrappers")])
if config.enable_rocm_runner:
- tools.extend([add_runtime('mlir_rocm_runtime')])
+ tools.extend([add_runtime("mlir_rocm_runtime")])
if config.enable_cuda_runner:
- tools.extend([add_runtime('mlir_cuda_runtime')])
+ tools.extend([add_runtime("mlir_cuda_runtime")])
# The following tools are optional
-tools.extend([
- ToolSubst('toyc-ch1', unresolved='ignore'),
- ToolSubst('toyc-ch2', unresolved='ignore'),
- ToolSubst('toyc-ch3', unresolved='ignore'),
- ToolSubst('toyc-ch4', unresolved='ignore'),
- ToolSubst('toyc-ch5', unresolved='ignore'),
- ToolSubst('toyc-ch6', unresolved='ignore'),
- ToolSubst('toyc-ch7', unresolved='ignore'),
- ToolSubst('%mlir_lib_dir', config.mlir_lib_dir, unresolved='ignore'),
- ToolSubst('%mlir_src_dir', config.mlir_src_root, unresolved='ignore'),
-])
+tools.extend(
+ [
+ ToolSubst("toyc-ch1", unresolved="ignore"),
+ ToolSubst("toyc-ch2", unresolved="ignore"),
+ ToolSubst("toyc-ch3", unresolved="ignore"),
+ ToolSubst("toyc-ch4", unresolved="ignore"),
+ ToolSubst("toyc-ch5", unresolved="ignore"),
+ ToolSubst("toyc-ch6", unresolved="ignore"),
+ ToolSubst("toyc-ch7", unresolved="ignore"),
+ ToolSubst("%mlir_lib_dir", config.mlir_lib_dir, unresolved="ignore"),
+ ToolSubst("%mlir_src_dir", config.mlir_src_root, unresolved="ignore"),
+ ]
+)
python_executable = config.python_executable
# Python configuration with sanitizer requires some magic preloading. This will only work on clang/linux.
# TODO: detect Darwin/Windows situation (or mark these tests as unsupported on these platforms).
if "asan" in config.available_features and "Linux" in config.host_os:
- python_executable = f"LD_PRELOAD=$({config.host_cxx} -print-file-name=libclang_rt.asan-{config.host_arch}.so) {config.python_executable}"
+ python_executable = f"LD_PRELOAD=$({config.host_cxx} -print-file-name=libclang_rt.asan-{config.host_arch}.so) {config.python_executable}"
# On Windows the path to python could contains spaces in which case it needs to be provided in quotes.
# This is the equivalent of how %python is setup in llvm/utils/lit/lit/llvm/config.py.
elif "Windows" in config.host_os:
- python_executable = '"%s"' % (python_executable)
-tools.extend([
- ToolSubst('%PYTHON', python_executable, unresolved='ignore'),
-])
+ python_executable = '"%s"' % (python_executable)
+tools.extend(
+ [
+ ToolSubst("%PYTHON", python_executable, unresolved="ignore"),
+ ]
+)
if "MLIR_OPT_CHECK_IR_ROUNDTRIP" in os.environ:
- tools.extend([
- ToolSubst('mlir-opt', 'mlir-opt --verify-roundtrip', unresolved='fatal'),
- ])
+ tools.extend(
+ [
+ ToolSubst("mlir-opt", "mlir-opt --verify-roundtrip", unresolved="fatal"),
+ ]
+ )
llvm_config.add_tool_substitutions(tools, tool_dirs)
@@ -135,40 +161,48 @@ def add_runtime(name):
# FileCheck -enable-var-scope is enabled by default in MLIR test
# This option avoids to accidentally reuse variable across -LABEL match,
# it can be explicitly opted-in by prefixing the variable name with $
-config.environment['FILECHECK_OPTS'] = "-enable-var-scope --allow-unused-prefixes=false"
+config.environment["FILECHECK_OPTS"] = "-enable-var-scope --allow-unused-prefixes=false"
# Add the python path for both the source and binary tree.
# Note that presently, the python sources come from the source tree and the
# binaries come from the build tree. This should be unified to the build tree
# by copying/linking sources to build.
if config.enable_bindings_python:
- llvm_config.with_environment('PYTHONPATH', [
- os.path.join(config.mlir_obj_root, 'python_packages', 'mlir_core'),
- os.path.join(config.mlir_obj_root, 'python_packages', 'mlir_test'),
- ], append_path=True)
+ llvm_config.with_environment(
+ "PYTHONPATH",
+ [
+ os.path.join(config.mlir_obj_root, "python_packages", "mlir_core"),
+ os.path.join(config.mlir_obj_root, "python_packages", "mlir_test"),
+ ],
+ append_path=True,
+ )
if config.enable_assertions:
- config.available_features.add('asserts')
+ config.available_features.add("asserts")
else:
- config.available_features.add('noasserts')
+ config.available_features.add("noasserts")
+
def have_host_jit_feature_support(feature_name):
- mlir_cpu_runner_exe = lit.util.which('mlir-cpu-runner', config.mlir_tools_dir)
+ mlir_cpu_runner_exe = lit.util.which("mlir-cpu-runner", config.mlir_tools_dir)
+
+ if not mlir_cpu_runner_exe:
+ return False
- if not mlir_cpu_runner_exe:
- return False
+ try:
+ mlir_cpu_runner_cmd = subprocess.Popen(
+ [mlir_cpu_runner_exe, "--host-supports-" + feature_name],
+ stdout=subprocess.PIPE,
+ )
+ except OSError:
+ print("could not exec mlir-cpu-runner")
+ return False
- try:
- mlir_cpu_runner_cmd = subprocess.Popen(
- [mlir_cpu_runner_exe, '--host-supports-' + feature_name], stdout=subprocess.PIPE)
- except OSError:
- print('could not exec mlir-cpu-runner')
- return False
+ mlir_cpu_runner_out = mlir_cpu_runner_cmd.stdout.read().decode("ascii")
+ mlir_cpu_runner_cmd.wait()
- mlir_cpu_runner_out = mlir_cpu_runner_cmd.stdout.read().decode('ascii')
- mlir_cpu_runner_cmd.wait()
+ return "true" in mlir_cpu_runner_out
- return 'true' in mlir_cpu_runner_out
-if have_host_jit_feature_support('jit'):
- config.available_features.add('host-supports-jit')
+if have_host_jit_feature_support("jit"):
+ config.available_features.add("host-supports-jit")
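
One detail worth noting from the lit.cfg.py hunk above: add_runtime() probes both the unprefixed and the lib-prefixed shared-library names so that the same substitution works across platform naming conventions. A minimal standalone sketch follows; the directory and extension values are assumptions for illustration, the real ones come from the lit configuration.

import os

# Assumed build layout; real values come from config.llvm_shlib_dir/_ext.
llvm_shlib_dir = "/build/lib"
llvm_shlib_ext = ".so"  # ".dylib" on macOS, ".dll" on Windows
name = "mlir_runner_utils"

path = ""
for prefix in ["", "lib"]:
    path = os.path.join(llvm_shlib_dir, f"{prefix}{name}{llvm_shlib_ext}")
    if os.path.isfile(path):
        # Pick whichever spelling actually exists on disk.
        break
# On a typical Linux build this resolves to /build/lib/libmlir_runner_utils.so,
# which then backs the %mlir_runner_utils substitution.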
diff --git a/mlir/test/mlir-cpu-runner/lit.local.cfg b/mlir/test/mlir-cpu-runner/lit.local.cfg
index 3f59ff1bd9774..3c20d203b800f 100644
--- a/mlir/test/mlir-cpu-runner/lit.local.cfg
+++ b/mlir/test/mlir-cpu-runner/lit.local.cfg
@@ -1,12 +1,11 @@
import sys
# MSAN does not work with JIT.
-if 'msan' in config.available_features:
- config.unsupported = True
+if "msan" in config.available_features:
+ config.unsupported = True
# Requires native execution.
-if 'host-supports-jit' not in config.available_features:
+if "host-supports-jit" not in config.available_features:
config.unsupported = True
-config.available_features.add(
- config.root.native_target.lower() + '-native-target')
+config.available_features.add(config.root.native_target.lower() + "-native-target")
diff --git a/mlir/test/mlir-pdll-lsp-server/lit.local.cfg b/mlir/test/mlir-pdll-lsp-server/lit.local.cfg
index 25d08c7aba306..aa35dbfa8c01f 100644
--- a/mlir/test/mlir-pdll-lsp-server/lit.local.cfg
+++ b/mlir/test/mlir-pdll-lsp-server/lit.local.cfg
@@ -1 +1 @@
-config.excludes = ['include']
+config.excludes = ["include"]
diff --git a/mlir/test/mlir-pdll/lit.local.cfg b/mlir/test/mlir-pdll/lit.local.cfg
index c438027edc2c4..4cb5622aaa761 100644
--- a/mlir/test/mlir-pdll/lit.local.cfg
+++ b/mlir/test/mlir-pdll/lit.local.cfg
@@ -1,2 +1,2 @@
-config.suffixes = ['.pdll', '.mlir']
-config.excludes = ['include']
+config.suffixes = [".pdll", ".mlir"]
+config.excludes = ["include"]
diff --git a/mlir/test/mlir-spirv-cpu-runner/lit.local.cfg b/mlir/test/mlir-spirv-cpu-runner/lit.local.cfg
index 286bea4cace67..8717dd025498e 100644
--- a/mlir/test/mlir-spirv-cpu-runner/lit.local.cfg
+++ b/mlir/test/mlir-spirv-cpu-runner/lit.local.cfg
@@ -1,4 +1,4 @@
import sys
if not config.enable_spirv_cpu_runner:
- config.unsupported = True
+ config.unsupported = True
diff --git a/mlir/test/mlir-vulkan-runner/lit.local.cfg b/mlir/test/mlir-vulkan-runner/lit.local.cfg
index f99be2aeb70e9..6da7fcdd610aa 100644
--- a/mlir/test/mlir-vulkan-runner/lit.local.cfg
+++ b/mlir/test/mlir-vulkan-runner/lit.local.cfg
@@ -1,2 +1,2 @@
if not config.enable_vulkan_runner:
- config.unsupported = True
+ config.unsupported = True
diff --git a/mlir/test/python/develoment_files.py b/mlir/test/python/develoment_files.py
index ea0a911890a07..4dc3a0b700b1f 100644
--- a/mlir/test/python/develoment_files.py
+++ b/mlir/test/python/develoment_files.py
@@ -14,5 +14,6 @@
all_libs = os.listdir(get_lib_dirs()[0])
found_lib = False
for file_name in all_libs:
- if expected_lib_name in file_name: found_lib = True
+ if expected_lib_name in file_name:
+ found_lib = True
assert found_lib, f"Did not find '{expected_lib_name}' lib in {all_libs}"
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index acae9b6474083..8e9613d052466 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -4,16 +4,18 @@
import mlir.dialects.func as func
import mlir.dialects.arith as arith
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
+ print("\nTEST:", f.__name__)
+ f()
+
# CHECK-LABEL: TEST: testConstantOp
@run
def testConstantOps():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- with InsertionPoint(module.body):
- arith.ConstantOp(value=42.42, result=F32Type.get())
- # CHECK: %cst = arith.constant 4.242000e+01 : f32
- print(module)
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ arith.ConstantOp(value=42.42, result=F32Type.get())
+ # CHECK: %cst = arith.constant 4.242000e+01 : f32
+ print(module)
diff --git a/mlir/test/python/dialects/async_dialect.py b/mlir/test/python/dialects/async_dialect.py
index da3103cecddf2..f6181cc76118e 100644
--- a/mlir/test/python/dialects/async_dialect.py
+++ b/mlir/test/python/dialects/async_dialect.py
@@ -5,14 +5,17 @@
import mlir.dialects.async_dialect.passes
from mlir.passmanager import *
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
+ print("\nTEST:", f.__name__)
+ f()
+
def testAsyncPass():
- with Context() as context:
- PassManager.parse('any(async-to-async-runtime)')
- print('SUCCESS')
+ with Context() as context:
+ PassManager.parse("any(async-to-async-runtime)")
+ print("SUCCESS")
+
# CHECK-LABEL: testAsyncPass
# CHECK: SUCCESS
diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index eab24b5c796b0..18ebba61e7fea 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -7,232 +7,242 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ return f
# CHECK-LABEL: TEST: testFromPyFunc
@run
def testFromPyFunc():
- with Context() as ctx, Location.unknown() as loc:
- ctx.allow_unregistered_dialects = True
- m = builtin.ModuleOp()
- f32 = F32Type.get()
- f64 = F64Type.get()
- with InsertionPoint(m.body):
- # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
- # CHECK: return %arg0 : f64
- @func.FuncOp.from_py_func(f64)
- def unary_return(a):
- return a
-
- # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
- # CHECK: return %arg0, %arg1 : f32, f64
- @func.FuncOp.from_py_func(f32, f64)
- def binary_return(a, b):
- return a, b
-
- # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
- # CHECK: return
- @func.FuncOp.from_py_func(f32, f64)
- def none_return(a, b):
- pass
-
- # CHECK-LABEL: func @call_unary
- # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
- # CHECK: return %0 : f64
- @func.FuncOp.from_py_func(f64)
- def call_unary(a):
- return unary_return(a)
-
- # CHECK-LABEL: func @call_binary
- # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
- # CHECK: return %0#0, %0#1 : f32, f64
- @func.FuncOp.from_py_func(f32, f64)
- def call_binary(a, b):
- return binary_return(a, b)
-
- # We expect coercion of a single result operation to a returned value.
- # CHECK-LABEL: func @single_result_op
- # CHECK: %0 = "custom.op1"() : () -> f32
- # CHECK: return %0 : f32
- @func.FuncOp.from_py_func()
- def single_result_op():
- return Operation.create("custom.op1", results=[f32])
-
- # CHECK-LABEL: func @call_none
- # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
- # CHECK: return
- @func.FuncOp.from_py_func(f32, f64)
- def call_none(a, b):
- return none_return(a, b)
-
- ## Variants and optional feature tests.
- # CHECK-LABEL: func @from_name_arg
- @func.FuncOp.from_py_func(f32, f64, name="from_name_arg")
- def explicit_name(a, b):
- return b
-
- @func.FuncOp.from_py_func(f32, f64)
- def positional_func_op(a, b, func_op):
- assert isinstance(func_op, func.FuncOp)
- return b
-
- @func.FuncOp.from_py_func(f32, f64)
- def kw_func_op(a, b=None, func_op=None):
- assert isinstance(func_op, func.FuncOp)
- return b
-
- @func.FuncOp.from_py_func(f32, f64)
- def kwargs_func_op(a, b=None, **kwargs):
- assert isinstance(kwargs["func_op"], func.FuncOp)
- return b
-
- # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
- # CHECK: return %arg1 : f64
- @func.FuncOp.from_py_func(f32, f64, results=[f64])
- def explicit_results(a, b):
- func.ReturnOp([b])
-
- print(m)
+ with Context() as ctx, Location.unknown() as loc:
+ ctx.allow_unregistered_dialects = True
+ m = builtin.ModuleOp()
+ f32 = F32Type.get()
+ f64 = F64Type.get()
+ with InsertionPoint(m.body):
+ # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
+ # CHECK: return %arg0 : f64
+ @func.FuncOp.from_py_func(f64)
+ def unary_return(a):
+ return a
+
+ # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
+ # CHECK: return %arg0, %arg1 : f32, f64
+ @func.FuncOp.from_py_func(f32, f64)
+ def binary_return(a, b):
+ return a, b
+
+ # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
+ # CHECK: return
+ @func.FuncOp.from_py_func(f32, f64)
+ def none_return(a, b):
+ pass
+
+ # CHECK-LABEL: func @call_unary
+ # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
+ # CHECK: return %0 : f64
+ @func.FuncOp.from_py_func(f64)
+ def call_unary(a):
+ return unary_return(a)
+
+ # CHECK-LABEL: func @call_binary
+ # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
+ # CHECK: return %0#0, %0#1 : f32, f64
+ @func.FuncOp.from_py_func(f32, f64)
+ def call_binary(a, b):
+ return binary_return(a, b)
+
+ # We expect coercion of a single result operation to a returned value.
+ # CHECK-LABEL: func @single_result_op
+ # CHECK: %0 = "custom.op1"() : () -> f32
+ # CHECK: return %0 : f32
+ @func.FuncOp.from_py_func()
+ def single_result_op():
+ return Operation.create("custom.op1", results=[f32])
+
+ # CHECK-LABEL: func @call_none
+ # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
+ # CHECK: return
+ @func.FuncOp.from_py_func(f32, f64)
+ def call_none(a, b):
+ return none_return(a, b)
+
+ ## Variants and optional feature tests.
+ # CHECK-LABEL: func @from_name_arg
+ @func.FuncOp.from_py_func(f32, f64, name="from_name_arg")
+ def explicit_name(a, b):
+ return b
+
+ @func.FuncOp.from_py_func(f32, f64)
+ def positional_func_op(a, b, func_op):
+ assert isinstance(func_op, func.FuncOp)
+ return b
+
+ @func.FuncOp.from_py_func(f32, f64)
+ def kw_func_op(a, b=None, func_op=None):
+ assert isinstance(func_op, func.FuncOp)
+ return b
+
+ @func.FuncOp.from_py_func(f32, f64)
+ def kwargs_func_op(a, b=None, **kwargs):
+ assert isinstance(kwargs["func_op"], func.FuncOp)
+ return b
+
+ # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
+ # CHECK: return %arg1 : f64
+ @func.FuncOp.from_py_func(f32, f64, results=[f64])
+ def explicit_results(a, b):
+ func.ReturnOp([b])
+
+ print(m)
# CHECK-LABEL: TEST: testFromPyFuncErrors
@run
def testFromPyFuncErrors():
- with Context() as ctx, Location.unknown() as loc:
- m = builtin.ModuleOp()
- f32 = F32Type.get()
- f64 = F64Type.get()
- with InsertionPoint(m.body):
- try:
+ with Context() as ctx, Location.unknown() as loc:
+ m = builtin.ModuleOp()
+ f32 = F32Type.get()
+ f64 = F64Type.get()
+ with InsertionPoint(m.body):
+ try:
- @func.FuncOp.from_py_func(f64, results=[f64])
- def unary_return(a):
- return a
- except AssertionError as e:
- # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
- print(e)
+ @func.FuncOp.from_py_func(f64, results=[f64])
+ def unary_return(a):
+ return a
+
+ except AssertionError as e:
+ # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
+ print(e)
# CHECK-LABEL: TEST: testBuildFuncOp
@run
def testBuildFuncOp():
- ctx = Context()
- with Location.unknown(ctx) as loc:
- m = builtin.ModuleOp()
-
- f32 = F32Type.get()
- tensor_type = RankedTensorType.get((2, 3, 4), f32)
- with InsertionPoint.at_block_begin(m.body):
- f = func.FuncOp(name="some_func",
- type=FunctionType.get(
- inputs=[tensor_type, tensor_type],
- results=[tensor_type]),
- visibility="nested")
- # CHECK: Name is: "some_func"
- print("Name is: ", f.name)
-
- # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
- print("Type is: ", f.type)
-
- # CHECK: Visibility is: "nested"
- print("Visibility is: ", f.visibility)
-
- try:
- entry_block = f.entry_block
- except IndexError as e:
- # CHECK: External function does not have a body
- print(e)
-
- with InsertionPoint(f.add_entry_block()):
- func.ReturnOp([f.entry_block.arguments[0]])
- pass
-
- try:
- f.add_entry_block()
- except IndexError as e:
- # CHECK: The function already has an entry block!
- print(e)
-
- # Try the callback builder and passing type as tuple.
- f = func.FuncOp(name="some_other_func",
- type=([tensor_type, tensor_type], [tensor_type]),
- visibility="nested",
- body_builder=lambda f: func.ReturnOp(
- [f.entry_block.arguments[0]]))
-
- # CHECK: module {
- # CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
- # CHECK: return %arg0 : tensor<2x3x4xf32>
- # CHECK: }
- # CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
- # CHECK: return %arg0 : tensor<2x3x4xf32>
- # CHECK: }
- print(m)
+ ctx = Context()
+ with Location.unknown(ctx) as loc:
+ m = builtin.ModuleOp()
+
+ f32 = F32Type.get()
+ tensor_type = RankedTensorType.get((2, 3, 4), f32)
+ with InsertionPoint.at_block_begin(m.body):
+ f = func.FuncOp(
+ name="some_func",
+ type=FunctionType.get(
+ inputs=[tensor_type, tensor_type], results=[tensor_type]
+ ),
+ visibility="nested",
+ )
+ # CHECK: Name is: "some_func"
+ print("Name is: ", f.name)
+
+ # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+ print("Type is: ", f.type)
+
+ # CHECK: Visibility is: "nested"
+ print("Visibility is: ", f.visibility)
+
+ try:
+ entry_block = f.entry_block
+ except IndexError as e:
+ # CHECK: External function does not have a body
+ print(e)
+
+ with InsertionPoint(f.add_entry_block()):
+ func.ReturnOp([f.entry_block.arguments[0]])
+ pass
+
+ try:
+ f.add_entry_block()
+ except IndexError as e:
+ # CHECK: The function already has an entry block!
+ print(e)
+
+ # Try the callback builder and passing type as tuple.
+ f = func.FuncOp(
+ name="some_other_func",
+ type=([tensor_type, tensor_type], [tensor_type]),
+ visibility="nested",
+ body_builder=lambda f: func.ReturnOp([f.entry_block.arguments[0]]),
+ )
+
+ # CHECK: module {
+ # CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+ # CHECK: return %arg0 : tensor<2x3x4xf32>
+ # CHECK: }
+ # CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+ # CHECK: return %arg0 : tensor<2x3x4xf32>
+ # CHECK: }
+ print(m)
# CHECK-LABEL: TEST: testFuncArgumentAccess
@run
def testFuncArgumentAccess():
- with Context() as ctx, Location.unknown():
- ctx.allow_unregistered_dialects = True
- module = Module.create()
- f32 = F32Type.get()
- f64 = F64Type.get()
- with InsertionPoint(module.body):
- f = func.FuncOp("some_func", ([f32, f32], [f32, f32]))
- with InsertionPoint(f.add_entry_block()):
- func.ReturnOp(f.arguments)
- f.arg_attrs = ArrayAttr.get([
- DictAttr.get({
- "custom_dialect.foo": StringAttr.get("bar"),
- "custom_dialect.baz": UnitAttr.get()
- }),
- DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])})
- ])
- f.result_attrs = ArrayAttr.get([
- DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
- DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)})
- ])
-
- other = func.FuncOp("other_func", ([f32, f32], []))
- with InsertionPoint(other.add_entry_block()):
- func.ReturnOp([])
- other.arg_attrs = [
- DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
- DictAttr.get()
- ]
-
- # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
- print(f.arg_attrs)
-
- # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
- print(f.result_attrs)
-
- # CHECK: func @some_func(
- # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
- # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
- # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
- # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
- # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
- #
- # CHECK: func @other_func(
- # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
- # CHECK: %{{.*}}: f32)
- print(module)
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ module = Module.create()
+ f32 = F32Type.get()
+ f64 = F64Type.get()
+ with InsertionPoint(module.body):
+ f = func.FuncOp("some_func", ([f32, f32], [f32, f32]))
+ with InsertionPoint(f.add_entry_block()):
+ func.ReturnOp(f.arguments)
+ f.arg_attrs = ArrayAttr.get(
+ [
+ DictAttr.get(
+ {
+ "custom_dialect.foo": StringAttr.get("bar"),
+ "custom_dialect.baz": UnitAttr.get(),
+ }
+ ),
+ DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])}),
+ ]
+ )
+ f.result_attrs = ArrayAttr.get(
+ [
+ DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
+ DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)}),
+ ]
+ )
+
+ other = func.FuncOp("other_func", ([f32, f32], []))
+ with InsertionPoint(other.add_entry_block()):
+ func.ReturnOp([])
+ other.arg_attrs = [
+ DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
+ DictAttr.get(),
+ ]
+
+ # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
+ print(f.arg_attrs)
+
+ # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
+ print(f.result_attrs)
+
+ # CHECK: func @some_func(
+ # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
+ # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
+ # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
+ # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
+ # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
+ #
+ # CHECK: func @other_func(
+ # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
+ # CHECK: %{{.*}}: f32)
+ print(module)
# CHECK-LABEL: testDenseElementsAttr
@run
def testDenseElementsAttr():
- with Context(), Location.unknown():
- values = np.arange(4, dtype=np.int32)
- i32 = IntegerType.get_signless(32)
- print(DenseElementsAttr.get(values, type=i32))
- # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>
- print(DenseElementsAttr.get(values, type=i32, shape=(2, 2)))
- # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
- print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
- # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
+ with Context(), Location.unknown():
+ values = np.arange(4, dtype=np.int32)
+ i32 = IntegerType.get_signless(32)
+ print(DenseElementsAttr.get(values, type=i32))
+ # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>
+ print(DenseElementsAttr.get(values, type=i32, shape=(2, 2)))
+ # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+ print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
+ # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
diff --git a/mlir/test/python/dialects/complex_dialect.py b/mlir/test/python/dialects/complex_dialect.py
index e724575b5bf5b..afad21757bc3c 100644
--- a/mlir/test/python/dialects/complex_dialect.py
+++ b/mlir/test/python/dialects/complex_dialect.py
@@ -9,24 +9,24 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
+ print("\nTEST:", f.__name__)
+ f()
# CHECK-LABEL: TEST: testComplexOps
@run
def testComplexOps():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(ComplexType.get(F32Type.get()))
- def emit_add(arg):
- return mlir_complex.AddOp(arg, arg)
-
- # CHECK-LABEL: func @emit_add(
- # CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
- # CHECK: %[[RES:.*]] = complex.add %[[ARG]], %[[ARG]] : complex<f32>
- # CHECK: return %[[RES]] : complex<f32>
- # CHECK: }
- print(module)
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(ComplexType.get(F32Type.get()))
+ def emit_add(arg):
+ return mlir_complex.AddOp(arg, arg)
+
+ # CHECK-LABEL: func @emit_add(
+ # CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
+ # CHECK: %[[RES:.*]] = complex.add %[[ARG]], %[[ARG]] : complex<f32>
+ # CHECK: return %[[RES]] : complex<f32>
+ # CHECK: }
+ print(module)
diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index 3be9cac2c1925..161a12d78776a 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -7,13 +7,13 @@
def constructAndPrintInModule(f):
- print("\nTEST:", f.__name__)
- with Context(), Location.unknown():
- module = Module.create()
- with InsertionPoint(module.body):
- f()
- print(module)
- return f
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ f()
+ print(module)
+ return f
# CHECK-LABEL: TEST: testConstantOp
@@ -21,21 +21,21 @@ def constructAndPrintInModule(f):
@constructAndPrintInModule
def testConstantOp():
- c1 = arith.ConstantOp(IntegerType.get_signless(32), 42)
- c2 = arith.ConstantOp(IntegerType.get_signless(64), 100)
- c3 = arith.ConstantOp(F32Type.get(), 3.14)
- c4 = arith.ConstantOp(F64Type.get(), 1.23)
- # CHECK: 42
- print(c1.literal_value)
+ c1 = arith.ConstantOp(IntegerType.get_signless(32), 42)
+ c2 = arith.ConstantOp(IntegerType.get_signless(64), 100)
+ c3 = arith.ConstantOp(F32Type.get(), 3.14)
+ c4 = arith.ConstantOp(F64Type.get(), 1.23)
+ # CHECK: 42
+ print(c1.literal_value)
- # CHECK: 100
- print(c2.literal_value)
+ # CHECK: 100
+ print(c2.literal_value)
- # CHECK: 3.140000104904175
- print(c3.literal_value)
+ # CHECK: 3.140000104904175
+ print(c3.literal_value)
- # CHECK: 1.23
- print(c4.literal_value)
+ # CHECK: 1.23
+ print(c4.literal_value)
# CHECK: = arith.constant 42 : i32
@@ -47,17 +47,17 @@ def testConstantOp():
# CHECK-LABEL: TEST: testVectorConstantOp
@constructAndPrintInModule
def testVectorConstantOp():
- int_type = IntegerType.get_signless(32)
- vec_type = VectorType.get([2, 2], int_type)
- c1 = arith.ConstantOp(
- vec_type,
- DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42)))
- try:
- print(c1.literal_value)
- except ValueError as e:
- assert "only integer and float constants have literal values" in str(e)
- else:
- assert False
+ int_type = IntegerType.get_signless(32)
+ vec_type = VectorType.get([2, 2], int_type)
+ c1 = arith.ConstantOp(
+ vec_type, DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42))
+ )
+ try:
+ print(c1.literal_value)
+ except ValueError as e:
+ assert "only integer and float constants have literal values" in str(e)
+ else:
+ assert False
# CHECK: = arith.constant dense<42> : vector<2x2xi32>
@@ -66,9 +66,9 @@ def testVectorConstantOp():
# CHECK-LABEL: TEST: testConstantIndexOp
@constructAndPrintInModule
def testConstantIndexOp():
- c1 = arith.ConstantOp.create_index(10)
- # CHECK: 10
- print(c1.literal_value)
+ c1 = arith.ConstantOp.create_index(10)
+ # CHECK: 10
+ print(c1.literal_value)
# CHECK: = arith.constant 10 : index
@@ -77,18 +77,18 @@ def testConstantIndexOp():
# CHECK-LABEL: TEST: testFunctionCalls
@constructAndPrintInModule
def testFunctionCalls():
- foo = func.FuncOp("foo", ([], []))
- foo.sym_visibility = StringAttr.get("private")
- bar = func.FuncOp("bar", ([], [IndexType.get()]))
- bar.sym_visibility = StringAttr.get("private")
- qux = func.FuncOp("qux", ([], [F32Type.get()]))
- qux.sym_visibility = StringAttr.get("private")
-
- with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()):
- func.CallOp(foo, [])
- func.CallOp([IndexType.get()], "bar", [])
- func.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
- func.ReturnOp([])
+ foo = func.FuncOp("foo", ([], []))
+ foo.sym_visibility = StringAttr.get("private")
+ bar = func.FuncOp("bar", ([], [IndexType.get()]))
+ bar.sym_visibility = StringAttr.get("private")
+ qux = func.FuncOp("qux", ([], [F32Type.get()]))
+ qux.sym_visibility = StringAttr.get("private")
+
+ with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()):
+ func.CallOp(foo, [])
+ func.CallOp([IndexType.get()], "bar", [])
+ func.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
+ func.ReturnOp([])
# CHECK: func private @foo()
diff --git a/mlir/test/python/dialects/gpu.py b/mlir/test/python/dialects/gpu.py
index 38bf038a5eeed..7eefaed711c2c 100644
--- a/mlir/test/python/dialects/gpu.py
+++ b/mlir/test/python/dialects/gpu.py
@@ -5,14 +5,17 @@
import mlir.dialects.gpu.passes
from mlir.passmanager import *
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
+ print("\nTEST:", f.__name__)
+ f()
+
def testGPUPass():
- with Context() as context:
- PassManager.parse('any(gpu-kernel-outlining)')
- print('SUCCESS')
+ with Context() as context:
+ PassManager.parse("any(gpu-kernel-outlining)")
+ print("SUCCESS")
+
# CHECK-LABEL: testGPUPass
# CHECK: SUCCESS
diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
index d787c5f49c441..7892d020d9067 100644
--- a/mlir/test/python/dialects/linalg/opdsl/arguments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py
@@ -34,8 +34,9 @@ def matmul(
C=TensorDef(U, S.M, S.N, output=True),
bfn=BinaryFnAttrDef(default=BinaryFn.mul),
ufn=UnaryFnAttrDef(default=UnaryFn.exp),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
- C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
+ cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+ C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
# CHECK: ---
@@ -47,7 +48,7 @@ def matmul(
# CHECK: type_var: T
@linalg_structured_op
def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
- O[D.m, D.n] = value
+ O[D.m, D.n] = value
# CHECK: ---
@@ -71,5 +72,6 @@ def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
def strided_copy(
I=TensorDef(T, S.IH, S.IW),
O=TensorDef(T, S.OH, S.OW, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 2])):
- O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 2]),
+):
+ O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]
diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py
index eacf43547b110..ad0a3eac5913e 100644
--- a/mlir/test/python/dialects/linalg/opdsl/assignments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py
@@ -35,8 +35,9 @@ def matmul(
B=TensorDef(T, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True),
mul=BinaryFnAttrDef(default=BinaryFn.mul),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
- C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
+ cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+ C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
# CHECK: ---
@@ -79,12 +80,12 @@ def matmul(
# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64'
@linalg_structured_op
def constants(
- O=TensorDef(T, S.M, S.K, output=True),
- exp=UnaryFnAttrDef(default=UnaryFn.exp)):
- pi = TypeFn.cast_signed(T, const(3.1415926535897931))
- cst42 = TypeFn.cast_signed(T, const(42))
- cst1000 = TypeFn.cast_signed(T, exp(const(1e+3)))
- O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000
+ O=TensorDef(T, S.M, S.K, output=True), exp=UnaryFnAttrDef(default=UnaryFn.exp)
+):
+ pi = TypeFn.cast_signed(T, const(3.1415926535897931))
+ cst42 = TypeFn.cast_signed(T, const(42))
+ cst1000 = TypeFn.cast_signed(T, exp(const(1e3)))
+ O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000
# CHECK: ---
@@ -100,7 +101,7 @@ def constants(
# CHECK: scalar_index: 0
@linalg_structured_op
def indices(O=TensorDef(T, S.M, S.K, output=True)):
- O[D.m, D.n] = index(D.n) + index(D.m)
+ O[D.m, D.n] = index(D.n) + index(D.m)
# CHECK: ---
@@ -111,4 +112,4 @@ def indices(O=TensorDef(T, S.M, S.K, output=True)):
# CHECK: scalar_arg: value
@linalg_structured_op
def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
- O[D.m, D.n] = value
+ O[D.m, D.n] = value
diff --git a/mlir/test/python/dialects/linalg/opdsl/doctests.py b/mlir/test/python/dialects/linalg/opdsl/doctests.py
index 4aae768848815..d2f9cec19d570 100644
--- a/mlir/test/python/dialects/linalg/opdsl/doctests.py
+++ b/mlir/test/python/dialects/linalg/opdsl/doctests.py
@@ -3,10 +3,11 @@
import doctest
import importlib
+
def test_module(module_name):
- print(f"--- Testing module: {module_name}")
- m = importlib.import_module(module_name)
- doctest.testmod(m, verbose=True, raise_on_error=True, report=True)
+ print(f"--- Testing module: {module_name}")
+ m = importlib.import_module(module_name)
+ doctest.testmod(m, verbose=True, raise_on_error=True, report=True)
test_module("mlir.dialects.linalg.opdsl.lang.affine")
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
index ebe2c0f33a286..d666d313767b9 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
@@ -17,43 +17,44 @@ def conv_poly(
K=TensorDef(T2, S.KH, S.KW, S.C),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2])):
- domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
- D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c])
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2]),
+):
+ domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+ O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+ ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c])
with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- i32 = IntegerType.get_signless(32)
- with InsertionPoint(module.body):
-
- # Convolution indexing maps.
- # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
- # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
- # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
-
- # CHECK-LABEL: @test_f32i32_conv
- # CHECK: linalg.generic
- # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$CONV_MAP_K]], #[[$CONV_MAP_O]]]
- # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
- # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32)
- # CHECK-NEXT: %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32
- # CHECK-NEXT: %[[FILTER_CAST:.+]] = arith.fptosi %[[FILTER:.+]] : f32 to i32
- # CHECK-NEXT: %[[PROD:.+]] = arith.muli %[[IN_CAST]], %[[FILTER_CAST]] : i32
- # CHECK-NEXT: %[[SUM:.+]] = arith.addi %[[OUT]], %[[PROD]] : i32
- # CHECK-NEXT: linalg.yield %[[SUM]] : i32
- # CHECK-NEXT: -> tensor<1x2x4x1xi32>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((1, 4, 16, 1), f32),
- RankedTensorType.get((2, 2, 1), f32),
- RankedTensorType.get((1, 2, 4, 1), i32))
- def test_f32i32_conv(input, filter, init_result):
- # Use default dilations and set non-default strides.
- return conv_poly(
- input, filter, outs=[init_result], strides=[2, 4])
+ module = Module.create()
+ f32 = F32Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+
+ # Convolution indexing maps.
+ # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
+ # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+ # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+
+ # CHECK-LABEL: @test_f32i32_conv
+ # CHECK: linalg.generic
+ # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$CONV_MAP_K]], #[[$CONV_MAP_O]]]
+ # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32)
+ # CHECK-NEXT: %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32
+ # CHECK-NEXT: %[[FILTER_CAST:.+]] = arith.fptosi %[[FILTER:.+]] : f32 to i32
+ # CHECK-NEXT: %[[PROD:.+]] = arith.muli %[[IN_CAST]], %[[FILTER_CAST]] : i32
+ # CHECK-NEXT: %[[SUM:.+]] = arith.addi %[[OUT]], %[[PROD]] : i32
+ # CHECK-NEXT: linalg.yield %[[SUM]] : i32
+ # CHECK-NEXT: -> tensor<1x2x4x1xi32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((1, 4, 16, 1), f32),
+ RankedTensorType.get((2, 2, 1), f32),
+ RankedTensorType.get((1, 2, 4, 1), i32),
+ )
+ def test_f32i32_conv(input, filter, init_result):
+ # Use default dilations and set non-default strides.
+ return conv_poly(input, filter, outs=[init_result], strides=[2, 4])
print(module)
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
index 1f840b09b0085..ffef737755e85 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
@@ -13,47 +13,51 @@
@linalg_structured_op
def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
- O[None] = TypeFn.cast_signed(U, value)
+ O[None] = TypeFn.cast_signed(U, value)
+
@linalg_structured_op
def fill_rank_zero_poly(I=TensorDef(T1), O=TensorDef(U, output=True)):
- O[None] = TypeFn.cast_signed(U, I[None])
+ O[None] = TypeFn.cast_signed(U, I[None])
+
with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
-
- # Fill indexing maps.
- # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
- # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
- # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
- # CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> ()>
- # CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
- # CHECK-LABEL: @test_fill_0d
- # CHECK: linalg.generic
- # CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]
- # CHECK-SAME: iterator_types = []
- @func.FuncOp.from_py_func(f32, RankedTensorType.get([], f32))
- def test_fill_0d(value, init_result):
- return fill_poly(value, outs=[init_result])
-
- # CHECK-LABEL: @test_fill_2d
- # CHECK: linalg.generic
- # CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]]]
- # CHECK-SAME: iterator_types = ["parallel", "parallel"]
- @func.FuncOp.from_py_func(f32, RankedTensorType.get([4, 16], f32))
- def test_fill_2d(value, init_result):
- return fill_poly(value, outs=[init_result])
-
- # CHECK-LABEL: @test_fill_rank_zero_3d
- # CHECK: linalg.generic
- # CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]]]
- # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
- @func.FuncOp.from_py_func(
- RankedTensorType.get([], f32), RankedTensorType.get([4, 8, 16], f32))
- def test_fill_rank_zero_3d(input, init_result):
- return fill_rank_zero_poly(input, outs=[init_result])
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+
+ # Fill indexing maps.
+ # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
+ # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+ # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+ # CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> ()>
+ # CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+ # CHECK-LABEL: @test_fill_0d
+ # CHECK: linalg.generic
+ # CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]
+ # CHECK-SAME: iterator_types = []
+ @func.FuncOp.from_py_func(f32, RankedTensorType.get([], f32))
+ def test_fill_0d(value, init_result):
+ return fill_poly(value, outs=[init_result])
+
+ # CHECK-LABEL: @test_fill_2d
+ # CHECK: linalg.generic
+ # CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]]]
+ # CHECK-SAME: iterator_types = ["parallel", "parallel"]
+ @func.FuncOp.from_py_func(f32, RankedTensorType.get([4, 16], f32))
+ def test_fill_2d(value, init_result):
+ return fill_poly(value, outs=[init_result])
+
+ # CHECK-LABEL: @test_fill_rank_zero_3d
+ # CHECK: linalg.generic
+ # CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]]]
+ # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get([], f32), RankedTensorType.get([4, 8, 16], f32)
+ )
+ def test_fill_rank_zero_3d(input, init_result):
+ return fill_rank_zero_poly(input, outs=[init_result])
+
print(module)
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
index 6dff754859eef..18c237c68081a 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
@@ -16,9 +16,10 @@
def matmul_mono(
A=TensorDef(T, S.M, S.K),
B=TensorDef(T, S.K, S.N),
- C=TensorDef(T, S.M, S.N, output=True)):
- domain(D.m, D.n, D.k)
- C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]
+ C=TensorDef(T, S.M, S.N, output=True),
+):
+ domain(D.m, D.n, D.k)
+ C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]
@linalg_structured_op
@@ -26,146 +27,162 @@ def matmul_poly(
A=TensorDef(T1, S.M, S.K),
B=TensorDef(T2, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
- domain(D.m, D.n, D.k)
- C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+ cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+ domain(D.m, D.n, D.k)
+ C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
with Context() as ctx, Location.unknown():
- module = Module.create()
- f16 = F16Type.get()
- f32 = F32Type.get()
- f64 = F64Type.get()
- i8 = IntegerType.get_signless(8)
- i16 = IntegerType.get_signless(16)
- i32 = IntegerType.get_signless(32)
- with InsertionPoint(module.body):
-
- # Multiplication indexing maps. We verify only the indexing maps of the
- # first multiplication and then do additional tests on casting and body
- # generation behavior.
- # CHECK: #[[$MUL_MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
- # CHECK: #[[$MUL_MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
- # CHECK: #[[$MUL_MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
- # CHECK-LABEL: func @test_matmul_mono
- # CHECK-SAME: %[[A:.+]]: tensor<4x16xf32>
- # CHECK-SAME: %[[B:.+]]: tensor<16x8xf32>
- # CHECK: %[[INITC:.+]] = tensor.empty() : tensor<4x8xf32>
- # CHECK: linalg.generic
- # CHECK-SAME: indexing_maps = [#[[$MUL_MAP_A]], #[[$MUL_MAP_B]], #[[$MUL_MAP_C]]]
- # CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
- # CHECK-SAME: ins(%[[A]], %[[B]]
- # CHECK-SAME: outs(%[[INITC]]
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32))
- def test_matmul_mono(lhs, rhs):
- init_result = tensor.EmptyOp([4, 8], f32)
- return matmul_mono(lhs, rhs, outs=[init_result.result])
-
- # CHECK-LABEL: @test_i8i8i32_matmul
- # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
- # CHECK-NEXT: %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32
- # CHECK-NEXT: %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i8 to i32
- # CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32
- # CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32
- # CHECK-NEXT: linalg.yield %[[ADD]] : i32
- # CHECK-NEXT: -> tensor<4x8xi32>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
- RankedTensorType.get((4, 8), i32))
- def test_i8i8i32_matmul(lhs, rhs, init_result):
- return matmul_poly(lhs, rhs, outs=[init_result])
-
- # CHECK-LABEL: @test_i8i8i32_matmul_unsigned
- # CHECK: = arith.extui
- # CHECK: = arith.extui
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
- RankedTensorType.get((4, 8), i32))
- def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result):
- return matmul_poly(
- lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
-
- # CHECK-LABEL: @test_i8i16i32_matmul
- # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
- # CHECK-NEXT: %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32
- # CHECK-NEXT: %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i16 to i32
- # CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32
- # CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32
- # CHECK-NEXT: linalg.yield %[[ADD]] : i32
- # CHECK-NEXT: -> tensor<4x8xi32>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i16),
- RankedTensorType.get((4, 8), i32))
- def test_i8i16i32_matmul(lhs, rhs, init_result):
- return matmul_poly(lhs, rhs, outs=[init_result])
-
- # CHECK-LABEL: @test_i32i32i16_matmul
- # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16)
- # CHECK-NEXT: %[[A_CAST:.+]] = arith.trunci %[[A_ARG]] : i32 to i16
- # CHECK-NEXT: %[[B_CAST:.+]] = arith.trunci %[[B_ARG]] : i32 to i16
- # CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i16
- # CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i16
- # CHECK-NEXT: linalg.yield %[[ADD]] : i16
- # CHECK-NEXT: -> tensor<4x8xi16>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), i32), RankedTensorType.get((16, 8), i32),
- RankedTensorType.get((4, 8), i16))
- def test_i32i32i16_matmul(lhs, rhs, init_result):
- return matmul_poly(lhs, rhs, outs=[init_result])
-
- # CHECK-LABEL: @test_i8i8f32_matmul
- # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32)
- # CHECK-NEXT: %[[A_CAST:.+]] = arith.sitofp %[[A_ARG]] : i8 to f32
- # CHECK-NEXT: %[[B_CAST:.+]] = arith.sitofp %[[B_ARG]] : i8 to f32
- # CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
- # CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
- # CHECK-NEXT: linalg.yield %[[ADD]] : f32
- # CHECK-NEXT: -> tensor<4x8xf32>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
- RankedTensorType.get((4, 8), f32))
- def test_i8i8f32_matmul(lhs, rhs, init_result):
- return matmul_poly(lhs, rhs, outs=[init_result])
-
- # CHECK-LABEL: @test_i8i8f32_matmul_unsigned
- # CHECK: = arith.uitofp
- # CHECK: = arith.uitofp
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
- RankedTensorType.get((4, 8), f32))
- def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result):
- return matmul_poly(
- lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
-
- # CHECK-LABEL: @test_f16f16f32_matmul
- # CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
- # CHECK-NEXT: %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
- # CHECK-NEXT: %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
- # CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
- # CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
- # CHECK-NEXT: linalg.yield %[[ADD]] : f32
- # CHECK-NEXT: -> tensor<4x8xf32>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f16), RankedTensorType.get((16, 8), f16),
- RankedTensorType.get((4, 8), f32))
- def test_f16f16f32_matmul(lhs, rhs, init_result):
- return matmul_poly(lhs, rhs, outs=[init_result])
-
- # CHECK-LABEL: @test_f64f64f32_matmul
- # CHECK: ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
- # CHECK-NEXT: %[[A_CAST:.+]] = arith.truncf %[[A_ARG]] : f64 to f32
- # CHECK-NEXT: %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
- # CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
- # CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
- # CHECK-NEXT: linalg.yield %[[ADD]] : f32
- # CHECK-NEXT: -> tensor<4x8xf32>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f64), RankedTensorType.get((16, 8), f64),
- RankedTensorType.get((4, 8), f32))
- def test_f64f64f32_matmul(lhs, rhs, init_result):
- return matmul_poly(lhs, rhs, outs=[init_result])
+ module = Module.create()
+ f16 = F16Type.get()
+ f32 = F32Type.get()
+ f64 = F64Type.get()
+ i8 = IntegerType.get_signless(8)
+ i16 = IntegerType.get_signless(16)
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+
+ # Multiplication indexing maps. We verify only the indexing maps of the
+ # first multiplication and then do additional tests on casting and body
+ # generation behavior.
+ # CHECK: #[[$MUL_MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+ # CHECK: #[[$MUL_MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+ # CHECK: #[[$MUL_MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+ # CHECK-LABEL: func @test_matmul_mono
+ # CHECK-SAME: %[[A:.+]]: tensor<4x16xf32>
+ # CHECK-SAME: %[[B:.+]]: tensor<16x8xf32>
+ # CHECK: %[[INITC:.+]] = tensor.empty() : tensor<4x8xf32>
+ # CHECK: linalg.generic
+ # CHECK-SAME: indexing_maps = [#[[$MUL_MAP_A]], #[[$MUL_MAP_B]], #[[$MUL_MAP_C]]]
+ # CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+ # CHECK-SAME: ins(%[[A]], %[[B]]
+ # CHECK-SAME: outs(%[[INITC]]
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
+ )
+ def test_matmul_mono(lhs, rhs):
+ init_result = tensor.EmptyOp([4, 8], f32)
+ return matmul_mono(lhs, rhs, outs=[init_result.result])
+
+ # CHECK-LABEL: @test_i8i8i32_matmul
+ # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
+ # CHECK-NEXT: %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32
+ # CHECK-NEXT: %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i8 to i32
+ # CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32
+ # CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32
+ # CHECK-NEXT: linalg.yield %[[ADD]] : i32
+ # CHECK-NEXT: -> tensor<4x8xi32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), i8),
+ RankedTensorType.get((16, 8), i8),
+ RankedTensorType.get((4, 8), i32),
+ )
+ def test_i8i8i32_matmul(lhs, rhs, init_result):
+ return matmul_poly(lhs, rhs, outs=[init_result])
+
+ # CHECK-LABEL: @test_i8i8i32_matmul_unsigned
+ # CHECK: = arith.extui
+ # CHECK: = arith.extui
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), i8),
+ RankedTensorType.get((16, 8), i8),
+ RankedTensorType.get((4, 8), i32),
+ )
+ def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result):
+ return matmul_poly(lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
+
+ # CHECK-LABEL: @test_i8i16i32_matmul
+ # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
+ # CHECK-NEXT: %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32
+ # CHECK-NEXT: %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i16 to i32
+ # CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32
+ # CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32
+ # CHECK-NEXT: linalg.yield %[[ADD]] : i32
+ # CHECK-NEXT: -> tensor<4x8xi32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), i8),
+ RankedTensorType.get((16, 8), i16),
+ RankedTensorType.get((4, 8), i32),
+ )
+ def test_i8i16i32_matmul(lhs, rhs, init_result):
+ return matmul_poly(lhs, rhs, outs=[init_result])
+
+ # CHECK-LABEL: @test_i32i32i16_matmul
+ # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16)
+ # CHECK-NEXT: %[[A_CAST:.+]] = arith.trunci %[[A_ARG]] : i32 to i16
+ # CHECK-NEXT: %[[B_CAST:.+]] = arith.trunci %[[B_ARG]] : i32 to i16
+ # CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i16
+ # CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i16
+ # CHECK-NEXT: linalg.yield %[[ADD]] : i16
+ # CHECK-NEXT: -> tensor<4x8xi16>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), i32),
+ RankedTensorType.get((16, 8), i32),
+ RankedTensorType.get((4, 8), i16),
+ )
+ def test_i32i32i16_matmul(lhs, rhs, init_result):
+ return matmul_poly(lhs, rhs, outs=[init_result])
+
+ # CHECK-LABEL: @test_i8i8f32_matmul
+ # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32)
+ # CHECK-NEXT: %[[A_CAST:.+]] = arith.sitofp %[[A_ARG]] : i8 to f32
+ # CHECK-NEXT: %[[B_CAST:.+]] = arith.sitofp %[[B_ARG]] : i8 to f32
+ # CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+ # CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+ # CHECK-NEXT: linalg.yield %[[ADD]] : f32
+ # CHECK-NEXT: -> tensor<4x8xf32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), i8),
+ RankedTensorType.get((16, 8), i8),
+ RankedTensorType.get((4, 8), f32),
+ )
+ def test_i8i8f32_matmul(lhs, rhs, init_result):
+ return matmul_poly(lhs, rhs, outs=[init_result])
+
+ # CHECK-LABEL: @test_i8i8f32_matmul_unsigned
+ # CHECK: = arith.uitofp
+ # CHECK: = arith.uitofp
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), i8),
+ RankedTensorType.get((16, 8), i8),
+ RankedTensorType.get((4, 8), f32),
+ )
+ def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result):
+ return matmul_poly(lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
+
+ # CHECK-LABEL: @test_f16f16f32_matmul
+ # CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
+ # CHECK-NEXT: %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+ # CHECK-NEXT: %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
+ # CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+ # CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+ # CHECK-NEXT: linalg.yield %[[ADD]] : f32
+ # CHECK-NEXT: -> tensor<4x8xf32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f16),
+ RankedTensorType.get((16, 8), f16),
+ RankedTensorType.get((4, 8), f32),
+ )
+ def test_f16f16f32_matmul(lhs, rhs, init_result):
+ return matmul_poly(lhs, rhs, outs=[init_result])
+
+ # CHECK-LABEL: @test_f64f64f32_matmul
+ # CHECK: ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
+ # CHECK-NEXT: %[[A_CAST:.+]] = arith.truncf %[[A_ARG]] : f64 to f32
+ # CHECK-NEXT: %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
+ # CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+ # CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+ # CHECK-NEXT: linalg.yield %[[ADD]] : f32
+ # CHECK-NEXT: -> tensor<4x8xf32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f64),
+ RankedTensorType.get((16, 8), f64),
+ RankedTensorType.get((4, 8), f32),
+ )
+ def test_f64f64f32_matmul(lhs, rhs, init_result):
+ return matmul_poly(lhs, rhs, outs=[init_result])
print(module)
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index aad714998c108..f8e034fb0e48b 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -17,14 +17,16 @@
@linalg_structured_op
def test_const(O=TensorDef(F32, S.M, S.N, output=True)):
- O[D.m, D.n] = TypeFn.cast_unsigned(F32, const(42)) + TypeFn.cast_unsigned(
- F32, const(2.3283064e-10))
+ O[D.m, D.n] = TypeFn.cast_unsigned(F32, const(42)) + TypeFn.cast_unsigned(
+ F32, const(2.3283064e-10)
+ )
@linalg_structured_op
def test_index(O=TensorDef(I32, S.M, S.N, output=True)):
- O[D.m, D.n] = TypeFn.cast_signed(I32, index(D.m)) + TypeFn.cast_signed(
- I32, index(D.n))
+ O[D.m, D.n] = TypeFn.cast_signed(I32, index(D.m)) + TypeFn.cast_signed(
+ I32, index(D.n)
+ )
@linalg_structured_op
@@ -32,120 +34,129 @@ def elemwise_unary_poly(
I=TensorDef(T),
O=TensorDef(U, output=True),
fun=UnaryFnAttrDef(default=UnaryFn.exp),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
- O[None] = fun(cast(U, I[None]))
+ cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+ O[None] = fun(cast(U, I[None]))
@linalg_structured_op(op_name="custom_op_name")
def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
- O[D.n] = I[D.n]
+ O[D.n] = I[D.n]
with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- c32 = ComplexType.get(f32)
- i32 = IntegerType.get_signless(32)
- with InsertionPoint(module.body):
-
- # CHECK-LABEL: @test_f32_const
- # CHECK-DAG: %[[CST0:.+]] = arith.constant 42 : i64
- # CHECK-DAG: %[[CST0_CAST:.+]] = arith.uitofp %[[CST0]] : i64 to f32
- # CHECK-DAG: %[[CST1:.+]] = arith.constant 2.3283063999999999E-10 : f64
- # CHECK-DAG: %[[CST1_CAST:.+]] = arith.truncf %[[CST1]] : f64 to f32
- # CHECK-DAG: %[[SUM:.+]] = arith.addf %[[CST0_CAST]], %[[CST1_CAST]] : f32
- # CHECK-NEXT: linalg.yield %[[SUM]] : f32
- @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32))
- def test_f32_const(init_result):
- return test_const(outs=[init_result])
-
- # CHECK-LABEL: @test_i32_index
- # CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
- # CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
- # CHECK-DAG: %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32
- # CHECK-DAG: %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32
- # CHECK-DAG: %[[SUM:.+]] = arith.addi %[[IDX0_CAST]], %[[IDX1_CAST]] : i32
- # CHECK-NEXT: linalg.yield %[[SUM]] : i32
- @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32))
- def test_i32_index(init_result):
- return test_index(outs=[init_result])
-
- # CHECK-LABEL: @test_f32_elemwise_exp
- # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32
- # CHECK-NEXT: linalg.yield %[[EXP]] : f32
- # CHECK-NEXT: -> tensor<4x16xf32>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
- def test_f32_elemwise_exp(input, init_result):
- return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
-
- # CHECK-LABEL: @test_f32_elemwise_log
- # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[LOG:.+]] = math.log %[[IN]] : f32
- # CHECK-NEXT: linalg.yield %[[LOG]] : f32
- # CHECK-NEXT: -> tensor<4x16xf32>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
- def test_f32_elemwise_log(input, init_result):
- return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
-
- # CHECK-LABEL: @test_f32_elemwise_abs
- # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[EXP:.+]] = math.absf %[[IN]] : f32
- # CHECK-NEXT: linalg.yield %[[EXP]] : f32
- # CHECK-NEXT: -> tensor<4x16xf32>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
- def test_f32_elemwise_abs(input, init_result):
- return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
-
- # CHECK-LABEL: @test_f32_elemwise_ceil
- # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[EXP:.+]] = math.ceil %[[IN]] : f32
- # CHECK-NEXT: linalg.yield %[[EXP]] : f32
- # CHECK-NEXT: -> tensor<4x16xf32>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
- def test_f32_elemwise_ceil(input, init_result):
- return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.ceil)
-
- # CHECK-LABEL: @test_f32_elemwise_floor
- # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[EXP:.+]] = math.floor %[[IN]] : f32
- # CHECK-NEXT: linalg.yield %[[EXP]] : f32
- # CHECK-NEXT: -> tensor<4x16xf32>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
- def test_f32_elemwise_floor(input, init_result):
- return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor)
-
- # CHECK-LABEL: @test_f32_elemwise_neg
- # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[EXP:.+]] = arith.negf %[[IN]] : f32
- # CHECK-NEXT: linalg.yield %[[EXP]] : f32
- # CHECK-NEXT: -> tensor<4x16xf32>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
- def test_f32_elemwise_neg(input, init_result):
- return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
-
- # CHECK-LABEL: @test_c32_elemwise_neg
- # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
- # CHECK-NEXT: %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
- # CHECK-NEXT: linalg.yield %[[EXP]] : complex<f32>
- # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32))
- def test_c32_elemwise_neg(input, init_result):
- return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
-
- # Just check that we don't assert out on name mismatch.
- # CHECK-LABEL: @test_non_default_op_name
- @func.FuncOp.from_py_func(
- RankedTensorType.get((42,), f32), RankedTensorType.get((42,), f32))
- def test_non_default_op_name(input, init_result):
- return non_default_op_name(input, outs=[init_result])
+ module = Module.create()
+ f32 = F32Type.get()
+ c32 = ComplexType.get(f32)
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+
+ # CHECK-LABEL: @test_f32_const
+ # CHECK-DAG: %[[CST0:.+]] = arith.constant 42 : i64
+ # CHECK-DAG: %[[CST0_CAST:.+]] = arith.uitofp %[[CST0]] : i64 to f32
+ # CHECK-DAG: %[[CST1:.+]] = arith.constant 2.3283063999999999E-10 : f64
+ # CHECK-DAG: %[[CST1_CAST:.+]] = arith.truncf %[[CST1]] : f64 to f32
+ # CHECK-DAG: %[[SUM:.+]] = arith.addf %[[CST0_CAST]], %[[CST1_CAST]] : f32
+ # CHECK-NEXT: linalg.yield %[[SUM]] : f32
+ @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32))
+ def test_f32_const(init_result):
+ return test_const(outs=[init_result])
+
+ # CHECK-LABEL: @test_i32_index
+ # CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
+ # CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
+ # CHECK-DAG: %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32
+ # CHECK-DAG: %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32
+ # CHECK-DAG: %[[SUM:.+]] = arith.addi %[[IDX0_CAST]], %[[IDX1_CAST]] : i32
+ # CHECK-NEXT: linalg.yield %[[SUM]] : i32
+ @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32))
+ def test_i32_index(init_result):
+ return test_index(outs=[init_result])
+
+ # CHECK-LABEL: @test_f32_elemwise_exp
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+ # CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: -> tensor<4x16xf32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+ )
+ def test_f32_elemwise_exp(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
+
+ # CHECK-LABEL: @test_f32_elemwise_log
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+ # CHECK-NEXT: %[[LOG:.+]] = math.log %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[LOG]] : f32
+ # CHECK-NEXT: -> tensor<4x16xf32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+ )
+ def test_f32_elemwise_log(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
+
+ # CHECK-LABEL: @test_f32_elemwise_abs
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+ # CHECK-NEXT: %[[EXP:.+]] = math.absf %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: -> tensor<4x16xf32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+ )
+ def test_f32_elemwise_abs(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+
+ # CHECK-LABEL: @test_f32_elemwise_ceil
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+ # CHECK-NEXT: %[[EXP:.+]] = math.ceil %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: -> tensor<4x16xf32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+ )
+ def test_f32_elemwise_ceil(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.ceil)
+
+ # CHECK-LABEL: @test_f32_elemwise_floor
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+ # CHECK-NEXT: %[[EXP:.+]] = math.floor %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: -> tensor<4x16xf32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+ )
+ def test_f32_elemwise_floor(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor)
+
+ # CHECK-LABEL: @test_f32_elemwise_neg
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+ # CHECK-NEXT: %[[EXP:.+]] = arith.negf %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: -> tensor<4x16xf32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+ )
+ def test_f32_elemwise_neg(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+
+ # CHECK-LABEL: @test_c32_elemwise_neg
+ # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+ # CHECK-NEXT: %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
+ # CHECK-NEXT: linalg.yield %[[EXP]] : complex<f32>
+ # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+ )
+ def test_c32_elemwise_neg(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+
+ # Just check that we don't assert out on name mismatch.
+ # CHECK-LABEL: @test_non_default_op_name
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((42,), f32), RankedTensorType.get((42,), f32)
+ )
+ def test_non_default_op_name(input, init_result):
+ return non_default_op_name(input, outs=[init_result])
print(module)
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
index 2fd63382c4ec3..ab049d3dfae57 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
@@ -19,121 +19,134 @@ def pooling_poly(
reduce=BinaryFnAttrDef(default=BinaryFn.max_signed),
cast=TypeFnAttrDef(default=TypeFn.cast_signed),
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
- domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw](
- cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
- D.c]))
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+ O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw](
+ cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
+ )
with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- i32 = IntegerType.get_signless(32)
- with InsertionPoint(module.body):
-
- # Pooling indexing maps.
- # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
- # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
- # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
-
- # CHECK-LABEL: @test_f32i32_max_pooling
- # CHECK: linalg.generic
- # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
- # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
- # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
- # CHECK-NEXT: %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32
- # CHECK-NEXT: %[[MAX:.+]] = arith.maxsi %[[OUT]], %[[IN_CAST:.+]] : i32
- # CHECK-NEXT: linalg.yield %[[MAX]] : i32
- # CHECK-NEXT: -> tensor<1x2x4x1xi32>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((1, 4, 16, 1), f32),
- RankedTensorType.get((2, 2), f32),
- RankedTensorType.get((1, 2, 4, 1), i32))
- def test_f32i32_max_pooling(input, shape, init_result):
- return pooling_poly(
- input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
-
- # CHECK-LABEL: @test_f32i32_max_unsigned_pooling
- # CHECK: = arith.fptoui
- # CHECK: = arith.maxui
- @func.FuncOp.from_py_func(
- RankedTensorType.get((1, 4, 16, 1), f32),
- RankedTensorType.get((2, 2), f32),
- RankedTensorType.get((1, 2, 4, 1), i32))
- def test_f32i32_max_unsigned_pooling(input, shape, init_result):
- return pooling_poly(
- input,
- shape,
- outs=[init_result],
- reduce=BinaryFn.max_unsigned,
- cast=TypeFn.cast_unsigned,
- strides=[2, 4],
- dilations=[1, 2])
-
- # CHECK-LABEL: @test_f32f32_max_pooling
- # CHECK: linalg.generic
- # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
- # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
- # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[MAX:.+]] = arith.maxf %[[OUT]], %[[IN:.+]] : f32
- # CHECK-NEXT: linalg.yield %[[MAX]] : f32
- # CHECK-NEXT: -> tensor<1x2x4x1xf32>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((1, 4, 16, 1), f32),
- RankedTensorType.get((2, 2), f32),
- RankedTensorType.get((1, 2, 4, 1), f32))
- def test_f32f32_max_pooling(input, shape, init_result):
- return pooling_poly(
- input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
-
- # CHECK-LABEL: @test_f32i32_min_pooling
- # CHECK: = arith.fptosi
- # CHECK: = arith.minsi
- @func.FuncOp.from_py_func(
- RankedTensorType.get((1, 4, 16, 1), f32),
- RankedTensorType.get((2, 2), f32),
- RankedTensorType.get((1, 2, 4, 1), i32))
- def test_f32i32_min_pooling(input, shape, init_result):
- return pooling_poly(
- input,
- shape,
- outs=[init_result],
- reduce=BinaryFn.min_signed,
- strides=[2, 4],
- dilations=[1, 2])
-
- # CHECK-LABEL: @test_f32i32_min_unsigned_pooling
- # CHECK: = arith.fptoui
- # CHECK: = arith.minui
- @func.FuncOp.from_py_func(
- RankedTensorType.get((1, 4, 16, 1), f32),
- RankedTensorType.get((2, 2), f32),
- RankedTensorType.get((1, 2, 4, 1), i32))
- def test_f32i32_min_unsigned_pooling(input, shape, init_result):
- return pooling_poly(
- input,
- shape,
- outs=[init_result],
- reduce=BinaryFn.min_unsigned,
- cast=TypeFn.cast_unsigned,
- strides=[2, 4],
- dilations=[1, 2])
-
- # CHECK-LABEL: @test_f32f32_min_pooling
- # CHECK: = arith.minf
- @func.FuncOp.from_py_func(
- RankedTensorType.get((1, 4, 16, 1), f32),
- RankedTensorType.get((2, 2), f32),
- RankedTensorType.get((1, 2, 4, 1), f32))
- def test_f32f32_min_pooling(input, shape, init_result):
- return pooling_poly(
- input,
- shape,
- outs=[init_result],
- reduce=BinaryFn.min_signed,
- strides=[2, 4],
- dilations=[1, 2])
+ module = Module.create()
+ f32 = F32Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+
+ # Pooling indexing maps.
+ # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
+ # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
+ # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+
+ # CHECK-LABEL: @test_f32i32_max_pooling
+ # CHECK: linalg.generic
+ # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
+ # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
+ # CHECK-NEXT: %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32
+ # CHECK-NEXT: %[[MAX:.+]] = arith.maxsi %[[OUT]], %[[IN_CAST:.+]] : i32
+ # CHECK-NEXT: linalg.yield %[[MAX]] : i32
+ # CHECK-NEXT: -> tensor<1x2x4x1xi32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((1, 4, 16, 1), f32),
+ RankedTensorType.get((2, 2), f32),
+ RankedTensorType.get((1, 2, 4, 1), i32),
+ )
+ def test_f32i32_max_pooling(input, shape, init_result):
+ return pooling_poly(
+ input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]
+ )
+
+ # CHECK-LABEL: @test_f32i32_max_unsigned_pooling
+ # CHECK: = arith.fptoui
+ # CHECK: = arith.maxui
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((1, 4, 16, 1), f32),
+ RankedTensorType.get((2, 2), f32),
+ RankedTensorType.get((1, 2, 4, 1), i32),
+ )
+ def test_f32i32_max_unsigned_pooling(input, shape, init_result):
+ return pooling_poly(
+ input,
+ shape,
+ outs=[init_result],
+ reduce=BinaryFn.max_unsigned,
+ cast=TypeFn.cast_unsigned,
+ strides=[2, 4],
+ dilations=[1, 2],
+ )
+
+ # CHECK-LABEL: @test_f32f32_max_pooling
+ # CHECK: linalg.generic
+ # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
+ # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
+ # CHECK-NEXT: %[[MAX:.+]] = arith.maxf %[[OUT]], %[[IN:.+]] : f32
+ # CHECK-NEXT: linalg.yield %[[MAX]] : f32
+ # CHECK-NEXT: -> tensor<1x2x4x1xf32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((1, 4, 16, 1), f32),
+ RankedTensorType.get((2, 2), f32),
+ RankedTensorType.get((1, 2, 4, 1), f32),
+ )
+ def test_f32f32_max_pooling(input, shape, init_result):
+ return pooling_poly(
+ input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]
+ )
+
+ # CHECK-LABEL: @test_f32i32_min_pooling
+ # CHECK: = arith.fptosi
+ # CHECK: = arith.minsi
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((1, 4, 16, 1), f32),
+ RankedTensorType.get((2, 2), f32),
+ RankedTensorType.get((1, 2, 4, 1), i32),
+ )
+ def test_f32i32_min_pooling(input, shape, init_result):
+ return pooling_poly(
+ input,
+ shape,
+ outs=[init_result],
+ reduce=BinaryFn.min_signed,
+ strides=[2, 4],
+ dilations=[1, 2],
+ )
+
+ # CHECK-LABEL: @test_f32i32_min_unsigned_pooling
+ # CHECK: = arith.fptoui
+ # CHECK: = arith.minui
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((1, 4, 16, 1), f32),
+ RankedTensorType.get((2, 2), f32),
+ RankedTensorType.get((1, 2, 4, 1), i32),
+ )
+ def test_f32i32_min_unsigned_pooling(input, shape, init_result):
+ return pooling_poly(
+ input,
+ shape,
+ outs=[init_result],
+ reduce=BinaryFn.min_unsigned,
+ cast=TypeFn.cast_unsigned,
+ strides=[2, 4],
+ dilations=[1, 2],
+ )
+
+ # CHECK-LABEL: @test_f32f32_min_pooling
+ # CHECK: = arith.minf
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((1, 4, 16, 1), f32),
+ RankedTensorType.get((2, 2), f32),
+ RankedTensorType.get((1, 2, 4, 1), f32),
+ )
+ def test_f32f32_min_pooling(input, shape, init_result):
+ return pooling_poly(
+ input,
+ shape,
+ outs=[init_result],
+ reduce=BinaryFn.min_signed,
+ strides=[2, 4],
+ dilations=[1, 2],
+ )
print(module)
diff --git a/mlir/test/python/dialects/linalg/opdsl/lit.local.cfg b/mlir/test/python/dialects/linalg/opdsl/lit.local.cfg
index cead85f016b96..18d2d458ed790 100644
--- a/mlir/test/python/dialects/linalg/opdsl/lit.local.cfg
+++ b/mlir/test/python/dialects/linalg/opdsl/lit.local.cfg
@@ -4,6 +4,6 @@
# Since both lit and the python bindings use the same python interpreter,
# we can just check whether yaml can be imported here and exclude if not.
try:
- import yaml
+ import yaml
except ModuleNotFoundError:
- config.unsupported = True
+ config.unsupported = True
diff --git a/mlir/test/python/dialects/linalg/opdsl/metadata.py b/mlir/test/python/dialects/linalg/opdsl/metadata.py
index a7502e9eb1aae..9c940e1060cab 100644
--- a/mlir/test/python/dialects/linalg/opdsl/metadata.py
+++ b/mlir/test/python/dialects/linalg/opdsl/metadata.py
@@ -13,8 +13,10 @@
def matmul(
A=TensorDef(T, S.M, S.K),
B=TensorDef(T, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True)):
- implements(ContractionOpInterface)
- defines(Canonicalizer)
- C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
- U, B[D.k, D.n])
+ C=TensorDef(U, S.M, S.N, output=True),
+):
+ implements(ContractionOpInterface)
+ defines(Canonicalizer)
+ C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
+ U, B[D.k, D.n]
+ )
diff --git a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
index 871341c835a5d..4f3569b7974d8 100644
--- a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
+++ b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
@@ -22,10 +22,12 @@
def matmul(
A=TensorDef(T, S.M, S.K),
B=TensorDef(T, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True)):
- domain(D.m, D.n, D.k)
- C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
- U, B[D.k, D.n])
+ C=TensorDef(U, S.M, S.N, output=True),
+):
+ domain(D.m, D.n, D.k)
+ C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
+ U, B[D.k, D.n]
+ )
# Verifies that assignment to a scalar (represented as [None]) is represented
@@ -43,7 +45,7 @@ def matmul(
# CHECK-NEXT: - reduction
@linalg_structured_op
def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
- C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
+ C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
# Verifies that the index_dims of shape-only operands translate to correct
@@ -64,6 +66,7 @@ def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
def pool(
I=TensorDef(T, S.I),
K=TensorDef(T, S.K, index_dims=[D.k]),
- O=TensorDef(U, S.O, output=True)):
- domain(D.o, D.k)
- O[D.o] += TypeFn.cast_signed(U, I[D.o * 2 + D.k])
+ O=TensorDef(U, S.O, output=True),
+):
+ domain(D.o, D.k)
+ O[D.o] += TypeFn.cast_signed(U, I[D.o * 2 + D.k])
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 1167abf84d4e4..5e8414ad4055c 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -6,145 +6,154 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ return f
# CHECK-LABEL: TEST: testFill
@run
def testFill():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
- # CHECK-LABEL: func @fill_tensor
- # CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
- # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
- # CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32>
- # CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
- @func.FuncOp.from_py_func(
- RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32))
- def fill_tensor(out):
- zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
- return linalg.fill(zero, outs=[out])
-
- # CHECK-LABEL: func @fill_buffer
- # CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
- # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
- # CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>)
- # CHECK-NEXT: return
- @func.FuncOp.from_py_func(
- MemRefType.get((12, ShapedType.get_dynamic_size()), f32))
- def fill_buffer(out):
- zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
- linalg.fill(zero, outs=[out])
-
- print(module)
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ # CHECK-LABEL: func @fill_tensor
+ # CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
+ # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
+ # CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32>
+ # CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)
+ )
+ def fill_tensor(out):
+ zero = arith.ConstantOp(
+ value=FloatAttr.get(f32, 0.0), result=f32
+ ).result
+ return linalg.fill(zero, outs=[out])
+
+ # CHECK-LABEL: func @fill_buffer
+ # CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
+ # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
+ # CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>)
+ # CHECK-NEXT: return
+ @func.FuncOp.from_py_func(
+ MemRefType.get((12, ShapedType.get_dynamic_size()), f32)
+ )
+ def fill_buffer(out):
+ zero = arith.ConstantOp(
+ value=FloatAttr.get(f32, 0.0), result=f32
+ ).result
+ linalg.fill(zero, outs=[out])
+
+ print(module)
# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
@run
def testNamedStructuredOpCustomForm():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32))
- def named_form(lhs, rhs):
- init_result = tensor.EmptyOp([4, 8], f32)
- # Check for the named form with custom format
- # CHECK: linalg.elemwise_unary
- # CHECK-SAME: cast = #linalg.type_fn<cast_signed>
- # CHECK-SAME: fun = #linalg.unary_fn<exp>
- # CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
- unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
- # CHECK: linalg.elemwise_binary
- # CHECK-SAME: cast = #linalg.type_fn<cast_unsigned>
- # CHECK-SAME: fun = #linalg.binary_fn<mul>
- # CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
- # CHECK: return
- binary_result = linalg.elemwise_binary(
- lhs,
- rhs,
- outs=[init_result.result],
- fun=BinaryFn.mul,
- cast=TypeFn.cast_unsigned)
- return unary_result, binary_result
-
- print(module)
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32)
+ )
+ def named_form(lhs, rhs):
+ init_result = tensor.EmptyOp([4, 8], f32)
+ # Check for the named form with custom format
+ # CHECK: linalg.elemwise_unary
+ # CHECK-SAME: cast = #linalg.type_fn<cast_signed>
+ # CHECK-SAME: fun = #linalg.unary_fn<exp>
+ # CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
+ unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
+ # CHECK: linalg.elemwise_binary
+ # CHECK-SAME: cast = #linalg.type_fn<cast_unsigned>
+ # CHECK-SAME: fun = #linalg.binary_fn<mul>
+ # CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
+ # CHECK: return
+ binary_result = linalg.elemwise_binary(
+ lhs,
+ rhs,
+ outs=[init_result.result],
+ fun=BinaryFn.mul,
+ cast=TypeFn.cast_unsigned,
+ )
+ return unary_result, binary_result
+
+ print(module)
# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
@run
def testNamedStructuredOpGenericForm():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
- f32))
- def named_form(lhs, rhs):
- init_result = tensor.EmptyOp([4, 8], f32)
- # CHECK: "linalg.matmul"(%{{.*}})
- # CHECK-SAME: cast = #linalg.type_fn<cast_signed>
- # CHECK-SAME: operand_segment_sizes = array<i32: 2, 1>
- # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
- # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32
- # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
- # CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
- # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
- return linalg.matmul(lhs, rhs, outs=[init_result.result])
-
- module.operation.print(print_generic_op_form=True)
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
+ )
+ def named_form(lhs, rhs):
+ init_result = tensor.EmptyOp([4, 8], f32)
+ # CHECK: "linalg.matmul"(%{{.*}})
+ # CHECK-SAME: cast = #linalg.type_fn<cast_signed>
+ # CHECK-SAME: operand_segment_sizes = array<i32: 2, 1>
+ # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
+ # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32
+ # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
+ # CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
+ # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
+ return linalg.matmul(lhs, rhs, outs=[init_result.result])
+
+ module.operation.print(print_generic_op_form=True)
# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
@run
def testNamedStructuredAsGenericOp():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
- f32))
- def generic_form(lhs, rhs):
- init_result = tensor.EmptyOp([4, 8], f32)
- # CHECK: linalg.generic
- return linalg.matmul(
- lhs, rhs, outs=[init_result.result], emit_generic=True)
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
+ )
+ def generic_form(lhs, rhs):
+ init_result = tensor.EmptyOp([4, 8], f32)
+ # CHECK: linalg.generic
+ return linalg.matmul(
+ lhs, rhs, outs=[init_result.result], emit_generic=True
+ )
- print(module)
+ print(module)
# CHECK-LABEL: TEST: testOpResultFromOtherOp
@run
def testOpResultFromOtherOp():
- with Context(), Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
- f32))
- def pass_an_op_directly(arg0, arg1):
- one = arith.ConstantOp(F32Type.get(), 1.0)
- # CHECK: %[[LHS:.*]] = linalg.fill
- lhs = linalg.fill(one, outs=[arg0])
- # CHECK: %[[RHS:.*]] = linalg.fill
- rhs = linalg.fill(one, outs=[arg1])
- # CHECK: %[[INIT:.*]] = tensor.empty
- init = tensor.EmptyOp([4, 8], f32)
- # CHECK: linalg.matmul
- # CHECK: ins(%[[LHS]], %[[RHS]]
- # CHECK: outs(%[[INIT]]
- return linalg.matmul(lhs, rhs, outs=init)
-
- print(module)
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
+ )
+ def pass_an_op_directly(arg0, arg1):
+ one = arith.ConstantOp(F32Type.get(), 1.0)
+ # CHECK: %[[LHS:.*]] = linalg.fill
+ lhs = linalg.fill(one, outs=[arg0])
+ # CHECK: %[[RHS:.*]] = linalg.fill
+ rhs = linalg.fill(one, outs=[arg1])
+ # CHECK: %[[INIT:.*]] = tensor.empty
+ init = tensor.EmptyOp([4, 8], f32)
+ # CHECK: linalg.matmul
+ # CHECK: ins(%[[LHS]], %[[RHS]]
+ # CHECK: outs(%[[INIT]]
+ return linalg.matmul(lhs, rhs, outs=init)
+
+ print(module)
diff --git a/mlir/test/python/dialects/math_dialect.py b/mlir/test/python/dialects/math_dialect.py
index 04b6d848c7422..3d402c54a11e3 100644
--- a/mlir/test/python/dialects/math_dialect.py
+++ b/mlir/test/python/dialects/math_dialect.py
@@ -7,23 +7,26 @@
import mlir.dialects.func as func
import mlir.dialects.math as mlir_math
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
+ print("\nTEST:", f.__name__)
+ f()
+
# CHECK-LABEL: TEST: testMathOps
@run
def testMathOps():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- with InsertionPoint(module.body):
- @func.FuncOp.from_py_func(F32Type.get())
- def emit_sqrt(arg):
- return mlir_math.SqrtOp(arg)
-
- # CHECK-LABEL: func @emit_sqrt(
- # CHECK-SAME: %[[ARG:.*]]: f32) -> f32 {
- # CHECK: math.sqrt %[[ARG]] : f32
- # CHECK: return
- # CHECK: }
- print(module)
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(F32Type.get())
+ def emit_sqrt(arg):
+ return mlir_math.SqrtOp(arg)
+
+ # CHECK-LABEL: func @emit_sqrt(
+ # CHECK-SAME: %[[ARG:.*]]: f32) -> f32 {
+ # CHECK: math.sqrt %[[ARG]] : f32
+ # CHECK: return
+ # CHECK: }
+ print(module)
diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
index 59092feffba23..2e3cae671a9f1 100644
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -6,17 +6,17 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ return f
# CHECK-LABEL: TEST: testSubViewAccessors
@run
def testSubViewAccessors():
- ctx = Context()
- module = Module.parse(
- r"""
+ ctx = Context()
+ module = Module.parse(
+ r"""
func.func @f1(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
@@ -27,48 +27,52 @@ def testSubViewAccessors():
memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
return
}
- """, ctx)
- func_body = module.body.operations[0].regions[0].blocks[0]
- subview = func_body.operations[6]
+ """,
+ ctx,
+ )
+ func_body = module.body.operations[0].regions[0].blocks[0]
+ subview = func_body.operations[6]
- assert subview.source == subview.operands[0]
- assert len(subview.offsets) == 2
- assert len(subview.sizes) == 2
- assert len(subview.strides) == 2
- assert subview.result == subview.results[0]
+ assert subview.source == subview.operands[0]
+ assert len(subview.offsets) == 2
+ assert len(subview.sizes) == 2
+ assert len(subview.strides) == 2
+ assert subview.result == subview.results[0]
- # CHECK: SubViewOp
- print(type(subview).__name__)
+ # CHECK: SubViewOp
+ print(type(subview).__name__)
- # CHECK: constant 0
- print(subview.offsets[0])
- # CHECK: constant 1
- print(subview.offsets[1])
- # CHECK: constant 2
- print(subview.sizes[0])
- # CHECK: constant 3
- print(subview.sizes[1])
- # CHECK: constant 4
- print(subview.strides[0])
- # CHECK: constant 5
- print(subview.strides[1])
+ # CHECK: constant 0
+ print(subview.offsets[0])
+ # CHECK: constant 1
+ print(subview.offsets[1])
+ # CHECK: constant 2
+ print(subview.sizes[0])
+ # CHECK: constant 3
+ print(subview.sizes[1])
+ # CHECK: constant 4
+ print(subview.strides[0])
+ # CHECK: constant 5
+ print(subview.strides[1])
# CHECK-LABEL: TEST: testCustomBuidlers
@run
def testCustomBuidlers():
- with Context() as ctx, Location.unknown(ctx):
- module = Module.parse(r"""
+ with Context() as ctx, Location.unknown(ctx):
+ module = Module.parse(
+ r"""
func.func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
return
}
- """)
- f = module.body.operations[0]
- func_body = f.regions[0].blocks[0]
- with InsertionPoint.at_block_terminator(func_body):
- memref.LoadOp(f.arguments[0], f.arguments[1:])
+ """
+ )
+ f = module.body.operations[0]
+ func_body = f.regions[0].blocks[0]
+ with InsertionPoint.at_block_terminator(func_body):
+ memref.LoadOp(f.arguments[0], f.arguments[1:])
- # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
- # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
- print(module)
- assert module.operation.verify()
+ # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+ # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
+ print(module)
+ assert module.operation.verify()
diff --git a/mlir/test/python/dialects/ml_program.py b/mlir/test/python/dialects/ml_program.py
index 4d9804ff34732..f16de2add3799 100644
--- a/mlir/test/python/dialects/ml_program.py
+++ b/mlir/test/python/dialects/ml_program.py
@@ -6,23 +6,23 @@
def constructAndPrintInModule(f):
- print("\nTEST:", f.__name__)
- with Context(), Location.unknown():
- module = Module.create()
- with InsertionPoint(module.body):
- f()
- print(module)
- return f
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ f()
+ print(module)
+ return f
# CHECK-LABEL: testFuncOp
@constructAndPrintInModule
def testFuncOp():
- # CHECK: ml_program.func @foobar(%arg0: si32) -> si32
- f = ml_program.FuncOp(
- name="foobar",
- type=([IntegerType.get_signed(32)], [IntegerType.get_signed(32)]))
- block = f.add_entry_block()
- with InsertionPoint(block):
- # CHECK: ml_program.return
- ml_program.ReturnOp([block.arguments[0]])
+ # CHECK: ml_program.func @foobar(%arg0: si32) -> si32
+ f = ml_program.FuncOp(
+ name="foobar", type=([IntegerType.get_signed(32)], [IntegerType.get_signed(32)])
+ )
+ block = f.add_entry_block()
+ with InsertionPoint(block):
+ # CHECK: ml_program.return
+ ml_program.ReturnOp([block.arguments[0]])
diff --git a/mlir/test/python/dialects/ods_helpers.py b/mlir/test/python/dialects/ods_helpers.py
index 802a1f271c565..71879bdcb51f5 100644
--- a/mlir/test/python/dialects/ods_helpers.py
+++ b/mlir/test/python/dialects/ods_helpers.py
@@ -6,207 +6,205 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
def add_dummy_value():
- return Operation.create(
- "custom.value",
- results=[IntegerType.get_signless(32)]).result
+ return Operation.create(
+ "custom.value", results=[IntegerType.get_signless(32)]
+ ).result
def testOdsBuildDefaultImplicitRegions():
-
- class TestFixedRegionsOp(OpView):
- OPERATION_NAME = "custom.test_op"
- _ODS_REGIONS = (2, True)
-
- class TestVariadicRegionsOp(OpView):
- OPERATION_NAME = "custom.test_any_regions_op"
- _ODS_REGIONS = (2, False)
-
- with Context() as ctx, Location.unknown():
- ctx.allow_unregistered_dialects = True
- m = Module.create()
- with InsertionPoint(m.body):
- op = TestFixedRegionsOp.build_generic(results=[], operands=[])
- # CHECK: NUM_REGIONS: 2
- print(f"NUM_REGIONS: {len(op.regions)}")
- # Including a regions= that matches should be fine.
- op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2)
- print(f"NUM_REGIONS: {len(op.regions)}")
- # Reject greater than.
- try:
- op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=3)
- except ValueError as e:
- # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
- print(f"ERROR:{e}")
- # Reject less than.
- try:
- op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=1)
- except ValueError as e:
- # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
- print(f"ERROR:{e}")
-
- # If no regions specified for a variadic region op, build the minimum.
- op = TestVariadicRegionsOp.build_generic(results=[], operands=[])
- # CHECK: DEFAULT_NUM_REGIONS: 2
- print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
- # Should also accept an explicit regions= that matches the minimum.
- op = TestVariadicRegionsOp.build_generic(
- results=[], operands=[], regions=2)
- # CHECK: EQ_NUM_REGIONS: 2
- print(f"EQ_NUM_REGIONS: {len(op.regions)}")
- # And accept greater than minimum.
- # Should also accept an explicit regions= that matches the minimum.
- op = TestVariadicRegionsOp.build_generic(
- results=[], operands=[], regions=3)
- # CHECK: GT_NUM_REGIONS: 3
- print(f"GT_NUM_REGIONS: {len(op.regions)}")
- # Should reject less than minimum.
- try:
- op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=1)
- except ValueError as e:
- # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
- print(f"ERROR:{e}")
-
+ class TestFixedRegionsOp(OpView):
+ OPERATION_NAME = "custom.test_op"
+ _ODS_REGIONS = (2, True)
+
+ class TestVariadicRegionsOp(OpView):
+ OPERATION_NAME = "custom.test_any_regions_op"
+ _ODS_REGIONS = (2, False)
+
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ m = Module.create()
+ with InsertionPoint(m.body):
+ op = TestFixedRegionsOp.build_generic(results=[], operands=[])
+ # CHECK: NUM_REGIONS: 2
+ print(f"NUM_REGIONS: {len(op.regions)}")
+ # Including a regions= that matches should be fine.
+ op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2)
+ print(f"NUM_REGIONS: {len(op.regions)}")
+ # Reject greater than.
+ try:
+ op = TestFixedRegionsOp.build_generic(
+ results=[], operands=[], regions=3
+ )
+ except ValueError as e:
+ # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
+ print(f"ERROR:{e}")
+ # Reject less than.
+ try:
+ op = TestFixedRegionsOp.build_generic(
+ results=[], operands=[], regions=1
+ )
+ except ValueError as e:
+ # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
+ print(f"ERROR:{e}")
+
+ # If no regions specified for a variadic region op, build the minimum.
+ op = TestVariadicRegionsOp.build_generic(results=[], operands=[])
+ # CHECK: DEFAULT_NUM_REGIONS: 2
+ print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
+ # Should also accept an explicit regions= that matches the minimum.
+ op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=2)
+ # CHECK: EQ_NUM_REGIONS: 2
+ print(f"EQ_NUM_REGIONS: {len(op.regions)}")
+ # And accept greater than minimum.
+ # Should also accept an explicit regions= that matches the minimum.
+ op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=3)
+ # CHECK: GT_NUM_REGIONS: 3
+ print(f"GT_NUM_REGIONS: {len(op.regions)}")
+ # Should reject less than minimum.
+ try:
+ op = TestVariadicRegionsOp.build_generic(
+ results=[], operands=[], regions=1
+ )
+ except ValueError as e:
+ # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
+ print(f"ERROR:{e}")
run(testOdsBuildDefaultImplicitRegions)
def testOdsBuildDefaultNonVariadic():
+ class TestOp(OpView):
+ OPERATION_NAME = "custom.test_op"
+
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ m = Module.create()
+ with InsertionPoint(m.body):
+ v0 = add_dummy_value()
+ v1 = add_dummy_value()
+ t0 = IntegerType.get_signless(8)
+ t1 = IntegerType.get_signless(16)
+ op = TestOp.build_generic(results=[t0, t1], operands=[v0, v1])
+ # CHECK: %[[V0:.+]] = "custom.value"
+ # CHECK: %[[V1:.+]] = "custom.value"
+ # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
+ # CHECK-NOT: operand_segment_sizes
+ # CHECK-NOT: result_segment_sizes
+ # CHECK-SAME: : (i32, i32) -> (i8, i16)
+ print(m)
- class TestOp(OpView):
- OPERATION_NAME = "custom.test_op"
-
- with Context() as ctx, Location.unknown():
- ctx.allow_unregistered_dialects = True
- m = Module.create()
- with InsertionPoint(m.body):
- v0 = add_dummy_value()
- v1 = add_dummy_value()
- t0 = IntegerType.get_signless(8)
- t1 = IntegerType.get_signless(16)
- op = TestOp.build_generic(results=[t0, t1], operands=[v0, v1])
- # CHECK: %[[V0:.+]] = "custom.value"
- # CHECK: %[[V1:.+]] = "custom.value"
- # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
- # CHECK-NOT: operand_segment_sizes
- # CHECK-NOT: result_segment_sizes
- # CHECK-SAME: : (i32, i32) -> (i8, i16)
- print(m)
run(testOdsBuildDefaultNonVariadic)
def testOdsBuildDefaultSizedVariadic():
+ class TestOp(OpView):
+ OPERATION_NAME = "custom.test_op"
+ _ODS_OPERAND_SEGMENTS = [1, -1, 0]
+ _ODS_RESULT_SEGMENTS = [-1, 0, 1]
+
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ m = Module.create()
+ with InsertionPoint(m.body):
+ v0 = add_dummy_value()
+ v1 = add_dummy_value()
+ v2 = add_dummy_value()
+ v3 = add_dummy_value()
+ t0 = IntegerType.get_signless(8)
+ t1 = IntegerType.get_signless(16)
+ t2 = IntegerType.get_signless(32)
+ t3 = IntegerType.get_signless(64)
+ # CHECK: %[[V0:.+]] = "custom.value"
+ # CHECK: %[[V1:.+]] = "custom.value"
+ # CHECK: %[[V2:.+]] = "custom.value"
+ # CHECK: %[[V3:.+]] = "custom.value"
+ # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
+ # CHECK-SAME: operand_segment_sizes = array<i32: 1, 2, 1>
+ # CHECK-SAME: result_segment_sizes = array<i32: 2, 1, 1>
+ # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
+ op = TestOp.build_generic(
+ results=[[t0, t1], t2, t3], operands=[v0, [v1, v2], v3]
+ )
+
+ # Now test with optional omitted.
+ # CHECK: "custom.test_op"(%[[V0]])
+ # CHECK-SAME: operand_segment_sizes = array<i32: 1, 0, 0>
+ # CHECK-SAME: result_segment_sizes = array<i32: 0, 0, 1>
+ # CHECK-SAME: (i32) -> i64
+ op = TestOp.build_generic(
+ results=[None, None, t3], operands=[v0, None, None]
+ )
+ print(m)
+
+ # And verify that errors are raised for None in a required operand.
+ try:
+ op = TestOp.build_generic(
+ results=[None, None, t3], operands=[None, None, None]
+ )
+ except ValueError as e:
+ # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
+ print(f"OPERAND_CAST_ERROR:{e}")
+
+ # And verify that errors are raised for None in a required result.
+ try:
+ op = TestOp.build_generic(
+ results=[None, None, None], operands=[v0, None, None]
+ )
+ except ValueError as e:
+ # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
+ print(f"RESULT_CAST_ERROR:{e}")
+
+ # Variadic lists with None elements should reject.
+ try:
+ op = TestOp.build_generic(
+ results=[None, None, t3], operands=[v0, [None], None]
+ )
+ except ValueError as e:
+ # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
+ print(f"OPERAND_LIST_CAST_ERROR:{e}")
+ try:
+ op = TestOp.build_generic(
+ results=[[None], None, t3], operands=[v0, None, None]
+ )
+ except ValueError as e:
+ # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
+ print(f"RESULT_LIST_CAST_ERROR:{e}")
- class TestOp(OpView):
- OPERATION_NAME = "custom.test_op"
- _ODS_OPERAND_SEGMENTS = [1, -1, 0]
- _ODS_RESULT_SEGMENTS = [-1, 0, 1]
-
- with Context() as ctx, Location.unknown():
- ctx.allow_unregistered_dialects = True
- m = Module.create()
- with InsertionPoint(m.body):
- v0 = add_dummy_value()
- v1 = add_dummy_value()
- v2 = add_dummy_value()
- v3 = add_dummy_value()
- t0 = IntegerType.get_signless(8)
- t1 = IntegerType.get_signless(16)
- t2 = IntegerType.get_signless(32)
- t3 = IntegerType.get_signless(64)
- # CHECK: %[[V0:.+]] = "custom.value"
- # CHECK: %[[V1:.+]] = "custom.value"
- # CHECK: %[[V2:.+]] = "custom.value"
- # CHECK: %[[V3:.+]] = "custom.value"
- # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
- # CHECK-SAME: operand_segment_sizes = array<i32: 1, 2, 1>
- # CHECK-SAME: result_segment_sizes = array<i32: 2, 1, 1>
- # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
- op = TestOp.build_generic(
- results=[[t0, t1], t2, t3],
- operands=[v0, [v1, v2], v3])
-
- # Now test with optional omitted.
- # CHECK: "custom.test_op"(%[[V0]])
- # CHECK-SAME: operand_segment_sizes = array<i32: 1, 0, 0>
- # CHECK-SAME: result_segment_sizes = array<i32: 0, 0, 1>
- # CHECK-SAME: (i32) -> i64
- op = TestOp.build_generic(
- results=[None, None, t3],
- operands=[v0, None, None])
- print(m)
-
- # And verify that errors are raised for None in a required operand.
- try:
- op = TestOp.build_generic(
- results=[None, None, t3],
- operands=[None, None, None])
- except ValueError as e:
- # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
- print(f"OPERAND_CAST_ERROR:{e}")
-
- # And verify that errors are raised for None in a required result.
- try:
- op = TestOp.build_generic(
- results=[None, None, None],
- operands=[v0, None, None])
- except ValueError as e:
- # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
- print(f"RESULT_CAST_ERROR:{e}")
-
- # Variadic lists with None elements should reject.
- try:
- op = TestOp.build_generic(
- results=[None, None, t3],
- operands=[v0, [None], None])
- except ValueError as e:
- # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
- print(f"OPERAND_LIST_CAST_ERROR:{e}")
- try:
- op = TestOp.build_generic(
- results=[[None], None, t3],
- operands=[v0, None, None])
- except ValueError as e:
- # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
- print(f"RESULT_LIST_CAST_ERROR:{e}")
run(testOdsBuildDefaultSizedVariadic)
def testOdsBuildDefaultCastError():
+ class TestOp(OpView):
+ OPERATION_NAME = "custom.test_op"
+
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ m = Module.create()
+ with InsertionPoint(m.body):
+ v0 = add_dummy_value()
+ v1 = add_dummy_value()
+ t0 = IntegerType.get_signless(8)
+ t1 = IntegerType.get_signless(16)
+ try:
+ op = TestOp.build_generic(results=[t0, t1], operands=[None, v1])
+ except ValueError as e:
+ # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
+ print(f"ERROR: {e}")
+ try:
+ op = TestOp.build_generic(results=[t0, None], operands=[v0, v1])
+ except ValueError as e:
+ # CHECK: Result 1 of operation "custom.test_op" must be a Type
+ print(f"ERROR: {e}")
- class TestOp(OpView):
- OPERATION_NAME = "custom.test_op"
-
- with Context() as ctx, Location.unknown():
- ctx.allow_unregistered_dialects = True
- m = Module.create()
- with InsertionPoint(m.body):
- v0 = add_dummy_value()
- v1 = add_dummy_value()
- t0 = IntegerType.get_signless(8)
- t1 = IntegerType.get_signless(16)
- try:
- op = TestOp.build_generic(
- results=[t0, t1],
- operands=[None, v1])
- except ValueError as e:
- # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
- print(f"ERROR: {e}")
- try:
- op = TestOp.build_generic(
- results=[t0, None],
- operands=[v0, v1])
- except ValueError as e:
- # CHECK: Result 1 of operation "custom.test_op" must be a Type
- print(f"ERROR: {e}")
run(testOdsBuildDefaultCastError)
diff --git a/mlir/test/python/dialects/pdl_ops.py b/mlir/test/python/dialects/pdl_ops.py
index 3d9cd19278997..0d364f9222a65 100644
--- a/mlir/test/python/dialects/pdl_ops.py
+++ b/mlir/test/python/dialects/pdl_ops.py
@@ -5,13 +5,13 @@
def constructAndPrintInModule(f):
- print("\nTEST:", f.__name__)
- with Context(), Location.unknown():
- module = Module.create()
- with InsertionPoint(module.body):
- f()
- print(module)
- return f
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ f()
+ print(module)
+ return f
# CHECK: module {
@@ -27,15 +27,15 @@ def constructAndPrintInModule(f):
# CHECK: }
@constructAndPrintInModule
def test_operations():
- pattern = PatternOp(1, "operations")
- with InsertionPoint(pattern.body):
- attr = AttributeOp()
- ty = TypeOp()
- op0 = OperationOp(attributes={"attr": attr}, types=[ty])
- op0_result = ResultOp(op0, 0)
- input = OperandOp()
- root = OperationOp(args=[op0_result, input])
- RewriteOp(root, "rewriter")
+ pattern = PatternOp(1, "operations")
+ with InsertionPoint(pattern.body):
+ attr = AttributeOp()
+ ty = TypeOp()
+ op0 = OperationOp(attributes={"attr": attr}, types=[ty])
+ op0_result = ResultOp(op0, 0)
+ input = OperandOp()
+ root = OperationOp(args=[op0_result, input])
+ RewriteOp(root, "rewriter")
# CHECK: module {
@@ -47,11 +47,12 @@ def test_operations():
# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_args():
- pattern = PatternOp(1, "rewrite_with_args")
- with InsertionPoint(pattern.body):
- input = OperandOp()
- root = OperationOp(args=[input])
- RewriteOp(root, "rewriter", args=[input])
+ pattern = PatternOp(1, "rewrite_with_args")
+ with InsertionPoint(pattern.body):
+ input = OperandOp()
+ root = OperationOp(args=[input])
+ RewriteOp(root, "rewriter", args=[input])
+
# CHECK: module {
# CHECK: pdl.pattern @rewrite_multi_root_optimal : benefit(1) {
@@ -69,18 +70,19 @@ def test_rewrite_with_args():
# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_optimal():
- pattern = PatternOp(1, "rewrite_multi_root_optimal")
- with InsertionPoint(pattern.body):
- input1 = OperandOp()
- input2 = OperandOp()
- ty = TypeOp()
- op1 = OperationOp(args=[input1], types=[ty])
- val1 = ResultOp(op1, 0)
- root1 = OperationOp(args=[val1])
- op2 = OperationOp(args=[input2], types=[ty])
- val2 = ResultOp(op2, 0)
- root2 = OperationOp(args=[val1, val2])
- RewriteOp(name="rewriter", args=[root1, root2])
+ pattern = PatternOp(1, "rewrite_multi_root_optimal")
+ with InsertionPoint(pattern.body):
+ input1 = OperandOp()
+ input2 = OperandOp()
+ ty = TypeOp()
+ op1 = OperationOp(args=[input1], types=[ty])
+ val1 = ResultOp(op1, 0)
+ root1 = OperationOp(args=[val1])
+ op2 = OperationOp(args=[input2], types=[ty])
+ val2 = ResultOp(op2, 0)
+ root2 = OperationOp(args=[val1, val2])
+ RewriteOp(name="rewriter", args=[root1, root2])
+
# CHECK: module {
# CHECK: pdl.pattern @rewrite_multi_root_forced : benefit(1) {
@@ -98,18 +100,19 @@ def test_rewrite_multi_root_optimal():
# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_forced():
- pattern = PatternOp(1, "rewrite_multi_root_forced")
- with InsertionPoint(pattern.body):
- input1 = OperandOp()
- input2 = OperandOp()
- ty = TypeOp()
- op1 = OperationOp(args=[input1], types=[ty])
- val1 = ResultOp(op1, 0)
- root1 = OperationOp(args=[val1])
- op2 = OperationOp(args=[input2], types=[ty])
- val2 = ResultOp(op2, 0)
- root2 = OperationOp(args=[val1, val2])
- RewriteOp(root1, name="rewriter", args=[root2])
+ pattern = PatternOp(1, "rewrite_multi_root_forced")
+ with InsertionPoint(pattern.body):
+ input1 = OperandOp()
+ input2 = OperandOp()
+ ty = TypeOp()
+ op1 = OperationOp(args=[input1], types=[ty])
+ val1 = ResultOp(op1, 0)
+ root1 = OperationOp(args=[val1])
+ op2 = OperationOp(args=[input2], types=[ty])
+ val2 = ResultOp(op2, 0)
+ root2 = OperationOp(args=[val1, val2])
+ RewriteOp(root1, name="rewriter", args=[root2])
+
# CHECK: module {
# CHECK: pdl.pattern @rewrite_add_body : benefit(1) {
@@ -125,16 +128,17 @@ def test_rewrite_multi_root_forced():
# CHECK: }
@constructAndPrintInModule
def test_rewrite_add_body():
- pattern = PatternOp(1, "rewrite_add_body")
- with InsertionPoint(pattern.body):
- ty1 = TypeOp(IntegerType.get_signless(32))
- ty2 = TypeOp()
- root = OperationOp(types=[ty1, ty2])
- rewrite = RewriteOp(root)
- with InsertionPoint(rewrite.add_body()):
- ty3 = TypeOp()
- newOp = OperationOp(name="foo.op", types=[ty1, ty3])
- ReplaceOp(root, with_op=newOp)
+ pattern = PatternOp(1, "rewrite_add_body")
+ with InsertionPoint(pattern.body):
+ ty1 = TypeOp(IntegerType.get_signless(32))
+ ty2 = TypeOp()
+ root = OperationOp(types=[ty1, ty2])
+ rewrite = RewriteOp(root)
+ with InsertionPoint(rewrite.add_body()):
+ ty3 = TypeOp()
+ newOp = OperationOp(name="foo.op", types=[ty1, ty3])
+ ReplaceOp(root, with_op=newOp)
+
# CHECK: module {
# CHECK: pdl.pattern @rewrite_type : benefit(1) {
@@ -148,14 +152,15 @@ def test_rewrite_add_body():
# CHECK: }
@constructAndPrintInModule
def test_rewrite_type():
- pattern = PatternOp(1, "rewrite_type")
- with InsertionPoint(pattern.body):
- ty1 = TypeOp(IntegerType.get_signless(32))
- ty2 = TypeOp()
- root = OperationOp(types=[ty1, ty2])
- rewrite = RewriteOp(root)
- with InsertionPoint(rewrite.add_body()):
- newOp = OperationOp(name="foo.op", types=[ty1, ty2])
+ pattern = PatternOp(1, "rewrite_type")
+ with InsertionPoint(pattern.body):
+ ty1 = TypeOp(IntegerType.get_signless(32))
+ ty2 = TypeOp()
+ root = OperationOp(types=[ty1, ty2])
+ rewrite = RewriteOp(root)
+ with InsertionPoint(rewrite.add_body()):
+ newOp = OperationOp(name="foo.op", types=[ty1, ty2])
+
# CHECK: module {
# CHECK: pdl.pattern @rewrite_types : benefit(1) {
@@ -169,14 +174,17 @@ def test_rewrite_type():
# CHECK: }
@constructAndPrintInModule
def test_rewrite_types():
- pattern = PatternOp(1, "rewrite_types")
- with InsertionPoint(pattern.body):
- types = TypesOp()
- root = OperationOp(types=[types])
- rewrite = RewriteOp(root)
- with InsertionPoint(rewrite.add_body()):
- otherTypes = TypesOp([IntegerType.get_signless(32), IntegerType.get_signless(64)])
- newOp = OperationOp(name="foo.op", types=[types, otherTypes])
+ pattern = PatternOp(1, "rewrite_types")
+ with InsertionPoint(pattern.body):
+ types = TypesOp()
+ root = OperationOp(types=[types])
+ rewrite = RewriteOp(root)
+ with InsertionPoint(rewrite.add_body()):
+ otherTypes = TypesOp(
+ [IntegerType.get_signless(32), IntegerType.get_signless(64)]
+ )
+ newOp = OperationOp(name="foo.op", types=[types, otherTypes])
+
# CHECK: module {
# CHECK: pdl.pattern @rewrite_operands : benefit(1) {
@@ -190,14 +198,15 @@ def test_rewrite_types():
# CHECK: }
@constructAndPrintInModule
def test_rewrite_operands():
- pattern = PatternOp(1, "rewrite_operands")
- with InsertionPoint(pattern.body):
- types = TypesOp()
- operands = OperandsOp(types)
- root = OperationOp(args=[operands])
- rewrite = RewriteOp(root)
- with InsertionPoint(rewrite.add_body()):
- newOp = OperationOp(name="foo.op", types=[types])
+ pattern = PatternOp(1, "rewrite_operands")
+ with InsertionPoint(pattern.body):
+ types = TypesOp()
+ operands = OperandsOp(types)
+ root = OperationOp(args=[operands])
+ rewrite = RewriteOp(root)
+ with InsertionPoint(rewrite.add_body()):
+ newOp = OperationOp(name="foo.op", types=[types])
+
# CHECK: module {
# CHECK: pdl.pattern @native_rewrite : benefit(1) {
@@ -209,12 +218,13 @@ def test_rewrite_operands():
# CHECK: }
@constructAndPrintInModule
def test_native_rewrite():
- pattern = PatternOp(1, "native_rewrite")
- with InsertionPoint(pattern.body):
- root = OperationOp()
- rewrite = RewriteOp(root)
- with InsertionPoint(rewrite.add_body()):
- ApplyNativeRewriteOp([], "NativeRewrite", args=[root])
+ pattern = PatternOp(1, "native_rewrite")
+ with InsertionPoint(pattern.body):
+ root = OperationOp()
+ rewrite = RewriteOp(root)
+ with InsertionPoint(rewrite.add_body()):
+ ApplyNativeRewriteOp([], "NativeRewrite", args=[root])
+
# CHECK: module {
# CHECK: pdl.pattern @attribute_with_value : benefit(1) {
@@ -227,13 +237,14 @@ def test_native_rewrite():
# CHECK: }
@constructAndPrintInModule
def test_attribute_with_value():
- pattern = PatternOp(1, "attribute_with_value")
- with InsertionPoint(pattern.body):
- root = OperationOp()
- rewrite = RewriteOp(root)
- with InsertionPoint(rewrite.add_body()):
- attr = AttributeOp(value=Attribute.parse('"value"'))
- ApplyNativeRewriteOp([], "NativeRewrite", args=[attr])
+ pattern = PatternOp(1, "attribute_with_value")
+ with InsertionPoint(pattern.body):
+ root = OperationOp()
+ rewrite = RewriteOp(root)
+ with InsertionPoint(rewrite.add_body()):
+ attr = AttributeOp(value=Attribute.parse('"value"'))
+ ApplyNativeRewriteOp([], "NativeRewrite", args=[attr])
+
# CHECK: module {
# CHECK: pdl.pattern @erase : benefit(1) {
@@ -245,12 +256,13 @@ def test_attribute_with_value():
# CHECK: }
@constructAndPrintInModule
def test_erase():
- pattern = PatternOp(1, "erase")
- with InsertionPoint(pattern.body):
- root = OperationOp()
- rewrite = RewriteOp(root)
- with InsertionPoint(rewrite.add_body()):
- EraseOp(root)
+ pattern = PatternOp(1, "erase")
+ with InsertionPoint(pattern.body):
+ root = OperationOp()
+ rewrite = RewriteOp(root)
+ with InsertionPoint(rewrite.add_body()):
+ EraseOp(root)
+
# CHECK: module {
# CHECK: pdl.pattern @operation_results : benefit(1) {
@@ -263,14 +275,15 @@ def test_erase():
# CHECK: }
@constructAndPrintInModule
def test_operation_results():
- valueRange = RangeType.get(ValueType.get())
- pattern = PatternOp(1, "operation_results")
- with InsertionPoint(pattern.body):
- types = TypesOp()
- inputOp = OperationOp(types=[types])
- results = ResultsOp(valueRange, inputOp)
- root = OperationOp(args=[results])
- RewriteOp(root, name="rewriter")
+ valueRange = RangeType.get(ValueType.get())
+ pattern = PatternOp(1, "operation_results")
+ with InsertionPoint(pattern.body):
+ types = TypesOp()
+ inputOp = OperationOp(types=[types])
+ results = ResultsOp(valueRange, inputOp)
+ root = OperationOp(args=[results])
+ RewriteOp(root, name="rewriter")
+
# CHECK: module {
# CHECK: pdl.pattern : benefit(1) {
@@ -282,9 +295,9 @@ def test_operation_results():
# CHECK: }
@constructAndPrintInModule
def test_apply_native_constraint():
- pattern = PatternOp(1)
- with InsertionPoint(pattern.body):
- resultType = TypeOp()
- ApplyNativeConstraintOp("typeConstraint", args=[resultType])
- root = OperationOp(types=[resultType])
- RewriteOp(root, name="rewrite")
+ pattern = PatternOp(1)
+ with InsertionPoint(pattern.body):
+ resultType = TypeOp()
+ ApplyNativeConstraintOp("typeConstraint", args=[resultType])
+ root = OperationOp(types=[resultType])
+ RewriteOp(root, name="rewrite")
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 2ca79b29f567c..72a765c75e52c 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -5,367 +5,373 @@
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ return f
+
# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
- with Context() as ctx, Location.unknown():
- ctx.allow_unregistered_dialects = True
-
- #
- # Check op construction with attributes.
- #
-
- i32 = IntegerType.get_signless(32)
- one = IntegerAttr.get(i32, 1)
- two = IntegerAttr.get(i32, 2)
- unit = UnitAttr.get()
-
- # CHECK: "python_test.attributed_op"() {
- # CHECK-DAG: mandatory_i32 = 1 : i32
- # CHECK-DAG: optional_i32 = 2 : i32
- # CHECK-DAG: unit
- # CHECK: }
- op = test.AttributedOp(one, optional_i32=two, unit=unit)
- print(f"{op}")
-
- # CHECK: "python_test.attributed_op"() {
- # CHECK: mandatory_i32 = 2 : i32
- # CHECK: }
- op2 = test.AttributedOp(two)
- print(f"{op2}")
-
- #
- # Check generic "attributes" access and mutation.
- #
-
- assert "additional" not in op.attributes
-
- # CHECK: "python_test.attributed_op"() {
- # CHECK-DAG: additional = 1 : i32
- # CHECK-DAG: mandatory_i32 = 2 : i32
- # CHECK: }
- op2.attributes["additional"] = one
- print(f"{op2}")
-
- # CHECK: "python_test.attributed_op"() {
- # CHECK-DAG: additional = 2 : i32
- # CHECK-DAG: mandatory_i32 = 2 : i32
- # CHECK: }
- op2.attributes["additional"] = two
- print(f"{op2}")
-
- # CHECK: "python_test.attributed_op"() {
- # CHECK-NOT: additional = 2 : i32
- # CHECK: mandatory_i32 = 2 : i32
- # CHECK: }
- del op2.attributes["additional"]
- print(f"{op2}")
-
- try:
- print(op.attributes["additional"])
- except KeyError:
- pass
- else:
- assert False, "expected KeyError on unknown attribute key"
-
- #
- # Check accessors to defined attributes.
- #
-
- # CHECK: Mandatory: 1
- # CHECK: Optional: 2
- # CHECK: Unit: True
- print(f"Mandatory: {op.mandatory_i32.value}")
- print(f"Optional: {op.optional_i32.value}")
- print(f"Unit: {op.unit}")
-
- # CHECK: Mandatory: 2
- # CHECK: Optional: None
- # CHECK: Unit: False
- print(f"Mandatory: {op2.mandatory_i32.value}")
- print(f"Optional: {op2.optional_i32}")
- print(f"Unit: {op2.unit}")
-
- # CHECK: Mandatory: 2
- # CHECK: Optional: None
- # CHECK: Unit: False
- op.mandatory_i32 = two
- op.optional_i32 = None
- op.unit = False
- print(f"Mandatory: {op.mandatory_i32.value}")
- print(f"Optional: {op.optional_i32}")
- print(f"Unit: {op.unit}")
- assert "optional_i32" not in op.attributes
- assert "unit" not in op.attributes
-
- try:
- op.mandatory_i32 = None
- except ValueError:
- pass
- else:
- assert False, "expected ValueError on setting a mandatory attribute to None"
-
- # CHECK: Optional: 2
- op.optional_i32 = two
- print(f"Optional: {op.optional_i32.value}")
-
- # CHECK: Optional: None
- del op.optional_i32
- print(f"Optional: {op.optional_i32}")
-
- # CHECK: Unit: False
- op.unit = None
- print(f"Unit: {op.unit}")
- assert "unit" not in op.attributes
-
- # CHECK: Unit: True
- op.unit = True
- print(f"Unit: {op.unit}")
-
- # CHECK: Unit: False
- del op.unit
- print(f"Unit: {op.unit}")
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+
+ #
+ # Check op construction with attributes.
+ #
+
+ i32 = IntegerType.get_signless(32)
+ one = IntegerAttr.get(i32, 1)
+ two = IntegerAttr.get(i32, 2)
+ unit = UnitAttr.get()
+
+ # CHECK: "python_test.attributed_op"() {
+ # CHECK-DAG: mandatory_i32 = 1 : i32
+ # CHECK-DAG: optional_i32 = 2 : i32
+ # CHECK-DAG: unit
+ # CHECK: }
+ op = test.AttributedOp(one, optional_i32=two, unit=unit)
+ print(f"{op}")
+
+ # CHECK: "python_test.attributed_op"() {
+ # CHECK: mandatory_i32 = 2 : i32
+ # CHECK: }
+ op2 = test.AttributedOp(two)
+ print(f"{op2}")
+
+ #
+ # Check generic "attributes" access and mutation.
+ #
+
+ assert "additional" not in op.attributes
+
+ # CHECK: "python_test.attributed_op"() {
+ # CHECK-DAG: additional = 1 : i32
+ # CHECK-DAG: mandatory_i32 = 2 : i32
+ # CHECK: }
+ op2.attributes["additional"] = one
+ print(f"{op2}")
+
+ # CHECK: "python_test.attributed_op"() {
+ # CHECK-DAG: additional = 2 : i32
+ # CHECK-DAG: mandatory_i32 = 2 : i32
+ # CHECK: }
+ op2.attributes["additional"] = two
+ print(f"{op2}")
+
+ # CHECK: "python_test.attributed_op"() {
+ # CHECK-NOT: additional = 2 : i32
+ # CHECK: mandatory_i32 = 2 : i32
+ # CHECK: }
+ del op2.attributes["additional"]
+ print(f"{op2}")
+
+ try:
+ print(op.attributes["additional"])
+ except KeyError:
+ pass
+ else:
+ assert False, "expected KeyError on unknown attribute key"
+
+ #
+ # Check accessors to defined attributes.
+ #
+
+ # CHECK: Mandatory: 1
+ # CHECK: Optional: 2
+ # CHECK: Unit: True
+ print(f"Mandatory: {op.mandatory_i32.value}")
+ print(f"Optional: {op.optional_i32.value}")
+ print(f"Unit: {op.unit}")
+
+ # CHECK: Mandatory: 2
+ # CHECK: Optional: None
+ # CHECK: Unit: False
+ print(f"Mandatory: {op2.mandatory_i32.value}")
+ print(f"Optional: {op2.optional_i32}")
+ print(f"Unit: {op2.unit}")
+
+ # CHECK: Mandatory: 2
+ # CHECK: Optional: None
+ # CHECK: Unit: False
+ op.mandatory_i32 = two
+ op.optional_i32 = None
+ op.unit = False
+ print(f"Mandatory: {op.mandatory_i32.value}")
+ print(f"Optional: {op.optional_i32}")
+ print(f"Unit: {op.unit}")
+ assert "optional_i32" not in op.attributes
+ assert "unit" not in op.attributes
+
+ try:
+ op.mandatory_i32 = None
+ except ValueError:
+ pass
+ else:
+ assert False, "expected ValueError on setting a mandatory attribute to None"
+
+ # CHECK: Optional: 2
+ op.optional_i32 = two
+ print(f"Optional: {op.optional_i32.value}")
+
+ # CHECK: Optional: None
+ del op.optional_i32
+ print(f"Optional: {op.optional_i32}")
+
+ # CHECK: Unit: False
+ op.unit = None
+ print(f"Unit: {op.unit}")
+ assert "unit" not in op.attributes
+
+ # CHECK: Unit: True
+ op.unit = True
+ print(f"Unit: {op.unit}")
+
+ # CHECK: Unit: False
+ del op.unit
+ print(f"Unit: {op.unit}")
+
# CHECK-LABEL: TEST: attrBuilder
@run
def attrBuilder():
- with Context() as ctx, Location.unknown():
- ctx.allow_unregistered_dialects = True
- op = test.AttributesOp(x_bool=True,
- x_i16=1,
- x_i32=2,
- x_i64=3,
- x_si16=-1,
- x_si32=-2,
- x_f32=1.5,
- x_f64=2.5,
- x_str='x_str',
- x_i32_array=[1, 2, 3],
- x_i64_array=[4, 5, 6],
- x_f32_array=[1.5, -2.5, 3.5],
- x_f64_array=[4.5, 5.5, -6.5],
- x_i64_dense=[1, 2, 3, 4, 5, 6])
- print(op)
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ op = test.AttributesOp(
+ x_bool=True,
+ x_i16=1,
+ x_i32=2,
+ x_i64=3,
+ x_si16=-1,
+ x_si32=-2,
+ x_f32=1.5,
+ x_f64=2.5,
+ x_str="x_str",
+ x_i32_array=[1, 2, 3],
+ x_i64_array=[4, 5, 6],
+ x_f32_array=[1.5, -2.5, 3.5],
+ x_f64_array=[4.5, 5.5, -6.5],
+ x_i64_dense=[1, 2, 3, 4, 5, 6],
+ )
+ print(op)
# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
- with Context() as ctx, Location.unknown(ctx):
- test.register_python_test_dialect(ctx)
- module = Module.create()
- with InsertionPoint(module.body):
- op = test.InferResultsOp()
- dummy = test.DummyOp()
-
- # CHECK: [Type(i32), Type(i64)]
- iface = InferTypeOpInterface(op)
- print(iface.inferReturnTypes())
-
- # CHECK: [Type(i32), Type(i64)]
- iface_static = InferTypeOpInterface(test.InferResultsOp)
- print(iface.inferReturnTypes())
-
- assert isinstance(iface.opview, test.InferResultsOp)
- assert iface.opview == iface.operation.opview
-
- try:
- iface_static.opview
- except TypeError:
- pass
- else:
- assert False, ("not expected to be able to obtain an opview from a static"
- " interface")
-
- try:
- InferTypeOpInterface(dummy)
- except ValueError:
- pass
- else:
- assert False, "not expected dummy op to implement the interface"
-
- try:
- InferTypeOpInterface(test.DummyOp)
- except ValueError:
- pass
- else:
- assert False, "not expected dummy op class to implement the interface"
+ with Context() as ctx, Location.unknown(ctx):
+ test.register_python_test_dialect(ctx)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ op = test.InferResultsOp()
+ dummy = test.DummyOp()
+
+ # CHECK: [Type(i32), Type(i64)]
+ iface = InferTypeOpInterface(op)
+ print(iface.inferReturnTypes())
+
+ # CHECK: [Type(i32), Type(i64)]
+ iface_static = InferTypeOpInterface(test.InferResultsOp)
+ print(iface.inferReturnTypes())
+
+ assert isinstance(iface.opview, test.InferResultsOp)
+ assert iface.opview == iface.operation.opview
+
+ try:
+ iface_static.opview
+ except TypeError:
+ pass
+ else:
+ assert False, (
+ "not expected to be able to obtain an opview from a static" " interface"
+ )
+
+ try:
+ InferTypeOpInterface(dummy)
+ except ValueError:
+ pass
+ else:
+ assert False, "not expected dummy op to implement the interface"
+
+ try:
+ InferTypeOpInterface(test.DummyOp)
+ except ValueError:
+ pass
+ else:
+ assert False, "not expected dummy op class to implement the interface"
# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
- with Context() as ctx, Location.unknown(ctx):
- test.register_python_test_dialect(ctx)
- module = Module.create()
- with InsertionPoint(module.body):
- inferred = test.InferResultsOp()
- same = test.SameOperandAndResultTypeOp([inferred.results[0]])
- # CHECK-COUNT-2: i32
- print(same.one.type)
- print(same.two.type)
-
- first_type_attr = test.FirstAttrDeriveTypeAttrOp(
- inferred.results[1], TypeAttr.get(IndexType.get()))
- # CHECK-COUNT-2: index
- print(first_type_attr.one.type)
- print(first_type_attr.two.type)
-
- first_attr = test.FirstAttrDeriveAttrOp(
- FloatAttr.get(F32Type.get(), 3.14))
- # CHECK-COUNT-3: f32
- print(first_attr.one.type)
- print(first_attr.two.type)
- print(first_attr.three.type)
-
- implied = test.InferResultsImpliedOp()
- # CHECK: i32
- print(implied.integer.type)
- # CHECK: f64
- print(implied.flt.type)
- # CHECK: index
- print(implied.index.type)
+ with Context() as ctx, Location.unknown(ctx):
+ test.register_python_test_dialect(ctx)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ inferred = test.InferResultsOp()
+ same = test.SameOperandAndResultTypeOp([inferred.results[0]])
+ # CHECK-COUNT-2: i32
+ print(same.one.type)
+ print(same.two.type)
+
+ first_type_attr = test.FirstAttrDeriveTypeAttrOp(
+ inferred.results[1], TypeAttr.get(IndexType.get())
+ )
+ # CHECK-COUNT-2: index
+ print(first_type_attr.one.type)
+ print(first_type_attr.two.type)
+
+ first_attr = test.FirstAttrDeriveAttrOp(FloatAttr.get(F32Type.get(), 3.14))
+ # CHECK-COUNT-3: f32
+ print(first_attr.one.type)
+ print(first_attr.two.type)
+ print(first_attr.three.type)
+
+ implied = test.InferResultsImpliedOp()
+ # CHECK: i32
+ print(implied.integer.type)
+ # CHECK: f64
+ print(implied.flt.type)
+ # CHECK: index
+ print(implied.index.type)
# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
- with Context() as ctx, Location.unknown():
- test.register_python_test_dialect(ctx)
+ with Context() as ctx, Location.unknown():
+ test.register_python_test_dialect(ctx)
- module = Module.create()
- with InsertionPoint(module.body):
+ module = Module.create()
+ with InsertionPoint(module.body):
- op1 = test.OptionalOperandOp()
- # CHECK: op1.input is None: True
- print(f"op1.input is None: {op1.input is None}")
+ op1 = test.OptionalOperandOp()
+ # CHECK: op1.input is None: True
+ print(f"op1.input is None: {op1.input is None}")
- op2 = test.OptionalOperandOp(input=op1)
- # CHECK: op2.input is None: False
- print(f"op2.input is None: {op2.input is None}")
+ op2 = test.OptionalOperandOp(input=op1)
+ # CHECK: op2.input is None: False
+ print(f"op2.input is None: {op2.input is None}")
# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
- with Context() as ctx:
- test.register_python_test_dialect(ctx)
- a = test.TestAttr.get()
- # CHECK: #python_test.test_attr
- print(a)
-
- # The following cast must not assert.
- b = test.TestAttr(a)
-
- unit = UnitAttr.get()
- try:
- test.TestAttr(unit)
- except ValueError as e:
- assert "Cannot cast attribute to TestAttr" in str(e)
- else:
- raise
-
- # The following must trigger a TypeError from our adaptors and must not
- # crash.
- try:
- test.TestAttr(42)
- except TypeError as e:
- assert "Expected an MLIR object" in str(e)
- else:
- raise
-
- # The following must trigger a TypeError from pybind (therefore, not
- # checking its message) and must not crash.
- try:
- test.TestAttr(42, 56)
- except TypeError:
- pass
- else:
- raise
+ with Context() as ctx:
+ test.register_python_test_dialect(ctx)
+ a = test.TestAttr.get()
+ # CHECK: #python_test.test_attr
+ print(a)
+
+ # The following cast must not assert.
+ b = test.TestAttr(a)
+
+ unit = UnitAttr.get()
+ try:
+ test.TestAttr(unit)
+ except ValueError as e:
+ assert "Cannot cast attribute to TestAttr" in str(e)
+ else:
+ raise
+
+ # The following must trigger a TypeError from our adaptors and must not
+ # crash.
+ try:
+ test.TestAttr(42)
+ except TypeError as e:
+ assert "Expected an MLIR object" in str(e)
+ else:
+ raise
+
+ # The following must trigger a TypeError from pybind (therefore, not
+ # checking its message) and must not crash.
+ try:
+ test.TestAttr(42, 56)
+ except TypeError:
+ pass
+ else:
+ raise
@run
def testCustomType():
- with Context() as ctx:
- test.register_python_test_dialect(ctx)
- a = test.TestType.get()
- # CHECK: !python_test.test_type
- print(a)
-
- # The following cast must not assert.
- b = test.TestType(a)
- # Instance custom types should have typeids
- assert isinstance(b.typeid, TypeID)
- # Subclasses of ir.Type should not have a static_typeid
- # CHECK: 'TestType' object has no attribute 'static_typeid'
- try:
- b.static_typeid
- except AttributeError as e:
- print(e)
-
- i8 = IntegerType.get_signless(8)
- try:
- test.TestType(i8)
- except ValueError as e:
- assert "Cannot cast type to TestType" in str(e)
- else:
- raise
-
- # The following must trigger a TypeError from our adaptors and must not
- # crash.
- try:
- test.TestType(42)
- except TypeError as e:
- assert "Expected an MLIR object" in str(e)
- else:
- raise
-
- # The following must trigger a TypeError from pybind (therefore, not
- # checking its message) and must not crash.
- try:
- test.TestType(42, 56)
- except TypeError:
- pass
- else:
- raise
+ with Context() as ctx:
+ test.register_python_test_dialect(ctx)
+ a = test.TestType.get()
+ # CHECK: !python_test.test_type
+ print(a)
+
+ # The following cast must not assert.
+ b = test.TestType(a)
+ # Instance custom types should have typeids
+ assert isinstance(b.typeid, TypeID)
+ # Subclasses of ir.Type should not have a static_typeid
+ # CHECK: 'TestType' object has no attribute 'static_typeid'
+ try:
+ b.static_typeid
+ except AttributeError as e:
+ print(e)
+
+ i8 = IntegerType.get_signless(8)
+ try:
+ test.TestType(i8)
+ except ValueError as e:
+ assert "Cannot cast type to TestType" in str(e)
+ else:
+ raise
+
+ # The following must trigger a TypeError from our adaptors and must not
+ # crash.
+ try:
+ test.TestType(42)
+ except TypeError as e:
+ assert "Expected an MLIR object" in str(e)
+ else:
+ raise
+
+ # The following must trigger a TypeError from pybind (therefore, not
+ # checking its message) and must not crash.
+ try:
+ test.TestType(42, 56)
+ except TypeError:
+ pass
+ else:
+ raise
@run
# CHECK-LABEL: TEST: testTensorValue
def testTensorValue():
- with Context() as ctx, Location.unknown():
- test.register_python_test_dialect(ctx)
+ with Context() as ctx, Location.unknown():
+ test.register_python_test_dialect(ctx)
- i8 = IntegerType.get_signless(8)
+ i8 = IntegerType.get_signless(8)
- class Tensor(test.TestTensorValue):
- def __str__(self):
- return super().__str__().replace("Value", "Tensor")
+ class Tensor(test.TestTensorValue):
+ def __str__(self):
+ return super().__str__().replace("Value", "Tensor")
- module = Module.create()
- with InsertionPoint(module.body):
- t = tensor.EmptyOp([10, 10], i8).result
+ module = Module.create()
+ with InsertionPoint(module.body):
+ t = tensor.EmptyOp([10, 10], i8).result
- # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
- print(Value(t))
+ # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
+ print(Value(t))
- tt = Tensor(t)
- # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
- print(tt)
+ tt = Tensor(t)
+ # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
+ print(tt)
- # CHECK: False
- print(tt.is_null())
+ # CHECK: False
+ print(tt.is_null())
- # Classes of custom types that inherit from concrete types should have
- # static_typeid
- assert isinstance(test.TestTensorType.static_typeid, TypeID)
- # And it should be equal to the in-tree concrete type
- assert test.TestTensorType.static_typeid == t.type.typeid
+ # Classes of custom types that inherit from concrete types should have
+ # static_typeid
+ assert isinstance(test.TestTensorType.static_typeid, TypeID)
+ # And it should be equal to the in-tree concrete type
+ assert test.TestTensorType.static_typeid == t.type.typeid
# CHECK-LABEL: TEST: inferReturnTypeComponents
@@ -412,7 +418,7 @@ def inferReturnTypeComponents():
# CHECK: shape: None
iface = InferShapedTypeOpInterface(unranked_op)
shaped_type_components = iface.inferReturnTypeComponents(
- operands=[unranked_op.operand]
+ operands=[unranked_op.operand]
)[0]
print("has rank:", shaped_type_components.has_rank)
print("rank:", shaped_type_components.rank)
diff --git a/mlir/test/python/dialects/quant.py b/mlir/test/python/dialects/quant.py
index 32614be31f377..0ee3327dec152 100644
--- a/mlir/test/python/dialects/quant.py
+++ b/mlir/test/python/dialects/quant.py
@@ -5,127 +5,133 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ return f
# CHECK-LABEL: TEST: test_type_hierarchy
@run
def test_type_hierarchy():
- with Context():
- i8 = IntegerType.get_signless(8)
- any = Type.parse("!quant.any<i8<-8:7>:f32>")
- uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
- per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
- calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
+ with Context():
+ i8 = IntegerType.get_signless(8)
+ any = Type.parse("!quant.any<i8<-8:7>:f32>")
+ uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
+ per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
+ calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
- assert not quant.QuantizedType.isinstance(i8)
- assert quant.QuantizedType.isinstance(any)
- assert quant.QuantizedType.isinstance(uniform)
- assert quant.QuantizedType.isinstance(per_axis)
- assert quant.QuantizedType.isinstance(calibrated)
+ assert not quant.QuantizedType.isinstance(i8)
+ assert quant.QuantizedType.isinstance(any)
+ assert quant.QuantizedType.isinstance(uniform)
+ assert quant.QuantizedType.isinstance(per_axis)
+ assert quant.QuantizedType.isinstance(calibrated)
- assert quant.AnyQuantizedType.isinstance(any)
- assert quant.UniformQuantizedType.isinstance(uniform)
- assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
- assert quant.CalibratedQuantizedType.isinstance(calibrated)
+ assert quant.AnyQuantizedType.isinstance(any)
+ assert quant.UniformQuantizedType.isinstance(uniform)
+ assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
+ assert quant.CalibratedQuantizedType.isinstance(calibrated)
- assert not quant.AnyQuantizedType.isinstance(uniform)
- assert not quant.UniformQuantizedType.isinstance(per_axis)
+ assert not quant.AnyQuantizedType.isinstance(uniform)
+ assert not quant.UniformQuantizedType.isinstance(per_axis)
# CHECK-LABEL: TEST: test_any_quantized_type
@run
def test_any_quantized_type():
- with Context():
- i8 = IntegerType.get_signless(8)
- f32 = F32Type.get()
- any = quant.AnyQuantizedType.get(quant.QuantizedType.FLAG_SIGNED, i8, f32,
- -8, 7)
-
- # CHECK: flags: 1
- print(f"flags: {any.flags}")
- # CHECK: signed: True
- print(f"signed: {any.is_signed}")
- # CHECK: storage type: i8
- print(f"storage type: {any.storage_type}")
- # CHECK: expressed type: f32
- print(f"expressed type: {any.expressed_type}")
- # CHECK: storage min: -8
- print(f"storage min: {any.storage_type_min}")
- # CHECK: storage max: 7
- print(f"storage max: {any.storage_type_max}")
- # CHECK: storage width: 8
- print(f"storage width: {any.storage_type_integral_width}")
- # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
- print(f"quantized element type: {any.quantized_element_type}")
- # CHECK: !quant.any<i8<-8:7>:f32>
- print(any)
- assert any == Type.parse("!quant.any<i8<-8:7>:f32>")
+ with Context():
+ i8 = IntegerType.get_signless(8)
+ f32 = F32Type.get()
+ any = quant.AnyQuantizedType.get(
+ quant.QuantizedType.FLAG_SIGNED, i8, f32, -8, 7
+ )
+
+ # CHECK: flags: 1
+ print(f"flags: {any.flags}")
+ # CHECK: signed: True
+ print(f"signed: {any.is_signed}")
+ # CHECK: storage type: i8
+ print(f"storage type: {any.storage_type}")
+ # CHECK: expressed type: f32
+ print(f"expressed type: {any.expressed_type}")
+ # CHECK: storage min: -8
+ print(f"storage min: {any.storage_type_min}")
+ # CHECK: storage max: 7
+ print(f"storage max: {any.storage_type_max}")
+ # CHECK: storage width: 8
+ print(f"storage width: {any.storage_type_integral_width}")
+ # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
+ print(f"quantized element type: {any.quantized_element_type}")
+ # CHECK: !quant.any<i8<-8:7>:f32>
+ print(any)
+ assert any == Type.parse("!quant.any<i8<-8:7>:f32>")
# CHECK-LABEL: TEST: test_uniform_type
@run
def test_uniform_type():
- with Context():
- i8 = IntegerType.get_signless(8)
- f32 = F32Type.get()
- uniform = quant.UniformQuantizedType.get(
- quant.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7)
-
- # CHECK: scale: 0.99872
- print(f"scale: {uniform.scale}")
- # CHECK: zero point: 127
- print(f"zero point: {uniform.zero_point}")
- # CHECK: fixed point: False
- print(f"fixed point: {uniform.is_fixed_point}")
- # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
- print(uniform)
- assert uniform == Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
+ with Context():
+ i8 = IntegerType.get_signless(8)
+ f32 = F32Type.get()
+ uniform = quant.UniformQuantizedType.get(
+ quant.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7
+ )
+
+ # CHECK: scale: 0.99872
+ print(f"scale: {uniform.scale}")
+ # CHECK: zero point: 127
+ print(f"zero point: {uniform.zero_point}")
+ # CHECK: fixed point: False
+ print(f"fixed point: {uniform.is_fixed_point}")
+ # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
+ print(uniform)
+ assert uniform == Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
# CHECK-LABEL: TEST: test_uniform_per_axis_type
@run
def test_uniform_per_axis_type():
- with Context():
- i8 = IntegerType.get_signless(8)
- f32 = F32Type.get()
- per_axis = quant.UniformQuantizedPerAxisType.get(
- quant.QuantizedType.FLAG_SIGNED,
- i8,
- f32, [200, 0.99872], [0, 120],
- quantized_dimension=1,
- storage_type_min=quant.QuantizedType.default_minimum_for_integer(
- is_signed=True, integral_width=8),
- storage_type_max=quant.QuantizedType.default_maximum_for_integer(
- is_signed=True, integral_width=8))
-
- # CHECK: scales: None
- print(f"scales: {per_axis.scales}")
- # CHECK: zero_points: None
- print(f"zero_points: {per_axis.zero_points}")
- # CHECK: quantized dim: 1
- print(f"quantized dim: {per_axis.quantized_dimension}")
- # CHECK: fixed point: False
- print(f"fixed point: {per_axis.is_fixed_point}")
- # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
- print(per_axis)
- assert per_axis == Type.parse(
- "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
+ with Context():
+ i8 = IntegerType.get_signless(8)
+ f32 = F32Type.get()
+ per_axis = quant.UniformQuantizedPerAxisType.get(
+ quant.QuantizedType.FLAG_SIGNED,
+ i8,
+ f32,
+ [200, 0.99872],
+ [0, 120],
+ quantized_dimension=1,
+ storage_type_min=quant.QuantizedType.default_minimum_for_integer(
+ is_signed=True, integral_width=8
+ ),
+ storage_type_max=quant.QuantizedType.default_maximum_for_integer(
+ is_signed=True, integral_width=8
+ ),
+ )
+
+ # CHECK: scales: None
+ print(f"scales: {per_axis.scales}")
+ # CHECK: zero_points: None
+ print(f"zero_points: {per_axis.zero_points}")
+ # CHECK: quantized dim: 1
+ print(f"quantized dim: {per_axis.quantized_dimension}")
+ # CHECK: fixed point: False
+ print(f"fixed point: {per_axis.is_fixed_point}")
+ # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
+ print(per_axis)
+ assert per_axis == Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
# CHECK-LABEL: TEST: test_calibrated_type
@run
def test_calibrated_type():
- with Context():
- f32 = F32Type.get()
- calibrated = quant.CalibratedQuantizedType.get(f32, -0.998, 1.2321)
-
- # CHECK: min: -0.998
- print(f"min: {calibrated.min}")
- # CHECK: max: 1.2321
- print(f"max: {calibrated.max}")
- # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
- print(calibrated)
- assert calibrated == Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
+ with Context():
+ f32 = F32Type.get()
+ calibrated = quant.CalibratedQuantizedType.get(f32, -0.998, 1.2321)
+
+ # CHECK: min: -0.998
+ print(f"min: {calibrated.min}")
+ # CHECK: max: 1.2321
+ print(f"max: {calibrated.max}")
+ # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
+ print(calibrated)
+ assert calibrated == Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 4a618ff4eecc3..8cb55fdf6a1eb 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -8,26 +8,26 @@
def constructAndPrintInModule(f):
- print("\nTEST:", f.__name__)
- with Context(), Location.unknown():
- module = Module.create()
- with InsertionPoint(module.body):
- f()
- print(module)
- return f
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ f()
+ print(module)
+ return f
# CHECK-LABEL: TEST: testSimpleLoop
@constructAndPrintInModule
def testSimpleLoop():
- index_type = IndexType.get()
+ index_type = IndexType.get()
- @func.FuncOp.from_py_func(index_type, index_type, index_type)
- def simple_loop(lb, ub, step):
- loop = scf.ForOp(lb, ub, step, [lb, lb])
- with InsertionPoint(loop.body):
- scf.YieldOp(loop.inner_iter_args)
- return
+ @func.FuncOp.from_py_func(index_type, index_type, index_type)
+ def simple_loop(lb, ub, step):
+ loop = scf.ForOp(lb, ub, step, [lb, lb])
+ with InsertionPoint(loop.body):
+ scf.YieldOp(loop.inner_iter_args)
+ return
# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
@@ -39,14 +39,14 @@ def simple_loop(lb, ub, step):
# CHECK-LABEL: TEST: testInductionVar
@constructAndPrintInModule
def testInductionVar():
- index_type = IndexType.get()
+ index_type = IndexType.get()
- @func.FuncOp.from_py_func(index_type, index_type, index_type)
- def induction_var(lb, ub, step):
- loop = scf.ForOp(lb, ub, step, [lb])
- with InsertionPoint(loop.body):
- scf.YieldOp([loop.induction_variable])
- return
+ @func.FuncOp.from_py_func(index_type, index_type, index_type)
+ def induction_var(lb, ub, step):
+ loop = scf.ForOp(lb, ub, step, [lb])
+ with InsertionPoint(loop.body):
+ scf.YieldOp([loop.induction_variable])
+ return
# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
@@ -56,19 +56,18 @@ def induction_var(lb, ub, step):
@constructAndPrintInModule
def testOpsAsArguments():
- index_type = IndexType.get()
- callee = func.FuncOp(
- "callee", ([], [index_type, index_type]), visibility="private")
- f = func.FuncOp("ops_as_arguments", ([], []))
- with InsertionPoint(f.add_entry_block()):
- lb = arith.ConstantOp.create_index(0)
- ub = arith.ConstantOp.create_index(42)
- step = arith.ConstantOp.create_index(2)
- iter_args = func.CallOp(callee, [])
- loop = scf.ForOp(lb, ub, step, iter_args)
- with InsertionPoint(loop.body):
- scf.YieldOp(loop.inner_iter_args)
- func.ReturnOp([])
+ index_type = IndexType.get()
+ callee = func.FuncOp("callee", ([], [index_type, index_type]), visibility="private")
+ f = func.FuncOp("ops_as_arguments", ([], []))
+ with InsertionPoint(f.add_entry_block()):
+ lb = arith.ConstantOp.create_index(0)
+ ub = arith.ConstantOp.create_index(42)
+ step = arith.ConstantOp.create_index(2)
+ iter_args = func.CallOp(callee, [])
+ loop = scf.ForOp(lb, ub, step, iter_args)
+ with InsertionPoint(loop.body):
+ scf.YieldOp(loop.inner_iter_args)
+ func.ReturnOp([])
# CHECK-LABEL: TEST: testOpsAsArguments
@@ -86,17 +85,17 @@ def testOpsAsArguments():
@constructAndPrintInModule
def testIfWithoutElse():
- bool = IntegerType.get_signless(1)
- i32 = IntegerType.get_signless(32)
+ bool = IntegerType.get_signless(1)
+ i32 = IntegerType.get_signless(32)
- @func.FuncOp.from_py_func(bool)
- def simple_if(cond):
- if_op = scf.IfOp(cond)
- with InsertionPoint(if_op.then_block):
- one = arith.ConstantOp(i32, 1)
- add = arith.AddIOp(one, one)
- scf.YieldOp([])
- return
+ @func.FuncOp.from_py_func(bool)
+ def simple_if(cond):
+ if_op = scf.IfOp(cond)
+ with InsertionPoint(if_op.then_block):
+ one = arith.ConstantOp(i32, 1)
+ add = arith.AddIOp(one, one)
+ scf.YieldOp([])
+ return
# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
@@ -108,22 +107,22 @@ def simple_if(cond):
@constructAndPrintInModule
def testIfWithElse():
- bool = IntegerType.get_signless(1)
- i32 = IntegerType.get_signless(32)
-
- @func.FuncOp.from_py_func(bool)
- def simple_if_else(cond):
- if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
- with InsertionPoint(if_op.then_block):
- x_true = arith.ConstantOp(i32, 0)
- y_true = arith.ConstantOp(i32, 1)
- scf.YieldOp([x_true, y_true])
- with InsertionPoint(if_op.else_block):
- x_false = arith.ConstantOp(i32, 2)
- y_false = arith.ConstantOp(i32, 3)
- scf.YieldOp([x_false, y_false])
- add = arith.AddIOp(if_op.results[0], if_op.results[1])
- return
+ bool = IntegerType.get_signless(1)
+ i32 = IntegerType.get_signless(32)
+
+ @func.FuncOp.from_py_func(bool)
+ def simple_if_else(cond):
+ if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
+ with InsertionPoint(if_op.then_block):
+ x_true = arith.ConstantOp(i32, 0)
+ y_true = arith.ConstantOp(i32, 1)
+ scf.YieldOp([x_true, y_true])
+ with InsertionPoint(if_op.else_block):
+ x_false = arith.ConstantOp(i32, 2)
+ y_false = arith.ConstantOp(i32, 3)
+ scf.YieldOp([x_false, y_false])
+ add = arith.AddIOp(if_op.results[0], if_op.results[1])
+ return
# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
diff --git a/mlir/test/python/dialects/shape.py b/mlir/test/python/dialects/shape.py
index 3e7a8b27d4c09..ad755852f5d37 100644
--- a/mlir/test/python/dialects/shape.py
+++ b/mlir/test/python/dialects/shape.py
@@ -7,36 +7,38 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ return f
# CHECK-LABEL: TEST: testConstShape
@run
def testConstShape():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
- @func.FuncOp.from_py_func(
- RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32))
- def const_shape_tensor(arg):
- shape.ConstWitnessOp(False)
- shape.ConstSizeOp(30)
- shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40))
- x = shape.ConstShapeOp([1, 2])
- shape.MeetOp(x, x, error="impossible")
- return shape.ConstShapeOp(
- DenseElementsAttr.get(
- np.array([3, 4], dtype=np.int64), type=IndexType.get()))
-
-
-
- # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
- # CHECK-DAG: shape.const_witness false
- # CHECK-DAG: shape.const_size 30
- # CHECK-DAG: shape.const_size 40
- # CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex>
- # CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex>
- print(module)
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)
+ )
+ def const_shape_tensor(arg):
+ shape.ConstWitnessOp(False)
+ shape.ConstSizeOp(30)
+ shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40))
+ x = shape.ConstShapeOp([1, 2])
+ shape.MeetOp(x, x, error="impossible")
+ return shape.ConstShapeOp(
+ DenseElementsAttr.get(
+ np.array([3, 4], dtype=np.int64), type=IndexType.get()
+ )
+ )
+
+ # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
+ # CHECK-DAG: shape.const_witness false
+ # CHECK-DAG: shape.const_size 30
+ # CHECK-DAG: shape.const_size 40
+ # CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex>
+ # CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex>
+ print(module)
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 6190bebcd5e98..b7a06067b5f56 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -3,97 +3,106 @@
from mlir.ir import *
from mlir.dialects import sparse_tensor as st
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ return f
# CHECK-LABEL: TEST: testEncodingAttr1D
@run
def testEncodingAttr1D():
- with Context() as ctx:
- parsed = Attribute.parse('#sparse_tensor.encoding<{'
- ' lvlTypes = [ "compressed" ],'
- ' posWidth = 16,'
- ' crdWidth = 32'
- '}>')
- # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 16, crdWidth = 32 }>
- print(parsed)
-
- casted = st.EncodingAttr(parsed)
- # CHECK: equal: True
- print(f"equal: {casted == parsed}")
-
- # CHECK: lvl_types: [<DimLevelType.compressed: 8>]
- print(f"lvl_types: {casted.lvl_types}")
- # CHECK: dim_ordering: None
- print(f"dim_ordering: {casted.dim_ordering}")
- # CHECK: pos_width: 16
- print(f"pos_width: {casted.pos_width}")
- # CHECK: crd_width: 32
- print(f"crd_width: {casted.crd_width}")
-
- created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
- # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>
- print(created)
- # CHECK: created_equal: False
- print(f"created_equal: {created == casted}")
-
- # Verify that the factory creates an instance of the proper type.
- # CHECK: is_proper_instance: True
- print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
- # CHECK: created_pos_width: 0
- print(f"created_pos_width: {created.pos_width}")
+ with Context() as ctx:
+ parsed = Attribute.parse(
+ "#sparse_tensor.encoding<{"
+ ' lvlTypes = [ "compressed" ],'
+ " posWidth = 16,"
+ " crdWidth = 32"
+ "}>"
+ )
+ # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 16, crdWidth = 32 }>
+ print(parsed)
+
+ casted = st.EncodingAttr(parsed)
+ # CHECK: equal: True
+ print(f"equal: {casted == parsed}")
+
+ # CHECK: lvl_types: [<DimLevelType.compressed: 8>]
+ print(f"lvl_types: {casted.lvl_types}")
+ # CHECK: dim_ordering: None
+ print(f"dim_ordering: {casted.dim_ordering}")
+ # CHECK: pos_width: 16
+ print(f"pos_width: {casted.pos_width}")
+ # CHECK: crd_width: 32
+ print(f"crd_width: {casted.crd_width}")
+
+ created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
+ # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>
+ print(created)
+ # CHECK: created_equal: False
+ print(f"created_equal: {created == casted}")
+
+ # Verify that the factory creates an instance of the proper type.
+ # CHECK: is_proper_instance: True
+ print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
+ # CHECK: created_pos_width: 0
+ print(f"created_pos_width: {created.pos_width}")
# CHECK-LABEL: TEST: testEncodingAttr2D
@run
def testEncodingAttr2D():
- with Context() as ctx:
- parsed = Attribute.parse('#sparse_tensor.encoding<{'
- ' lvlTypes = [ "dense", "compressed" ],'
- ' dimOrdering = affine_map<(d0, d1) -> (d1, d0)>,'
- ' posWidth = 8,'
- ' crdWidth = 32'
- '}>')
- # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }>
- print(parsed)
-
- casted = st.EncodingAttr(parsed)
- # CHECK: equal: True
- print(f"equal: {casted == parsed}")
-
- # CHECK: lvl_types: [<DimLevelType.dense: 4>, <DimLevelType.compressed: 8>]
- print(f"lvl_types: {casted.lvl_types}")
- # CHECK: dim_ordering: (d0, d1) -> (d1, d0)
- print(f"dim_ordering: {casted.dim_ordering}")
- # CHECK: pos_width: 8
- print(f"pos_width: {casted.pos_width}")
- # CHECK: crd_width: 32
- print(f"crd_width: {casted.crd_width}")
-
- created = st.EncodingAttr.get(casted.lvl_types, casted.dim_ordering,
- casted.higher_ordering, 8, 32)
- # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }>
- print(created)
- # CHECK: created_equal: True
- print(f"created_equal: {created == casted}")
+ with Context() as ctx:
+ parsed = Attribute.parse(
+ "#sparse_tensor.encoding<{"
+ ' lvlTypes = [ "dense", "compressed" ],'
+ " dimOrdering = affine_map<(d0, d1) -> (d1, d0)>,"
+ " posWidth = 8,"
+ " crdWidth = 32"
+ "}>"
+ )
+ # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }>
+ print(parsed)
+
+ casted = st.EncodingAttr(parsed)
+ # CHECK: equal: True
+ print(f"equal: {casted == parsed}")
+
+ # CHECK: lvl_types: [<DimLevelType.dense: 4>, <DimLevelType.compressed: 8>]
+ print(f"lvl_types: {casted.lvl_types}")
+ # CHECK: dim_ordering: (d0, d1) -> (d1, d0)
+ print(f"dim_ordering: {casted.dim_ordering}")
+ # CHECK: pos_width: 8
+ print(f"pos_width: {casted.pos_width}")
+ # CHECK: crd_width: 32
+ print(f"crd_width: {casted.crd_width}")
+
+ created = st.EncodingAttr.get(
+ casted.lvl_types, casted.dim_ordering, casted.higher_ordering, 8, 32
+ )
+ # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }>
+ print(created)
+ # CHECK: created_equal: True
+ print(f"created_equal: {created == casted}")
# CHECK-LABEL: TEST: testEncodingAttrOnTensorType
@run
def testEncodingAttrOnTensorType():
- with Context() as ctx, Location.unknown():
- encoding = st.EncodingAttr(
- Attribute.parse('#sparse_tensor.encoding<{'
- ' lvlTypes = [ "compressed" ], '
- ' posWidth = 64,'
- ' crdWidth = 32'
- '}>'))
- tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
- # CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }>>
- print(tt)
- # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }>
- print(tt.encoding)
- assert tt.encoding == encoding
+ with Context() as ctx, Location.unknown():
+ encoding = st.EncodingAttr(
+ Attribute.parse(
+ "#sparse_tensor.encoding<{"
+ ' lvlTypes = [ "compressed" ], '
+ " posWidth = 64,"
+ " crdWidth = 32"
+ "}>"
+ )
+ )
+ tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
+ # CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }>>
+ print(tt)
+ # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }>
+ print(tt.encoding)
+ assert tt.encoding == encoding
diff --git a/mlir/test/python/dialects/sparse_tensor/passes.py b/mlir/test/python/dialects/sparse_tensor/passes.py
index 9319e16e054de..c37c5207ebd9f 100644
--- a/mlir/test/python/dialects/sparse_tensor/passes.py
+++ b/mlir/test/python/dialects/sparse_tensor/passes.py
@@ -7,16 +7,16 @@
def run(f):
- print('\nTEST:', f.__name__)
- f()
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ return f
# CHECK-LABEL: TEST: testSparseTensorPass
@run
def testSparseTensorPass():
- with Context() as context:
- PassManager.parse('any(sparsification)')
- PassManager.parse('any(sparse-tensor-conversion)')
- # CHECK: SUCCESS
- print('SUCCESS')
+ with Context() as context:
+ PassManager.parse("any(sparsification)")
+ PassManager.parse("any(sparse-tensor-conversion)")
+ # CHECK: SUCCESS
+ print("SUCCESS")
diff --git a/mlir/test/python/dialects/tensor.py b/mlir/test/python/dialects/tensor.py
index b0ad4b4ad4d65..b690c934dc46b 100644
--- a/mlir/test/python/dialects/tensor.py
+++ b/mlir/test/python/dialects/tensor.py
@@ -7,125 +7,135 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ return f
# CHECK-LABEL: TEST: testDimOp
@run
def testDimOp():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32Type = F32Type.get()
- indexType = IndexType.get()
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- RankedTensorType.get(
- (ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()),
- f32Type))
- # CHECK: func @tensor_static_dim
- # CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
- # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
- # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
- # CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
- # CHECK: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
- # CHECK: return %[[D0]], %[[D1]]
- def tensor_static_dim(t):
- c0 = arith.ConstantOp(indexType, 0)
- c1 = arith.ConstantOp(indexType, 1)
- d0 = tensor.DimOp(t, c0)
- d1 = tensor.DimOp(t, c1)
- return [d0.result, d1.result]
-
- print(module)
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32Type = F32Type.get()
+ indexType = IndexType.get()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get(
+ (ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()),
+ f32Type,
+ )
+ )
+ # CHECK: func @tensor_static_dim
+ # CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+ # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+ # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+ # CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+ # CHECK: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+ # CHECK: return %[[D0]], %[[D1]]
+ def tensor_static_dim(t):
+ c0 = arith.ConstantOp(indexType, 0)
+ c1 = arith.ConstantOp(indexType, 1)
+ d0 = tensor.DimOp(t, c0)
+ d1 = tensor.DimOp(t, c1)
+ return [d0.result, d1.result]
+
+ print(module)
# CHECK-LABEL: TEST: testEmptyOp
@run
def testEmptyOp():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
- # CHECK-LABEL: func @static_sizes
- # CHECK: %0 = tensor.empty() : tensor<3x4xf32>
- @func.FuncOp.from_py_func()
- def static_sizes():
- return tensor.EmptyOp([3, 4], f32)
-
- # CHECK-LABEL: func @dynamic_sizes
- # CHECK: %0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
- @func.FuncOp.from_py_func(IndexType.get(), IndexType.get())
- def dynamic_sizes(d0, d1):
- return tensor.EmptyOp([d0, d1], f32)
-
- # CHECK-LABEL: func @mixed_static_dynamic_sizes
- # CHECK: %0 = tensor.empty(%arg0) : tensor<?x4xf32>
- @func.FuncOp.from_py_func(IndexType.get())
- def mixed_static_dynamic_sizes(d0):
- return tensor.EmptyOp([d0, 4], f32)
-
- # CHECK-LABEL: func @zero_d
- # CHECK: %0 = tensor.empty() : tensor<f32>
- @func.FuncOp.from_py_func()
- def zero_d():
- return tensor.EmptyOp([], f32)
-
- print(module)
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ # CHECK-LABEL: func @static_sizes
+ # CHECK: %0 = tensor.empty() : tensor<3x4xf32>
+ @func.FuncOp.from_py_func()
+ def static_sizes():
+ return tensor.EmptyOp([3, 4], f32)
+
+ # CHECK-LABEL: func @dynamic_sizes
+ # CHECK: %0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
+ @func.FuncOp.from_py_func(IndexType.get(), IndexType.get())
+ def dynamic_sizes(d0, d1):
+ return tensor.EmptyOp([d0, d1], f32)
+
+ # CHECK-LABEL: func @mixed_static_dynamic_sizes
+ # CHECK: %0 = tensor.empty(%arg0) : tensor<?x4xf32>
+ @func.FuncOp.from_py_func(IndexType.get())
+ def mixed_static_dynamic_sizes(d0):
+ return tensor.EmptyOp([d0, 4], f32)
+
+ # CHECK-LABEL: func @zero_d
+ # CHECK: %0 = tensor.empty() : tensor<f32>
+ @func.FuncOp.from_py_func()
+ def zero_d():
+ return tensor.EmptyOp([], f32)
+
+ print(module)
# CHECK-LABEL: TEST: testInferTypesInsertSlice
@run
def testInferTypesInsertSlice():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32Type = F32Type.get()
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- RankedTensorType.get((1, 1), f32Type),
- RankedTensorType.get((1, 1), f32Type))
- # CHECK: func @f
- # CHECK: tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] :
- # CHECK-SAME: tensor<1x1xf32> into tensor<1x1xf32>
- def f(source, dest):
- d0 = tensor.InsertSliceOp(source, dest, [], [], [],
- DenseI64ArrayAttr.get([0, 0]),
- DenseI64ArrayAttr.get([1, 1]),
- DenseI64ArrayAttr.get([0, 0]))
- return [d0.result]
-
- print(module)
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32Type = F32Type.get()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((1, 1), f32Type),
+ RankedTensorType.get((1, 1), f32Type),
+ )
+ # CHECK: func @f
+ # CHECK: tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] :
+ # CHECK-SAME: tensor<1x1xf32> into tensor<1x1xf32>
+ def f(source, dest):
+ d0 = tensor.InsertSliceOp(
+ source,
+ dest,
+ [],
+ [],
+ [],
+ DenseI64ArrayAttr.get([0, 0]),
+ DenseI64ArrayAttr.get([1, 1]),
+ DenseI64ArrayAttr.get([0, 0]),
+ )
+ return [d0.result]
+
+ print(module)
# CHECK-LABEL: TEST: testFromElementsOp
@run
def testFromElementsOp():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
- @func.FuncOp.from_py_func()
- def default_builder():
- c0 = arith.ConstantOp(f32, 0.0)
- # CHECK: %[[C0:.*]] = "arith.constant
- # CHECK-SAME: value = 0.000000e+00 : f32
- print(c0)
- c1 = arith.ConstantOp(f32, 1.0)
- # CHECK: %[[C1:.*]] = "arith.constant
- # CHECK-SAME: value = 1.000000e+00 : f32
- print(c1)
-
- t = tensor.FromElementsOp(RankedTensorType.get((2,), f32), [c0, c1])
- # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2xf32>
- print(t)
-
- t = tensor.FromElementsOp(RankedTensorType.get((2, 1), f32), [c0, c1])
- # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2x1xf32>
- print(t)
-
- t = tensor.FromElementsOp(RankedTensorType.get((1, 2), f32), [c0, c1])
- # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32>
- print(t)
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func()
+ def default_builder():
+ c0 = arith.ConstantOp(f32, 0.0)
+ # CHECK: %[[C0:.*]] = "arith.constant
+ # CHECK-SAME: value = 0.000000e+00 : f32
+ print(c0)
+ c1 = arith.ConstantOp(f32, 1.0)
+ # CHECK: %[[C1:.*]] = "arith.constant
+ # CHECK-SAME: value = 1.000000e+00 : f32
+ print(c1)
+
+ t = tensor.FromElementsOp(RankedTensorType.get((2,), f32), [c0, c1])
+ # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2xf32>
+ print(t)
+
+ t = tensor.FromElementsOp(RankedTensorType.get((2, 1), f32), [c0, c1])
+ # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2x1xf32>
+ print(t)
+
+ t = tensor.FromElementsOp(RankedTensorType.get((1, 2), f32), [c0, c1])
+ # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32>
+ print(t)
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 6b36c025eafa3..ca6499b5706d1 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -6,158 +6,188 @@
def run(f):
- with Context(), Location.unknown():
- module = Module.create()
- with InsertionPoint(module.body):
- print("\nTEST:", f.__name__)
- f()
- print(module)
- return f
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ print("\nTEST:", f.__name__)
+ f()
+ print(module)
+ return f
@run
def testTypes():
- # CHECK-LABEL: TEST: testTypes
- # CHECK: !transform.any_op
- any_op = transform.AnyOpType.get()
- print(any_op)
+ # CHECK-LABEL: TEST: testTypes
+ # CHECK: !transform.any_op
+ any_op = transform.AnyOpType.get()
+ print(any_op)
- # CHECK: !transform.op<"foo.bar">
- # CHECK: foo.bar
- concrete_op = transform.OperationType.get("foo.bar")
- print(concrete_op)
- print(concrete_op.operation_name)
+ # CHECK: !transform.op<"foo.bar">
+ # CHECK: foo.bar
+ concrete_op = transform.OperationType.get("foo.bar")
+ print(concrete_op)
+ print(concrete_op.operation_name)
@run
def testSequenceOp():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
- [transform.AnyOpType.get()],
- transform.AnyOpType.get())
- with InsertionPoint(sequence.body):
- transform.YieldOp([sequence.bodyTarget])
- # CHECK-LABEL: TEST: testSequenceOp
- # CHECK: = transform.sequence -> !transform.any_op failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
- # CHECK: yield %[[ARG0]] : !transform.any_op
- # CHECK: }
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE,
+ [transform.AnyOpType.get()],
+ transform.AnyOpType.get(),
+ )
+ with InsertionPoint(sequence.body):
+ transform.YieldOp([sequence.bodyTarget])
+ # CHECK-LABEL: TEST: testSequenceOp
+ # CHECK: = transform.sequence -> !transform.any_op failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+ # CHECK: yield %[[ARG0]] : !transform.any_op
+ # CHECK: }
@run
def testNestedSequenceOp():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get())
- with InsertionPoint(sequence.body):
- nested = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], sequence.bodyTarget)
- with InsertionPoint(nested.body):
- doubly_nested = transform.SequenceOp(
- transform.FailurePropagationMode.PROPAGATE,
- [transform.AnyOpType.get()], nested.bodyTarget)
- with InsertionPoint(doubly_nested.body):
- transform.YieldOp([doubly_nested.bodyTarget])
- transform.YieldOp()
- transform.YieldOp()
- # CHECK-LABEL: TEST: testNestedSequenceOp
- # CHECK: transform.sequence failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
- # CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
- # CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
- # CHECK: yield %[[ARG2]] : !transform.any_op
- # CHECK: }
- # CHECK: }
- # CHECK: }
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ nested = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], sequence.bodyTarget
+ )
+ with InsertionPoint(nested.body):
+ doubly_nested = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE,
+ [transform.AnyOpType.get()],
+ nested.bodyTarget,
+ )
+ with InsertionPoint(doubly_nested.body):
+ transform.YieldOp([doubly_nested.bodyTarget])
+ transform.YieldOp()
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testNestedSequenceOp
+ # CHECK: transform.sequence failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+ # CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+ # CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
+ # CHECK: yield %[[ARG2]] : !transform.any_op
+ # CHECK: }
+ # CHECK: }
+ # CHECK: }
@run
def testSequenceOpWithExtras():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(),
- [transform.AnyOpType.get(),
- transform.OperationType.get("foo.bar")])
- with InsertionPoint(sequence.body):
- transform.YieldOp()
- # CHECK-LABEL: TEST: testSequenceOpWithExtras
- # CHECK: transform.sequence failures(propagate)
- # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE,
+ [],
+ transform.AnyOpType.get(),
+ [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+ )
+ with InsertionPoint(sequence.body):
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testSequenceOpWithExtras
+ # CHECK: transform.sequence failures(propagate)
+ # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
@run
def testNestedSequenceOpWithExtras():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(),
- [transform.AnyOpType.get(),
- transform.OperationType.get("foo.bar")])
- with InsertionPoint(sequence.body):
- nested = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
- [], sequence.bodyTarget,
- sequence.bodyExtraArgs)
- with InsertionPoint(nested.body):
- transform.YieldOp()
- transform.YieldOp()
- # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
- # CHECK: transform.sequence failures(propagate)
- # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
- # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE,
+ [],
+ transform.AnyOpType.get(),
+ [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+ )
+ with InsertionPoint(sequence.body):
+ nested = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE,
+ [],
+ sequence.bodyTarget,
+ sequence.bodyExtraArgs,
+ )
+ with InsertionPoint(nested.body):
+ transform.YieldOp()
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
+ # CHECK: transform.sequence failures(propagate)
+ # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
+ # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
@run
def testTransformPDLOps():
- withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
- with InsertionPoint(withPdl.body):
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
- [transform.AnyOpType.get()],
- withPdl.bodyTarget)
- with InsertionPoint(sequence.body):
- match = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher")
- transform.YieldOp(match)
- # CHECK-LABEL: TEST: testTransformPDLOps
- # CHECK: transform.with_pdl_patterns {
- # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
- # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
- # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
- # CHECK: yield %[[RES]] : !transform.any_op
- # CHECK: }
- # CHECK: }
+ withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
+ with InsertionPoint(withPdl.body):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE,
+ [transform.AnyOpType.get()],
+ withPdl.bodyTarget,
+ )
+ with InsertionPoint(sequence.body):
+ match = transform_pdl.PDLMatchOp(
+ transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
+ )
+ transform.YieldOp(match)
+ # CHECK-LABEL: TEST: testTransformPDLOps
+ # CHECK: transform.with_pdl_patterns {
+ # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+ # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+ # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
+ # CHECK: yield %[[RES]] : !transform.any_op
+ # CHECK: }
+ # CHECK: }
@run
def testGetClosestIsolatedParentOp():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get())
- with InsertionPoint(sequence.body):
- transform.GetClosestIsolatedParentOp(transform.AnyOpType.get(), sequence.bodyTarget)
- transform.YieldOp()
- # CHECK-LABEL: TEST: testGetClosestIsolatedParentOp
- # CHECK: transform.sequence
- # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
- # CHECK: = get_closest_isolated_parent %[[ARG1]]
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ transform.GetClosestIsolatedParentOp(
+ transform.AnyOpType.get(), sequence.bodyTarget
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testGetClosestIsolatedParentOp
+ # CHECK: transform.sequence
+ # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+ # CHECK: = get_closest_isolated_parent %[[ARG1]]
@run
def testMergeHandlesOp():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get())
- with InsertionPoint(sequence.body):
- transform.MergeHandlesOp([sequence.bodyTarget])
- transform.YieldOp()
- # CHECK-LABEL: TEST: testMergeHandlesOp
- # CHECK: transform.sequence
- # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
- # CHECK: = merge_handles %[[ARG1]]
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ transform.MergeHandlesOp([sequence.bodyTarget])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testMergeHandlesOp
+ # CHECK: transform.sequence
+ # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+ # CHECK: = merge_handles %[[ARG1]]
@run
def testReplicateOp():
- with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
- with InsertionPoint(with_pdl.body):
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget)
- with InsertionPoint(sequence.body):
- m1 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "first")
- m2 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "second")
- transform.ReplicateOp(m1, [m2])
- transform.YieldOp()
- # CHECK-LABEL: TEST: testReplicateOp
- # CHECK: %[[FIRST:.+]] = pdl_match
- # CHECK: %[[SECOND:.+]] = pdl_match
- # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
+ with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
+ with InsertionPoint(with_pdl.body):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget
+ )
+ with InsertionPoint(sequence.body):
+ m1 = transform_pdl.PDLMatchOp(
+ transform.AnyOpType.get(), sequence.bodyTarget, "first"
+ )
+ m2 = transform_pdl.PDLMatchOp(
+ transform.AnyOpType.get(), sequence.bodyTarget, "second"
+ )
+ transform.ReplicateOp(m1, [m2])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testReplicateOp
+ # CHECK: %[[FIRST:.+]] = pdl_match
+ # CHECK: %[[SECOND:.+]] = pdl_match
+ # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
diff --git a/mlir/test/python/dialects/transform_loop_ext.py b/mlir/test/python/dialects/transform_loop_ext.py
index 067a8b60d4f89..28a022a400fe6 100644
--- a/mlir/test/python/dialects/transform_loop_ext.py
+++ b/mlir/test/python/dialects/transform_loop_ext.py
@@ -7,70 +7,92 @@
def run(f):
- with Context(), Location.unknown():
- module = Module.create()
- with InsertionPoint(module.body):
- print("\nTEST:", f.__name__)
- f()
- print(module)
- return f
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ print("\nTEST:", f.__name__)
+ f()
+ print(module)
+ return f
@run
def getParentLoop():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
- [], pdl.OperationType.get())
- with InsertionPoint(sequence.body):
- loop.GetParentForOp(transform.OperationType.get("scf.for"), sequence.bodyTarget, num_loops=2)
- transform.YieldOp()
- # CHECK-LABEL: TEST: getParentLoop
- # CHECK: = transform.loop.get_parent_for %
- # CHECK: num_loops = 2
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ loop.GetParentForOp(
+ transform.OperationType.get("scf.for"), sequence.bodyTarget, num_loops=2
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: getParentLoop
+ # CHECK: = transform.loop.get_parent_for %
+ # CHECK: num_loops = 2
@run
def loopOutline():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
- [], transform.OperationType.get("scf.for"))
- with InsertionPoint(sequence.body):
- loop.LoopOutlineOp(transform.AnyOpType.get(), transform.AnyOpType.get(), sequence.bodyTarget, func_name="foo")
- transform.YieldOp()
- # CHECK-LABEL: TEST: loopOutline
- # CHECK: = transform.loop.outline %
- # CHECK: func_name = "foo"
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE,
+ [],
+ transform.OperationType.get("scf.for"),
+ )
+ with InsertionPoint(sequence.body):
+ loop.LoopOutlineOp(
+ transform.AnyOpType.get(),
+ transform.AnyOpType.get(),
+ sequence.bodyTarget,
+ func_name="foo",
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: loopOutline
+ # CHECK: = transform.loop.outline %
+ # CHECK: func_name = "foo"
@run
def loopPeel():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
- [], transform.OperationType.get("scf.for"))
- with InsertionPoint(sequence.body):
- loop.LoopPeelOp(pdl.OperationType.get(), sequence.bodyTarget)
- transform.YieldOp()
- # CHECK-LABEL: TEST: loopPeel
- # CHECK: = transform.loop.peel %
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE,
+ [],
+ transform.OperationType.get("scf.for"),
+ )
+ with InsertionPoint(sequence.body):
+ loop.LoopPeelOp(pdl.OperationType.get(), sequence.bodyTarget)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: loopPeel
+ # CHECK: = transform.loop.peel %
@run
def loopPipeline():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
- [], transform.OperationType.get("scf.for"))
- with InsertionPoint(sequence.body):
- loop.LoopPipelineOp(pdl.OperationType.get(), sequence.bodyTarget, iteration_interval=3)
- transform.YieldOp()
- # CHECK-LABEL: TEST: loopPipeline
- # CHECK: = transform.loop.pipeline %
- # CHECK-DAG: iteration_interval = 3
- # (read_latency has default value and is not printed)
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE,
+ [],
+ transform.OperationType.get("scf.for"),
+ )
+ with InsertionPoint(sequence.body):
+ loop.LoopPipelineOp(
+ pdl.OperationType.get(), sequence.bodyTarget, iteration_interval=3
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: loopPipeline
+ # CHECK: = transform.loop.pipeline %
+ # CHECK-DAG: iteration_interval = 3
+ # (read_latency has default value and is not printed)
@run
def loopUnroll():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
- [], transform.OperationType.get("scf.for"))
- with InsertionPoint(sequence.body):
- loop.LoopUnrollOp(sequence.bodyTarget, factor=42)
- transform.YieldOp()
- # CHECK-LABEL: TEST: loopUnroll
- # CHECK: transform.loop.unroll %
- # CHECK: factor = 42
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE,
+ [],
+ transform.OperationType.get("scf.for"),
+ )
+ with InsertionPoint(sequence.body):
+ loop.LoopUnrollOp(sequence.bodyTarget, factor=42)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: loopUnroll
+ # CHECK: transform.loop.unroll %
+ # CHECK: factor = 42
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index d2a82b8218f25..2dfae47bdfb49 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -8,204 +8,230 @@
def run(f):
- with Context(), Location.unknown():
- module = Module.create()
- with InsertionPoint(module.body):
- print("\nTEST:", f.__name__)
- f()
- print(module)
- return f
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ print("\nTEST:", f.__name__)
+ f()
+ print(module)
+ return f
@run
def testDecompose():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
- with InsertionPoint(sequence.body):
- structured.DecomposeOp(sequence.bodyTarget)
- transform.YieldOp()
- # CHECK-LABEL: TEST: testDecompose
- # CHECK: transform.sequence
- # CHECK: transform.structured.decompose
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.DecomposeOp(sequence.bodyTarget)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testDecompose
+ # CHECK: transform.sequence
+ # CHECK: transform.structured.decompose
@run
def testGeneralize():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
- with InsertionPoint(sequence.body):
- structured.GeneralizeOp(sequence.bodyTarget)
- transform.YieldOp()
- # CHECK-LABEL: TEST: testGeneralize
- # CHECK: transform.sequence
- # CHECK: transform.structured.generalize
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.GeneralizeOp(sequence.bodyTarget)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testGeneralize
+ # CHECK: transform.sequence
+ # CHECK: transform.structured.generalize
@run
def testInterchange():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
- with InsertionPoint(sequence.body):
- structured.InterchangeOp(
- sequence.bodyTarget,
- iterator_interchange=[1, 0])
- transform.YieldOp()
- # CHECK-LABEL: TEST: testInterchange
- # CHECK: transform.sequence
- # CHECK: transform.structured.interchange
- # CHECK: iterator_interchange = [1, 0]
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.InterchangeOp(sequence.bodyTarget, iterator_interchange=[1, 0])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testInterchange
+ # CHECK: transform.sequence
+ # CHECK: transform.structured.interchange
+ # CHECK: iterator_interchange = [1, 0]
@run
def testMultitileSizes():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
- with InsertionPoint(sequence.body):
- structured.MultiTileSizesOp(pdl.OperationType.get(),
- sequence.bodyTarget,
- dimension=1,
- target_size=42)
- transform.YieldOp()
- # CHECK-LABEL: TEST: testMultitileSizes
- # CHECK: transform.sequence
- # CHECK: transform.structured.multitile_sizes
- # CHECK-DAG: dimension = 1
- # CHECK-DAG: target_size = 42
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.MultiTileSizesOp(
+ pdl.OperationType.get(), sequence.bodyTarget, dimension=1, target_size=42
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testMultitileSizes
+ # CHECK: transform.sequence
+ # CHECK: transform.structured.multitile_sizes
+ # CHECK-DAG: dimension = 1
+ # CHECK-DAG: target_size = 42
@run
def testPad():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
- with InsertionPoint(sequence.body):
- structured.PadOp(
- sequence.bodyTarget,
- padding_values=[FloatAttr.get_f32(42.0)],
- padding_dimensions=[1],
- transpose_paddings=[[1, 0]])
- transform.YieldOp()
- # CHECK-LABEL: TEST: testPad
- # CHECK: transform.sequence
- # CHECK: transform.structured.pad
- # CHECK-DAG: padding_values = [4.200000e+01 : f32]
- # CHECK-DAG: padding_dimensions = [1]
- # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
- # (pack_paddings has default values)
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.PadOp(
+ sequence.bodyTarget,
+ padding_values=[FloatAttr.get_f32(42.0)],
+ padding_dimensions=[1],
+ transpose_paddings=[[1, 0]],
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testPad
+ # CHECK: transform.sequence
+ # CHECK: transform.structured.pad
+ # CHECK-DAG: padding_values = [4.200000e+01 : f32]
+ # CHECK-DAG: padding_dimensions = [1]
+ # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
+ # (pack_paddings has default values)
+
@run
def testScalarize():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
- with InsertionPoint(sequence.body):
- structured.ScalarizeOp(sequence.bodyTarget)
- transform.YieldOp()
- # CHECK-LABEL: TEST: testScalarize
- # CHECK: transform.structured.scalarize
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.ScalarizeOp(sequence.bodyTarget)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testScalarize
+ # CHECK: transform.structured.scalarize
@run
def testSplit():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
- with InsertionPoint(sequence.body):
- split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
- structured.SplitOp(
- split.results[0], dimension=3, split_point=split.results[1])
- transform.YieldOp()
- # CHECK-LABEL: TEST: testSplit
- # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
- # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
+ structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testSplit
+ # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
+ # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
+
@run
def testTileCompact():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
- with InsertionPoint(sequence.body):
- structured.TileOp(sequence.bodyTarget,
- sizes=[4, 8],
- interchange=[0, 1])
- transform.YieldOp()
- # CHECK-LABEL: TEST: testTileCompact
- # CHECK: transform.sequence
- # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
- # CHECK: interchange = [0, 1]
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testTileCompact
+ # CHECK: transform.sequence
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
+ # CHECK: interchange = [0, 1]
+
@run
def testTileAttributes():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
- attr = DenseI64ArrayAttr.get([4, 8])
- ichange = DenseI64ArrayAttr.get([0, 1])
- with InsertionPoint(sequence.body):
- structured.TileOp(sequence.bodyTarget,
- sizes=attr,
- interchange=ichange)
- transform.YieldOp()
- # CHECK-LABEL: TEST: testTileAttributes
- # CHECK: transform.sequence
- # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
- # CHECK: interchange = [0, 1]
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ attr = DenseI64ArrayAttr.get([4, 8])
+ ichange = DenseI64ArrayAttr.get([0, 1])
+ with InsertionPoint(sequence.body):
+ structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testTileAttributes
+ # CHECK: transform.sequence
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
+ # CHECK: interchange = [0, 1]
+
@run
def testTileZero():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
- with InsertionPoint(sequence.body):
- structured.TileOp(sequence.bodyTarget,
- sizes=[4, 0, 2, 0],
- interchange=[0, 1, 2, 3])
- transform.YieldOp()
- # CHECK-LABEL: TEST: testTileZero
- # CHECK: transform.sequence
- # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
- # CHECK: interchange = [0, 1, 2, 3]
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.TileOp(
+ sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3]
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testTileZero
+ # CHECK: transform.sequence
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
+ # CHECK: interchange = [0, 1, 2, 3]
+
@run
def testTileDynamic():
- with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get())
- with InsertionPoint(with_pdl.body):
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [],
- with_pdl.bodyTarget)
- with InsertionPoint(sequence.body):
- m1 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first")
- m2 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second")
- structured.TileOp(sequence.bodyTarget,
- sizes=[m1, 3, m2, 0])
- transform.YieldOp()
- # CHECK-LABEL: TEST: testTileDynamic
- # CHECK: %[[FIRST:.+]] = pdl_match
- # CHECK: %[[SECOND:.+]] = pdl_match
- # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]
+ with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get())
+ with InsertionPoint(with_pdl.body):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget
+ )
+ with InsertionPoint(sequence.body):
+ m1 = transform_pdl.PDLMatchOp(
+ pdl.OperationType.get(), sequence.bodyTarget, "first"
+ )
+ m2 = transform_pdl.PDLMatchOp(
+ pdl.OperationType.get(), sequence.bodyTarget, "second"
+ )
+ structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testTileDynamic
+ # CHECK: %[[FIRST:.+]] = pdl_match
+ # CHECK: %[[SECOND:.+]] = pdl_match
+ # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]
@run
def testTileExplicitLoopTypeSingle():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
- [], transform.AnyOpType.get())
- with InsertionPoint(sequence.body):
- structured.TileOp(transform.OperationType.get("scf.for"),
- sequence.bodyTarget,
- sizes=[2, 3, 4])
- transform.YieldOp()
- # CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle
- # CHECK: = transform.structured.tile %{{.*}} : (!{{.*}}) ->
- # CHECK-COUNT-3: !transform.op<"scf.for">
-
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.TileOp(
+ transform.OperationType.get("scf.for"), sequence.bodyTarget, sizes=[2, 3, 4]
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle
+ # CHECK: = transform.structured.tile %{{.*}} : (!{{.*}}) ->
+ # CHECK-COUNT-3: !transform.op<"scf.for">
@run
def testTileExplicitLoopTypeAll():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
- [], transform.AnyOpType.get())
- types = [
- transform.OperationType.get(x)
- for x in ["scf.for", "scf.parallel", "scf.forall"]
- ]
- with InsertionPoint(sequence.body):
- structured.TileOp(types, sequence.bodyTarget, sizes=[2, 3, 4])
- transform.YieldOp()
- # CHECK-LABEL: TEST: testTileExplicitLoopTypeAll
- # CHECK: = transform.structured.tile
- # CHECK-SAME : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,
- # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+ )
+ types = [
+ transform.OperationType.get(x)
+ for x in ["scf.for", "scf.parallel", "scf.forall"]
+ ]
+ with InsertionPoint(sequence.body):
+ structured.TileOp(types, sequence.bodyTarget, sizes=[2, 3, 4])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testTileExplicitLoopTypeAll
+ # CHECK: = transform.structured.tile
+ # CHECK-SAME : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,
+ # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">
+
@run
def testVectorize():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
- with InsertionPoint(sequence.body):
- structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
- transform.YieldOp()
- # CHECK-LABEL: TEST: testVectorize
- # CHECK: transform.sequence
- # CHECK: = transform.structured.vectorize
- # CHECK: {vectorize_padding}
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testVectorize
+ # CHECK: transform.sequence
+ # CHECK: = transform.structured.vectorize
+ # CHECK: {vectorize_padding}
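For reference, a minimal stand-alone sketch of the pattern all of the transform/structured tests above exercise: build a transform.SequenceOp, populate its body region under an InsertionPoint, and terminate it with YieldOp. This sketch is not part of the patch; it assumes the mlir Python bindings are importable, and the module path for `structured` is my reading of the test's imports rather than something shown in this diff.

from mlir.ir import Context, Location, Module, InsertionPoint
from mlir.dialects import pdl, transform
from mlir.dialects.transform import structured

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        sequence = transform.SequenceOp(
            transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
        )
        with InsertionPoint(sequence.body):
            # Tile the payload handle by 4x8, as in testTileCompact above.
            structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
            transform.YieldOp()
    print(module)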
diff --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py
index 83c09616d56fe..2347abb62b410 100644
--- a/mlir/test/python/dialects/vector.py
+++ b/mlir/test/python/dialects/vector.py
@@ -5,57 +5,62 @@
import mlir.dialects.func as func
import mlir.dialects.vector as vector
+
def run(f):
- print("\nTEST:", f.__name__)
- with Context(), Location.unknown():
- f()
- return f
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ f()
+ return f
+
# CHECK-LABEL: TEST: testPrintOp
@run
def testPrintOp():
- module = Module.create()
- with InsertionPoint(module.body):
+ module = Module.create()
+ with InsertionPoint(module.body):
- @func.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get()))
- def print_vector(arg):
- return vector.PrintOp(arg)
+ @func.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get()))
+ def print_vector(arg):
+ return vector.PrintOp(arg)
- # CHECK-LABEL: func @print_vector(
- # CHECK-SAME: %[[ARG:.*]]: vector<12x5xf32>) {
- # CHECK: vector.print %[[ARG]] : vector<12x5xf32>
- # CHECK: return
- # CHECK: }
- print(module)
+ # CHECK-LABEL: func @print_vector(
+ # CHECK-SAME: %[[ARG:.*]]: vector<12x5xf32>) {
+ # CHECK: vector.print %[[ARG]] : vector<12x5xf32>
+ # CHECK: return
+ # CHECK: }
+ print(module)
# CHECK-LABEL: TEST: testTransferReadOp
@run
def testTransferReadOp():
- module = Module.create()
- with InsertionPoint(module.body):
- vector_type = VectorType.get([2, 3], F32Type.get())
- memref_type = MemRefType.get(
- [ShapedType.get_dynamic_size(),
- ShapedType.get_dynamic_size()], F32Type.get())
- index_type = IndexType.get()
- mask_type = VectorType.get(vector_type.shape, IntegerType.get_signless(1))
- identity_map = AffineMap.get_identity(vector_type.rank)
- identity_map_attr = AffineMapAttr.get(identity_map)
- f = func.FuncOp("transfer_read",
- ([memref_type, index_type,
- F32Type.get(), mask_type], []))
- with InsertionPoint(f.add_entry_block()):
- A, zero, padding, mask = f.arguments
- vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
- padding, mask=mask)
- vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
- padding)
- func.ReturnOp([])
-
- # CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index,
- # CHECK: %[[PAD:.*]]: f32, %[[MASK:.*]]: vector<2x3xi1>)
- # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]]
- # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]]
- # CHECK-NOT: %[[MASK]]
- print(module)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ vector_type = VectorType.get([2, 3], F32Type.get())
+ memref_type = MemRefType.get(
+ [ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()],
+ F32Type.get(),
+ )
+ index_type = IndexType.get()
+ mask_type = VectorType.get(vector_type.shape, IntegerType.get_signless(1))
+ identity_map = AffineMap.get_identity(vector_type.rank)
+ identity_map_attr = AffineMapAttr.get(identity_map)
+ f = func.FuncOp(
+ "transfer_read", ([memref_type, index_type, F32Type.get(), mask_type], [])
+ )
+ with InsertionPoint(f.add_entry_block()):
+ A, zero, padding, mask = f.arguments
+ vector.TransferReadOp(
+ vector_type, A, [zero, zero], identity_map_attr, padding, mask=mask
+ )
+ vector.TransferReadOp(
+ vector_type, A, [zero, zero], identity_map_attr, padding
+ )
+ func.ReturnOp([])
+
+ # CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index,
+ # CHECK: %[[PAD:.*]]: f32, %[[MASK:.*]]: vector<2x3xi1>)
+ # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]]
+ # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]]
+ # CHECK-NOT: %[[MASK]]
+ print(module)
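For reference, the core pattern of the vector.py test in isolation; a sketch, not part of the patch, assuming the mlir Python bindings are importable. FuncOp.from_py_func traces the decorated Python body into a func.func inside the enclosing module, so the printed module contains the vector.print op the CHECK lines above match.

from mlir.ir import Context, Location, Module, InsertionPoint, F32Type, VectorType
import mlir.dialects.func as func
import mlir.dialects.vector as vector

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):

        @func.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get()))
        def print_vector(arg):
            # Emits vector.print for the vector<12x5xf32> argument.
            return vector.PrintOp(arg)

    print(module)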
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 973810d5b82f4..50d6e82348a9f 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -10,34 +10,36 @@
# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
- print(*args, file=sys.stderr)
- sys.stderr.flush()
+ print(*args, file=sys.stderr)
+ sys.stderr.flush()
def run(f):
- log("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
+ log("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
- with Context():
- module = Module.parse(r"""
+ with Context():
+ module = Module.parse(
+ r"""
llvm.func @none() {
llvm.return
}
- """)
- execution_engine = ExecutionEngine(module)
- execution_engine_capsule = execution_engine._CAPIPtr
- # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
- log(repr(execution_engine_capsule))
- execution_engine._testing_release()
- execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule)
- # CHECK: _mlirExecutionEngine.ExecutionEngine
- log(repr(execution_engine1))
+ """
+ )
+ execution_engine = ExecutionEngine(module)
+ execution_engine_capsule = execution_engine._CAPIPtr
+ # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
+ log(repr(execution_engine_capsule))
+ execution_engine._testing_release()
+ execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule)
+ # CHECK: _mlirExecutionEngine.ExecutionEngine
+ log(repr(execution_engine1))
run(testCapsule)
@@ -46,40 +48,45 @@ def testCapsule():
# Test invalid ExecutionEngine creation
# CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
- with Context():
- # Builtin function
- module = Module.parse(r"""
+ with Context():
+ # Builtin function
+ module = Module.parse(
+ r"""
func.func @foo() { return }
- """)
- # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine.
- try:
- execution_engine = ExecutionEngine(module)
- except RuntimeError as e:
- log("Got RuntimeError: ", e)
+ """
+ )
+ # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine.
+ try:
+ execution_engine = ExecutionEngine(module)
+ except RuntimeError as e:
+ log("Got RuntimeError: ", e)
run(testInvalidModule)
def lowerToLLVM(module):
- pm = PassManager.parse(
- "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)")
- pm.run(module.operation)
- return module
+ pm = PassManager.parse(
+ "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)"
+ )
+ pm.run(module.operation)
+ return module
# Test simple ExecutionEngine execution
# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
- with Context():
- module = Module.parse(r"""
+ with Context():
+ module = Module.parse(
+ r"""
func.func @void() attributes { llvm.emit_c_interface } {
return
}
- """)
- execution_engine = ExecutionEngine(lowerToLLVM(module))
- # Nothing to check other than no exception thrown here.
- execution_engine.invoke("void")
+ """
+ )
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ # Nothing to check other than no exception thrown here.
+ execution_engine.invoke("void")
run(testInvokeVoid)
@@ -88,23 +95,25 @@ def testInvokeVoid():
# Test argument passing and result with a simple float addition.
# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
- with Context():
- module = Module.parse(r"""
+ with Context():
+ module = Module.parse(
+ r"""
func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
%add = arith.addf %arg0, %arg1 : f32
return %add : f32
}
- """)
- execution_engine = ExecutionEngine(lowerToLLVM(module))
- # Prepare arguments: two input floats and one result.
- # Arguments must be passed as pointers.
- c_float_p = ctypes.c_float * 1
- arg0 = c_float_p(42.)
- arg1 = c_float_p(2.)
- res = c_float_p(-1.)
- execution_engine.invoke("add", arg0, arg1, res)
- # CHECK: 42.0 + 2.0 = 44.0
- log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]))
+ """
+ )
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ # Prepare arguments: two input floats and one result.
+ # Arguments must be passed as pointers.
+ c_float_p = ctypes.c_float * 1
+ arg0 = c_float_p(42.0)
+ arg1 = c_float_p(2.0)
+ res = c_float_p(-1.0)
+ execution_engine.invoke("add", arg0, arg1, res)
+ # CHECK: 42.0 + 2.0 = 44.0
+ log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]))
run(testInvokeFloatAdd)
@@ -113,33 +122,35 @@ def testInvokeFloatAdd():
# Test callback
# CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
- # Define a callback function that takes a float and an integer and returns a float.
- @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
- def callback(a, b):
- return a / 2 + b / 2
-
- with Context():
- # The module just forwards to a runtime function known as "some_callback_into_python".
- module = Module.parse(r"""
+ # Define a callback function that takes a float and an integer and returns a float.
+ @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
+ def callback(a, b):
+ return a / 2 + b / 2
+
+ with Context():
+ # The module just forwards to a runtime function known as "some_callback_into_python".
+ module = Module.parse(
+ r"""
func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
%resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
return %resf : f32
}
func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
- """)
- execution_engine = ExecutionEngine(lowerToLLVM(module))
- execution_engine.register_runtime("some_callback_into_python", callback)
-
- # Prepare arguments: two input floats and one result.
- # Arguments must be passed as pointers.
- c_float_p = ctypes.c_float * 1
- c_int_p = ctypes.c_int * 1
- arg0 = c_float_p(42.)
- arg1 = c_int_p(2)
- res = c_float_p(-1.)
- execution_engine.invoke("add", arg0, arg1, res)
- # CHECK: 42.0 + 2 = 44.0
- log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0] * 2))
+ """
+ )
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ execution_engine.register_runtime("some_callback_into_python", callback)
+
+ # Prepare arguments: two input floats and one result.
+ # Arguments must be passed as pointers.
+ c_float_p = ctypes.c_float * 1
+ c_int_p = ctypes.c_int * 1
+ arg0 = c_float_p(42.0)
+ arg1 = c_int_p(2)
+ res = c_float_p(-1.0)
+ execution_engine.invoke("add", arg0, arg1, res)
+ # CHECK: 42.0 + 2 = 44.0
+ log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0] * 2))
run(testBasicCallback)
@@ -148,44 +159,46 @@ def callback(a, b):
# Test callback with an unranked memref
# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
- # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
- @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
- def callback(a):
- arr = unranked_memref_to_numpy(a, np.float32)
- log("Inside callback: ")
- log(arr)
-
- with Context():
- # The module just forwards to a runtime function known as "some_callback_into_python".
- module = Module.parse(r"""
+ # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
+ @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
+ def callback(a):
+ arr = unranked_memref_to_numpy(a, np.float32)
+ log("Inside callback: ")
+ log(arr)
+
+ with Context():
+ # The module just forwards to a runtime function known as "some_callback_into_python".
+ module = Module.parse(
+ r"""
func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
return
}
func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
-""")
- execution_engine = ExecutionEngine(lowerToLLVM(module))
- execution_engine.register_runtime("some_callback_into_python", callback)
- inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
- # CHECK: Inside callback:
- # CHECK{LITERAL}: [[1. 2.]
- # CHECK{LITERAL}: [3. 4.]]
- execution_engine.invoke(
- "callback_memref",
- ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))),
- )
- inp_arr_1 = np.array([5, 6, 7], dtype=np.float32)
- strided_arr = np.lib.stride_tricks.as_strided(
- inp_arr_1, strides=(4, 0), shape=(3, 4))
- # CHECK: Inside callback:
- # CHECK{LITERAL}: [[5. 5. 5. 5.]
- # CHECK{LITERAL}: [6. 6. 6. 6.]
- # CHECK{LITERAL}: [7. 7. 7. 7.]]
- execution_engine.invoke(
- "callback_memref",
- ctypes.pointer(
- ctypes.pointer(get_unranked_memref_descriptor(strided_arr))),
- )
+"""
+ )
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ execution_engine.register_runtime("some_callback_into_python", callback)
+ inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
+ # CHECK: Inside callback:
+ # CHECK{LITERAL}: [[1. 2.]
+ # CHECK{LITERAL}: [3. 4.]]
+ execution_engine.invoke(
+ "callback_memref",
+ ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))),
+ )
+ inp_arr_1 = np.array([5, 6, 7], dtype=np.float32)
+ strided_arr = np.lib.stride_tricks.as_strided(
+ inp_arr_1, strides=(4, 0), shape=(3, 4)
+ )
+ # CHECK: Inside callback:
+ # CHECK{LITERAL}: [[5. 5. 5. 5.]
+ # CHECK{LITERAL}: [6. 6. 6. 6.]
+ # CHECK{LITERAL}: [7. 7. 7. 7.]]
+ execution_engine.invoke(
+ "callback_memref",
+ ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(strided_arr))),
+ )
run(testUnrankedMemRefCallback)
@@ -194,36 +207,39 @@ def callback(a):
# Test callback with a ranked memref.
# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
- # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
- @ctypes.CFUNCTYPE(
- None,
- ctypes.POINTER(
- make_nd_memref_descriptor(2,
- np.ctypeslib.as_ctypes_type(np.float32))),
- )
- def callback(a):
- arr = ranked_memref_to_numpy(a)
- log("Inside Callback: ")
- log(arr)
-
- with Context():
- # The module just forwards to a runtime function known as "some_callback_into_python".
- module = Module.parse(r"""
+ # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
+ @ctypes.CFUNCTYPE(
+ None,
+ ctypes.POINTER(
+ make_nd_memref_descriptor(2, np.ctypeslib.as_ctypes_type(np.float32))
+ ),
+ )
+ def callback(a):
+ arr = ranked_memref_to_numpy(a)
+ log("Inside Callback: ")
+ log(arr)
+
+ with Context():
+ # The module just forwards to a runtime function known as "some_callback_into_python".
+ module = Module.parse(
+ r"""
func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
return
}
func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
-""")
- execution_engine = ExecutionEngine(lowerToLLVM(module))
- execution_engine.register_runtime("some_callback_into_python", callback)
- inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
- # CHECK: Inside Callback:
- # CHECK{LITERAL}: [[1. 5.]
- # CHECK{LITERAL}: [6. 7.]]
- execution_engine.invoke(
- "callback_memref",
- ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))))
+"""
+ )
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ execution_engine.register_runtime("some_callback_into_python", callback)
+ inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
+ # CHECK: Inside Callback:
+ # CHECK{LITERAL}: [[1. 5.]
+ # CHECK{LITERAL}: [6. 7.]]
+ execution_engine.invoke(
+ "callback_memref",
+ ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
+ )
run(testRankedMemRefCallback)
@@ -232,8 +248,9 @@ def callback(a):
# Test addition of two memrefs.
# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
- with Context():
- module = Module.parse("""
+ with Context():
+ module = Module.parse(
+ """
module {
func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
%0 = arith.constant 0 : index
@@ -243,23 +260,28 @@ def testMemrefAdd():
memref.store %3, %arg2[%0] : memref<1xf32>
return
}
- } """)
- arg1 = np.array([32.5]).astype(np.float32)
- arg2 = np.array(6).astype(np.float32)
- res = np.array([0]).astype(np.float32)
-
- arg1_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_ranked_memref_descriptor(arg1)))
- arg2_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_ranked_memref_descriptor(arg2)))
- res_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_ranked_memref_descriptor(res)))
-
- execution_engine = ExecutionEngine(lowerToLLVM(module))
- execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr,
- res_memref_ptr)
- # CHECK: [32.5] + 6.0 = [38.5]
- log("{0} + {1} = {2}".format(arg1, arg2, res))
+ } """
+ )
+ arg1 = np.array([32.5]).astype(np.float32)
+ arg2 = np.array(6).astype(np.float32)
+ res = np.array([0]).astype(np.float32)
+
+ arg1_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg1))
+ )
+ arg2_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg2))
+ )
+ res_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(res))
+ )
+
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ execution_engine.invoke(
+ "main", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
+ )
+ # CHECK: [32.5] + 6.0 = [38.5]
+ log("{0} + {1} = {2}".format(arg1, arg2, res))
run(testMemrefAdd)
@@ -268,8 +290,9 @@ def testMemrefAdd():
# Test addition of two f16 memrefs
# CHECK-LABEL: TEST: testF16MemrefAdd
def testF16MemrefAdd():
- with Context():
- module = Module.parse("""
+ with Context():
+ module = Module.parse(
+ """
module {
func.func @main(%arg0: memref<1xf16>,
%arg1: memref<1xf16>,
@@ -281,29 +304,34 @@ def testF16MemrefAdd():
memref.store %3, %arg2[%0] : memref<1xf16>
return
}
- } """)
-
- arg1 = np.array([11.]).astype(np.float16)
- arg2 = np.array([22.]).astype(np.float16)
- arg3 = np.array([0.]).astype(np.float16)
-
- arg1_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_ranked_memref_descriptor(arg1)))
- arg2_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_ranked_memref_descriptor(arg2)))
- arg3_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_ranked_memref_descriptor(arg3)))
-
- execution_engine = ExecutionEngine(lowerToLLVM(module))
- execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr,
- arg3_memref_ptr)
- # CHECK: [11.] + [22.] = [33.]
- log("{0} + {1} = {2}".format(arg1, arg2, arg3))
-
- # test to-numpy utility
- # CHECK: [33.]
- npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
- log(npout)
+ } """
+ )
+
+ arg1 = np.array([11.0]).astype(np.float16)
+ arg2 = np.array([22.0]).astype(np.float16)
+ arg3 = np.array([0.0]).astype(np.float16)
+
+ arg1_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg1))
+ )
+ arg2_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg2))
+ )
+ arg3_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg3))
+ )
+
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ execution_engine.invoke(
+ "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr
+ )
+ # CHECK: [11.] + [22.] = [33.]
+ log("{0} + {1} = {2}".format(arg1, arg2, arg3))
+
+ # test to-numpy utility
+ # CHECK: [33.]
+ npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
+ log(npout)
run(testF16MemrefAdd)
@@ -312,8 +340,9 @@ def testF16MemrefAdd():
# Test addition of two complex memrefs
# CHECK-LABEL: TEST: testComplexMemrefAdd
def testComplexMemrefAdd():
- with Context():
- module = Module.parse("""
+ with Context():
+ module = Module.parse(
+ """
module {
func.func @main(%arg0: memref<1xcomplex<f64>>,
%arg1: memref<1xcomplex<f64>>,
@@ -325,31 +354,34 @@ def testComplexMemrefAdd():
memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
return
}
- } """)
-
- arg1 = np.array([1.+2.j]).astype(np.complex128)
- arg2 = np.array([3.+4.j]).astype(np.complex128)
- arg3 = np.array([0.+0.j]).astype(np.complex128)
-
- arg1_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_ranked_memref_descriptor(arg1)))
- arg2_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_ranked_memref_descriptor(arg2)))
- arg3_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_ranked_memref_descriptor(arg3)))
-
- execution_engine = ExecutionEngine(lowerToLLVM(module))
- execution_engine.invoke("main",
- arg1_memref_ptr,
- arg2_memref_ptr,
- arg3_memref_ptr)
- # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
- log("{0} + {1} = {2}".format(arg1, arg2, arg3))
-
- # test to-numpy utility
- # CHECK: [4.+6.j]
- npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
- log(npout)
+ } """
+ )
+
+ arg1 = np.array([1.0 + 2.0j]).astype(np.complex128)
+ arg2 = np.array([3.0 + 4.0j]).astype(np.complex128)
+ arg3 = np.array([0.0 + 0.0j]).astype(np.complex128)
+
+ arg1_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg1))
+ )
+ arg2_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg2))
+ )
+ arg3_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg3))
+ )
+
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ execution_engine.invoke(
+ "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr
+ )
+ # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
+ log("{0} + {1} = {2}".format(arg1, arg2, arg3))
+
+ # test to-numpy utility
+ # CHECK: [4.+6.j]
+ npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
+ log(npout)
run(testComplexMemrefAdd)
@@ -358,8 +390,9 @@ def testComplexMemrefAdd():
# Test addition of two complex unranked memrefs
# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
def testComplexUnrankedMemrefAdd():
- with Context():
- module = Module.parse("""
+ with Context():
+ module = Module.parse(
+ """
module {
func.func @main(%arg0: memref<*xcomplex<f32>>,
%arg1: memref<*xcomplex<f32>>,
@@ -374,32 +407,34 @@ def testComplexUnrankedMemrefAdd():
memref.store %3, %C[%0] : memref<1xcomplex<f32>>
return
}
- } """)
-
- arg1 = np.array([5.+6.j]).astype(np.complex64)
- arg2 = np.array([7.+8.j]).astype(np.complex64)
- arg3 = np.array([0.+0.j]).astype(np.complex64)
-
- arg1_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_unranked_memref_descriptor(arg1)))
- arg2_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_unranked_memref_descriptor(arg2)))
- arg3_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_unranked_memref_descriptor(arg3)))
-
- execution_engine = ExecutionEngine(lowerToLLVM(module))
- execution_engine.invoke("main",
- arg1_memref_ptr,
- arg2_memref_ptr,
- arg3_memref_ptr)
- # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
- log("{0} + {1} = {2}".format(arg1, arg2, arg3))
-
- # test to-numpy utility
- # CHECK: [12.+14.j]
- npout = unranked_memref_to_numpy(arg3_memref_ptr[0],
- np.dtype(np.complex64))
- log(npout)
+ } """
+ )
+
+ arg1 = np.array([5.0 + 6.0j]).astype(np.complex64)
+ arg2 = np.array([7.0 + 8.0j]).astype(np.complex64)
+ arg3 = np.array([0.0 + 0.0j]).astype(np.complex64)
+
+ arg1_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_unranked_memref_descriptor(arg1))
+ )
+ arg2_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_unranked_memref_descriptor(arg2))
+ )
+ arg3_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_unranked_memref_descriptor(arg3))
+ )
+
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ execution_engine.invoke(
+ "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr
+ )
+ # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
+ log("{0} + {1} = {2}".format(arg1, arg2, arg3))
+
+ # test to-numpy utility
+ # CHECK: [12.+14.j]
+ npout = unranked_memref_to_numpy(arg3_memref_ptr[0], np.dtype(np.complex64))
+ log(npout)
run(testComplexUnrankedMemrefAdd)
@@ -408,8 +443,9 @@ def testComplexUnrankedMemrefAdd():
# Test addition of two 2d_memref
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
- with Context():
- module = Module.parse("""
+ with Context():
+ module = Module.parse(
+ """
module {
func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
@@ -441,23 +477,28 @@ def testDynamicMemrefAdd2D():
return
}
}
- """)
- arg1 = np.random.randn(2, 2).astype(np.float32)
- arg2 = np.random.randn(2, 2).astype(np.float32)
- res = np.random.randn(2, 2).astype(np.float32)
-
- arg1_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_ranked_memref_descriptor(arg1)))
- arg2_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_ranked_memref_descriptor(arg2)))
- res_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_ranked_memref_descriptor(res)))
-
- execution_engine = ExecutionEngine(lowerToLLVM(module))
- execution_engine.invoke("memref_add_2d", arg1_memref_ptr, arg2_memref_ptr,
- res_memref_ptr)
- # CHECK: True
- log(np.allclose(arg1 + arg2, res))
+ """
+ )
+ arg1 = np.random.randn(2, 2).astype(np.float32)
+ arg2 = np.random.randn(2, 2).astype(np.float32)
+ res = np.random.randn(2, 2).astype(np.float32)
+
+ arg1_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg1))
+ )
+ arg2_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg2))
+ )
+ res_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(res))
+ )
+
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ execution_engine.invoke(
+ "memref_add_2d", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
+ )
+ # CHECK: True
+ log(np.allclose(arg1 + arg2, res))
run(testDynamicMemrefAdd2D)
@@ -466,8 +507,9 @@ def testDynamicMemrefAdd2D():
# Test loading of shared libraries.
# CHECK-LABEL: TEST: testSharedLibLoad
def testSharedLibLoad():
- with Context():
- module = Module.parse("""
+ with Context():
+ module = Module.parse(
+ """
module {
func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
%c0 = arith.constant 0 : index
@@ -478,35 +520,36 @@ def testSharedLibLoad():
return
}
func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
- } """)
- arg0 = np.array([0.0]).astype(np.float32)
-
- arg0_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_ranked_memref_descriptor(arg0)))
-
- if sys.platform == 'win32':
- shared_libs = [
- "../../../../bin/mlir_runner_utils.dll",
- "../../../../bin/mlir_c_runner_utils.dll"
- ]
- elif sys.platform == 'darwin':
- shared_libs = [
- "../../../../lib/libmlir_runner_utils.dylib",
- "../../../../lib/libmlir_c_runner_utils.dylib"
- ]
- else:
- shared_libs = [
- "../../../../lib/libmlir_runner_utils.so",
- "../../../../lib/libmlir_c_runner_utils.so"
- ]
-
- execution_engine = ExecutionEngine(
- lowerToLLVM(module),
- opt_level=3,
- shared_libs=shared_libs)
- execution_engine.invoke("main", arg0_memref_ptr)
- # CHECK: Unranked Memref
- # CHECK-NEXT: [42]
+ } """
+ )
+ arg0 = np.array([0.0]).astype(np.float32)
+
+ arg0_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg0))
+ )
+
+ if sys.platform == "win32":
+ shared_libs = [
+ "../../../../bin/mlir_runner_utils.dll",
+ "../../../../bin/mlir_c_runner_utils.dll",
+ ]
+ elif sys.platform == "darwin":
+ shared_libs = [
+ "../../../../lib/libmlir_runner_utils.dylib",
+ "../../../../lib/libmlir_c_runner_utils.dylib",
+ ]
+ else:
+ shared_libs = [
+ "../../../../lib/libmlir_runner_utils.so",
+ "../../../../lib/libmlir_c_runner_utils.so",
+ ]
+
+ execution_engine = ExecutionEngine(
+ lowerToLLVM(module), opt_level=3, shared_libs=shared_libs
+ )
+ execution_engine.invoke("main", arg0_memref_ptr)
+ # CHECK: Unranked Memref
+ # CHECK-NEXT: [42]
run(testSharedLibLoad)
@@ -515,8 +558,9 @@ def testSharedLibLoad():
# Test that nano time clock is available.
# CHECK-LABEL: TEST: testNanoTime
def testNanoTime():
- with Context():
- module = Module.parse("""
+ with Context():
+ module = Module.parse(
+ """
module {
func.func @main() attributes { llvm.emit_c_interface } {
%now = call @nanoTime() : () -> i64
@@ -529,26 +573,26 @@ def testNanoTime():
}
func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
- }""")
-
- if sys.platform == 'win32':
- shared_libs = [
- "../../../../bin/mlir_runner_utils.dll",
- "../../../../bin/mlir_c_runner_utils.dll"
- ]
- else:
- shared_libs = [
- "../../../../lib/libmlir_runner_utils.so",
- "../../../../lib/libmlir_c_runner_utils.so"
- ]
-
- execution_engine = ExecutionEngine(
- lowerToLLVM(module),
- opt_level=3,
- shared_libs=shared_libs)
- execution_engine.invoke("main")
- # CHECK: Unranked Memref
- # CHECK: [{{.*}}]
+ }"""
+ )
+
+ if sys.platform == "win32":
+ shared_libs = [
+ "../../../../bin/mlir_runner_utils.dll",
+ "../../../../bin/mlir_c_runner_utils.dll",
+ ]
+ else:
+ shared_libs = [
+ "../../../../lib/libmlir_runner_utils.so",
+ "../../../../lib/libmlir_c_runner_utils.so",
+ ]
+
+ execution_engine = ExecutionEngine(
+ lowerToLLVM(module), opt_level=3, shared_libs=shared_libs
+ )
+ execution_engine.invoke("main")
+ # CHECK: Unranked Memref
+ # CHECK: [{{.*}}]
run(testNanoTime)
@@ -557,36 +601,36 @@ def testNanoTime():
# Test that nano time clock is available.
# CHECK-LABEL: TEST: testDumpToObjectFile
def testDumpToObjectFile():
- fd, object_path = tempfile.mkstemp(suffix=".o")
+ fd, object_path = tempfile.mkstemp(suffix=".o")
- try:
- with Context():
- module = Module.parse("""
+ try:
+ with Context():
+ module = Module.parse(
+ """
module {
func.func @main() attributes { llvm.emit_c_interface } {
return
}
- }""")
+ }"""
+ )
- execution_engine = ExecutionEngine(
- lowerToLLVM(module),
- opt_level=3)
+ execution_engine = ExecutionEngine(lowerToLLVM(module), opt_level=3)
- # CHECK: Object file exists: True
- print(f"Object file exists: {os.path.exists(object_path)}")
- # CHECK: Object file is empty: True
- print(f"Object file is empty: {os.path.getsize(object_path) == 0}")
+ # CHECK: Object file exists: True
+ print(f"Object file exists: {os.path.exists(object_path)}")
+ # CHECK: Object file is empty: True
+ print(f"Object file is empty: {os.path.getsize(object_path) == 0}")
- execution_engine.dump_to_object_file(object_path)
+ execution_engine.dump_to_object_file(object_path)
- # CHECK: Object file exists: True
- print(f"Object file exists: {os.path.exists(object_path)}")
- # CHECK: Object file is empty: False
- print(f"Object file is empty: {os.path.getsize(object_path) == 0}")
+ # CHECK: Object file exists: True
+ print(f"Object file exists: {os.path.exists(object_path)}")
+ # CHECK: Object file is empty: False
+ print(f"Object file is empty: {os.path.getsize(object_path) == 0}")
- finally:
- os.close(fd)
- os.remove(object_path)
+ finally:
+ os.close(fd)
+ os.remove(object_path)
run(testDumpToObjectFile)
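For reference, a condensed sketch of the invocation pattern the execution_engine.py hunks above reformat; it is not part of the patch and assumes an MLIR build with the Python bindings and execution engine enabled. The flow is: parse a module whose functions carry llvm.emit_c_interface, lower it with the same pipeline as the test's lowerToLLVM() helper, then pass scalar arguments to invoke() as pointers.

import ctypes

from mlir.ir import Context, Module
from mlir.passmanager import PassManager
from mlir.execution_engine import ExecutionEngine

with Context():
    module = Module.parse(
        r"""
        func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
          %add = arith.addf %arg0, %arg1 : f32
          return %add : f32
        }
        """
    )
    # Same lowering pipeline as the test's lowerToLLVM() helper.
    PassManager.parse(
        "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,"
        "convert-func-to-llvm,reconcile-unrealized-casts)"
    ).run(module.operation)
    engine = ExecutionEngine(module)
    # Scalars are passed to invoke() as pointers, here one-element ctypes arrays.
    c_float_p = ctypes.c_float * 1
    arg0, arg1, res = c_float_p(40.0), c_float_p(2.0), c_float_p(-1.0)
    engine.invoke("add", arg0, arg1, res)
    print(res[0])  # 42.0

Memref arguments follow the same shape, as the testMemrefAdd hunks show: numpy arrays are wrapped with get_ranked_memref_descriptor() (or the unranked variant) and handed to invoke() as ctypes pointers to pointers.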
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 2cba577b33266..f6519fb17a6b9 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -15,8 +15,8 @@
# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
- print(*args, file=sys.stderr)
- sys.stderr.flush()
+ print(*args, file=sys.stderr)
+ sys.stderr.flush()
elemwise_boiler = """
@@ -186,428 +186,458 @@ def log(*args):
def transform(module, boilerplate):
- # TODO: Allow cloning functions from one module to another.
- # Atm we have to resort to string concatenation.
- ops = module.operation.regions[0].blocks[0].operations
- mod = Module.parse("\n".join([str(op) for op in ops]) + boilerplate)
-
- pm = PassManager('builtin.module')
- pm.add("func.func(convert-linalg-to-loops)")
- pm.add("func.func(lower-affine)")
- pm.add("func.func(convert-math-to-llvm)")
- pm.add("func.func(convert-scf-to-cf)")
- pm.add("func.func(arith-expand)")
- pm.add("func.func(memref-expand)")
- pm.add("convert-vector-to-llvm")
- pm.add("finalize-memref-to-llvm")
- pm.add("convert-func-to-llvm")
- pm.add("reconcile-unrealized-casts")
- pm.run(mod.operation)
- return mod
+ # TODO: Allow cloning functions from one module to another.
+ # Atm we have to resort to string concatenation.
+ ops = module.operation.regions[0].blocks[0].operations
+ mod = Module.parse("\n".join([str(op) for op in ops]) + boilerplate)
+
+ pm = PassManager("builtin.module")
+ pm.add("func.func(convert-linalg-to-loops)")
+ pm.add("func.func(lower-affine)")
+ pm.add("func.func(convert-math-to-llvm)")
+ pm.add("func.func(convert-scf-to-cf)")
+ pm.add("func.func(arith-expand)")
+ pm.add("func.func(memref-expand)")
+ pm.add("convert-vector-to-llvm")
+ pm.add("finalize-memref-to-llvm")
+ pm.add("convert-func-to-llvm")
+ pm.add("reconcile-unrealized-casts")
+ pm.run(mod.operation)
+ return mod
def test_elemwise_builtin():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- i8 = IntegerType.get_signless(8)
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- MemRefType.get((), f32), MemRefType.get((4, 8), f32),
- MemRefType.get((4, 8), f32))
- def elemwise_exp_add_on_buffers(lhs, rhs, out):
- linalg.elemwise_unary(lhs, outs=[out])
- linalg.elemwise_binary(out, rhs, outs=[out])
-
- @func.FuncOp.from_py_func(
- MemRefType.get((), f32), MemRefType.get((4, 8), f32),
- MemRefType.get((4, 8), f32))
- def elemwise_log_mul_on_buffers(lhs, rhs, out):
- linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
- linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul)
-
- execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
-
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result f32.
- # Arguments must be passed as pointers.
- c_float_p = ctypes.c_float * 1
- res = c_float_p(-1.)
- execution_engine.invoke("main", res)
-
- log("RESULT: ", res[0])
- # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
- # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
- # CHECK: RESULT: 4.71828
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ i8 = IntegerType.get_signless(8)
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ MemRefType.get((), f32),
+ MemRefType.get((4, 8), f32),
+ MemRefType.get((4, 8), f32),
+ )
+ def elemwise_exp_add_on_buffers(lhs, rhs, out):
+ linalg.elemwise_unary(lhs, outs=[out])
+ linalg.elemwise_binary(out, rhs, outs=[out])
+
+ @func.FuncOp.from_py_func(
+ MemRefType.get((), f32),
+ MemRefType.get((4, 8), f32),
+ MemRefType.get((4, 8), f32),
+ )
+ def elemwise_log_mul_on_buffers(lhs, rhs, out):
+ linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
+ linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul)
+
+ execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result f32.
+ # Arguments must be passed as pointers.
+ c_float_p = ctypes.c_float * 1
+ res = c_float_p(-1.0)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
+ # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
+ # CHECK: RESULT: 4.71828
test_elemwise_builtin()
def test_elemwise_generic():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- i8 = IntegerType.get_signless(8)
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- MemRefType.get((), f32), MemRefType.get((4, 8), f32),
- MemRefType.get((4, 8), f32))
- def elemwise_exp_add_on_buffers(lhs, rhs, out):
- linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
- linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)
-
- @func.FuncOp.from_py_func(
- MemRefType.get((), f32), MemRefType.get((4, 8), f32),
- MemRefType.get((4, 8), f32))
- def elemwise_log_mul_on_buffers(lhs, rhs, out):
- linalg.elemwise_unary(
- lhs, outs=[out], fun=UnaryFn.log, emit_generic=True)
- linalg.elemwise_binary(
- out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True)
-
- execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
-
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result f32.
- # Arguments must be passed as pointers.
- c_float_p = ctypes.c_float * 1
- res = c_float_p(-1.)
- execution_engine.invoke("main", res)
-
- log("RESULT: ", res[0])
- # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
- # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
- # CHECK: RESULT: 4.71828
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ i8 = IntegerType.get_signless(8)
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ MemRefType.get((), f32),
+ MemRefType.get((4, 8), f32),
+ MemRefType.get((4, 8), f32),
+ )
+ def elemwise_exp_add_on_buffers(lhs, rhs, out):
+ linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
+ linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)
+
+ @func.FuncOp.from_py_func(
+ MemRefType.get((), f32),
+ MemRefType.get((4, 8), f32),
+ MemRefType.get((4, 8), f32),
+ )
+ def elemwise_log_mul_on_buffers(lhs, rhs, out):
+ linalg.elemwise_unary(
+ lhs, outs=[out], fun=UnaryFn.log, emit_generic=True
+ )
+ linalg.elemwise_binary(
+ out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True
+ )
+
+ execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result f32.
+ # Arguments must be passed as pointers.
+ c_float_p = ctypes.c_float * 1
+ res = c_float_p(-1.0)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
+ # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
+ # CHECK: RESULT: 4.71828
test_elemwise_generic()
def test_matmul_builtin():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- i8 = IntegerType.get_signless(8)
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
- MemRefType.get((4, 8), f32))
- def matmul_signed_on_buffers(lhs, rhs, out):
- linalg.matmul(lhs, rhs, outs=[out])
-
- @func.FuncOp.from_py_func(
- MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
- MemRefType.get((4, 8), f32))
- def matmul_unsigned_on_buffers(lhs, rhs, out):
- linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned)
-
- execution_engine = ExecutionEngine(transform(module, matmul_boiler))
-
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result f32.
- # Arguments must be passed as pointers.
- c_float_p = ctypes.c_float * 1
- res = c_float_p(-1.)
- execution_engine.invoke("main", res)
-
- log("RESULT: ", res[0])
- # matmul_signed_on_buffers: -1 * 2.0 * 16 = -32
- # matmul_unsigned_on_buffers: (2^8-1) * 2.0 * 16 = 8160
- # CHECK: RESULT: 8128
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ i8 = IntegerType.get_signless(8)
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ MemRefType.get((4, 16), i8),
+ MemRefType.get((16, 8), f32),
+ MemRefType.get((4, 8), f32),
+ )
+ def matmul_signed_on_buffers(lhs, rhs, out):
+ linalg.matmul(lhs, rhs, outs=[out])
+
+ @func.FuncOp.from_py_func(
+ MemRefType.get((4, 16), i8),
+ MemRefType.get((16, 8), f32),
+ MemRefType.get((4, 8), f32),
+ )
+ def matmul_unsigned_on_buffers(lhs, rhs, out):
+ linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned)
+
+ execution_engine = ExecutionEngine(transform(module, matmul_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result f32.
+ # Arguments must be passed as pointers.
+ c_float_p = ctypes.c_float * 1
+ res = c_float_p(-1.0)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # matmul_signed_on_buffers: -1 * 2.0 * 16 = -32
+ # matmul_unsigned_on_buffers: (2^8-1) * 2.0 * 16 = 8160
+ # CHECK: RESULT: 8128
test_matmul_builtin()
def test_matmul_generic():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- i8 = IntegerType.get_signless(8)
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
- MemRefType.get((4, 8), f32))
- def matmul_signed_on_buffers(lhs, rhs, out):
- linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)
-
- @func.FuncOp.from_py_func(
- MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
- MemRefType.get((4, 8), f32))
- def matmul_unsigned_on_buffers(lhs, rhs, out):
- linalg.matmul(
- lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned, emit_generic=True)
-
- execution_engine = ExecutionEngine(transform(module, matmul_boiler))
-
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result f32.
- # Arguments must be passed as pointers.
- c_float_p = ctypes.c_float * 1
- res = c_float_p(-1.)
- execution_engine.invoke("main", res)
-
- log("RESULT: ", res[0])
- # matmul_signed_on_buffers = -1 * 2.0 * 16 = -32
- # matmul_unsigned_on_buffers = (2^8-1) * 2.0 * 16 = 8160
- # CHECK: RESULT: 8128
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ i8 = IntegerType.get_signless(8)
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ MemRefType.get((4, 16), i8),
+ MemRefType.get((16, 8), f32),
+ MemRefType.get((4, 8), f32),
+ )
+ def matmul_signed_on_buffers(lhs, rhs, out):
+ linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)
+
+ @func.FuncOp.from_py_func(
+ MemRefType.get((4, 16), i8),
+ MemRefType.get((16, 8), f32),
+ MemRefType.get((4, 8), f32),
+ )
+ def matmul_unsigned_on_buffers(lhs, rhs, out):
+ linalg.matmul(
+ lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned, emit_generic=True
+ )
+
+ execution_engine = ExecutionEngine(transform(module, matmul_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result f32.
+ # Arguments must be passed as pointers.
+ c_float_p = ctypes.c_float * 1
+ res = c_float_p(-1.0)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # matmul_signed_on_buffers = -1 * 2.0 * 16 = -32
+ # matmul_unsigned_on_buffers = (2^8-1) * 2.0 * 16 = 8160
+ # CHECK: RESULT: 8128
test_matmul_generic()
def test_fill_builtin():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- i32 = IntegerType.get_signless(32)
- with InsertionPoint(module.body):
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
- @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
- def fill_0d_on_buffers(value, out):
- linalg.fill(value, outs=[out])
+ @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+ def fill_0d_on_buffers(value, out):
+ linalg.fill(value, outs=[out])
- @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
- def fill_1d_on_buffers(value, out):
- linalg.fill(value, outs=[out])
+ @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+ def fill_1d_on_buffers(value, out):
+ linalg.fill(value, outs=[out])
- @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
- def fill_2d_on_buffers(value, out):
- linalg.fill(value, outs=[out])
+ @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+ def fill_2d_on_buffers(value, out):
+ linalg.fill(value, outs=[out])
- execution_engine = ExecutionEngine(transform(module, fill_boiler))
+ execution_engine = ExecutionEngine(transform(module, fill_boiler))
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result i32.
- # Arguments must be passed as pointers.
- c_int_p = ctypes.c_int * 1
- res = c_int_p(-1)
- execution_engine.invoke("main", res)
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result i32.
+ # Arguments must be passed as pointers.
+ c_int_p = ctypes.c_int * 1
+ res = c_int_p(-1)
+ execution_engine.invoke("main", res)
- log("RESULT: ", res[0])
- # CHECK: RESULT: 6
+ log("RESULT: ", res[0])
+ # CHECK: RESULT: 6
test_fill_builtin()
def test_fill_generic():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- i32 = IntegerType.get_signless(32)
- with InsertionPoint(module.body):
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
- @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
- def fill_0d_on_buffers(value, out):
- linalg.fill(value, outs=[out], emit_generic=True)
+ @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+ def fill_0d_on_buffers(value, out):
+ linalg.fill(value, outs=[out], emit_generic=True)
- @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
- def fill_1d_on_buffers(value, out):
- linalg.fill(value, outs=[out], emit_generic=True)
+ @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+ def fill_1d_on_buffers(value, out):
+ linalg.fill(value, outs=[out], emit_generic=True)
- @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
- def fill_2d_on_buffers(value, out):
- linalg.fill(value, outs=[out], emit_generic=True)
+ @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+ def fill_2d_on_buffers(value, out):
+ linalg.fill(value, outs=[out], emit_generic=True)
- execution_engine = ExecutionEngine(transform(module, fill_boiler))
+ execution_engine = ExecutionEngine(transform(module, fill_boiler))
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result i32.
- # Arguments must be passed as pointers.
- c_int_p = ctypes.c_int * 1
- res = c_int_p(-1)
- execution_engine.invoke("main", res)
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result i32.
+ # Arguments must be passed as pointers.
+ c_int_p = ctypes.c_int * 1
+ res = c_int_p(-1)
+ execution_engine.invoke("main", res)
- log("RESULT: ", res[0])
- # CHECK: RESULT: 6
+ log("RESULT: ", res[0])
+ # CHECK: RESULT: 6
test_fill_generic()
def test_fill_rng_builtin():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f64 = F64Type.get()
- i32 = IntegerType.get_signless(32)
- with InsertionPoint(module.body):
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f64 = F64Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
- @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
- def fill_rng_on_buffers(min, max, seed, out):
- linalg.fill_rng_2d(min, max, seed, outs=[out])
+ @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
+ def fill_rng_on_buffers(min, max, seed, out):
+ linalg.fill_rng_2d(min, max, seed, outs=[out])
- execution_engine = ExecutionEngine(transform(module, fill_rng_boiler))
+ execution_engine = ExecutionEngine(transform(module, fill_rng_boiler))
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result i32.
- # Arguments must be passed as pointers.
- c_int_p = ctypes.c_int * 1
- res = c_int_p(-1)
- execution_engine.invoke("main", res)
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result i32.
+ # Arguments must be passed as pointers.
+ c_int_p = ctypes.c_int * 1
+ res = c_int_p(-1)
+ execution_engine.invoke("main", res)
- log("RESULT: ", res[0])
- # CHECK: RESULT: -480
+ log("RESULT: ", res[0])
+ # CHECK: RESULT: -480
test_fill_rng_builtin()
def test_fill_rng_generic():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f64 = F64Type.get()
- i32 = IntegerType.get_signless(32)
- with InsertionPoint(module.body):
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f64 = F64Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
- @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
- def fill_rng_on_buffers(min, max, seed, out):
- linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)
+ @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
+ def fill_rng_on_buffers(min, max, seed, out):
+ linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)
- execution_engine = ExecutionEngine(transform(module, fill_rng_boiler))
+ execution_engine = ExecutionEngine(transform(module, fill_rng_boiler))
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result i32.
- # Arguments must be passed as pointers.
- c_int_p = ctypes.c_int * 1
- res = c_int_p(-1)
- execution_engine.invoke("main", res)
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result i32.
+ # Arguments must be passed as pointers.
+ c_int_p = ctypes.c_int * 1
+ res = c_int_p(-1)
+ execution_engine.invoke("main", res)
- log("RESULT: ", res[0])
- # CHECK: RESULT: -480
+ log("RESULT: ", res[0])
+ # CHECK: RESULT: -480
test_fill_rng_generic()
def test_max_pooling_builtin():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f64 = F64Type.get()
- i32 = IntegerType.get_signless(32)
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
- MemRefType.get((1, 2, 4, 1), i32))
- def pooling_on_buffers(input, shape, output):
- linalg.pooling_nhwc_max(
- input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])
-
- execution_engine = ExecutionEngine(transform(module, pooling_boiler))
-
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result i32.
- # Arguments must be passed as pointers.
- c_int_p = ctypes.c_int * 1
- res = c_int_p(-1)
- execution_engine.invoke("main", res)
-
- log("RESULT: ", res[0])
- # 77 is not selected due to the dilation 2 in the second dimension.
- # CHECK: RESULT: 42
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f64 = F64Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ MemRefType.get((1, 4, 16, 1), f64),
+ MemRefType.get((2, 2), f64),
+ MemRefType.get((1, 2, 4, 1), i32),
+ )
+ def pooling_on_buffers(input, shape, output):
+ linalg.pooling_nhwc_max(
+ input, shape, outs=[output], strides=[2, 4], dilations=[1, 2]
+ )
+
+ execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result i32.
+ # Arguments must be passed as pointers.
+ c_int_p = ctypes.c_int * 1
+ res = c_int_p(-1)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # 77 is not selected due to the dilation 2 in the second dimension.
+ # CHECK: RESULT: 42
test_max_pooling_builtin()
def test_max_pooling_generic():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f64 = F64Type.get()
- i32 = IntegerType.get_signless(32)
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
- MemRefType.get((1, 2, 4, 1), i32))
- def pooling_on_buffers(input, shape, output):
- linalg.pooling_nhwc_max(
- input,
- shape,
- outs=[output],
- strides=[2, 4],
- dilations=[1, 2],
- emit_generic=True)
-
- execution_engine = ExecutionEngine(transform(module, pooling_boiler))
-
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result i32.
- # Arguments must be passed as pointers.
- c_int_p = ctypes.c_int * 1
- res = c_int_p(-1)
- execution_engine.invoke("main", res)
-
- log("RESULT: ", res[0])
- # 77 is not selected due to the dilation 2 in the second dimension.
- # CHECK: RESULT: 42
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f64 = F64Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ MemRefType.get((1, 4, 16, 1), f64),
+ MemRefType.get((2, 2), f64),
+ MemRefType.get((1, 2, 4, 1), i32),
+ )
+ def pooling_on_buffers(input, shape, output):
+ linalg.pooling_nhwc_max(
+ input,
+ shape,
+ outs=[output],
+ strides=[2, 4],
+ dilations=[1, 2],
+ emit_generic=True,
+ )
+
+ execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result i32.
+ # Arguments must be passed as pointers.
+ c_int_p = ctypes.c_int * 1
+ res = c_int_p(-1)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # 77 is not selected due to the dilation 2 in the second dimension.
+ # CHECK: RESULT: 42
test_max_pooling_generic()
def test_min_pooling_builtin():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f64 = F64Type.get()
- i32 = IntegerType.get_signless(32)
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
- MemRefType.get((1, 2, 4, 1), i32))
- # Set the strides and use the default dilations.
- def pooling_on_buffers(input, shape, output):
- linalg.pooling_nhwc_min(input, shape, outs=[output], strides=[2, 4])
-
- execution_engine = ExecutionEngine(transform(module, pooling_boiler))
-
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result i32.
- # Arguments must be passed as pointers.
- c_int_p = ctypes.c_int * 1
- res = c_int_p(-1)
- execution_engine.invoke("main", res)
-
- log("RESULT: ", res[0])
- # CHECK: RESULT: -13
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f64 = F64Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ MemRefType.get((1, 4, 16, 1), f64),
+ MemRefType.get((2, 2), f64),
+ MemRefType.get((1, 2, 4, 1), i32),
+ )
+ # Set the strides and use the default dilations.
+ def pooling_on_buffers(input, shape, output):
+ linalg.pooling_nhwc_min(input, shape, outs=[output], strides=[2, 4])
+
+ execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result i32.
+ # Arguments must be passed as pointers.
+ c_int_p = ctypes.c_int * 1
+ res = c_int_p(-1)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # CHECK: RESULT: -13
test_min_pooling_builtin()
def test_min_pooling_generic():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f64 = F64Type.get()
- i32 = IntegerType.get_signless(32)
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
- MemRefType.get((1, 2, 4, 1), i32))
- # Set the strides and use the default dilations.
- def pooling_on_buffers(input, shape, output):
- linalg.pooling_nhwc_min(
- input, shape, outs=[output], strides=[2, 4], emit_generic=True)
-
- execution_engine = ExecutionEngine(transform(module, pooling_boiler))
-
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result i32.
- # Arguments must be passed as pointers.
- c_int_p = ctypes.c_int * 1
- res = c_int_p(-1)
- execution_engine.invoke("main", res)
-
- log("RESULT: ", res[0])
- # CHECK: RESULT: -13
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f64 = F64Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ MemRefType.get((1, 4, 16, 1), f64),
+ MemRefType.get((2, 2), f64),
+ MemRefType.get((1, 2, 4, 1), i32),
+ )
+ # Set the strides and use the default dilations.
+ def pooling_on_buffers(input, shape, output):
+ linalg.pooling_nhwc_min(
+ input, shape, outs=[output], strides=[2, 4], emit_generic=True
+ )
+
+ execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result i32.
+ # Arguments must be passed as pointers.
+ c_int_p = ctypes.c_int * 1
+ res = c_int_p(-1)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # CHECK: RESULT: -13
test_min_pooling_generic()
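For readers skimming the linalg hunks above: every one of these tests invokes the jitted module through the same ctypes calling convention. A minimal sketch of that convention follows, assuming an `execution_engine` already built the way the tests do (ExecutionEngine over a lowered module exposing a `main` entry point); the helper name is purely illustrative and is not part of this commit:

    import ctypes

    def read_i32_result(execution_engine):
        # Hypothetical helper. Scalar results come back through a
        # one-element ctypes array that invoke() receives by pointer.
        c_int_p = ctypes.c_int * 1
        res = c_int_p(-1)
        execution_engine.invoke("main", res)
        return res[0]

The tests then FileCheck the printed value of res[0], which is why each hunk ends with a log("RESULT: ", ...) line.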
diff --git a/mlir/test/python/ir/affine_expr.py b/mlir/test/python/ir/affine_expr.py
index 6a3a6fcc65e1b..63564303e8315 100644
--- a/mlir/test/python/ir/affine_expr.py
+++ b/mlir/test/python/ir/affine_expr.py
@@ -3,59 +3,61 @@
import gc
from mlir.ir import *
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
# CHECK-LABEL: TEST: testAffineExprCapsule
@run
def testAffineExprCapsule():
- with Context() as ctx:
- affine_expr = AffineExpr.get_constant(42)
+ with Context() as ctx:
+ affine_expr = AffineExpr.get_constant(42)
- affine_expr_capsule = affine_expr._CAPIPtr
- # CHECK: capsule object
- # CHECK: mlir.ir.AffineExpr._CAPIPtr
- print(affine_expr_capsule)
+ affine_expr_capsule = affine_expr._CAPIPtr
+ # CHECK: capsule object
+ # CHECK: mlir.ir.AffineExpr._CAPIPtr
+ print(affine_expr_capsule)
- affine_expr_2 = AffineExpr._CAPICreate(affine_expr_capsule)
- assert affine_expr == affine_expr_2
- assert affine_expr_2.context == ctx
+ affine_expr_2 = AffineExpr._CAPICreate(affine_expr_capsule)
+ assert affine_expr == affine_expr_2
+ assert affine_expr_2.context == ctx
# CHECK-LABEL: TEST: testAffineExprEq
@run
def testAffineExprEq():
- with Context():
- a1 = AffineExpr.get_constant(42)
- a2 = AffineExpr.get_constant(42)
- a3 = AffineExpr.get_constant(43)
- # CHECK: True
- print(a1 == a1)
- # CHECK: True
- print(a1 == a2)
- # CHECK: False
- print(a1 == a3)
- # CHECK: False
- print(a1 == None)
- # CHECK: False
- print(a1 == "foo")
+ with Context():
+ a1 = AffineExpr.get_constant(42)
+ a2 = AffineExpr.get_constant(42)
+ a3 = AffineExpr.get_constant(43)
+ # CHECK: True
+ print(a1 == a1)
+ # CHECK: True
+ print(a1 == a2)
+ # CHECK: False
+ print(a1 == a3)
+ # CHECK: False
+ print(a1 == None)
+ # CHECK: False
+ print(a1 == "foo")
# CHECK-LABEL: TEST: testAffineExprContext
@run
def testAffineExprContext():
- with Context():
- a1 = AffineExpr.get_constant(42)
- with Context():
- a2 = AffineExpr.get_constant(42)
+ with Context():
+ a1 = AffineExpr.get_constant(42)
+ with Context():
+ a2 = AffineExpr.get_constant(42)
+
+ # CHECK: False
+ print(a1 == a2)
- # CHECK: False
- print(a1 == a2)
run(testAffineExprContext)
@@ -63,340 +65,343 @@ def testAffineExprContext():
# CHECK-LABEL: TEST: testAffineExprConstant
@run
def testAffineExprConstant():
- with Context():
- a1 = AffineExpr.get_constant(42)
- # CHECK: 42
- print(a1.value)
- # CHECK: 42
- print(a1)
+ with Context():
+ a1 = AffineExpr.get_constant(42)
+ # CHECK: 42
+ print(a1.value)
+ # CHECK: 42
+ print(a1)
- a2 = AffineConstantExpr.get(42)
- # CHECK: 42
- print(a2.value)
- # CHECK: 42
- print(a2)
+ a2 = AffineConstantExpr.get(42)
+ # CHECK: 42
+ print(a2.value)
+ # CHECK: 42
+ print(a2)
- assert a1 == a2
+ assert a1 == a2
# CHECK-LABEL: TEST: testAffineExprDim
@run
def testAffineExprDim():
- with Context():
- d1 = AffineExpr.get_dim(1)
- d11 = AffineDimExpr.get(1)
- d2 = AffineDimExpr.get(2)
+ with Context():
+ d1 = AffineExpr.get_dim(1)
+ d11 = AffineDimExpr.get(1)
+ d2 = AffineDimExpr.get(2)
- # CHECK: 1
- print(d1.position)
- # CHECK: d1
- print(d1)
+ # CHECK: 1
+ print(d1.position)
+ # CHECK: d1
+ print(d1)
- # CHECK: 2
- print(d2.position)
- # CHECK: d2
- print(d2)
+ # CHECK: 2
+ print(d2.position)
+ # CHECK: d2
+ print(d2)
- assert d1 == d11
- assert d1 != d2
+ assert d1 == d11
+ assert d1 != d2
# CHECK-LABEL: TEST: testAffineExprSymbol
@run
def testAffineExprSymbol():
- with Context():
- s1 = AffineExpr.get_symbol(1)
- s11 = AffineSymbolExpr.get(1)
- s2 = AffineSymbolExpr.get(2)
+ with Context():
+ s1 = AffineExpr.get_symbol(1)
+ s11 = AffineSymbolExpr.get(1)
+ s2 = AffineSymbolExpr.get(2)
- # CHECK: 1
- print(s1.position)
- # CHECK: s1
- print(s1)
+ # CHECK: 1
+ print(s1.position)
+ # CHECK: s1
+ print(s1)
- # CHECK: 2
- print(s2.position)
- # CHECK: s2
- print(s2)
+ # CHECK: 2
+ print(s2.position)
+ # CHECK: s2
+ print(s2)
- assert s1 == s11
- assert s1 != s2
+ assert s1 == s11
+ assert s1 != s2
# CHECK-LABEL: TEST: testAffineAddExpr
@run
def testAffineAddExpr():
- with Context():
- d1 = AffineDimExpr.get(1)
- d2 = AffineDimExpr.get(2)
- d12 = AffineExpr.get_add(d1, d2)
- # CHECK: d1 + d2
- print(d12)
+ with Context():
+ d1 = AffineDimExpr.get(1)
+ d2 = AffineDimExpr.get(2)
+ d12 = AffineExpr.get_add(d1, d2)
+ # CHECK: d1 + d2
+ print(d12)
- d12op = d1 + d2
- # CHECK: d1 + d2
- print(d12op)
+ d12op = d1 + d2
+ # CHECK: d1 + d2
+ print(d12op)
- d1cst_op = d1 + 2
- # CHECK: d1 + 2
- print(d1cst_op)
+ d1cst_op = d1 + 2
+ # CHECK: d1 + 2
+ print(d1cst_op)
- d1cst_op2 = 2 + d1
- # CHECK: d1 + 2
- print(d1cst_op2)
+ d1cst_op2 = 2 + d1
+ # CHECK: d1 + 2
+ print(d1cst_op2)
- assert d12 == d12op
- assert d12.lhs == d1
- assert d12.rhs == d2
+ assert d12 == d12op
+ assert d12.lhs == d1
+ assert d12.rhs == d2
# CHECK-LABEL: TEST: testAffineMulExpr
@run
def testAffineMulExpr():
- with Context():
- d1 = AffineDimExpr.get(1)
- c2 = AffineConstantExpr.get(2)
- expr = AffineExpr.get_mul(d1, c2)
- # CHECK: d1 * 2
- print(expr)
+ with Context():
+ d1 = AffineDimExpr.get(1)
+ c2 = AffineConstantExpr.get(2)
+ expr = AffineExpr.get_mul(d1, c2)
+ # CHECK: d1 * 2
+ print(expr)
- # CHECK: d1 * 2
- op = d1 * c2
- print(op)
+ # CHECK: d1 * 2
+ op = d1 * c2
+ print(op)
- # CHECK: d1 * 2
- op_cst = d1 * 2
- print(op_cst)
+ # CHECK: d1 * 2
+ op_cst = d1 * 2
+ print(op_cst)
- # CHECK: d1 * 2
- op_cst2 = 2 * d1
- print(op_cst2)
+ # CHECK: d1 * 2
+ op_cst2 = 2 * d1
+ print(op_cst2)
- assert expr == op
- assert expr == op_cst
- assert expr.lhs == d1
- assert expr.rhs == c2
+ assert expr == op
+ assert expr == op_cst
+ assert expr.lhs == d1
+ assert expr.rhs == c2
# CHECK-LABEL: TEST: testAffineModExpr
@run
def testAffineModExpr():
- with Context():
- d1 = AffineDimExpr.get(1)
- c2 = AffineConstantExpr.get(2)
- expr = AffineExpr.get_mod(d1, c2)
- # CHECK: d1 mod 2
- print(expr)
+ with Context():
+ d1 = AffineDimExpr.get(1)
+ c2 = AffineConstantExpr.get(2)
+ expr = AffineExpr.get_mod(d1, c2)
+ # CHECK: d1 mod 2
+ print(expr)
- # CHECK: d1 mod 2
- op = d1 % c2
- print(op)
+ # CHECK: d1 mod 2
+ op = d1 % c2
+ print(op)
- # CHECK: d1 mod 2
- op_cst = d1 % 2
- print(op_cst)
+ # CHECK: d1 mod 2
+ op_cst = d1 % 2
+ print(op_cst)
- # CHECK: 2 mod d1
- print(2 % d1)
+ # CHECK: 2 mod d1
+ print(2 % d1)
- assert expr == op
- assert expr == op_cst
- assert expr.lhs == d1
- assert expr.rhs == c2
+ assert expr == op
+ assert expr == op_cst
+ assert expr.lhs == d1
+ assert expr.rhs == c2
- expr2 = AffineExpr.get_mod(c2, d1)
- expr3 = AffineExpr.get_mod(2, d1)
- expr4 = AffineExpr.get_mod(d1, 2)
+ expr2 = AffineExpr.get_mod(c2, d1)
+ expr3 = AffineExpr.get_mod(2, d1)
+ expr4 = AffineExpr.get_mod(d1, 2)
- # CHECK: 2 mod d1
- print(expr2)
- # CHECK: 2 mod d1
- print(expr3)
- # CHECK: d1 mod 2
- print(expr4)
+ # CHECK: 2 mod d1
+ print(expr2)
+ # CHECK: 2 mod d1
+ print(expr3)
+ # CHECK: d1 mod 2
+ print(expr4)
- assert expr2 == expr3
- assert expr4 == expr
+ assert expr2 == expr3
+ assert expr4 == expr
# CHECK-LABEL: TEST: testAffineFloorDivExpr
@run
def testAffineFloorDivExpr():
- with Context():
- d1 = AffineDimExpr.get(1)
- c2 = AffineConstantExpr.get(2)
- expr = AffineExpr.get_floor_div(d1, c2)
- # CHECK: d1 floordiv 2
- print(expr)
+ with Context():
+ d1 = AffineDimExpr.get(1)
+ c2 = AffineConstantExpr.get(2)
+ expr = AffineExpr.get_floor_div(d1, c2)
+ # CHECK: d1 floordiv 2
+ print(expr)
- assert expr.lhs == d1
- assert expr.rhs == c2
+ assert expr.lhs == d1
+ assert expr.rhs == c2
- expr2 = AffineExpr.get_floor_div(c2, d1)
- expr3 = AffineExpr.get_floor_div(2, d1)
- expr4 = AffineExpr.get_floor_div(d1, 2)
+ expr2 = AffineExpr.get_floor_div(c2, d1)
+ expr3 = AffineExpr.get_floor_div(2, d1)
+ expr4 = AffineExpr.get_floor_div(d1, 2)
- # CHECK: 2 floordiv d1
- print(expr2)
- # CHECK: 2 floordiv d1
- print(expr3)
- # CHECK: d1 floordiv 2
- print(expr4)
+ # CHECK: 2 floordiv d1
+ print(expr2)
+ # CHECK: 2 floordiv d1
+ print(expr3)
+ # CHECK: d1 floordiv 2
+ print(expr4)
- assert expr2 == expr3
- assert expr4 == expr
+ assert expr2 == expr3
+ assert expr4 == expr
# CHECK-LABEL: TEST: testAffineCeilDivExpr
@run
def testAffineCeilDivExpr():
- with Context():
- d1 = AffineDimExpr.get(1)
- c2 = AffineConstantExpr.get(2)
- expr = AffineExpr.get_ceil_div(d1, c2)
- # CHECK: d1 ceildiv 2
- print(expr)
+ with Context():
+ d1 = AffineDimExpr.get(1)
+ c2 = AffineConstantExpr.get(2)
+ expr = AffineExpr.get_ceil_div(d1, c2)
+ # CHECK: d1 ceildiv 2
+ print(expr)
- assert expr.lhs == d1
- assert expr.rhs == c2
+ assert expr.lhs == d1
+ assert expr.rhs == c2
- expr2 = AffineExpr.get_ceil_div(c2, d1)
- expr3 = AffineExpr.get_ceil_div(2, d1)
- expr4 = AffineExpr.get_ceil_div(d1, 2)
+ expr2 = AffineExpr.get_ceil_div(c2, d1)
+ expr3 = AffineExpr.get_ceil_div(2, d1)
+ expr4 = AffineExpr.get_ceil_div(d1, 2)
- # CHECK: 2 ceildiv d1
- print(expr2)
- # CHECK: 2 ceildiv d1
- print(expr3)
- # CHECK: d1 ceildiv 2
- print(expr4)
+ # CHECK: 2 ceildiv d1
+ print(expr2)
+ # CHECK: 2 ceildiv d1
+ print(expr3)
+ # CHECK: d1 ceildiv 2
+ print(expr4)
- assert expr2 == expr3
- assert expr4 == expr
+ assert expr2 == expr3
+ assert expr4 == expr
# CHECK-LABEL: TEST: testAffineExprSub
@run
def testAffineExprSub():
- with Context():
- d1 = AffineDimExpr.get(1)
- d2 = AffineDimExpr.get(2)
- expr = d1 - d2
- # CHECK: d1 - d2
- print(expr)
-
- assert expr.lhs == d1
- rhs = AffineMulExpr(expr.rhs)
- # CHECK: d2
- print(rhs.lhs)
- # CHECK: -1
- print(rhs.rhs)
-
- # CHECK: d1 - 42
- print(d1 - 42)
- # CHECK: -d1 + 42
- print(42 - d1)
-
- c42 = AffineConstantExpr.get(42)
- assert d1 - 42 == d1 - c42
- assert 42 - d1 == c42 - d1
+ with Context():
+ d1 = AffineDimExpr.get(1)
+ d2 = AffineDimExpr.get(2)
+ expr = d1 - d2
+ # CHECK: d1 - d2
+ print(expr)
+
+ assert expr.lhs == d1
+ rhs = AffineMulExpr(expr.rhs)
+ # CHECK: d2
+ print(rhs.lhs)
+ # CHECK: -1
+ print(rhs.rhs)
+
+ # CHECK: d1 - 42
+ print(d1 - 42)
+ # CHECK: -d1 + 42
+ print(42 - d1)
+
+ c42 = AffineConstantExpr.get(42)
+ assert d1 - 42 == d1 - c42
+ assert 42 - d1 == c42 - d1
+
# CHECK-LABEL: TEST: testClassHierarchy
@run
def testClassHierarchy():
- with Context():
- d1 = AffineDimExpr.get(1)
- c2 = AffineConstantExpr.get(2)
- add = AffineAddExpr.get(d1, c2)
- mul = AffineMulExpr.get(d1, c2)
- mod = AffineModExpr.get(d1, c2)
- floor_div = AffineFloorDivExpr.get(d1, c2)
- ceil_div = AffineCeilDivExpr.get(d1, c2)
+ with Context():
+ d1 = AffineDimExpr.get(1)
+ c2 = AffineConstantExpr.get(2)
+ add = AffineAddExpr.get(d1, c2)
+ mul = AffineMulExpr.get(d1, c2)
+ mod = AffineModExpr.get(d1, c2)
+ floor_div = AffineFloorDivExpr.get(d1, c2)
+ ceil_div = AffineCeilDivExpr.get(d1, c2)
+
+ # CHECK: False
+ print(isinstance(d1, AffineBinaryExpr))
+ # CHECK: False
+ print(isinstance(c2, AffineBinaryExpr))
+ # CHECK: True
+ print(isinstance(add, AffineBinaryExpr))
+ # CHECK: True
+ print(isinstance(mul, AffineBinaryExpr))
+ # CHECK: True
+ print(isinstance(mod, AffineBinaryExpr))
+ # CHECK: True
+ print(isinstance(floor_div, AffineBinaryExpr))
+ # CHECK: True
+ print(isinstance(ceil_div, AffineBinaryExpr))
+
+ try:
+ AffineBinaryExpr(d1)
+ except ValueError as e:
+ # CHECK: Cannot cast affine expression to AffineBinaryExpr
+ print(e)
+
+ try:
+ AffineBinaryExpr(c2)
+ except ValueError as e:
+ # CHECK: Cannot cast affine expression to AffineBinaryExpr
+ print(e)
- # CHECK: False
- print(isinstance(d1, AffineBinaryExpr))
- # CHECK: False
- print(isinstance(c2, AffineBinaryExpr))
- # CHECK: True
- print(isinstance(add, AffineBinaryExpr))
- # CHECK: True
- print(isinstance(mul, AffineBinaryExpr))
- # CHECK: True
- print(isinstance(mod, AffineBinaryExpr))
- # CHECK: True
- print(isinstance(floor_div, AffineBinaryExpr))
- # CHECK: True
- print(isinstance(ceil_div, AffineBinaryExpr))
-
- try:
- AffineBinaryExpr(d1)
- except ValueError as e:
- # CHECK: Cannot cast affine expression to AffineBinaryExpr
- print(e)
-
- try:
- AffineBinaryExpr(c2)
- except ValueError as e:
- # CHECK: Cannot cast affine expression to AffineBinaryExpr
- print(e)
# CHECK-LABEL: TEST: testIsInstance
@run
def testIsInstance():
- with Context():
- d1 = AffineDimExpr.get(1)
- c2 = AffineConstantExpr.get(2)
- add = AffineAddExpr.get(d1, c2)
- mul = AffineMulExpr.get(d1, c2)
-
- # CHECK: True
- print(AffineDimExpr.isinstance(d1))
- # CHECK: False
- print(AffineConstantExpr.isinstance(d1))
- # CHECK: True
- print(AffineConstantExpr.isinstance(c2))
- # CHECK: False
- print(AffineMulExpr.isinstance(c2))
- # CHECK: True
- print(AffineAddExpr.isinstance(add))
- # CHECK: False
- print(AffineMulExpr.isinstance(add))
- # CHECK: True
- print(AffineMulExpr.isinstance(mul))
- # CHECK: False
- print(AffineAddExpr.isinstance(mul))
+ with Context():
+ d1 = AffineDimExpr.get(1)
+ c2 = AffineConstantExpr.get(2)
+ add = AffineAddExpr.get(d1, c2)
+ mul = AffineMulExpr.get(d1, c2)
+
+ # CHECK: True
+ print(AffineDimExpr.isinstance(d1))
+ # CHECK: False
+ print(AffineConstantExpr.isinstance(d1))
+ # CHECK: True
+ print(AffineConstantExpr.isinstance(c2))
+ # CHECK: False
+ print(AffineMulExpr.isinstance(c2))
+ # CHECK: True
+ print(AffineAddExpr.isinstance(add))
+ # CHECK: False
+ print(AffineMulExpr.isinstance(add))
+ # CHECK: True
+ print(AffineMulExpr.isinstance(mul))
+ # CHECK: False
+ print(AffineAddExpr.isinstance(mul))
# CHECK-LABEL: TEST: testCompose
@run
def testCompose():
- with Context():
- # d0 + d2.
- expr = AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(2))
+ with Context():
+ # d0 + d2.
+ expr = AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(2))
- # (d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)
- map1 = AffineAddExpr.get(AffineDimExpr.get(0), AffineSymbolExpr.get(1))
- map2 = AffineAddExpr.get(AffineDimExpr.get(1), AffineSymbolExpr.get(0))
- map3 = AffineAddExpr.get(
- AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(1)),
- AffineDimExpr.get(2))
- map = AffineMap.get(3, 2, [map1, map2, map3])
+ # (d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)
+ map1 = AffineAddExpr.get(AffineDimExpr.get(0), AffineSymbolExpr.get(1))
+ map2 = AffineAddExpr.get(AffineDimExpr.get(1), AffineSymbolExpr.get(0))
+ map3 = AffineAddExpr.get(
+ AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(1)),
+ AffineDimExpr.get(2),
+ )
+ map = AffineMap.get(3, 2, [map1, map2, map3])
- # CHECK: d0 + s1 + d0 + d1 + d2
- print(expr.compose(map))
+ # CHECK: d0 + s1 + d0 + d1 + d2
+ print(expr.compose(map))
# CHECK-LABEL: TEST: testHash
@run
def testHash():
- with Context():
- d0 = AffineDimExpr.get(0)
- s1 = AffineSymbolExpr.get(1)
- assert hash(d0) == hash(AffineDimExpr.get(0))
- assert hash(d0 + s1) == hash(AffineAddExpr.get(d0, s1))
-
- dictionary = dict()
- dictionary[d0] = 0
- dictionary[s1] = 1
- assert d0 in dictionary
- assert s1 in dictionary
+ with Context():
+ d0 = AffineDimExpr.get(0)
+ s1 = AffineSymbolExpr.get(1)
+ assert hash(d0) == hash(AffineDimExpr.get(0))
+ assert hash(d0 + s1) == hash(AffineAddExpr.get(d0, s1))
+
+ dictionary = dict()
+ dictionary[d0] = 0
+ dictionary[s1] = 1
+ assert d0 in dictionary
+ assert s1 in dictionary
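The affine_expr.py hunks above are formatting-only; the AffineExpr bindings they cover behave the same before and after. As a quick refresher of the API being exercised (an illustrative sketch, not taken from the diff), dimension and constant expressions come from .get() factories and compose either through explicit factories or overloaded operators:

    from mlir.ir import *

    with Context():
        d0 = AffineDimExpr.get(0)
        c2 = AffineConstantExpr.get(2)
        # Overloaded operators build the same nodes as the factories.
        assert d0 + c2 == AffineAddExpr.get(d0, c2)
        print(d0 * 2)  # prints "d0 * 2"
        print(d0 % 2)  # prints "d0 mod 2"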
diff --git a/mlir/test/python/ir/affine_map.py b/mlir/test/python/ir/affine_map.py
index 52c7261500c90..672335e9bf8a7 100644
--- a/mlir/test/python/ir/affine_map.py
+++ b/mlir/test/python/ir/affine_map.py
@@ -5,237 +5,241 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
# CHECK-LABEL: TEST: testAffineMapCapsule
@run
def testAffineMapCapsule():
- with Context() as ctx:
- am1 = AffineMap.get_empty(ctx)
- # CHECK: mlir.ir.AffineMap._CAPIPtr
- affine_map_capsule = am1._CAPIPtr
- print(affine_map_capsule)
- am2 = AffineMap._CAPICreate(affine_map_capsule)
- assert am2 == am1
- assert am2.context is ctx
+ with Context() as ctx:
+ am1 = AffineMap.get_empty(ctx)
+ # CHECK: mlir.ir.AffineMap._CAPIPtr
+ affine_map_capsule = am1._CAPIPtr
+ print(affine_map_capsule)
+ am2 = AffineMap._CAPICreate(affine_map_capsule)
+ assert am2 == am1
+ assert am2.context is ctx
# CHECK-LABEL: TEST: testAffineMapGet
@run
def testAffineMapGet():
- with Context() as ctx:
- d0 = AffineDimExpr.get(0)
- d1 = AffineDimExpr.get(1)
- c2 = AffineConstantExpr.get(2)
-
- # CHECK: (d0, d1)[s0, s1, s2] -> ()
- map0 = AffineMap.get(2, 3, [])
- print(map0)
-
- # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2)
- map1 = AffineMap.get(2, 3, [d1, c2])
- print(map1)
-
- # CHECK: () -> (2)
- map2 = AffineMap.get(0, 0, [c2])
- print(map2)
-
- # CHECK: (d0, d1) -> (d0, d1)
- map3 = AffineMap.get(2, 0, [d0, d1])
- print(map3)
-
- # CHECK: (d0, d1) -> (d1)
- map4 = AffineMap.get(2, 0, [d1])
- print(map4)
-
- # CHECK: (d0, d1, d2) -> (d2, d0, d1)
- map5 = AffineMap.get_permutation([2, 0, 1])
- print(map5)
-
- assert map1 == AffineMap.get(2, 3, [d1, c2])
- assert AffineMap.get(0, 0, []) == AffineMap.get_empty()
- assert map2 == AffineMap.get_constant(2)
- assert map3 == AffineMap.get_identity(2)
- assert map4 == AffineMap.get_minor_identity(2, 1)
-
- try:
- AffineMap.get(1, 1, [1])
- except RuntimeError as e:
- # CHECK: Invalid expression when attempting to create an AffineMap
- print(e)
-
- try:
- AffineMap.get(1, 1, [None])
- except RuntimeError as e:
- # CHECK: Invalid expression (None?) when attempting to create an AffineMap
- print(e)
-
- try:
- AffineMap.get_permutation([1, 0, 1])
- except RuntimeError as e:
- # CHECK: Invalid permutation when attempting to create an AffineMap
- print(e)
-
- try:
- map3.get_submap([42])
- except ValueError as e:
- # CHECK: result position out of bounds
- print(e)
-
- try:
- map3.get_minor_submap(42)
- except ValueError as e:
- # CHECK: number of results out of bounds
- print(e)
-
- try:
- map3.get_major_submap(42)
- except ValueError as e:
- # CHECK: number of results out of bounds
- print(e)
+ with Context() as ctx:
+ d0 = AffineDimExpr.get(0)
+ d1 = AffineDimExpr.get(1)
+ c2 = AffineConstantExpr.get(2)
+
+ # CHECK: (d0, d1)[s0, s1, s2] -> ()
+ map0 = AffineMap.get(2, 3, [])
+ print(map0)
+
+ # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2)
+ map1 = AffineMap.get(2, 3, [d1, c2])
+ print(map1)
+
+ # CHECK: () -> (2)
+ map2 = AffineMap.get(0, 0, [c2])
+ print(map2)
+
+ # CHECK: (d0, d1) -> (d0, d1)
+ map3 = AffineMap.get(2, 0, [d0, d1])
+ print(map3)
+
+ # CHECK: (d0, d1) -> (d1)
+ map4 = AffineMap.get(2, 0, [d1])
+ print(map4)
+
+ # CHECK: (d0, d1, d2) -> (d2, d0, d1)
+ map5 = AffineMap.get_permutation([2, 0, 1])
+ print(map5)
+
+ assert map1 == AffineMap.get(2, 3, [d1, c2])
+ assert AffineMap.get(0, 0, []) == AffineMap.get_empty()
+ assert map2 == AffineMap.get_constant(2)
+ assert map3 == AffineMap.get_identity(2)
+ assert map4 == AffineMap.get_minor_identity(2, 1)
+
+ try:
+ AffineMap.get(1, 1, [1])
+ except RuntimeError as e:
+ # CHECK: Invalid expression when attempting to create an AffineMap
+ print(e)
+
+ try:
+ AffineMap.get(1, 1, [None])
+ except RuntimeError as e:
+ # CHECK: Invalid expression (None?) when attempting to create an AffineMap
+ print(e)
+
+ try:
+ AffineMap.get_permutation([1, 0, 1])
+ except RuntimeError as e:
+ # CHECK: Invalid permutation when attempting to create an AffineMap
+ print(e)
+
+ try:
+ map3.get_submap([42])
+ except ValueError as e:
+ # CHECK: result position out of bounds
+ print(e)
+
+ try:
+ map3.get_minor_submap(42)
+ except ValueError as e:
+ # CHECK: number of results out of bounds
+ print(e)
+
+ try:
+ map3.get_major_submap(42)
+ except ValueError as e:
+ # CHECK: number of results out of bounds
+ print(e)
# CHECK-LABEL: TEST: testAffineMapDerive
@run
def testAffineMapDerive():
- with Context() as ctx:
- map5 = AffineMap.get_identity(5)
+ with Context() as ctx:
+ map5 = AffineMap.get_identity(5)
- # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3)
- map123 = map5.get_submap([1, 2, 3])
- print(map123)
+ # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3)
+ map123 = map5.get_submap([1, 2, 3])
+ print(map123)
- # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1)
- map01 = map5.get_major_submap(2)
- print(map01)
+ # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1)
+ map01 = map5.get_major_submap(2)
+ print(map01)
- # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4)
- map34 = map5.get_minor_submap(2)
- print(map34)
+ # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4)
+ map34 = map5.get_minor_submap(2)
+ print(map34)
# CHECK-LABEL: TEST: testAffineMapProperties
@run
def testAffineMapProperties():
- with Context():
- d0 = AffineDimExpr.get(0)
- d1 = AffineDimExpr.get(1)
- d2 = AffineDimExpr.get(2)
- map1 = AffineMap.get(3, 0, [d2, d0])
- map2 = AffineMap.get(3, 0, [d2, d0, d1])
- map3 = AffineMap.get(3, 1, [d2, d0, d1])
- # CHECK: False
- print(map1.is_permutation)
- # CHECK: True
- print(map1.is_projected_permutation)
- # CHECK: True
- print(map2.is_permutation)
- # CHECK: True
- print(map2.is_projected_permutation)
- # CHECK: False
- print(map3.is_permutation)
- # CHECK: False
- print(map3.is_projected_permutation)
+ with Context():
+ d0 = AffineDimExpr.get(0)
+ d1 = AffineDimExpr.get(1)
+ d2 = AffineDimExpr.get(2)
+ map1 = AffineMap.get(3, 0, [d2, d0])
+ map2 = AffineMap.get(3, 0, [d2, d0, d1])
+ map3 = AffineMap.get(3, 1, [d2, d0, d1])
+ # CHECK: False
+ print(map1.is_permutation)
+ # CHECK: True
+ print(map1.is_projected_permutation)
+ # CHECK: True
+ print(map2.is_permutation)
+ # CHECK: True
+ print(map2.is_projected_permutation)
+ # CHECK: False
+ print(map3.is_permutation)
+ # CHECK: False
+ print(map3.is_projected_permutation)
# CHECK-LABEL: TEST: testAffineMapExprs
@run
def testAffineMapExprs():
- with Context():
- d0 = AffineDimExpr.get(0)
- d1 = AffineDimExpr.get(1)
- d2 = AffineDimExpr.get(2)
- map3 = AffineMap.get(3, 1, [d2, d0, d1])
-
- # CHECK: 3
- print(map3.n_dims)
- # CHECK: 4
- print(map3.n_inputs)
- # CHECK: 1
- print(map3.n_symbols)
- assert map3.n_inputs == map3.n_dims + map3.n_symbols
-
- # CHECK: 3
- print(len(map3.results))
- for expr in map3.results:
- # CHECK: d2
- # CHECK: d0
- # CHECK: d1
- print(expr)
- for expr in map3.results[-1:-4:-1]:
- # CHECK: d1
- # CHECK: d0
- # CHECK: d2
- print(expr)
- assert list(map3.results) == [d2, d0, d1]
+ with Context():
+ d0 = AffineDimExpr.get(0)
+ d1 = AffineDimExpr.get(1)
+ d2 = AffineDimExpr.get(2)
+ map3 = AffineMap.get(3, 1, [d2, d0, d1])
+
+ # CHECK: 3
+ print(map3.n_dims)
+ # CHECK: 4
+ print(map3.n_inputs)
+ # CHECK: 1
+ print(map3.n_symbols)
+ assert map3.n_inputs == map3.n_dims + map3.n_symbols
+
+ # CHECK: 3
+ print(len(map3.results))
+ for expr in map3.results:
+ # CHECK: d2
+ # CHECK: d0
+ # CHECK: d1
+ print(expr)
+ for expr in map3.results[-1:-4:-1]:
+ # CHECK: d1
+ # CHECK: d0
+ # CHECK: d2
+ print(expr)
+ assert list(map3.results) == [d2, d0, d1]
# CHECK-LABEL: TEST: testCompressUnusedSymbols
@run
def testCompressUnusedSymbols():
- with Context() as ctx:
- d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1),
- AffineDimExpr.get(2))
- s0, s1, s2 = (AffineSymbolExpr.get(0), AffineSymbolExpr.get(1),
- AffineSymbolExpr.get(2))
- maps = [
- AffineMap.get(3, 3, [d2, d0, d1]),
- AffineMap.get(3, 3, [d2, d0 + s2, d1]),
- AffineMap.get(3, 3, [d1, d2, d0])
- ]
-
- compressed_maps = AffineMap.compress_unused_symbols(maps, ctx)
-
- # CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1))
- # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1))
- # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0))
- print(maps)
-
- # CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1))
- # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1))
- # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0))
- print(compressed_maps)
+ with Context() as ctx:
+ d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), AffineDimExpr.get(2))
+ s0, s1, s2 = (
+ AffineSymbolExpr.get(0),
+ AffineSymbolExpr.get(1),
+ AffineSymbolExpr.get(2),
+ )
+ maps = [
+ AffineMap.get(3, 3, [d2, d0, d1]),
+ AffineMap.get(3, 3, [d2, d0 + s2, d1]),
+ AffineMap.get(3, 3, [d1, d2, d0]),
+ ]
+
+ compressed_maps = AffineMap.compress_unused_symbols(maps, ctx)
+
+ # CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1))
+ # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1))
+ # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0))
+ print(maps)
+
+ # CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1))
+ # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1))
+ # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0))
+ print(compressed_maps)
# CHECK-LABEL: TEST: testReplace
@run
def testReplace():
- with Context() as ctx:
- d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1),
- AffineDimExpr.get(2))
- s0, s1, s2 = (AffineSymbolExpr.get(0), AffineSymbolExpr.get(1),
- AffineSymbolExpr.get(2))
- map1 = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])
+ with Context() as ctx:
+ d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), AffineDimExpr.get(2))
+ s0, s1, s2 = (
+ AffineSymbolExpr.get(0),
+ AffineSymbolExpr.get(1),
+ AffineSymbolExpr.get(2),
+ )
+ map1 = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])
- replace0 = map1.replace(s0, AffineConstantExpr.get(42), 3, 3)
- replace1 = map1.replace(s1, AffineConstantExpr.get(42), 3, 3)
- replace3 = map1.replace(s2, AffineConstantExpr.get(42), 3, 2)
+ replace0 = map1.replace(s0, AffineConstantExpr.get(42), 3, 3)
+ replace1 = map1.replace(s1, AffineConstantExpr.get(42), 3, 3)
+ replace3 = map1.replace(s2, AffineConstantExpr.get(42), 3, 2)
- # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42)
- print(replace0)
+ # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42)
+ print(replace0)
- # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0)
- print(replace1)
+ # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0)
+ print(replace1)
- # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0)
- print(replace3)
+ # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0)
+ print(replace3)
# CHECK-LABEL: TEST: testHash
@run
def testHash():
- with Context():
- d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1)
- m1 = AffineMap.get(2, 0, [d0, d1])
- m2 = AffineMap.get(2, 0, [d1, d0])
- assert hash(m1) == hash(AffineMap.get(2, 0, [d0, d1]))
-
- dictionary = dict()
- dictionary[m1] = 1
- dictionary[m2] = 2
- assert m1 in dictionary
+ with Context():
+ d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1)
+ m1 = AffineMap.get(2, 0, [d0, d1])
+ m2 = AffineMap.get(2, 0, [d1, d0])
+ assert hash(m1) == hash(AffineMap.get(2, 0, [d0, d1]))
+
+ dictionary = dict()
+ dictionary[m1] = 1
+ dictionary[m2] = 2
+ assert m1 in dictionary
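Likewise, the affine_map.py changes only re-indent the tests; the AffineMap API they cover works along these lines (a hedged sketch under the same Context conventions as the tests, not part of the diff):

    from mlir.ir import *

    with Context():
        d0, d1, d2 = AffineDimExpr.get(0), AffineDimExpr.get(1), AffineDimExpr.get(2)
        # (d0, d1, d2) -> (d2, d0, d1): a full permutation of the dims.
        m = AffineMap.get(3, 0, [d2, d0, d1])
        print(m.is_permutation)      # True
        print(m.get_submap([0, 2]))  # keeps results 0 and 2, i.e. (d2, d1)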
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 3de4edb884157..5ce8bc66fcf96 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -6,26 +6,30 @@
from mlir.ir import *
import numpy as np
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
+
################################################################################
# Tests of the array/buffer .get() factory method on unsupported dtype.
################################################################################
+
@run
def testGetDenseElementsUnsupported():
- with Context():
- array = np.array([["hello", "goodbye"]])
- try:
- attr = DenseElementsAttr.get(array)
- except ValueError as e:
- # CHECK: unimplemented array format conversion from format:
- print(e)
+ with Context():
+ array = np.array([["hello", "goodbye"]])
+ try:
+ attr = DenseElementsAttr.get(array)
+ except ValueError as e:
+ # CHECK: unimplemented array format conversion from format:
+ print(e)
+
################################################################################
# Splats.
@@ -34,85 +38,85 @@ def testGetDenseElementsUnsupported():
# CHECK-LABEL: TEST: testGetDenseElementsSplatInt
@run
def testGetDenseElementsSplatInt():
- with Context(), Location.unknown():
- t = IntegerType.get_signless(32)
- element = IntegerAttr.get(t, 555)
- shaped_type = RankedTensorType.get((2, 3, 4), t)
- attr = DenseElementsAttr.get_splat(shaped_type, element)
- # CHECK: dense<555> : tensor<2x3x4xi32>
- print(attr)
- # CHECK: is_splat: True
- print("is_splat:", attr.is_splat)
- assert attr.get_splat_value() == element
+ with Context(), Location.unknown():
+ t = IntegerType.get_signless(32)
+ element = IntegerAttr.get(t, 555)
+ shaped_type = RankedTensorType.get((2, 3, 4), t)
+ attr = DenseElementsAttr.get_splat(shaped_type, element)
+ # CHECK: dense<555> : tensor<2x3x4xi32>
+ print(attr)
+ # CHECK: is_splat: True
+ print("is_splat:", attr.is_splat)
+ assert attr.get_splat_value() == element
# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
@run
def testGetDenseElementsSplatFloat():
- with Context(), Location.unknown():
- t = F32Type.get()
- element = FloatAttr.get(t, 1.2)
- shaped_type = RankedTensorType.get((2, 3, 4), t)
- attr = DenseElementsAttr.get_splat(shaped_type, element)
- # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
- print(attr)
- assert attr.get_splat_value() == element
+ with Context(), Location.unknown():
+ t = F32Type.get()
+ element = FloatAttr.get(t, 1.2)
+ shaped_type = RankedTensorType.get((2, 3, 4), t)
+ attr = DenseElementsAttr.get_splat(shaped_type, element)
+ # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
+ print(attr)
+ assert attr.get_splat_value() == element
# CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
@run
def testGetDenseElementsSplatErrors():
- with Context(), Location.unknown():
- t = F32Type.get()
- other_t = F64Type.get()
- element = FloatAttr.get(t, 1.2)
- other_element = FloatAttr.get(other_t, 1.2)
- shaped_type = RankedTensorType.get((2, 3, 4), t)
- dynamic_shaped_type = UnrankedTensorType.get(t)
- non_shaped_type = t
-
- try:
- attr = DenseElementsAttr.get_splat(non_shaped_type, element)
- except ValueError as e:
- # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
- print(e)
-
- try:
- attr = DenseElementsAttr.get_splat(dynamic_shaped_type, element)
- except ValueError as e:
- # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
- print(e)
-
- try:
- attr = DenseElementsAttr.get_splat(shaped_type, other_element)
- except ValueError as e:
- # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
- print(e)
+ with Context(), Location.unknown():
+ t = F32Type.get()
+ other_t = F64Type.get()
+ element = FloatAttr.get(t, 1.2)
+ other_element = FloatAttr.get(other_t, 1.2)
+ shaped_type = RankedTensorType.get((2, 3, 4), t)
+ dynamic_shaped_type = UnrankedTensorType.get(t)
+ non_shaped_type = t
+
+ try:
+ attr = DenseElementsAttr.get_splat(non_shaped_type, element)
+ except ValueError as e:
+ # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
+ print(e)
+
+ try:
+ attr = DenseElementsAttr.get_splat(dynamic_shaped_type, element)
+ except ValueError as e:
+ # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
+ print(e)
+
+ try:
+ attr = DenseElementsAttr.get_splat(shaped_type, other_element)
+ except ValueError as e:
+ # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
+ print(e)
# CHECK-LABEL: TEST: testRepeatedValuesSplat
@run
def testRepeatedValuesSplat():
- with Context():
- array = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=np.float32)
- attr = DenseElementsAttr.get(array)
- # CHECK: dense<1.000000e+00> : tensor<2x3xf32>
- print(attr)
- # CHECK: is_splat: True
- print("is_splat:", attr.is_splat)
- # CHECK{LITERAL}: [[1. 1. 1.]
- # CHECK{LITERAL}: [1. 1. 1.]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=np.float32)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<1.000000e+00> : tensor<2x3xf32>
+ print(attr)
+ # CHECK: is_splat: True
+ print("is_splat:", attr.is_splat)
+ # CHECK{LITERAL}: [[1. 1. 1.]
+ # CHECK{LITERAL}: [1. 1. 1.]]
+ print(np.array(attr))
# CHECK-LABEL: TEST: testNonSplat
@run
def testNonSplat():
- with Context():
- array = np.array([2.0, 1.0, 1.0], dtype=np.float32)
- attr = DenseElementsAttr.get(array)
- # CHECK: is_splat: False
- print("is_splat:", attr.is_splat)
+ with Context():
+ array = np.array([2.0, 1.0, 1.0], dtype=np.float32)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: is_splat: False
+ print("is_splat:", attr.is_splat)
################################################################################
@@ -121,50 +125,59 @@ def testNonSplat():
### explicitly provided types
+
@run
def testGetDenseElementsBF16():
- with Context():
- array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16)
- attr = DenseElementsAttr.get(array, type=BF16Type.get())
- # Note: These values don't mean much since just bit-casting. But they
- # shouldn't change.
- # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16>
- print(attr)
+ with Context():
+ array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16)
+ attr = DenseElementsAttr.get(array, type=BF16Type.get())
+ # Note: These values don't mean much since just bit-casting. But they
+ # shouldn't change.
+ # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16>
+ print(attr)
+
@run
def testGetDenseElementsInteger4():
- with Context():
- array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.uint8)
- attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4))
- # Note: These values don't mean much since just bit-casting. But they
- # shouldn't change.
- # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4>
- print(attr)
+ with Context():
+ array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.uint8)
+ attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4))
+ # Note: These values don't mean much since just bit-casting. But they
+ # shouldn't change.
+ # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4>
+ print(attr)
@run
def testGetDenseElementsBool():
- with Context():
- bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_)
- array = np.packbits(bool_array, axis=None, bitorder="little")
- attr = DenseElementsAttr.get(
- array, type=IntegerType.get_signless(1), shape=bool_array.shape)
- # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1>
- print(attr)
+ with Context():
+ bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_)
+ array = np.packbits(bool_array, axis=None, bitorder="little")
+ attr = DenseElementsAttr.get(
+ array, type=IntegerType.get_signless(1), shape=bool_array.shape
+ )
+ # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1>
+ print(attr)
@run
def testGetDenseElementsBoolSplat():
- with Context():
- zero = np.array(0, dtype=np.uint8)
- one = np.array(255, dtype=np.uint8)
- print(one)
- # CHECK: dense<false> : tensor<4x2x5xi1>
- print(DenseElementsAttr.get(
- zero, type=IntegerType.get_signless(1), shape=(4, 2, 5)))
- # CHECK: dense<true> : tensor<4x2x5xi1>
- print(DenseElementsAttr.get(
- one, type=IntegerType.get_signless(1), shape=(4, 2, 5)))
+ with Context():
+ zero = np.array(0, dtype=np.uint8)
+ one = np.array(255, dtype=np.uint8)
+ print(one)
+ # CHECK: dense<false> : tensor<4x2x5xi1>
+ print(
+ DenseElementsAttr.get(
+ zero, type=IntegerType.get_signless(1), shape=(4, 2, 5)
+ )
+ )
+ # CHECK: dense<true> : tensor<4x2x5xi1>
+ print(
+ DenseElementsAttr.get(
+ one, type=IntegerType.get_signless(1), shape=(4, 2, 5)
+ )
+ )
### float and double arrays.
@@ -172,213 +185,213 @@ def testGetDenseElementsBoolSplat():
# CHECK-LABEL: TEST: testGetDenseElementsF16
@run
def testGetDenseElementsF16():
- with Context():
- array = np.array([[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]], dtype=np.float16)
- attr = DenseElementsAttr.get(array)
- # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16>
- print(attr)
- # CHECK: {{\[}}[ 2. 4. 8.]
- # CHECK: {{\[}}16. 32. 64.]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]], dtype=np.float16)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16>
+ print(attr)
+ # CHECK: {{\[}}[ 2. 4. 8.]
+ # CHECK: {{\[}}16. 32. 64.]]
+ print(np.array(attr))
# CHECK-LABEL: TEST: testGetDenseElementsF32
@run
def testGetDenseElementsF32():
- with Context():
- array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
- attr = DenseElementsAttr.get(array)
- # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
- print(attr)
- # CHECK: {{\[}}[1.1 2.2 3.3]
- # CHECK: {{\[}}4.4 5.5 6.6]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
+ print(attr)
+ # CHECK: {{\[}}[1.1 2.2 3.3]
+ # CHECK: {{\[}}4.4 5.5 6.6]]
+ print(np.array(attr))
# CHECK-LABEL: TEST: testGetDenseElementsF64
@run
def testGetDenseElementsF64():
- with Context():
- array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
- attr = DenseElementsAttr.get(array)
- # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
- print(attr)
- # CHECK: {{\[}}[1.1 2.2 3.3]
- # CHECK: {{\[}}4.4 5.5 6.6]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
+ print(attr)
+ # CHECK: {{\[}}[1.1 2.2 3.3]
+ # CHECK: {{\[}}4.4 5.5 6.6]]
+ print(np.array(attr))
### 16 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
@run
def testGetDenseElementsI16Signless():
- with Context():
- array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
- attr = DenseElementsAttr.get(array)
- # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
- print(attr)
- # CHECK: {{\[}}[1 2 3]
- # CHECK: {{\[}}4 5 6]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
# CHECK-LABEL: TEST: testGetDenseElementsUI16Signless
@run
def testGetDenseElementsUI16Signless():
- with Context():
- array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
- attr = DenseElementsAttr.get(array)
- # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
- print(attr)
- # CHECK: {{\[}}[1 2 3]
- # CHECK: {{\[}}4 5 6]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
# CHECK-LABEL: TEST: testGetDenseElementsI16
@run
def testGetDenseElementsI16():
- with Context():
- array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
- attr = DenseElementsAttr.get(array, signless=False)
- # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16>
- print(attr)
- # CHECK: {{\[}}[1 2 3]
- # CHECK: {{\[}}4 5 6]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
+ attr = DenseElementsAttr.get(array, signless=False)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
# CHECK-LABEL: TEST: testGetDenseElementsUI16
@run
def testGetDenseElementsUI16():
- with Context():
- array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
- attr = DenseElementsAttr.get(array, signless=False)
- # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16>
- print(attr)
- # CHECK: {{\[}}[1 2 3]
- # CHECK: {{\[}}4 5 6]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
+ attr = DenseElementsAttr.get(array, signless=False)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
+
### 32 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI32Signless
@run
def testGetDenseElementsI32Signless():
- with Context():
- array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
- attr = DenseElementsAttr.get(array)
- # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
- print(attr)
- # CHECK: {{\[}}[1 2 3]
- # CHECK: {{\[}}4 5 6]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
# CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
@run
def testGetDenseElementsUI32Signless():
- with Context():
- array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
- attr = DenseElementsAttr.get(array)
- # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
- print(attr)
- # CHECK: {{\[}}[1 2 3]
- # CHECK: {{\[}}4 5 6]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
# CHECK-LABEL: TEST: testGetDenseElementsI32
@run
def testGetDenseElementsI32():
- with Context():
- array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
- attr = DenseElementsAttr.get(array, signless=False)
- # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
- print(attr)
- # CHECK: {{\[}}[1 2 3]
- # CHECK: {{\[}}4 5 6]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
+ attr = DenseElementsAttr.get(array, signless=False)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
# CHECK-LABEL: TEST: testGetDenseElementsUI32
@run
def testGetDenseElementsUI32():
- with Context():
- array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
- attr = DenseElementsAttr.get(array, signless=False)
- # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
- print(attr)
- # CHECK: {{\[}}[1 2 3]
- # CHECK: {{\[}}4 5 6]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
+ attr = DenseElementsAttr.get(array, signless=False)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
## 64bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI64Signless
@run
def testGetDenseElementsI64Signless():
- with Context():
- array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
- attr = DenseElementsAttr.get(array)
- # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
- print(attr)
- # CHECK: {{\[}}[1 2 3]
- # CHECK: {{\[}}4 5 6]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
# CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
@run
def testGetDenseElementsUI64Signless():
- with Context():
- array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
- attr = DenseElementsAttr.get(array)
- # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
- print(attr)
- # CHECK: {{\[}}[1 2 3]
- # CHECK: {{\[}}4 5 6]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
# CHECK-LABEL: TEST: testGetDenseElementsI64
@run
def testGetDenseElementsI64():
- with Context():
- array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
- attr = DenseElementsAttr.get(array, signless=False)
- # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
- print(attr)
- # CHECK: {{\[}}[1 2 3]
- # CHECK: {{\[}}4 5 6]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+ attr = DenseElementsAttr.get(array, signless=False)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
# CHECK-LABEL: TEST: testGetDenseElementsUI64
@run
def testGetDenseElementsUI64():
- with Context():
- array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
- attr = DenseElementsAttr.get(array, signless=False)
- # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
- print(attr)
- # CHECK: {{\[}}[1 2 3]
- # CHECK: {{\[}}4 5 6]]
- print(np.array(attr))
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
+ attr = DenseElementsAttr.get(array, signless=False)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
# CHECK-LABEL: TEST: testGetDenseElementsIndex
@run
def testGetDenseElementsIndex():
- with Context(), Location.unknown():
- idx_type = IndexType.get()
- array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
- attr = DenseElementsAttr.get(array, type=idx_type)
- # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xindex>
- print(attr)
- arr = np.array(attr)
- # CHECK: {{\[}}[1 2 3]
- # CHECK: {{\[}}4 5 6]]
- print(arr)
- # CHECK: True
- print(arr.dtype == np.int64)
-
+ with Context(), Location.unknown():
+ idx_type = IndexType.get()
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+ attr = DenseElementsAttr.get(array, type=idx_type)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xindex>
+ print(attr)
+ arr = np.array(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(arr)
+ # CHECK: True
+ print(arr.dtype == np.int64)
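The array_attributes.py tests above all revolve around one round trip: a NumPy array in, a DenseElementsAttr out, and back again via np.array(). Stripped of the reformatting churn, the pattern is roughly the following (illustrative only, with a made-up 2x2 input):

    import numpy as np
    from mlir.ir import *

    with Context():
        # The dtype selects the element type and the shape carries over
        # to the tensor type unchanged.
        array = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
        attr = DenseElementsAttr.get(array)
        print(attr)            # dense<...> : tensor<2x2xf32>
        print(np.array(attr))  # converts back to a NumPy array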
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 6aad94317e6f1..29074052796b4 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -6,554 +6,550 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
- with Context() as ctx:
- t = Attribute.parse('"hello"')
- assert t.context is ctx
- ctx = None
- gc.collect()
- # CHECK: "hello"
- print(str(t))
- # CHECK: Attribute("hello")
- print(repr(t))
+ with Context() as ctx:
+ t = Attribute.parse('"hello"')
+ assert t.context is ctx
+ ctx = None
+ gc.collect()
+ # CHECK: "hello"
+ print(str(t))
+ # CHECK: Attribute("hello")
+ print(repr(t))
# CHECK-LABEL: TEST: testParseError
@run
def testParseError():
- with Context():
- try:
- t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
- except MLIRError as e:
- # CHECK: testParseError: <
- # CHECK: Unable to parse attribute:
- # CHECK: error: "BAD_ATTR_DOES_NOT_EXIST":1:1: expected attribute value
- # CHECK: >
- print(f"testParseError: <{e}>")
- else:
- print("Exception not produced")
+ with Context():
+ try:
+ t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
+ except MLIRError as e:
+ # CHECK: testParseError: <
+ # CHECK: Unable to parse attribute:
+ # CHECK: error: "BAD_ATTR_DOES_NOT_EXIST":1:1: expected attribute value
+ # CHECK: >
+ print(f"testParseError: <{e}>")
+ else:
+ print("Exception not produced")
# CHECK-LABEL: TEST: testAttrEq
@run
def testAttrEq():
- with Context():
- a1 = Attribute.parse('"attr1"')
- a2 = Attribute.parse('"attr2"')
- a3 = Attribute.parse('"attr1"')
- # CHECK: a1 == a1: True
- print("a1 == a1:", a1 == a1)
- # CHECK: a1 == a2: False
- print("a1 == a2:", a1 == a2)
- # CHECK: a1 == a3: True
- print("a1 == a3:", a1 == a3)
- # CHECK: a1 == None: False
- print("a1 == None:", a1 == None)
+ with Context():
+ a1 = Attribute.parse('"attr1"')
+ a2 = Attribute.parse('"attr2"')
+ a3 = Attribute.parse('"attr1"')
+ # CHECK: a1 == a1: True
+ print("a1 == a1:", a1 == a1)
+ # CHECK: a1 == a2: False
+ print("a1 == a2:", a1 == a2)
+ # CHECK: a1 == a3: True
+ print("a1 == a3:", a1 == a3)
+ # CHECK: a1 == None: False
+ print("a1 == None:", a1 == None)
# CHECK-LABEL: TEST: testAttrHash
@run
def testAttrHash():
- with Context():
- a1 = Attribute.parse('"attr1"')
- a2 = Attribute.parse('"attr2"')
- a3 = Attribute.parse('"attr1"')
- # CHECK: hash(a1) == hash(a3): True
- print("hash(a1) == hash(a3):", a1.__hash__() == a3.__hash__())
+ with Context():
+ a1 = Attribute.parse('"attr1"')
+ a2 = Attribute.parse('"attr2"')
+ a3 = Attribute.parse('"attr1"')
+ # CHECK: hash(a1) == hash(a3): True
+ print("hash(a1) == hash(a3):", a1.__hash__() == a3.__hash__())
- s = set()
- s.add(a1)
- s.add(a2)
- s.add(a3)
- # CHECK: len(s): 2
- print("len(s): ", len(s))
+ s = set()
+ s.add(a1)
+ s.add(a2)
+ s.add(a3)
+ # CHECK: len(s): 2
+ print("len(s): ", len(s))
# CHECK-LABEL: TEST: testAttrCast
@run
def testAttrCast():
- with Context():
- a1 = Attribute.parse('"attr1"')
- a2 = Attribute(a1)
- # CHECK: a1 == a2: True
- print("a1 == a2:", a1 == a2)
+ with Context():
+ a1 = Attribute.parse('"attr1"')
+ a2 = Attribute(a1)
+ # CHECK: a1 == a2: True
+ print("a1 == a2:", a1 == a2)
# CHECK-LABEL: TEST: testAttrIsInstance
@run
def testAttrIsInstance():
- with Context():
- a1 = Attribute.parse("42")
- a2 = Attribute.parse("[42]")
- assert IntegerAttr.isinstance(a1)
- assert not IntegerAttr.isinstance(a2)
- assert not ArrayAttr.isinstance(a1)
- assert ArrayAttr.isinstance(a2)
+ with Context():
+ a1 = Attribute.parse("42")
+ a2 = Attribute.parse("[42]")
+ assert IntegerAttr.isinstance(a1)
+ assert not IntegerAttr.isinstance(a2)
+ assert not ArrayAttr.isinstance(a1)
+ assert ArrayAttr.isinstance(a2)
# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
@run
def testAttrEqDoesNotRaise():
- with Context():
- a1 = Attribute.parse('"attr1"')
- not_an_attr = "foo"
- # CHECK: False
- print(a1 == not_an_attr)
- # CHECK: False
- print(a1 == None)
- # CHECK: True
- print(a1 != None)
+ with Context():
+ a1 = Attribute.parse('"attr1"')
+ not_an_attr = "foo"
+ # CHECK: False
+ print(a1 == not_an_attr)
+ # CHECK: False
+ print(a1 == None)
+ # CHECK: True
+ print(a1 != None)
# CHECK-LABEL: TEST: testAttrCapsule
@run
def testAttrCapsule():
- with Context() as ctx:
- a1 = Attribute.parse('"attr1"')
- # CHECK: mlir.ir.Attribute._CAPIPtr
- attr_capsule = a1._CAPIPtr
- print(attr_capsule)
- a2 = Attribute._CAPICreate(attr_capsule)
- assert a2 == a1
- assert a2.context is ctx
+ with Context() as ctx:
+ a1 = Attribute.parse('"attr1"')
+ # CHECK: mlir.ir.Attribute._CAPIPtr
+ attr_capsule = a1._CAPIPtr
+ print(attr_capsule)
+ a2 = Attribute._CAPICreate(attr_capsule)
+ assert a2 == a1
+ assert a2.context is ctx
# CHECK-LABEL: TEST: testStandardAttrCasts
@run
def testStandardAttrCasts():
- with Context():
- a1 = Attribute.parse('"attr1"')
- astr = StringAttr(a1)
- aself = StringAttr(astr)
- # CHECK: Attribute("attr1")
- print(repr(astr))
- try:
- tillegal = StringAttr(Attribute.parse("1.0"))
- except ValueError as e:
- # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
- print("ValueError:", e)
- else:
- print("Exception not produced")
+ with Context():
+ a1 = Attribute.parse('"attr1"')
+ astr = StringAttr(a1)
+ aself = StringAttr(astr)
+ # CHECK: Attribute("attr1")
+ print(repr(astr))
+ try:
+ tillegal = StringAttr(Attribute.parse("1.0"))
+ except ValueError as e:
+ # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
+ print("ValueError:", e)
+ else:
+ print("Exception not produced")
# CHECK-LABEL: TEST: testAffineMapAttr
@run
def testAffineMapAttr():
- with Context() as ctx:
- d0 = AffineDimExpr.get(0)
- d1 = AffineDimExpr.get(1)
- c2 = AffineConstantExpr.get(2)
- map0 = AffineMap.get(2, 3, [])
+ with Context() as ctx:
+ d0 = AffineDimExpr.get(0)
+ d1 = AffineDimExpr.get(1)
+ c2 = AffineConstantExpr.get(2)
+ map0 = AffineMap.get(2, 3, [])
- # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
- attr_built = AffineMapAttr.get(map0)
- print(str(attr_built))
+ # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
+ attr_built = AffineMapAttr.get(map0)
+ print(str(attr_built))
- attr_parsed = Attribute.parse(str(attr_built))
- assert attr_built == attr_parsed
+ attr_parsed = Attribute.parse(str(attr_built))
+ assert attr_built == attr_parsed
# CHECK-LABEL: TEST: testFloatAttr
@run
def testFloatAttr():
- with Context(), Location.unknown():
- fattr = FloatAttr(Attribute.parse("42.0 : f32"))
- # CHECK: fattr value: 42.0
- print("fattr value:", fattr.value)
-
- # Test factory methods.
- # CHECK: default_get: 4.200000e+01 : f32
- print("default_get:", FloatAttr.get(
- F32Type.get(), 42.0))
- # CHECK: f32_get: 4.200000e+01 : f32
- print("f32_get:", FloatAttr.get_f32(42.0))
- # CHECK: f64_get: 4.200000e+01 : f64
- print("f64_get:", FloatAttr.get_f64(42.0))
- try:
- fattr_invalid = FloatAttr.get(
- IntegerType.get_signless(32), 42)
- except MLIRError as e:
- # CHECK: Invalid attribute:
- # CHECK: error: unknown: expected floating point type
- print(e)
- else:
- print("Exception not produced")
+ with Context(), Location.unknown():
+ fattr = FloatAttr(Attribute.parse("42.0 : f32"))
+ # CHECK: fattr value: 42.0
+ print("fattr value:", fattr.value)
+
+ # Test factory methods.
+ # CHECK: default_get: 4.200000e+01 : f32
+ print("default_get:", FloatAttr.get(F32Type.get(), 42.0))
+ # CHECK: f32_get: 4.200000e+01 : f32
+ print("f32_get:", FloatAttr.get_f32(42.0))
+ # CHECK: f64_get: 4.200000e+01 : f64
+ print("f64_get:", FloatAttr.get_f64(42.0))
+ try:
+ fattr_invalid = FloatAttr.get(IntegerType.get_signless(32), 42)
+ except MLIRError as e:
+ # CHECK: Invalid attribute:
+ # CHECK: error: unknown: expected floating point type
+ print(e)
+ else:
+ print("Exception not produced")
# CHECK-LABEL: TEST: testIntegerAttr
@run
def testIntegerAttr():
- with Context() as ctx:
- i_attr = IntegerAttr(Attribute.parse("42"))
- # CHECK: i_attr value: 42
- print("i_attr value:", i_attr.value)
- # CHECK: i_attr type: i64
- print("i_attr type:", i_attr.type)
- si_attr = IntegerAttr(Attribute.parse("-1 : si8"))
- # CHECK: si_attr value: -1
- print("si_attr value:", si_attr.value)
- ui_attr = IntegerAttr(Attribute.parse("255 : ui8"))
- # CHECK: ui_attr value: 255
- print("ui_attr value:", ui_attr.value)
- idx_attr = IntegerAttr(Attribute.parse("-1 : index"))
- # CHECK: idx_attr value: -1
- print("idx_attr value:", idx_attr.value)
-
- # Test factory methods.
- # CHECK: default_get: 42 : i32
- print("default_get:", IntegerAttr.get(
- IntegerType.get_signless(32), 42))
+ with Context() as ctx:
+ i_attr = IntegerAttr(Attribute.parse("42"))
+ # CHECK: i_attr value: 42
+ print("i_attr value:", i_attr.value)
+ # CHECK: i_attr type: i64
+ print("i_attr type:", i_attr.type)
+ si_attr = IntegerAttr(Attribute.parse("-1 : si8"))
+ # CHECK: si_attr value: -1
+ print("si_attr value:", si_attr.value)
+ ui_attr = IntegerAttr(Attribute.parse("255 : ui8"))
+ # CHECK: ui_attr value: 255
+ print("ui_attr value:", ui_attr.value)
+ idx_attr = IntegerAttr(Attribute.parse("-1 : index"))
+ # CHECK: idx_attr value: -1
+ print("idx_attr value:", idx_attr.value)
+
+ # Test factory methods.
+ # CHECK: default_get: 42 : i32
+ print("default_get:", IntegerAttr.get(IntegerType.get_signless(32), 42))
# CHECK-LABEL: TEST: testBoolAttr
@run
def testBoolAttr():
- with Context() as ctx:
- battr = BoolAttr(Attribute.parse("true"))
- # CHECK: iattr value: True
- print("iattr value:", battr.value)
+ with Context() as ctx:
+ battr = BoolAttr(Attribute.parse("true"))
+ # CHECK: iattr value: True
+ print("iattr value:", battr.value)
- # Test factory methods.
- # CHECK: default_get: true
- print("default_get:", BoolAttr.get(True))
+ # Test factory methods.
+ # CHECK: default_get: true
+ print("default_get:", BoolAttr.get(True))
# CHECK-LABEL: TEST: testFlatSymbolRefAttr
@run
def testFlatSymbolRefAttr():
- with Context() as ctx:
- sattr = FlatSymbolRefAttr(Attribute.parse('@symbol'))
- # CHECK: symattr value: symbol
- print("symattr value:", sattr.value)
+ with Context() as ctx:
+ sattr = FlatSymbolRefAttr(Attribute.parse("@symbol"))
+ # CHECK: symattr value: symbol
+ print("symattr value:", sattr.value)
- # Test factory methods.
- # CHECK: default_get: @foobar
- print("default_get:", FlatSymbolRefAttr.get("foobar"))
+ # Test factory methods.
+ # CHECK: default_get: @foobar
+ print("default_get:", FlatSymbolRefAttr.get("foobar"))
# CHECK-LABEL: TEST: testOpaqueAttr
@run
def testOpaqueAttr():
- with Context() as ctx:
- ctx.allow_unregistered_dialects = True
- oattr = OpaqueAttr(Attribute.parse("#pytest_dummy.dummyattr<>"))
- # CHECK: oattr value: pytest_dummy
- print("oattr value:", oattr.dialect_namespace)
- # CHECK: oattr value: b'dummyattr<>'
- print("oattr value:", oattr.data)
-
- # Test factory methods.
- # CHECK: default_get: #foobar<123>
- print(
- "default_get:",
- OpaqueAttr.get("foobar", bytes("123", "utf-8"), NoneType.get()))
+ with Context() as ctx:
+ ctx.allow_unregistered_dialects = True
+ oattr = OpaqueAttr(Attribute.parse("#pytest_dummy.dummyattr<>"))
+ # CHECK: oattr value: pytest_dummy
+ print("oattr value:", oattr.dialect_namespace)
+ # CHECK: oattr value: b'dummyattr<>'
+ print("oattr value:", oattr.data)
+
+ # Test factory methods.
+ # CHECK: default_get: #foobar<123>
+ print(
+ "default_get:",
+ OpaqueAttr.get("foobar", bytes("123", "utf-8"), NoneType.get()),
+ )
# CHECK-LABEL: TEST: testStringAttr
@run
def testStringAttr():
- with Context() as ctx:
- sattr = StringAttr(Attribute.parse('"stringattr"'))
- # CHECK: sattr value: stringattr
- print("sattr value:", sattr.value)
- # CHECK: sattr value: b'stringattr'
- print("sattr value:", sattr.value_bytes)
+ with Context() as ctx:
+ sattr = StringAttr(Attribute.parse('"stringattr"'))
+ # CHECK: sattr value: stringattr
+ print("sattr value:", sattr.value)
+ # CHECK: sattr value: b'stringattr'
+ print("sattr value:", sattr.value_bytes)
- # Test factory methods.
- # CHECK: default_get: "foobar"
- print("default_get:", StringAttr.get("foobar"))
- # CHECK: typed_get: "12345" : i32
- print("typed_get:", StringAttr.get_typed(
- IntegerType.get_signless(32), "12345"))
+ # Test factory methods.
+ # CHECK: default_get: "foobar"
+ print("default_get:", StringAttr.get("foobar"))
+ # CHECK: typed_get: "12345" : i32
+ print("typed_get:", StringAttr.get_typed(IntegerType.get_signless(32), "12345"))
# CHECK-LABEL: TEST: testNamedAttr
@run
def testNamedAttr():
- with Context():
- a = Attribute.parse('"stringattr"')
- named = a.get_named("foobar") # Note: under the small object threshold
- # CHECK: attr: "stringattr"
- print("attr:", named.attr)
- # CHECK: name: foobar
- print("name:", named.name)
- # CHECK: named: NamedAttribute(foobar="stringattr")
- print("named:", named)
+ with Context():
+ a = Attribute.parse('"stringattr"')
+ named = a.get_named("foobar") # Note: under the small object threshold
+ # CHECK: attr: "stringattr"
+ print("attr:", named.attr)
+ # CHECK: name: foobar
+ print("name:", named.name)
+ # CHECK: named: NamedAttribute(foobar="stringattr")
+ print("named:", named)
# CHECK-LABEL: TEST: testDenseIntAttr
@run
def testDenseIntAttr():
- with Context():
- raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
- # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
- print("attr:", raw)
+ with Context():
+ raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
+ # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
+ print("attr:", raw)
- a = DenseIntElementsAttr(raw)
- assert len(a) == 6
+ a = DenseIntElementsAttr(raw)
+ assert len(a) == 6
- # CHECK: 0 1 2 3 4 5
- for value in a:
- print(value, end=" ")
- print()
+ # CHECK: 0 1 2 3 4 5
+ for value in a:
+ print(value, end=" ")
+ print()
- # CHECK: i32
- print(ShapedType(a.type).element_type)
+ # CHECK: i32
+ print(ShapedType(a.type).element_type)
- raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
- # CHECK: attr: dense<[true, false, true, false]>
- print("attr:", raw)
+ raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
+ # CHECK: attr: dense<[true, false, true, false]>
+ print("attr:", raw)
- a = DenseIntElementsAttr(raw)
- assert len(a) == 4
+ a = DenseIntElementsAttr(raw)
+ assert len(a) == 4
- # CHECK: 1 0 1 0
- for value in a:
- print(value, end=" ")
- print()
+ # CHECK: 1 0 1 0
+ for value in a:
+ print(value, end=" ")
+ print()
- # CHECK: i1
- print(ShapedType(a.type).element_type)
+ # CHECK: i1
+ print(ShapedType(a.type).element_type)
@run
def testDenseArrayGetItem():
- def print_item(AttrClass, attr_asm):
- attr = AttrClass(Attribute.parse(attr_asm))
- print(f"{len(attr)}: {attr[0]}, {attr[1]}")
-
- with Context():
- # CHECK: 2: 0, 1
- print_item(DenseBoolArrayAttr, "array<i1: false, true>")
- # CHECK: 2: 2, 3
- print_item(DenseI8ArrayAttr, "array<i8: 2, 3>")
- # CHECK: 2: 4, 5
- print_item(DenseI16ArrayAttr, "array<i16: 4, 5>")
- # CHECK: 2: 6, 7
- print_item(DenseI32ArrayAttr, "array<i32: 6, 7>")
- # CHECK: 2: 8, 9
- print_item(DenseI64ArrayAttr, "array<i64: 8, 9>")
- # CHECK: 2: 1.{{0+}}, 2.{{0+}}
- print_item(DenseF32ArrayAttr, "array<f32: 1.0, 2.0>")
- # CHECK: 2: 3.{{0+}}, 4.{{0+}}
- print_item(DenseF64ArrayAttr, "array<f64: 3.0, 4.0>")
+ def print_item(AttrClass, attr_asm):
+ attr = AttrClass(Attribute.parse(attr_asm))
+ print(f"{len(attr)}: {attr[0]}, {attr[1]}")
+
+ with Context():
+ # CHECK: 2: 0, 1
+ print_item(DenseBoolArrayAttr, "array<i1: false, true>")
+ # CHECK: 2: 2, 3
+ print_item(DenseI8ArrayAttr, "array<i8: 2, 3>")
+ # CHECK: 2: 4, 5
+ print_item(DenseI16ArrayAttr, "array<i16: 4, 5>")
+ # CHECK: 2: 6, 7
+ print_item(DenseI32ArrayAttr, "array<i32: 6, 7>")
+ # CHECK: 2: 8, 9
+ print_item(DenseI64ArrayAttr, "array<i64: 8, 9>")
+ # CHECK: 2: 1.{{0+}}, 2.{{0+}}
+ print_item(DenseF32ArrayAttr, "array<f32: 1.0, 2.0>")
+ # CHECK: 2: 3.{{0+}}, 4.{{0+}}
+ print_item(DenseF64ArrayAttr, "array<f64: 3.0, 4.0>")
# CHECK-LABEL: TEST: testDenseIntAttrGetItem
@run
def testDenseIntAttrGetItem():
- def print_item(attr_asm):
- attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
- dtype = ShapedType(attr.type).element_type
- try:
- item = attr[0]
- print(f"{dtype}:", item)
- except TypeError as e:
- print(f"{dtype}:", e)
-
- with Context():
- # CHECK: i1: 1
- print_item("dense<true> : tensor<i1>")
- # CHECK: i8: 123
- print_item("dense<123> : tensor<i8>")
- # CHECK: i16: 123
- print_item("dense<123> : tensor<i16>")
- # CHECK: i32: 123
- print_item("dense<123> : tensor<i32>")
- # CHECK: i64: 123
- print_item("dense<123> : tensor<i64>")
- # CHECK: ui8: 123
- print_item("dense<123> : tensor<ui8>")
- # CHECK: ui16: 123
- print_item("dense<123> : tensor<ui16>")
- # CHECK: ui32: 123
- print_item("dense<123> : tensor<ui32>")
- # CHECK: ui64: 123
- print_item("dense<123> : tensor<ui64>")
- # CHECK: si8: -123
- print_item("dense<-123> : tensor<si8>")
- # CHECK: si16: -123
- print_item("dense<-123> : tensor<si16>")
- # CHECK: si32: -123
- print_item("dense<-123> : tensor<si32>")
- # CHECK: si64: -123
- print_item("dense<-123> : tensor<si64>")
-
- # CHECK: i7: Unsupported integer type
- print_item("dense<123> : tensor<i7>")
+ def print_item(attr_asm):
+ attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
+ dtype = ShapedType(attr.type).element_type
+ try:
+ item = attr[0]
+ print(f"{dtype}:", item)
+ except TypeError as e:
+ print(f"{dtype}:", e)
+
+ with Context():
+ # CHECK: i1: 1
+ print_item("dense<true> : tensor<i1>")
+ # CHECK: i8: 123
+ print_item("dense<123> : tensor<i8>")
+ # CHECK: i16: 123
+ print_item("dense<123> : tensor<i16>")
+ # CHECK: i32: 123
+ print_item("dense<123> : tensor<i32>")
+ # CHECK: i64: 123
+ print_item("dense<123> : tensor<i64>")
+ # CHECK: ui8: 123
+ print_item("dense<123> : tensor<ui8>")
+ # CHECK: ui16: 123
+ print_item("dense<123> : tensor<ui16>")
+ # CHECK: ui32: 123
+ print_item("dense<123> : tensor<ui32>")
+ # CHECK: ui64: 123
+ print_item("dense<123> : tensor<ui64>")
+ # CHECK: si8: -123
+ print_item("dense<-123> : tensor<si8>")
+ # CHECK: si16: -123
+ print_item("dense<-123> : tensor<si16>")
+ # CHECK: si32: -123
+ print_item("dense<-123> : tensor<si32>")
+ # CHECK: si64: -123
+ print_item("dense<-123> : tensor<si64>")
+
+ # CHECK: i7: Unsupported integer type
+ print_item("dense<123> : tensor<i7>")
# CHECK-LABEL: TEST: testDenseFPAttr
@run
def testDenseFPAttr():
- with Context():
- raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
- # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
+ with Context():
+ raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
+ # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
- print("attr:", raw)
+ print("attr:", raw)
- a = DenseFPElementsAttr(raw)
- assert len(a) == 4
+ a = DenseFPElementsAttr(raw)
+ assert len(a) == 4
- # CHECK: 0.0 1.0 2.0 3.0
- for value in a:
- print(value, end=" ")
- print()
+ # CHECK: 0.0 1.0 2.0 3.0
+ for value in a:
+ print(value, end=" ")
+ print()
- # CHECK: f32
- print(ShapedType(a.type).element_type)
+ # CHECK: f32
+ print(ShapedType(a.type).element_type)
# CHECK-LABEL: TEST: testDictAttr
@run
def testDictAttr():
- with Context():
- dict_attr = {
- 'stringattr': StringAttr.get('string'),
- 'integerattr' : IntegerAttr.get(
- IntegerType.get_signless(32), 42)
- }
+ with Context():
+ dict_attr = {
+ "stringattr": StringAttr.get("string"),
+ "integerattr": IntegerAttr.get(IntegerType.get_signless(32), 42),
+ }
- a = DictAttr.get(dict_attr)
+ a = DictAttr.get(dict_attr)
- # CHECK attr: {integerattr = 42 : i32, stringattr = "string"}
- print("attr:", a)
+ # CHECK attr: {integerattr = 42 : i32, stringattr = "string"}
+ print("attr:", a)
- assert len(a) == 2
+ assert len(a) == 2
- # CHECK: 42 : i32
- print(a['integerattr'])
+ # CHECK: 42 : i32
+ print(a["integerattr"])
- # CHECK: "string"
- print(a['stringattr'])
+ # CHECK: "string"
+ print(a["stringattr"])
- # CHECK: True
- print('stringattr' in a)
+ # CHECK: True
+ print("stringattr" in a)
- # CHECK: False
- print('not_in_dict' in a)
+ # CHECK: False
+ print("not_in_dict" in a)
- # Check that exceptions are raised as expected.
- try:
- _ = a['does_not_exist']
- except KeyError:
- pass
- else:
- assert False, "Exception not produced"
+ # Check that exceptions are raised as expected.
+ try:
+ _ = a["does_not_exist"]
+ except KeyError:
+ pass
+ else:
+ assert False, "Exception not produced"
- try:
- _ = a[42]
- except IndexError:
- pass
- else:
- assert False, "expected IndexError on accessing an out-of-bounds attribute"
+ try:
+ _ = a[42]
+ except IndexError:
+ pass
+ else:
+ assert False, "expected IndexError on accessing an out-of-bounds attribute"
- # CHECK "empty: {}"
- print("empty: ", DictAttr.get())
+ # CHECK "empty: {}"
+ print("empty: ", DictAttr.get())
# CHECK-LABEL: TEST: testTypeAttr
@run
def testTypeAttr():
- with Context():
- raw = Attribute.parse("vector<4xf32>")
- # CHECK: attr: vector<4xf32>
- print("attr:", raw)
- type_attr = TypeAttr(raw)
- # CHECK: f32
- print(ShapedType(type_attr.value).element_type)
+ with Context():
+ raw = Attribute.parse("vector<4xf32>")
+ # CHECK: attr: vector<4xf32>
+ print("attr:", raw)
+ type_attr = TypeAttr(raw)
+ # CHECK: f32
+ print(ShapedType(type_attr.value).element_type)
# CHECK-LABEL: TEST: testArrayAttr
@run
def testArrayAttr():
- with Context():
- raw = Attribute.parse("[42, true, vector<4xf32>]")
- # CHECK: attr: [42, true, vector<4xf32>]
- print("raw attr:", raw)
- # CHECK: - 42
- # CHECK: - true
- # CHECK: - vector<4xf32>
- for attr in ArrayAttr(raw):
- print("- ", attr)
-
- with Context():
- intAttr = Attribute.parse("42")
- vecAttr = Attribute.parse("vector<4xf32>")
- boolAttr = BoolAttr.get(True)
- raw = ArrayAttr.get([vecAttr, boolAttr, intAttr])
- # CHECK: attr: [vector<4xf32>, true, 42]
- print("raw attr:", raw)
- # CHECK: - vector<4xf32>
- # CHECK: - true
- # CHECK: - 42
- arr = ArrayAttr(raw)
- for attr in arr:
- print("- ", attr)
- # CHECK: attr[0]: vector<4xf32>
- print("attr[0]:", arr[0])
- # CHECK: attr[1]: true
- print("attr[1]:", arr[1])
- # CHECK: attr[2]: 42
- print("attr[2]:", arr[2])
- try:
- print("attr[3]:", arr[3])
- except IndexError as e:
- # CHECK: Error: ArrayAttribute index out of range
- print("Error: ", e)
- with Context():
+ with Context():
+ raw = Attribute.parse("[42, true, vector<4xf32>]")
+ # CHECK: attr: [42, true, vector<4xf32>]
+ print("raw attr:", raw)
+ # CHECK: - 42
+ # CHECK: - true
+ # CHECK: - vector<4xf32>
+ for attr in ArrayAttr(raw):
+ print("- ", attr)
+
+ with Context():
+ intAttr = Attribute.parse("42")
+ vecAttr = Attribute.parse("vector<4xf32>")
+ boolAttr = BoolAttr.get(True)
+ raw = ArrayAttr.get([vecAttr, boolAttr, intAttr])
+ # CHECK: attr: [vector<4xf32>, true, 42]
+ print("raw attr:", raw)
+ # CHECK: - vector<4xf32>
+ # CHECK: - true
+ # CHECK: - 42
+ arr = ArrayAttr(raw)
+ for attr in arr:
+ print("- ", attr)
+ # CHECK: attr[0]: vector<4xf32>
+ print("attr[0]:", arr[0])
+ # CHECK: attr[1]: true
+ print("attr[1]:", arr[1])
+ # CHECK: attr[2]: 42
+ print("attr[2]:", arr[2])
try:
- ArrayAttr.get([None])
- except RuntimeError as e:
- # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
- print("Error: ", e)
- try:
- ArrayAttr.get([42])
- except RuntimeError as e:
- # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
- print("Error: ", e)
-
- with Context():
- array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
- array = array + [StringAttr.get("c")]
- # CHECK: concat: ["a", "b", "c"]
- print("concat: ", array)
+ print("attr[3]:", arr[3])
+ except IndexError as e:
+ # CHECK: Error: ArrayAttribute index out of range
+ print("Error: ", e)
+ with Context():
+ try:
+ ArrayAttr.get([None])
+ except RuntimeError as e:
+ # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
+ print("Error: ", e)
+ try:
+ ArrayAttr.get([42])
+ except RuntimeError as e:
+ # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
+ print("Error: ", e)
+
+ with Context():
+ array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
+ array = array + [StringAttr.get("c")]
+ # CHECK: concat: ["a", "b", "c"]
+ print("concat: ", array)
# CHECK-LABEL: TEST: testStridedLayoutAttr
@run
def testStridedLayoutAttr():
- with Context():
- attr = StridedLayoutAttr.get(42, [5, 7, 13])
- # CHECK: strided<[5, 7, 13], offset: 42>
- print(attr)
- # CHECK: 42
- print(attr.offset)
- # CHECK: 3
- print(len(attr.strides))
- # CHECK: 5
- print(attr.strides[0])
- # CHECK: 7
- print(attr.strides[1])
- # CHECK: 13
- print(attr.strides[2])
-
- attr = StridedLayoutAttr.get_fully_dynamic(3)
- dynamic = ShapedType.get_dynamic_stride_or_offset()
- # CHECK: strided<[?, ?, ?], offset: ?>
- print(attr)
- # CHECK: offset is dynamic: True
- print(f"offset is dynamic: {attr.offset == dynamic}")
- # CHECK: rank: 3
- print(f"rank: {len(attr.strides)}")
- # CHECK: strides are dynamic: [True, True, True]
- print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}")
+ with Context():
+ attr = StridedLayoutAttr.get(42, [5, 7, 13])
+ # CHECK: strided<[5, 7, 13], offset: 42>
+ print(attr)
+ # CHECK: 42
+ print(attr.offset)
+ # CHECK: 3
+ print(len(attr.strides))
+ # CHECK: 5
+ print(attr.strides[0])
+ # CHECK: 7
+ print(attr.strides[1])
+ # CHECK: 13
+ print(attr.strides[2])
+
+ attr = StridedLayoutAttr.get_fully_dynamic(3)
+ dynamic = ShapedType.get_dynamic_stride_or_offset()
+ # CHECK: strided<[?, ?, ?], offset: ?>
+ print(attr)
+ # CHECK: offset is dynamic: True
+ print(f"offset is dynamic: {attr.offset == dynamic}")
+ # CHECK: rank: 3
+ print(f"rank: {len(attr.strides)}")
+ # CHECK: strides are dynamic: [True, True, True]
+ print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}")
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index e929d79e6c5cc..8b4d946c97b8d 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -10,11 +10,11 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
# CHECK-LABEL: TEST: testBlockCreation
@@ -26,60 +26,66 @@ def run(f):
# CHECK: return
@run
def testBlockCreation():
- with Context() as ctx, Location.unknown():
- module = builtin.ModuleOp()
- with InsertionPoint(module.body):
- f_type = FunctionType.get(
- [IntegerType.get_signless(32),
- IntegerType.get_signless(16)], [])
- f_op = func.FuncOp("test", f_type)
- entry_block = f_op.add_entry_block([Location.name("arg0"), Location.name("arg1")])
- i32_arg, i16_arg = entry_block.arguments
- successor_block = entry_block.create_after(i32_arg.type, arg_locs=[Location.name("successor")])
- with InsertionPoint(successor_block) as successor_ip:
- assert successor_ip.block == successor_block
- func.ReturnOp([])
- middle_block = successor_block.create_before(i16_arg.type, arg_locs=[Location.name("middle")])
-
- with InsertionPoint(entry_block) as entry_ip:
- assert entry_ip.block == entry_block
- cf.BranchOp([i16_arg], dest=middle_block)
-
- with InsertionPoint(middle_block) as middle_ip:
- assert middle_ip.block == middle_block
- cf.BranchOp([i32_arg], dest=successor_block)
- module.print(enable_debug_info=True)
- # Ensure region back references are coherent.
- assert entry_block.region == middle_block.region == successor_block.region
+ with Context() as ctx, Location.unknown():
+ module = builtin.ModuleOp()
+ with InsertionPoint(module.body):
+ f_type = FunctionType.get(
+ [IntegerType.get_signless(32), IntegerType.get_signless(16)], []
+ )
+ f_op = func.FuncOp("test", f_type)
+ entry_block = f_op.add_entry_block(
+ [Location.name("arg0"), Location.name("arg1")]
+ )
+ i32_arg, i16_arg = entry_block.arguments
+ successor_block = entry_block.create_after(
+ i32_arg.type, arg_locs=[Location.name("successor")]
+ )
+ with InsertionPoint(successor_block) as successor_ip:
+ assert successor_ip.block == successor_block
+ func.ReturnOp([])
+ middle_block = successor_block.create_before(
+ i16_arg.type, arg_locs=[Location.name("middle")]
+ )
+
+ with InsertionPoint(entry_block) as entry_ip:
+ assert entry_ip.block == entry_block
+ cf.BranchOp([i16_arg], dest=middle_block)
+
+ with InsertionPoint(middle_block) as middle_ip:
+ assert middle_ip.block == middle_block
+ cf.BranchOp([i32_arg], dest=successor_block)
+ module.print(enable_debug_info=True)
+ # Ensure region back references are coherent.
+ assert entry_block.region == middle_block.region == successor_block.region
# CHECK-LABEL: TEST: testBlockCreationArgLocs
@run
def testBlockCreationArgLocs():
- with Context() as ctx:
- ctx.allow_unregistered_dialects = True
- f32 = F32Type.get()
- op = Operation.create("test", regions=1, loc=Location.unknown())
- blocks = op.regions[0].blocks
-
- with Location.name("default_loc"):
- blocks.append(f32)
- blocks.append()
- # CHECK: ^bb0(%{{.+}}: f32 loc("default_loc")):
- # CHECK-NEXT: ^bb1:
- op.print(enable_debug_info=True)
-
- try:
- blocks.append(f32)
- except RuntimeError as err:
- # CHECK: Missing loc: An MLIR function requires a Location but none was provided
- print("Missing loc:", err)
-
- try:
- blocks.append(f32, f32, arg_locs=[Location.unknown()])
- except ValueError as err:
- # CHECK: Wrong loc count: Expected 2 locations, got: 1
- print("Wrong loc count:", err)
+ with Context() as ctx:
+ ctx.allow_unregistered_dialects = True
+ f32 = F32Type.get()
+ op = Operation.create("test", regions=1, loc=Location.unknown())
+ blocks = op.regions[0].blocks
+
+ with Location.name("default_loc"):
+ blocks.append(f32)
+ blocks.append()
+ # CHECK: ^bb0(%{{.+}}: f32 loc("default_loc")):
+ # CHECK-NEXT: ^bb1:
+ op.print(enable_debug_info=True)
+
+ try:
+ blocks.append(f32)
+ except RuntimeError as err:
+ # CHECK: Missing loc: An MLIR function requires a Location but none was provided
+ print("Missing loc:", err)
+
+ try:
+ blocks.append(f32, f32, arg_locs=[Location.unknown()])
+ except ValueError as err:
+ # CHECK: Wrong loc count: Expected 2 locations, got: 1
+ print("Wrong loc count:", err)
# CHECK-LABEL: TEST: testFirstBlockCreation
@@ -87,19 +93,20 @@ def testBlockCreationArgLocs():
# CHECK: return
@run
def testFirstBlockCreation():
- with Context() as ctx, Location.unknown():
- module = builtin.ModuleOp()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
- f = func.FuncOp("test", ([f32], []))
- entry_block = Block.create_at_start(f.operation.regions[0],
- [f32], [Location.name("arg_loc")])
- with InsertionPoint(entry_block):
- func.ReturnOp([])
-
- module.print(enable_debug_info=True)
- assert module.verify()
- assert f.body.blocks[0] == entry_block
+ with Context() as ctx, Location.unknown():
+ module = builtin.ModuleOp()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ f = func.FuncOp("test", ([f32], []))
+ entry_block = Block.create_at_start(
+ f.operation.regions[0], [f32], [Location.name("arg_loc")]
+ )
+ with InsertionPoint(entry_block):
+ func.ReturnOp([])
+
+ module.print(enable_debug_info=True)
+ assert module.verify()
+ assert f.body.blocks[0] == entry_block
# CHECK-LABEL: TEST: testBlockMove
@@ -109,32 +116,32 @@ def testFirstBlockCreation():
# CHECK: }) : () -> f32
@run
def testBlockMove():
- with Context() as ctx, Location.unknown():
- ctx.allow_unregistered_dialects = True
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
- dummy = Operation.create("dummy", regions=1)
- block = Block.create_at_start(dummy.operation.regions[0], [f32])
- with InsertionPoint(block):
- ret_op = Operation.create("ret", operands=[block.arguments[0]])
- realop = Operation.create("realop",
- results=[r.type for r in ret_op.operands],
- regions=1)
- block.append_to(realop.operation.regions[0])
- dummy.operation.erase()
- print(module)
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ dummy = Operation.create("dummy", regions=1)
+ block = Block.create_at_start(dummy.operation.regions[0], [f32])
+ with InsertionPoint(block):
+ ret_op = Operation.create("ret", operands=[block.arguments[0]])
+ realop = Operation.create(
+ "realop", results=[r.type for r in ret_op.operands], regions=1
+ )
+ block.append_to(realop.operation.regions[0])
+ dummy.operation.erase()
+ print(module)
# CHECK-LABEL: TEST: testBlockHash
@run
def testBlockHash():
- with Context() as ctx, Location.unknown():
- ctx.allow_unregistered_dialects = True
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
- dummy = Operation.create("dummy", regions=1)
- block1 = Block.create_at_start(dummy.operation.regions[0], [f32])
- block2 = Block.create_at_start(dummy.operation.regions[0], [f32])
- assert hash(block1) != hash(block2)
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ dummy = Operation.create("dummy", regions=1)
+ block1 = Block.create_at_start(dummy.operation.regions[0], [f32])
+ block2 = Block.create_at_start(dummy.operation.regions[0], [f32])
+ assert hash(block1) != hash(block2)
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 19e21fff8dba0..fc484a5050839 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -5,246 +5,246 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
- ctx = Context()
- t = Type.parse("i32", ctx)
- assert t.context is ctx
- ctx = None
- gc.collect()
- # CHECK: i32
- print(str(t))
- # CHECK: Type(i32)
- print(repr(t))
+ ctx = Context()
+ t = Type.parse("i32", ctx)
+ assert t.context is ctx
+ ctx = None
+ gc.collect()
+ # CHECK: i32
+ print(str(t))
+ # CHECK: Type(i32)
+ print(repr(t))
# CHECK-LABEL: TEST: testParseError
@run
def testParseError():
- ctx = Context()
- try:
- t = Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx)
- except MLIRError as e:
- # CHECK: testParseError: <
- # CHECK: Unable to parse type:
- # CHECK: error: "BAD_TYPE_DOES_NOT_EXIST":1:1: expected non-function type
- # CHECK: >
- print(f"testParseError: <{e}>")
- else:
- print("Exception not produced")
+ ctx = Context()
+ try:
+ t = Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx)
+ except MLIRError as e:
+ # CHECK: testParseError: <
+ # CHECK: Unable to parse type:
+ # CHECK: error: "BAD_TYPE_DOES_NOT_EXIST":1:1: expected non-function type
+ # CHECK: >
+ print(f"testParseError: <{e}>")
+ else:
+ print("Exception not produced")
# CHECK-LABEL: TEST: testTypeEq
@run
def testTypeEq():
- ctx = Context()
- t1 = Type.parse("i32", ctx)
- t2 = Type.parse("f32", ctx)
- t3 = Type.parse("i32", ctx)
- # CHECK: t1 == t1: True
- print("t1 == t1:", t1 == t1)
- # CHECK: t1 == t2: False
- print("t1 == t2:", t1 == t2)
- # CHECK: t1 == t3: True
- print("t1 == t3:", t1 == t3)
- # CHECK: t1 == None: False
- print("t1 == None:", t1 == None)
+ ctx = Context()
+ t1 = Type.parse("i32", ctx)
+ t2 = Type.parse("f32", ctx)
+ t3 = Type.parse("i32", ctx)
+ # CHECK: t1 == t1: True
+ print("t1 == t1:", t1 == t1)
+ # CHECK: t1 == t2: False
+ print("t1 == t2:", t1 == t2)
+ # CHECK: t1 == t3: True
+ print("t1 == t3:", t1 == t3)
+ # CHECK: t1 == None: False
+ print("t1 == None:", t1 == None)
# CHECK-LABEL: TEST: testTypeHash
@run
def testTypeHash():
- ctx = Context()
- t1 = Type.parse("i32", ctx)
- t2 = Type.parse("f32", ctx)
- t3 = Type.parse("i32", ctx)
+ ctx = Context()
+ t1 = Type.parse("i32", ctx)
+ t2 = Type.parse("f32", ctx)
+ t3 = Type.parse("i32", ctx)
- # CHECK: hash(t1) == hash(t3): True
- print("hash(t1) == hash(t3):", t1.__hash__() == t3.__hash__())
+ # CHECK: hash(t1) == hash(t3): True
+ print("hash(t1) == hash(t3):", t1.__hash__() == t3.__hash__())
- s = set()
- s.add(t1)
- s.add(t2)
- s.add(t3)
- # CHECK: len(s): 2
- print("len(s): ", len(s))
+ s = set()
+ s.add(t1)
+ s.add(t2)
+ s.add(t3)
+ # CHECK: len(s): 2
+ print("len(s): ", len(s))
# CHECK-LABEL: TEST: testTypeCast
@run
def testTypeCast():
- ctx = Context()
- t1 = Type.parse("i32", ctx)
- t2 = Type(t1)
- # CHECK: t1 == t2: True
- print("t1 == t2:", t1 == t2)
+ ctx = Context()
+ t1 = Type.parse("i32", ctx)
+ t2 = Type(t1)
+ # CHECK: t1 == t2: True
+ print("t1 == t2:", t1 == t2)
# CHECK-LABEL: TEST: testTypeIsInstance
@run
def testTypeIsInstance():
- ctx = Context()
- t1 = Type.parse("i32", ctx)
- t2 = Type.parse("f32", ctx)
- # CHECK: True
- print(IntegerType.isinstance(t1))
- # CHECK: False
- print(F32Type.isinstance(t1))
- # CHECK: True
- print(F32Type.isinstance(t2))
+ ctx = Context()
+ t1 = Type.parse("i32", ctx)
+ t2 = Type.parse("f32", ctx)
+ # CHECK: True
+ print(IntegerType.isinstance(t1))
+ # CHECK: False
+ print(F32Type.isinstance(t1))
+ # CHECK: True
+ print(F32Type.isinstance(t2))
# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
@run
def testTypeEqDoesNotRaise():
- ctx = Context()
- t1 = Type.parse("i32", ctx)
- not_a_type = "foo"
- # CHECK: False
- print(t1 == not_a_type)
- # CHECK: False
- print(t1 == None)
- # CHECK: True
- print(t1 != None)
+ ctx = Context()
+ t1 = Type.parse("i32", ctx)
+ not_a_type = "foo"
+ # CHECK: False
+ print(t1 == not_a_type)
+ # CHECK: False
+ print(t1 == None)
+ # CHECK: True
+ print(t1 != None)
# CHECK-LABEL: TEST: testTypeCapsule
@run
def testTypeCapsule():
- with Context() as ctx:
- t1 = Type.parse("i32", ctx)
- # CHECK: mlir.ir.Type._CAPIPtr
- type_capsule = t1._CAPIPtr
- print(type_capsule)
- t2 = Type._CAPICreate(type_capsule)
- assert t2 == t1
- assert t2.context is ctx
+ with Context() as ctx:
+ t1 = Type.parse("i32", ctx)
+ # CHECK: mlir.ir.Type._CAPIPtr
+ type_capsule = t1._CAPIPtr
+ print(type_capsule)
+ t2 = Type._CAPICreate(type_capsule)
+ assert t2 == t1
+ assert t2.context is ctx
# CHECK-LABEL: TEST: testStandardTypeCasts
@run
def testStandardTypeCasts():
- ctx = Context()
- t1 = Type.parse("i32", ctx)
- tint = IntegerType(t1)
- tself = IntegerType(tint)
- # CHECK: Type(i32)
- print(repr(tint))
- try:
- tillegal = IntegerType(Type.parse("f32", ctx))
- except ValueError as e:
- # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
- print("ValueError:", e)
- else:
- print("Exception not produced")
+ ctx = Context()
+ t1 = Type.parse("i32", ctx)
+ tint = IntegerType(t1)
+ tself = IntegerType(tint)
+ # CHECK: Type(i32)
+ print(repr(tint))
+ try:
+ tillegal = IntegerType(Type.parse("f32", ctx))
+ except ValueError as e:
+ # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
+ print("ValueError:", e)
+ else:
+ print("Exception not produced")
# CHECK-LABEL: TEST: testIntegerType
@run
def testIntegerType():
- with Context() as ctx:
- i32 = IntegerType(Type.parse("i32"))
- # CHECK: i32 width: 32
- print("i32 width:", i32.width)
- # CHECK: i32 signless: True
- print("i32 signless:", i32.is_signless)
- # CHECK: i32 signed: False
- print("i32 signed:", i32.is_signed)
- # CHECK: i32 unsigned: False
- print("i32 unsigned:", i32.is_unsigned)
-
- s32 = IntegerType(Type.parse("si32"))
- # CHECK: s32 signless: False
- print("s32 signless:", s32.is_signless)
- # CHECK: s32 signed: True
- print("s32 signed:", s32.is_signed)
- # CHECK: s32 unsigned: False
- print("s32 unsigned:", s32.is_unsigned)
-
- u32 = IntegerType(Type.parse("ui32"))
- # CHECK: u32 signless: False
- print("u32 signless:", u32.is_signless)
- # CHECK: u32 signed: False
- print("u32 signed:", u32.is_signed)
- # CHECK: u32 unsigned: True
- print("u32 unsigned:", u32.is_unsigned)
-
- # CHECK: signless: i16
- print("signless:", IntegerType.get_signless(16))
- # CHECK: signed: si8
- print("signed:", IntegerType.get_signed(8))
- # CHECK: unsigned: ui64
- print("unsigned:", IntegerType.get_unsigned(64))
+ with Context() as ctx:
+ i32 = IntegerType(Type.parse("i32"))
+ # CHECK: i32 width: 32
+ print("i32 width:", i32.width)
+ # CHECK: i32 signless: True
+ print("i32 signless:", i32.is_signless)
+ # CHECK: i32 signed: False
+ print("i32 signed:", i32.is_signed)
+ # CHECK: i32 unsigned: False
+ print("i32 unsigned:", i32.is_unsigned)
+
+ s32 = IntegerType(Type.parse("si32"))
+ # CHECK: s32 signless: False
+ print("s32 signless:", s32.is_signless)
+ # CHECK: s32 signed: True
+ print("s32 signed:", s32.is_signed)
+ # CHECK: s32 unsigned: False
+ print("s32 unsigned:", s32.is_unsigned)
+
+ u32 = IntegerType(Type.parse("ui32"))
+ # CHECK: u32 signless: False
+ print("u32 signless:", u32.is_signless)
+ # CHECK: u32 signed: False
+ print("u32 signed:", u32.is_signed)
+ # CHECK: u32 unsigned: True
+ print("u32 unsigned:", u32.is_unsigned)
+
+ # CHECK: signless: i16
+ print("signless:", IntegerType.get_signless(16))
+ # CHECK: signed: si8
+ print("signed:", IntegerType.get_signed(8))
+ # CHECK: unsigned: ui64
+ print("unsigned:", IntegerType.get_unsigned(64))
# CHECK-LABEL: TEST: testIndexType
@run
def testIndexType():
- with Context() as ctx:
- # CHECK: index type: index
- print("index type:", IndexType.get())
+ with Context() as ctx:
+ # CHECK: index type: index
+ print("index type:", IndexType.get())
# CHECK-LABEL: TEST: testFloatType
@run
def testFloatType():
- with Context():
- # CHECK: float: f8E4M3FN
- print("float:", Float8E4M3FNType.get())
- # CHECK: float: f8E5M2
- print("float:", Float8E5M2Type.get())
- # CHECK: float: f8E5M2FNUZ
- print("float:", Float8E5M2FNUZType.get())
- # CHECK: float: f8E4M3FNUZ
- print("float:", Float8E4M3FNUZType.get())
- # CHECK: float: f8E4M3B11FNUZ
- print("float:", Float8E4M3B11FNUZType.get())
- # CHECK: float: bf16
- print("float:", BF16Type.get())
- # CHECK: float: f16
- print("float:", F16Type.get())
- # CHECK: float: f32
- print("float:", F32Type.get())
- # CHECK: float: f64
- print("float:", F64Type.get())
+ with Context():
+ # CHECK: float: f8E4M3FN
+ print("float:", Float8E4M3FNType.get())
+ # CHECK: float: f8E5M2
+ print("float:", Float8E5M2Type.get())
+ # CHECK: float: f8E5M2FNUZ
+ print("float:", Float8E5M2FNUZType.get())
+ # CHECK: float: f8E4M3FNUZ
+ print("float:", Float8E4M3FNUZType.get())
+ # CHECK: float: f8E4M3B11FNUZ
+ print("float:", Float8E4M3B11FNUZType.get())
+ # CHECK: float: bf16
+ print("float:", BF16Type.get())
+ # CHECK: float: f16
+ print("float:", F16Type.get())
+ # CHECK: float: f32
+ print("float:", F32Type.get())
+ # CHECK: float: f64
+ print("float:", F64Type.get())
# CHECK-LABEL: TEST: testNoneType
@run
def testNoneType():
- with Context():
- # CHECK: none type: none
- print("none type:", NoneType.get())
+ with Context():
+ # CHECK: none type: none
+ print("none type:", NoneType.get())
# CHECK-LABEL: TEST: testComplexType
@run
def testComplexType():
- with Context() as ctx:
- complex_i32 = ComplexType(Type.parse("complex<i32>"))
- # CHECK: complex type element: i32
- print("complex type element:", complex_i32.element_type)
+ with Context() as ctx:
+ complex_i32 = ComplexType(Type.parse("complex<i32>"))
+ # CHECK: complex type element: i32
+ print("complex type element:", complex_i32.element_type)
- f32 = F32Type.get()
- # CHECK: complex type: complex<f32>
- print("complex type:", ComplexType.get(f32))
+ f32 = F32Type.get()
+ # CHECK: complex type: complex<f32>
+ print("complex type:", ComplexType.get(f32))
- index = IndexType.get()
- try:
- complex_invalid = ComplexType.get(index)
- except ValueError as e:
- # CHECK: invalid 'Type(index)' and expected floating point or integer type.
- print(e)
- else:
- print("Exception not produced")
+ index = IndexType.get()
+ try:
+ complex_invalid = ComplexType.get(index)
+ except ValueError as e:
+ # CHECK: invalid 'Type(index)' and expected floating point or integer type.
+ print(e)
+ else:
+ print("Exception not produced")
# CHECK-LABEL: TEST: testConcreteShapedType
@@ -253,27 +253,26 @@ def testComplexType():
# shaped type. The class hierarchy is preserved on the python side.
@run
def testConcreteShapedType():
- with Context() as ctx:
- vector = VectorType(Type.parse("vector<2x3xf32>"))
- # CHECK: element type: f32
- print("element type:", vector.element_type)
- # CHECK: whether the given shaped type is ranked: True
- print("whether the given shaped type is ranked:", vector.has_rank)
- # CHECK: rank: 2
- print("rank:", vector.rank)
- # CHECK: whether the shaped type has a static shape: True
- print("whether the shaped type has a static shape:",
- vector.has_static_shape)
- # CHECK: whether the dim-th dimension is dynamic: False
- print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
- # CHECK: dim size: 3
- print("dim size:", vector.get_dim_size(1))
- # CHECK: is_dynamic_size: False
- print("is_dynamic_size:", vector.is_dynamic_size(3))
- # CHECK: is_dynamic_stride_or_offset: False
- print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1))
- # CHECK: isinstance(ShapedType): True
- print("isinstance(ShapedType):", isinstance(vector, ShapedType))
+ with Context() as ctx:
+ vector = VectorType(Type.parse("vector<2x3xf32>"))
+ # CHECK: element type: f32
+ print("element type:", vector.element_type)
+ # CHECK: whether the given shaped type is ranked: True
+ print("whether the given shaped type is ranked:", vector.has_rank)
+ # CHECK: rank: 2
+ print("rank:", vector.rank)
+ # CHECK: whether the shaped type has a static shape: True
+ print("whether the shaped type has a static shape:", vector.has_static_shape)
+ # CHECK: whether the dim-th dimension is dynamic: False
+ print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
+ # CHECK: dim size: 3
+ print("dim size:", vector.get_dim_size(1))
+ # CHECK: is_dynamic_size: False
+ print("is_dynamic_size:", vector.is_dynamic_size(3))
+ # CHECK: is_dynamic_stride_or_offset: False
+ print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1))
+ # CHECK: isinstance(ShapedType): True
+ print("isinstance(ShapedType):", isinstance(vector, ShapedType))
# CHECK-LABEL: TEST: testAbstractShapedType
@@ -281,321 +280,322 @@ def testConcreteShapedType():
# shaped type (using vector as an example).
@run
def testAbstractShapedType():
- ctx = Context()
- vector = ShapedType(Type.parse("vector<2x3xf32>", ctx))
- # CHECK: element type: f32
- print("element type:", vector.element_type)
+ ctx = Context()
+ vector = ShapedType(Type.parse("vector<2x3xf32>", ctx))
+ # CHECK: element type: f32
+ print("element type:", vector.element_type)
# CHECK-LABEL: TEST: testVectorType
@run
def testVectorType():
- with Context(), Location.unknown():
- f32 = F32Type.get()
- shape = [2, 3]
- # CHECK: vector type: vector<2x3xf32>
- print("vector type:", VectorType.get(shape, f32))
-
- none = NoneType.get()
- try:
- vector_invalid = VectorType.get(shape, none)
- except MLIRError as e:
- # CHECK: Invalid type:
- # CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
- print(e)
- else:
- print("Exception not produced")
+ with Context(), Location.unknown():
+ f32 = F32Type.get()
+ shape = [2, 3]
+ # CHECK: vector type: vector<2x3xf32>
+ print("vector type:", VectorType.get(shape, f32))
+
+ none = NoneType.get()
+ try:
+ vector_invalid = VectorType.get(shape, none)
+ except MLIRError as e:
+ # CHECK: Invalid type:
+ # CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
+ print(e)
+ else:
+ print("Exception not produced")
# CHECK-LABEL: TEST: testRankedTensorType
@run
def testRankedTensorType():
- with Context(), Location.unknown():
- f32 = F32Type.get()
- shape = [2, 3]
- loc = Location.unknown()
- # CHECK: ranked tensor type: tensor<2x3xf32>
- print("ranked tensor type:", RankedTensorType.get(shape, f32))
-
- none = NoneType.get()
- try:
- tensor_invalid = RankedTensorType.get(shape, none)
- except MLIRError as e:
- # CHECK: Invalid type:
- # CHECK: error: unknown: invalid tensor element type: 'none'
- print(e)
- else:
- print("Exception not produced")
-
- # Encoding should be None.
- assert RankedTensorType.get(shape, f32).encoding is None
-
- tensor = RankedTensorType.get(shape, f32)
- assert tensor.shape == shape
+ with Context(), Location.unknown():
+ f32 = F32Type.get()
+ shape = [2, 3]
+ loc = Location.unknown()
+ # CHECK: ranked tensor type: tensor<2x3xf32>
+ print("ranked tensor type:", RankedTensorType.get(shape, f32))
+
+ none = NoneType.get()
+ try:
+ tensor_invalid = RankedTensorType.get(shape, none)
+ except MLIRError as e:
+ # CHECK: Invalid type:
+ # CHECK: error: unknown: invalid tensor element type: 'none'
+ print(e)
+ else:
+ print("Exception not produced")
+
+ # Encoding should be None.
+ assert RankedTensorType.get(shape, f32).encoding is None
+
+ tensor = RankedTensorType.get(shape, f32)
+ assert tensor.shape == shape
# CHECK-LABEL: TEST: testUnrankedTensorType
@run
def testUnrankedTensorType():
- with Context(), Location.unknown():
- f32 = F32Type.get()
- loc = Location.unknown()
- unranked_tensor = UnrankedTensorType.get(f32)
- # CHECK: unranked tensor type: tensor<*xf32>
- print("unranked tensor type:", unranked_tensor)
- try:
- invalid_rank = unranked_tensor.rank
- except ValueError as e:
- # CHECK: calling this method requires that the type has a rank.
- print(e)
- else:
- print("Exception not produced")
- try:
- invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0)
- except ValueError as e:
- # CHECK: calling this method requires that the type has a rank.
- print(e)
- else:
- print("Exception not produced")
- try:
- invalid_get_dim_size = unranked_tensor.get_dim_size(1)
- except ValueError as e:
- # CHECK: calling this method requires that the type has a rank.
- print(e)
- else:
- print("Exception not produced")
-
- none = NoneType.get()
- try:
- tensor_invalid = UnrankedTensorType.get(none)
- except MLIRError as e:
- # CHECK: Invalid type:
- # CHECK: error: unknown: invalid tensor element type: 'none'
- print(e)
- else:
- print("Exception not produced")
+ with Context(), Location.unknown():
+ f32 = F32Type.get()
+ loc = Location.unknown()
+ unranked_tensor = UnrankedTensorType.get(f32)
+ # CHECK: unranked tensor type: tensor<*xf32>
+ print("unranked tensor type:", unranked_tensor)
+ try:
+ invalid_rank = unranked_tensor.rank
+ except ValueError as e:
+ # CHECK: calling this method requires that the type has a rank.
+ print(e)
+ else:
+ print("Exception not produced")
+ try:
+ invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0)
+ except ValueError as e:
+ # CHECK: calling this method requires that the type has a rank.
+ print(e)
+ else:
+ print("Exception not produced")
+ try:
+ invalid_get_dim_size = unranked_tensor.get_dim_size(1)
+ except ValueError as e:
+ # CHECK: calling this method requires that the type has a rank.
+ print(e)
+ else:
+ print("Exception not produced")
+
+ none = NoneType.get()
+ try:
+ tensor_invalid = UnrankedTensorType.get(none)
+ except MLIRError as e:
+ # CHECK: Invalid type:
+ # CHECK: error: unknown: invalid tensor element type: 'none'
+ print(e)
+ else:
+ print("Exception not produced")
# CHECK-LABEL: TEST: testMemRefType
@run
def testMemRefType():
- with Context(), Location.unknown():
- f32 = F32Type.get()
- shape = [2, 3]
- loc = Location.unknown()
- memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
- # CHECK: memref type: memref<2x3xf32, 2>
- print("memref type:", memref)
- # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)>
- print("memref layout:", memref.layout)
- # CHECK: memref affine map: (d0, d1) -> (d0, d1)
- print("memref affine map:", memref.affine_map)
- # CHECK: memory space: 2
- print("memory space:", memref.memory_space)
-
- layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
- memref_layout = MemRefType.get(shape, f32, layout=layout)
- # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
- print("memref type:", memref_layout)
- # CHECK: memref layout: affine_map<(d0, d1) -> (d1, d0)>
- print("memref layout:", memref_layout.layout)
- # CHECK: memref affine map: (d0, d1) -> (d1, d0)
- print("memref affine map:", memref_layout.affine_map)
- # CHECK: memory space: <<NULL ATTRIBUTE>>
- print("memory space:", memref_layout.memory_space)
-
- none = NoneType.get()
- try:
- memref_invalid = MemRefType.get(shape, none)
- except MLIRError as e:
- # CHECK: Invalid type:
- # CHECK: error: unknown: invalid memref element type
- print(e)
- else:
- print("Exception not produced")
-
- assert memref.shape == shape
+ with Context(), Location.unknown():
+ f32 = F32Type.get()
+ shape = [2, 3]
+ loc = Location.unknown()
+ memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
+ # CHECK: memref type: memref<2x3xf32, 2>
+ print("memref type:", memref)
+ # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)>
+ print("memref layout:", memref.layout)
+ # CHECK: memref affine map: (d0, d1) -> (d0, d1)
+ print("memref affine map:", memref.affine_map)
+ # CHECK: memory space: 2
+ print("memory space:", memref.memory_space)
+
+ layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
+ memref_layout = MemRefType.get(shape, f32, layout=layout)
+ # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
+ print("memref type:", memref_layout)
+ # CHECK: memref layout: affine_map<(d0, d1) -> (d1, d0)>
+ print("memref layout:", memref_layout.layout)
+ # CHECK: memref affine map: (d0, d1) -> (d1, d0)
+ print("memref affine map:", memref_layout.affine_map)
+ # CHECK: memory space: <<NULL ATTRIBUTE>>
+ print("memory space:", memref_layout.memory_space)
+
+ none = NoneType.get()
+ try:
+ memref_invalid = MemRefType.get(shape, none)
+ except MLIRError as e:
+ # CHECK: Invalid type:
+ # CHECK: error: unknown: invalid memref element type
+ print(e)
+ else:
+ print("Exception not produced")
+
+ assert memref.shape == shape
# CHECK-LABEL: TEST: testUnrankedMemRefType
@run
def testUnrankedMemRefType():
- with Context(), Location.unknown():
- f32 = F32Type.get()
- loc = Location.unknown()
- unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2"))
- # CHECK: unranked memref type: memref<*xf32, 2>
- print("unranked memref type:", unranked_memref)
- try:
- invalid_rank = unranked_memref.rank
- except ValueError as e:
- # CHECK: calling this method requires that the type has a rank.
- print(e)
- else:
- print("Exception not produced")
- try:
- invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0)
- except ValueError as e:
- # CHECK: calling this method requires that the type has a rank.
- print(e)
- else:
- print("Exception not produced")
- try:
- invalid_get_dim_size = unranked_memref.get_dim_size(1)
- except ValueError as e:
- # CHECK: calling this method requires that the type has a rank.
- print(e)
- else:
- print("Exception not produced")
-
- none = NoneType.get()
- try:
- memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2"))
- except MLIRError as e:
- # CHECK: Invalid type:
- # CHECK: error: unknown: invalid memref element type
- print(e)
- else:
- print("Exception not produced")
+ with Context(), Location.unknown():
+ f32 = F32Type.get()
+ loc = Location.unknown()
+ unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2"))
+ # CHECK: unranked memref type: memref<*xf32, 2>
+ print("unranked memref type:", unranked_memref)
+ try:
+ invalid_rank = unranked_memref.rank
+ except ValueError as e:
+ # CHECK: calling this method requires that the type has a rank.
+ print(e)
+ else:
+ print("Exception not produced")
+ try:
+ invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0)
+ except ValueError as e:
+ # CHECK: calling this method requires that the type has a rank.
+ print(e)
+ else:
+ print("Exception not produced")
+ try:
+ invalid_get_dim_size = unranked_memref.get_dim_size(1)
+ except ValueError as e:
+ # CHECK: calling this method requires that the type has a rank.
+ print(e)
+ else:
+ print("Exception not produced")
+
+ none = NoneType.get()
+ try:
+ memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2"))
+ except MLIRError as e:
+ # CHECK: Invalid type:
+ # CHECK: error: unknown: invalid memref element type
+ print(e)
+ else:
+ print("Exception not produced")
# CHECK-LABEL: TEST: testTupleType
@run
def testTupleType():
- with Context() as ctx:
- i32 = IntegerType(Type.parse("i32"))
- f32 = F32Type.get()
- vector = VectorType(Type.parse("vector<2x3xf32>"))
- l = [i32, f32, vector]
- tuple_type = TupleType.get_tuple(l)
- # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>>
- print("tuple type:", tuple_type)
- # CHECK: number of types: 3
- print("number of types:", tuple_type.num_types)
- # CHECK: pos-th type in the tuple type: f32
- print("pos-th type in the tuple type:", tuple_type.get_type(1))
+ with Context() as ctx:
+ i32 = IntegerType(Type.parse("i32"))
+ f32 = F32Type.get()
+ vector = VectorType(Type.parse("vector<2x3xf32>"))
+ l = [i32, f32, vector]
+ tuple_type = TupleType.get_tuple(l)
+ # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>>
+ print("tuple type:", tuple_type)
+ # CHECK: number of types: 3
+ print("number of types:", tuple_type.num_types)
+ # CHECK: pos-th type in the tuple type: f32
+ print("pos-th type in the tuple type:", tuple_type.get_type(1))
# CHECK-LABEL: TEST: testFunctionType
@run
def testFunctionType():
- with Context() as ctx:
- input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
- result_types = [IndexType.get()]
- func = FunctionType.get(input_types, result_types)
- # CHECK: INPUTS: [Type(i32), Type(i16)]
- print("INPUTS:", func.inputs)
- # CHECK: RESULTS: [Type(index)]
- print("RESULTS:", func.results)
+ with Context() as ctx:
+ input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
+ result_types = [IndexType.get()]
+ func = FunctionType.get(input_types, result_types)
+ # CHECK: INPUTS: [Type(i32), Type(i16)]
+ print("INPUTS:", func.inputs)
+ # CHECK: RESULTS: [Type(index)]
+ print("RESULTS:", func.results)
# CHECK-LABEL: TEST: testOpaqueType
@run
def testOpaqueType():
- with Context() as ctx:
- ctx.allow_unregistered_dialects = True
- opaque = OpaqueType.get("dialect", "type")
- # CHECK: opaque type: !dialect.type
- print("opaque type:", opaque)
- # CHECK: dialect namespace: dialect
- print("dialect namespace:", opaque.dialect_namespace)
- # CHECK: data: type
- print("data:", opaque.data)
+ with Context() as ctx:
+ ctx.allow_unregistered_dialects = True
+ opaque = OpaqueType.get("dialect", "type")
+ # CHECK: opaque type: !dialect.type
+ print("opaque type:", opaque)
+ # CHECK: dialect namespace: dialect
+ print("dialect namespace:", opaque.dialect_namespace)
+ # CHECK: data: type
+ print("data:", opaque.data)
# CHECK-LABEL: TEST: testShapedTypeConstants
# Tests that ShapedType exposes magic value constants.
@run
def testShapedTypeConstants():
- # CHECK: <class 'int'>
- print(type(ShapedType.get_dynamic_size()))
- # CHECK: <class 'int'>
- print(type(ShapedType.get_dynamic_stride_or_offset()))
+ # CHECK: <class 'int'>
+ print(type(ShapedType.get_dynamic_size()))
+ # CHECK: <class 'int'>
+ print(type(ShapedType.get_dynamic_stride_or_offset()))
# CHECK-LABEL: TEST: testTypeIDs
@run
def testTypeIDs():
- with Context(), Location.unknown():
- f32 = F32Type.get()
-
- types = [
- (IntegerType, IntegerType.get_signless(16)),
- (IndexType, IndexType.get()),
- (Float8E4M3FNType, Float8E4M3FNType.get()),
- (Float8E5M2Type, Float8E5M2Type.get()),
- (Float8E4M3FNUZType, Float8E4M3FNUZType.get()),
- (Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()),
- (Float8E5M2FNUZType, Float8E5M2FNUZType.get()),
- (BF16Type, BF16Type.get()),
- (F16Type, F16Type.get()),
- (F32Type, F32Type.get()),
- (F64Type, F64Type.get()),
- (NoneType, NoneType.get()),
- (ComplexType, ComplexType.get(f32)),
- (VectorType, VectorType.get([2, 3], f32)),
- (RankedTensorType, RankedTensorType.get([2, 3], f32)),
- (UnrankedTensorType, UnrankedTensorType.get(f32)),
- (MemRefType, MemRefType.get([2, 3], f32)),
- (UnrankedMemRefType, UnrankedMemRefType.get(f32, Attribute.parse("2"))),
- (TupleType, TupleType.get_tuple([f32])),
- (FunctionType, FunctionType.get([], [])),
- (OpaqueType, OpaqueType.get("tensor", "bob")),
- ]
-
- # CHECK: IntegerType(i16)
- # CHECK: IndexType(index)
- # CHECK: Float8E4M3FNType(f8E4M3FN)
- # CHECK: Float8E5M2Type(f8E5M2)
- # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
- # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
- # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
- # CHECK: BF16Type(bf16)
- # CHECK: F16Type(f16)
- # CHECK: F32Type(f32)
- # CHECK: F64Type(f64)
- # CHECK: NoneType(none)
- # CHECK: ComplexType(complex<f32>)
- # CHECK: VectorType(vector<2x3xf32>)
- # CHECK: RankedTensorType(tensor<2x3xf32>)
- # CHECK: UnrankedTensorType(tensor<*xf32>)
- # CHECK: MemRefType(memref<2x3xf32>)
- # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
- # CHECK: TupleType(tuple<f32>)
- # CHECK: FunctionType(() -> ())
- # CHECK: OpaqueType(!tensor.bob)
- for _, t in types:
- print(repr(t))
-
- # Test getTypeIdFunction agrees with
- # mlirTypeGetTypeID(self) for an instance.
- # CHECK: all equal
- for t1, t2 in types:
- tid1, tid2 = t1.static_typeid, Type(t2).typeid
- assert tid1 == tid2 and hash(tid1) == hash(
- tid2), f"expected hash and value equality {t1} {t2}"
- else:
- print("all equal")
-
- # Test that storing PyTypeID in python dicts
- # works as expected.
- typeid_dict = dict(types)
- assert len(typeid_dict)
-
- # CHECK: all equal
- for t1, t2 in typeid_dict.items():
- assert t1.static_typeid == t2.typeid and hash(
- t1.static_typeid) == hash(
- t2.typeid), f"expected hash and value equality {t1} {t2}"
- else:
- print("all equal")
-
- # CHECK: ShapedType has no typeid.
- try:
- print(ShapedType.static_typeid)
- except AttributeError as e:
- print(e)
-
- vector_type = Type.parse("vector<2x3xf32>")
- # CHECK: True
- print(ShapedType(vector_type).typeid == vector_type.typeid)
+ with Context(), Location.unknown():
+ f32 = F32Type.get()
+
+ types = [
+ (IntegerType, IntegerType.get_signless(16)),
+ (IndexType, IndexType.get()),
+ (Float8E4M3FNType, Float8E4M3FNType.get()),
+ (Float8E5M2Type, Float8E5M2Type.get()),
+ (Float8E4M3FNUZType, Float8E4M3FNUZType.get()),
+ (Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()),
+ (Float8E5M2FNUZType, Float8E5M2FNUZType.get()),
+ (BF16Type, BF16Type.get()),
+ (F16Type, F16Type.get()),
+ (F32Type, F32Type.get()),
+ (F64Type, F64Type.get()),
+ (NoneType, NoneType.get()),
+ (ComplexType, ComplexType.get(f32)),
+ (VectorType, VectorType.get([2, 3], f32)),
+ (RankedTensorType, RankedTensorType.get([2, 3], f32)),
+ (UnrankedTensorType, UnrankedTensorType.get(f32)),
+ (MemRefType, MemRefType.get([2, 3], f32)),
+ (UnrankedMemRefType, UnrankedMemRefType.get(f32, Attribute.parse("2"))),
+ (TupleType, TupleType.get_tuple([f32])),
+ (FunctionType, FunctionType.get([], [])),
+ (OpaqueType, OpaqueType.get("tensor", "bob")),
+ ]
+
+ # CHECK: IntegerType(i16)
+ # CHECK: IndexType(index)
+ # CHECK: Float8E4M3FNType(f8E4M3FN)
+ # CHECK: Float8E5M2Type(f8E5M2)
+ # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
+ # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
+ # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
+ # CHECK: BF16Type(bf16)
+ # CHECK: F16Type(f16)
+ # CHECK: F32Type(f32)
+ # CHECK: F64Type(f64)
+ # CHECK: NoneType(none)
+ # CHECK: ComplexType(complex<f32>)
+ # CHECK: VectorType(vector<2x3xf32>)
+ # CHECK: RankedTensorType(tensor<2x3xf32>)
+ # CHECK: UnrankedTensorType(tensor<*xf32>)
+ # CHECK: MemRefType(memref<2x3xf32>)
+ # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
+ # CHECK: TupleType(tuple<f32>)
+ # CHECK: FunctionType(() -> ())
+ # CHECK: OpaqueType(!tensor.bob)
+ for _, t in types:
+ print(repr(t))
+
+ # Test getTypeIdFunction agrees with
+ # mlirTypeGetTypeID(self) for an instance.
+ # CHECK: all equal
+ for t1, t2 in types:
+ tid1, tid2 = t1.static_typeid, Type(t2).typeid
+ assert tid1 == tid2 and hash(tid1) == hash(
+ tid2
+ ), f"expected hash and value equality {t1} {t2}"
+ else:
+ print("all equal")
+
+ # Test that storing PyTypeID in python dicts
+ # works as expected.
+ typeid_dict = dict(types)
+ assert len(typeid_dict)
+
+ # CHECK: all equal
+ for t1, t2 in typeid_dict.items():
+ assert t1.static_typeid == t2.typeid and hash(t1.static_typeid) == hash(
+ t2.typeid
+ ), f"expected hash and value equality {t1} {t2}"
+ else:
+ print("all equal")
+
+ # CHECK: ShapedType has no typeid.
+ try:
+ print(ShapedType.static_typeid)
+ except AttributeError as e:
+ print(e)
+
+ vector_type = Type.parse("vector<2x3xf32>")
+ # CHECK: True
+ print(ShapedType(vector_type).typeid == vector_type.typeid)
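For readers skimming the diff, a minimal standalone sketch of the tuple and function type API the test above exercises (assuming the in-tree `mlir` Python bindings; the element types chosen are just examples):

from mlir.ir import Context, IntegerType, F32Type, IndexType, TupleType, FunctionType

with Context():
    i32 = IntegerType.get_signless(32)
    f32 = F32Type.get()
    # Tuple types are built from a list of element types and can be queried.
    tup = TupleType.get_tuple([i32, f32])
    print(tup)              # tuple<i32, f32>
    print(tup.num_types)    # 2
    print(tup.get_type(1))  # f32
    # Function types expose their inputs and results as Python lists.
    fn = FunctionType.get([i32, f32], [IndexType.get()])
    print(fn.inputs)        # [Type(i32), Type(f32)]
    print(fn.results)       # [Type(index)]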
diff --git a/mlir/test/python/ir/context_managers.py b/mlir/test/python/ir/context_managers.py
index b93fcf70ac482..48d9e357324c9 100644
--- a/mlir/test/python/ir/context_managers.py
+++ b/mlir/test/python/ir/context_managers.py
@@ -3,97 +3,110 @@
import gc
from mlir.ir import *
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
# CHECK-LABEL: TEST: testContextEnterExit
def testContextEnterExit():
- with Context() as ctx:
- assert Context.current is ctx
- try:
- _ = Context.current
- except ValueError as e:
- # CHECK: No current Context
- print(e)
- else: assert False, "Expected exception"
+ with Context() as ctx:
+ assert Context.current is ctx
+ try:
+ _ = Context.current
+ except ValueError as e:
+ # CHECK: No current Context
+ print(e)
+ else:
+ assert False, "Expected exception"
+
run(testContextEnterExit)
# CHECK-LABEL: TEST: testLocationEnterExit
def testLocationEnterExit():
- ctx1 = Context()
- with Location.unknown(ctx1) as loc1:
- assert Context.current is ctx1
- assert Location.current is loc1
-
- # Re-asserting the same context should not change the location.
- with ctx1:
- assert Context.current is ctx1
- assert Location.current is loc1
- # Asserting a different context should clear it.
- with Context() as ctx2:
- assert Context.current is ctx2
- try:
- _ = Location.current
- except ValueError: pass
- else: assert False, "Expected exception"
-
- # And should restore.
- assert Context.current is ctx1
- assert Location.current is loc1
-
- # All should clear.
- try:
- _ = Location.current
- except ValueError as e:
- # CHECK: No current Location
- print(e)
- else: assert False, "Expected exception"
+ ctx1 = Context()
+ with Location.unknown(ctx1) as loc1:
+ assert Context.current is ctx1
+ assert Location.current is loc1
+
+ # Re-asserting the same context should not change the location.
+ with ctx1:
+ assert Context.current is ctx1
+ assert Location.current is loc1
+ # Asserting a different context should clear it.
+ with Context() as ctx2:
+ assert Context.current is ctx2
+ try:
+ _ = Location.current
+ except ValueError:
+ pass
+ else:
+ assert False, "Expected exception"
+
+ # And should restore.
+ assert Context.current is ctx1
+ assert Location.current is loc1
+
+ # All should clear.
+ try:
+ _ = Location.current
+ except ValueError as e:
+ # CHECK: No current Location
+ print(e)
+ else:
+ assert False, "Expected exception"
+
run(testLocationEnterExit)
# CHECK-LABEL: TEST: testInsertionPointEnterExit
def testInsertionPointEnterExit():
- ctx1 = Context()
- m = Module.create(Location.unknown(ctx1))
- ip = InsertionPoint(m.body)
-
- with ip:
- assert InsertionPoint.current is ip
- # Asserting a location from the same context should preserve.
- with Location.unknown(ctx1) as loc1:
- assert InsertionPoint.current is ip
- assert Location.current is loc1
- # Location should clear.
+ ctx1 = Context()
+ m = Module.create(Location.unknown(ctx1))
+ ip = InsertionPoint(m.body)
+
+ with ip:
+ assert InsertionPoint.current is ip
+ # Asserting a location from the same context should preserve.
+ with Location.unknown(ctx1) as loc1:
+ assert InsertionPoint.current is ip
+ assert Location.current is loc1
+ # Location should clear.
+ try:
+ _ = Location.current
+ except ValueError:
+ pass
+ else:
+ assert False, "Expected exception"
+
+ # Asserting the same Context should preserve.
+ with ctx1:
+ assert InsertionPoint.current is ip
+
+ # Asserting a different context should clear it.
+ with Context() as ctx2:
+ assert Context.current is ctx2
+ try:
+ _ = InsertionPoint.current
+ except ValueError:
+ pass
+ else:
+ assert False, "Expected exception"
+
+ # All should clear.
try:
- _ = Location.current
- except ValueError: pass
- else: assert False, "Expected exception"
-
- # Asserting the same Context should preserve.
- with ctx1:
- assert InsertionPoint.current is ip
-
- # Asserting a different context should clear it.
- with Context() as ctx2:
- assert Context.current is ctx2
- try:
_ = InsertionPoint.current
- except ValueError: pass
- else: assert False, "Expected exception"
-
- # All should clear.
- try:
- _ = InsertionPoint.current
- except ValueError as e:
- # CHECK: No current InsertionPoint
- print(e)
- else: assert False, "Expected exception"
+ except ValueError as e:
+ # CHECK: No current InsertionPoint
+ print(e)
+ else:
+ assert False, "Expected exception"
+
run(testInsertionPointEnterExit)
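A compact sketch of the implicit Context/Location/InsertionPoint stacking these tests rely on (a reading aid only, assuming the same `mlir.ir` bindings):

from mlir.ir import Context, Location, Module, InsertionPoint

ctx = Context()
m = Module.create(Location.unknown(ctx))
ip = InsertionPoint(m.body)
with Location.unknown(ctx) as loc, ip:
    # The innermost `with` of each kind defines the implicit "current" value.
    assert Context.current is ctx
    assert Location.current is loc
    assert InsertionPoint.current is ip
# Outside the blocks, accessing the current value raises ValueError.
try:
    Location.current
except ValueError as e:
    print(e)  # No current Location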
diff --git a/mlir/test/python/ir/debug.py b/mlir/test/python/ir/debug.py
index 3268d9fa0865c..629a710e68585 100644
--- a/mlir/test/python/ir/debug.py
+++ b/mlir/test/python/ir/debug.py
@@ -2,38 +2,40 @@
from mlir.ir import *
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
+ print("\nTEST:", f.__name__)
+ f()
# CHECK-LABEL: TEST: testNameIsPrivate
def testNameIsPrivate():
- # `import *` ignores private names starting with an underscore, so the debug
- # flag shouldn't be visible unless explicitly imported.
- try:
- _GlobalDebug.flag = True
- except NameError:
- pass
- else:
- assert False, "_GlobalDebug must not be available by default"
+ # `import *` ignores private names starting with an underscore, so the debug
+ # flag shouldn't be visible unless explicitly imported.
+ try:
+ _GlobalDebug.flag = True
+ except NameError:
+ pass
+ else:
+ assert False, "_GlobalDebug must not be available by default"
+
run(testNameIsPrivate)
# CHECK-LABEL: TEST: testDebugDlag
def testDebugDlag():
- # Private names must be imported explicitly.
- from mlir.ir import _GlobalDebug
-
- # CHECK: False
- print(_GlobalDebug.flag)
- _GlobalDebug.flag = True
- # CHECK: True
- print(_GlobalDebug.flag)
- _GlobalDebug.flag = False
- # CHECK: False
- print(_GlobalDebug.flag)
+ # Private names must be imported explicitly.
+ from mlir.ir import _GlobalDebug
+
+ # CHECK: False
+ print(_GlobalDebug.flag)
+ _GlobalDebug.flag = True
+ # CHECK: True
+ print(_GlobalDebug.flag)
+ _GlobalDebug.flag = False
+ # CHECK: False
+ print(_GlobalDebug.flag)
-run(testDebugDlag)
+run(testDebugDlag)
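A short sketch of the private debug toggle shown above; it has to be imported explicitly, since `import *` skips underscore-prefixed names:

from mlir.ir import _GlobalDebug

# The flag is process-wide; flipping it turns MLIR debug output on and off.
print(_GlobalDebug.flag)   # False
_GlobalDebug.flag = True
print(_GlobalDebug.flag)   # True
_GlobalDebug.flag = False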
diff --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py
index cc07f6eaf56ed..2f4300d2c55df 100644
--- a/mlir/test/python/ir/diagnostic_handler.py
+++ b/mlir/test/python/ir/diagnostic_handler.py
@@ -3,191 +3,222 @@
import gc
from mlir.ir import *
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
@run
def testLifecycleContextDestroy():
- ctx = Context()
- def callback(foo): ...
- handler = ctx.attach_diagnostic_handler(callback)
- assert handler.attached
- # If context is destroyed before the handler, it should auto-detach.
- ctx = None
- gc.collect()
- assert not handler.attached
+ ctx = Context()
+
+ def callback(foo):
+ ...
+
+ handler = ctx.attach_diagnostic_handler(callback)
+ assert handler.attached
+ # If context is destroyed before the handler, it should auto-detach.
+ ctx = None
+ gc.collect()
+ assert not handler.attached
- # And finally collecting the handler should be fine.
- handler = None
- gc.collect()
+ # And finally collecting the handler should be fine.
+ handler = None
+ gc.collect()
@run
def testLifecycleExplicitDetach():
- ctx = Context()
- def callback(foo): ...
- handler = ctx.attach_diagnostic_handler(callback)
- assert handler.attached
- handler.detach()
- assert not handler.attached
+ ctx = Context()
+
+ def callback(foo):
+ ...
+
+ handler = ctx.attach_diagnostic_handler(callback)
+ assert handler.attached
+ handler.detach()
+ assert not handler.attached
@run
def testLifecycleWith():
- ctx = Context()
- def callback(foo): ...
- with ctx.attach_diagnostic_handler(callback) as handler:
- assert handler.attached
- assert not handler.attached
+ ctx = Context()
+
+ def callback(foo):
+ ...
+
+ with ctx.attach_diagnostic_handler(callback) as handler:
+ assert handler.attached
+ assert not handler.attached
@run
def testLifecycleWithAndExplicitDetach():
- ctx = Context()
- def callback(foo): ...
- with ctx.attach_diagnostic_handler(callback) as handler:
- assert handler.attached
- handler.detach()
- assert not handler.attached
+ ctx = Context()
+
+ def callback(foo):
+ ...
+
+ with ctx.attach_diagnostic_handler(callback) as handler:
+ assert handler.attached
+ handler.detach()
+ assert not handler.attached
# CHECK-LABEL: TEST: testDiagnosticCallback
@run
def testDiagnosticCallback():
- ctx = Context()
- def callback(d):
- # CHECK: DIAGNOSTIC: message='foobar', severity=DiagnosticSeverity.ERROR, loc=loc(unknown)
- print(f"DIAGNOSTIC: message='{d.message}', severity={d.severity}, loc={d.location}")
- return True
- handler = ctx.attach_diagnostic_handler(callback)
- loc = Location.unknown(ctx)
- loc.emit_error("foobar")
- assert not handler.had_error
+ ctx = Context()
+
+ def callback(d):
+ # CHECK: DIAGNOSTIC: message='foobar', severity=DiagnosticSeverity.ERROR, loc=loc(unknown)
+ print(
+ f"DIAGNOSTIC: message='{d.message}', severity={d.severity}, loc={d.location}"
+ )
+ return True
+
+ handler = ctx.attach_diagnostic_handler(callback)
+ loc = Location.unknown(ctx)
+ loc.emit_error("foobar")
+ assert not handler.had_error
# CHECK-LABEL: TEST: testDiagnosticEmptyNotes
# TODO: Come up with a way to inject a diagnostic with notes from this API.
@run
def testDiagnosticEmptyNotes():
- ctx = Context()
- def callback(d):
- # CHECK: DIAGNOSTIC: notes=()
- print(f"DIAGNOSTIC: notes={d.notes}")
- return True
- handler = ctx.attach_diagnostic_handler(callback)
- loc = Location.unknown(ctx)
- loc.emit_error("foobar")
- assert not handler.had_error
+ ctx = Context()
+
+ def callback(d):
+ # CHECK: DIAGNOSTIC: notes=()
+ print(f"DIAGNOSTIC: notes={d.notes}")
+ return True
+
+ handler = ctx.attach_diagnostic_handler(callback)
+ loc = Location.unknown(ctx)
+ loc.emit_error("foobar")
+ assert not handler.had_error
# CHECK-LABEL: TEST: testDiagnosticNonEmptyNotes
@run
def testDiagnosticNonEmptyNotes():
- ctx = Context()
- ctx.emit_error_diagnostics = True
- def callback(d):
- # CHECK: DIAGNOSTIC:
- # CHECK: message='arith.addi' op requires one result
- # CHECK: notes=['see current operation: "arith.addi"() : () -> ()']
- print(f"DIAGNOSTIC:")
- print(f" message={d.message}")
- print(f" notes={list(map(str, d.notes))}")
- return True
- handler = ctx.attach_diagnostic_handler(callback)
- loc = Location.unknown(ctx)
- try:
- Operation.create('arith.addi', loc=loc).verify()
- except MLIRError:
- pass
- assert not handler.had_error
+ ctx = Context()
+ ctx.emit_error_diagnostics = True
+
+ def callback(d):
+ # CHECK: DIAGNOSTIC:
+ # CHECK: message='arith.addi' op requires one result
+ # CHECK: notes=['see current operation: "arith.addi"() : () -> ()']
+ print(f"DIAGNOSTIC:")
+ print(f" message={d.message}")
+ print(f" notes={list(map(str, d.notes))}")
+ return True
+
+ handler = ctx.attach_diagnostic_handler(callback)
+ loc = Location.unknown(ctx)
+ try:
+ Operation.create("arith.addi", loc=loc).verify()
+ except MLIRError:
+ pass
+ assert not handler.had_error
+
# CHECK-LABEL: TEST: testDiagnosticCallbackException
@run
def testDiagnosticCallbackException():
- ctx = Context()
- def callback(d):
- raise ValueError("Error in handler")
- handler = ctx.attach_diagnostic_handler(callback)
- loc = Location.unknown(ctx)
- loc.emit_error("foobar")
- assert handler.had_error
+ ctx = Context()
+
+ def callback(d):
+ raise ValueError("Error in handler")
+
+ handler = ctx.attach_diagnostic_handler(callback)
+ loc = Location.unknown(ctx)
+ loc.emit_error("foobar")
+ assert handler.had_error
# CHECK-LABEL: TEST: testEscapingDiagnostic
@run
def testEscapingDiagnostic():
- ctx = Context()
- diags = []
- def callback(d):
- diags.append(d)
- return True
- handler = ctx.attach_diagnostic_handler(callback)
- loc = Location.unknown(ctx)
- loc.emit_error("foobar")
- assert not handler.had_error
-
- # CHECK: DIAGNOSTIC: <Invalid Diagnostic>
- print(f"DIAGNOSTIC: {str(diags[0])}")
- try:
- diags[0].severity
- raise RuntimeError("expected exception")
- except ValueError:
- pass
- try:
- diags[0].location
- raise RuntimeError("expected exception")
- except ValueError:
- pass
- try:
- diags[0].message
- raise RuntimeError("expected exception")
- except ValueError:
- pass
- try:
- diags[0].notes
- raise RuntimeError("expected exception")
- except ValueError:
- pass
-
+ ctx = Context()
+ diags = []
+
+ def callback(d):
+ diags.append(d)
+ return True
+
+ handler = ctx.attach_diagnostic_handler(callback)
+ loc = Location.unknown(ctx)
+ loc.emit_error("foobar")
+ assert not handler.had_error
+
+ # CHECK: DIAGNOSTIC: <Invalid Diagnostic>
+ print(f"DIAGNOSTIC: {str(diags[0])}")
+ try:
+ diags[0].severity
+ raise RuntimeError("expected exception")
+ except ValueError:
+ pass
+ try:
+ diags[0].location
+ raise RuntimeError("expected exception")
+ except ValueError:
+ pass
+ try:
+ diags[0].message
+ raise RuntimeError("expected exception")
+ except ValueError:
+ pass
+ try:
+ diags[0].notes
+ raise RuntimeError("expected exception")
+ except ValueError:
+ pass
# CHECK-LABEL: TEST: testDiagnosticReturnTrueHandles
@run
def testDiagnosticReturnTrueHandles():
- ctx = Context()
- def callback1(d):
- print(f"CALLBACK1: {d}")
- return True
- def callback2(d):
- print(f"CALLBACK2: {d}")
- return True
- ctx.attach_diagnostic_handler(callback1)
- ctx.attach_diagnostic_handler(callback2)
- loc = Location.unknown(ctx)
- # CHECK-NOT: CALLBACK1
- # CHECK: CALLBACK2: foobar
- # CHECK-NOT: CALLBACK1
- loc.emit_error("foobar")
+ ctx = Context()
+
+ def callback1(d):
+ print(f"CALLBACK1: {d}")
+ return True
+
+ def callback2(d):
+ print(f"CALLBACK2: {d}")
+ return True
+
+ ctx.attach_diagnostic_handler(callback1)
+ ctx.attach_diagnostic_handler(callback2)
+ loc = Location.unknown(ctx)
+ # CHECK-NOT: CALLBACK1
+ # CHECK: CALLBACK2: foobar
+ # CHECK-NOT: CALLBACK1
+ loc.emit_error("foobar")
# CHECK-LABEL: TEST: testDiagnosticReturnFalseDoesNotHandle
@run
def testDiagnosticReturnFalseDoesNotHandle():
- ctx = Context()
- def callback1(d):
- print(f"CALLBACK1: {d}")
- return True
- def callback2(d):
- print(f"CALLBACK2: {d}")
- return False
- ctx.attach_diagnostic_handler(callback1)
- ctx.attach_diagnostic_handler(callback2)
- loc = Location.unknown(ctx)
- # CHECK: CALLBACK2: foobar
- # CHECK: CALLBACK1: foobar
- loc.emit_error("foobar")
+ ctx = Context()
+
+ def callback1(d):
+ print(f"CALLBACK1: {d}")
+ return True
+
+ def callback2(d):
+ print(f"CALLBACK2: {d}")
+ return False
+
+ ctx.attach_diagnostic_handler(callback1)
+ ctx.attach_diagnostic_handler(callback2)
+ loc = Location.unknown(ctx)
+ # CHECK: CALLBACK2: foobar
+ # CHECK: CALLBACK1: foobar
+ loc.emit_error("foobar")
diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py
index 65e81e84354c2..eebf7c3e48989 100644
--- a/mlir/test/python/ir/dialects.py
+++ b/mlir/test/python/ir/dialects.py
@@ -5,60 +5,60 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
# CHECK-LABEL: TEST: testDialectDescriptor
@run
def testDialectDescriptor():
- ctx = Context()
- d = ctx.get_dialect_descriptor("func")
- # CHECK: <DialectDescriptor func>
- print(d)
- # CHECK: func
- print(d.namespace)
- try:
- _ = ctx.get_dialect_descriptor("not_existing")
- except ValueError:
- pass
- else:
- assert False, "Expected exception"
+ ctx = Context()
+ d = ctx.get_dialect_descriptor("func")
+ # CHECK: <DialectDescriptor func>
+ print(d)
+ # CHECK: func
+ print(d.namespace)
+ try:
+ _ = ctx.get_dialect_descriptor("not_existing")
+ except ValueError:
+ pass
+ else:
+ assert False, "Expected exception"
# CHECK-LABEL: TEST: testUserDialectClass
@run
def testUserDialectClass():
- ctx = Context()
- # Access using attribute.
- d = ctx.dialects.func
- # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
- print(d)
- try:
- _ = ctx.dialects.not_existing
- except AttributeError:
- pass
- else:
- assert False, "Expected exception"
-
- # Access using index.
- d = ctx.dialects["func"]
- # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
- print(d)
- try:
- _ = ctx.dialects["not_existing"]
- except IndexError:
- pass
- else:
- assert False, "Expected exception"
-
- # Using the 'd' alias.
- d = ctx.d["func"]
- # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
- print(d)
+ ctx = Context()
+ # Access using attribute.
+ d = ctx.dialects.func
+ # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
+ print(d)
+ try:
+ _ = ctx.dialects.not_existing
+ except AttributeError:
+ pass
+ else:
+ assert False, "Expected exception"
+
+ # Access using index.
+ d = ctx.dialects["func"]
+ # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
+ print(d)
+ try:
+ _ = ctx.dialects["not_existing"]
+ except IndexError:
+ pass
+ else:
+ assert False, "Expected exception"
+
+ # Using the 'd' alias.
+ d = ctx.d["func"]
+ # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
+ print(d)
# CHECK-LABEL: TEST: testCustomOpView
@@ -67,40 +67,40 @@ def testUserDialectClass():
# additional capabilities come online.
@run
def testCustomOpView():
+ def createInput():
+ op = Operation.create("pytest_dummy.intinput", results=[f32])
+ # TODO: Auto result cast from operation
+ return op.results[0]
- def createInput():
- op = Operation.create("pytest_dummy.intinput", results=[f32])
- # TODO: Auto result cast from operation
- return op.results[0]
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ m = Module.create()
- with Context() as ctx, Location.unknown():
- ctx.allow_unregistered_dialects = True
- m = Module.create()
+ with InsertionPoint(m.body):
+ f32 = F32Type.get()
+ # Create via dialects context collection.
+ input1 = createInput()
+ input2 = createInput()
+ op1 = ctx.dialects.arith.AddFOp(input1, input2)
- with InsertionPoint(m.body):
- f32 = F32Type.get()
- # Create via dialects context collection.
- input1 = createInput()
- input2 = createInput()
- op1 = ctx.dialects.arith.AddFOp(input1, input2)
+ # Create via an import
+ from mlir.dialects.arith import AddFOp
- # Create via an import
- from mlir.dialects.arith import AddFOp
- AddFOp(input1, op1.result)
+ AddFOp(input1, op1.result)
- # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
- # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
- # CHECK: %[[R0:.*]] = arith.addf %[[INPUT0]], %[[INPUT1]] : f32
- # CHECK: %[[R1:.*]] = arith.addf %[[INPUT0]], %[[R0]] : f32
- m.operation.print()
+ # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
+ # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
+ # CHECK: %[[R0:.*]] = arith.addf %[[INPUT0]], %[[INPUT1]] : f32
+ # CHECK: %[[R1:.*]] = arith.addf %[[INPUT0]], %[[R0]] : f32
+ m.operation.print()
# CHECK-LABEL: TEST: testIsRegisteredOperation
@run
def testIsRegisteredOperation():
- ctx = Context()
+ ctx = Context()
- # CHECK: cf.cond_br: True
- print(f"cf.cond_br: {ctx.is_registered_operation('cf.cond_br')}")
- # CHECK: func.not_existing: False
- print(f"func.not_existing: {ctx.is_registered_operation('func.not_existing')}")
+ # CHECK: cf.cond_br: True
+ print(f"cf.cond_br: {ctx.is_registered_operation('cf.cond_br')}")
+ # CHECK: func.not_existing: False
+ print(f"func.not_existing: {ctx.is_registered_operation('func.not_existing')}")
diff --git a/mlir/test/python/ir/exception.py b/mlir/test/python/ir/exception.py
index 6cb2375a13247..74085cd349643 100644
--- a/mlir/test/python/ir/exception.py
+++ b/mlir/test/python/ir/exception.py
@@ -3,75 +3,93 @@
import gc
from mlir.ir import *
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
# CHECK-LABEL: TEST: test_exception
@run
def test_exception():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- try:
- Operation.parse("""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ try:
+ Operation.parse(
+ """
func.func @foo() {
"test.use"(%0) : (i64) -> () loc("use")
%0 = "test.def"() : () -> i64 loc("def")
return
}
- """, context=ctx)
- except MLIRError as e:
- # CHECK: Exception: <
- # CHECK: Unable to parse operation assembly:
- # CHECK: error: "use": operand #0 does not dominate this use
- # CHECK: note: "use": see current operation: "test.use"(%0) : (i64) -> ()
- # CHECK: note: "def": operand defined here (op in the same block)
- # CHECK: >
- print(f"Exception: <{e}>")
+ """,
+ context=ctx,
+ )
+ except MLIRError as e:
+ # CHECK: Exception: <
+ # CHECK: Unable to parse operation assembly:
+ # CHECK: error: "use": operand #0 does not dominate this use
+ # CHECK: note: "use": see current operation: "test.use"(%0) : (i64) -> ()
+ # CHECK: note: "def": operand defined here (op in the same block)
+ # CHECK: >
+ print(f"Exception: <{e}>")
- # CHECK: message: Unable to parse operation assembly
- print(f"message: {e.message}")
+ # CHECK: message: Unable to parse operation assembly
+ print(f"message: {e.message}")
- # CHECK: error_diagnostics[0]: loc("use") operand #0 does not dominate this use
- # CHECK: error_diagnostics[0].notes[0]: loc("use") see current operation: "test.use"(%0) : (i64) -> ()
- # CHECK: error_diagnostics[0].notes[1]: loc("def") operand defined here (op in the same block)
- print("error_diagnostics[0]: ", e.error_diagnostics[0].location, e.error_diagnostics[0].message)
- print("error_diagnostics[0].notes[0]: ", e.error_diagnostics[0].notes[0].location, e.error_diagnostics[0].notes[0].message)
- print("error_diagnostics[0].notes[1]: ", e.error_diagnostics[0].notes[1].location, e.error_diagnostics[0].notes[1].message)
+ # CHECK: error_diagnostics[0]: loc("use") operand #0 does not dominate this use
+ # CHECK: error_diagnostics[0].notes[0]: loc("use") see current operation: "test.use"(%0) : (i64) -> ()
+ # CHECK: error_diagnostics[0].notes[1]: loc("def") operand defined here (op in the same block)
+ print(
+ "error_diagnostics[0]: ",
+ e.error_diagnostics[0].location,
+ e.error_diagnostics[0].message,
+ )
+ print(
+ "error_diagnostics[0].notes[0]: ",
+ e.error_diagnostics[0].notes[0].location,
+ e.error_diagnostics[0].notes[0].message,
+ )
+ print(
+ "error_diagnostics[0].notes[1]: ",
+ e.error_diagnostics[0].notes[1].location,
+ e.error_diagnostics[0].notes[1].message,
+ )
# CHECK-LABEL: test_emit_error_diagnostics
@run
def test_emit_error_diagnostics():
- ctx = Context()
- loc = Location.unknown(ctx)
- handler_diags = []
- def handler(d):
- handler_diags.append(str(d))
- return True
- ctx.attach_diagnostic_handler(handler)
+ ctx = Context()
+ loc = Location.unknown(ctx)
+ handler_diags = []
+
+ def handler(d):
+ handler_diags.append(str(d))
+ return True
+
+ ctx.attach_diagnostic_handler(handler)
- try:
- Attribute.parse("not an attr", ctx)
- except MLIRError as e:
- # CHECK: emit_error_diagnostics=False:
- # CHECK: e.error_diagnostics: ['expected attribute value']
- # CHECK: handler_diags: []
- print(f"emit_error_diagnostics=False:")
- print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}")
- print(f"handler_diags: {handler_diags}")
+ try:
+ Attribute.parse("not an attr", ctx)
+ except MLIRError as e:
+ # CHECK: emit_error_diagnostics=False:
+ # CHECK: e.error_diagnostics: ['expected attribute value']
+ # CHECK: handler_diags: []
+ print(f"emit_error_diagnostics=False:")
+ print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}")
+ print(f"handler_diags: {handler_diags}")
- ctx.emit_error_diagnostics = True
- try:
- Attribute.parse("not an attr", ctx)
- except MLIRError as e:
- # CHECK: emit_error_diagnostics=True:
- # CHECK: e.error_diagnostics: []
- # CHECK: handler_diags: ['expected attribute value']
- print(f"emit_error_diagnostics=True:")
- print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}")
- print(f"handler_diags: {handler_diags}")
+ ctx.emit_error_diagnostics = True
+ try:
+ Attribute.parse("not an attr", ctx)
+ except MLIRError as e:
+ # CHECK: emit_error_diagnostics=True:
+ # CHECK: e.error_diagnostics: []
+ # CHECK: handler_diags: ['expected attribute value']
+ print(f"emit_error_diagnostics=True:")
+ print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}")
+ print(f"handler_diags: {handler_diags}")
diff --git a/mlir/test/python/ir/insertion_point.py b/mlir/test/python/ir/insertion_point.py
index 81a6ec2984d54..0dc7d757f56d1 100644
--- a/mlir/test/python/ir/insertion_point.py
+++ b/mlir/test/python/ir/insertion_point.py
@@ -5,168 +5,191 @@
import itertools
from mlir.ir import *
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
# CHECK-LABEL: TEST: test_insert_at_block_end
def test_insert_at_block_end():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- module = Module.parse(r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ module = Module.parse(
+ r"""
func.func @foo() -> () {
"custom.op1"() : () -> ()
}
- """)
- entry_block = module.body.operations[0].regions[0].blocks[0]
- ip = InsertionPoint(entry_block)
- ip.insert(Operation.create("custom.op2"))
- # CHECK: "custom.op1"
- # CHECK: "custom.op2"
- module.operation.print()
+ """
+ )
+ entry_block = module.body.operations[0].regions[0].blocks[0]
+ ip = InsertionPoint(entry_block)
+ ip.insert(Operation.create("custom.op2"))
+ # CHECK: "custom.op1"
+ # CHECK: "custom.op2"
+ module.operation.print()
+
run(test_insert_at_block_end)
# CHECK-LABEL: TEST: test_insert_before_operation
def test_insert_before_operation():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- module = Module.parse(r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ module = Module.parse(
+ r"""
func.func @foo() -> () {
"custom.op1"() : () -> ()
"custom.op2"() : () -> ()
}
- """)
- entry_block = module.body.operations[0].regions[0].blocks[0]
- ip = InsertionPoint(entry_block.operations[1])
- ip.insert(Operation.create("custom.op3"))
- # CHECK: "custom.op1"
- # CHECK: "custom.op3"
- # CHECK: "custom.op2"
- module.operation.print()
+ """
+ )
+ entry_block = module.body.operations[0].regions[0].blocks[0]
+ ip = InsertionPoint(entry_block.operations[1])
+ ip.insert(Operation.create("custom.op3"))
+ # CHECK: "custom.op1"
+ # CHECK: "custom.op3"
+ # CHECK: "custom.op2"
+ module.operation.print()
+
run(test_insert_before_operation)
# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- module = Module.parse(r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ module = Module.parse(
+ r"""
func.func @foo() -> () {
"custom.op2"() : () -> ()
}
- """)
- entry_block = module.body.operations[0].regions[0].blocks[0]
- ip = InsertionPoint.at_block_begin(entry_block)
- ip.insert(Operation.create("custom.op1"))
- # CHECK: "custom.op1"
- # CHECK: "custom.op2"
- module.operation.print()
+ """
+ )
+ entry_block = module.body.operations[0].regions[0].blocks[0]
+ ip = InsertionPoint.at_block_begin(entry_block)
+ ip.insert(Operation.create("custom.op1"))
+ # CHECK: "custom.op1"
+ # CHECK: "custom.op2"
+ module.operation.print()
+
run(test_insert_at_block_begin)
# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
def test_insert_at_block_begin_empty():
- # TODO: Write this test case when we can create such a situation.
- pass
+ # TODO: Write this test case when we can create such a situation.
+ pass
+
run(test_insert_at_block_begin_empty)
# CHECK-LABEL: TEST: test_insert_at_terminator
def test_insert_at_terminator():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- module = Module.parse(r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ module = Module.parse(
+ r"""
func.func @foo() -> () {
"custom.op1"() : () -> ()
return
}
- """)
- entry_block = module.body.operations[0].regions[0].blocks[0]
- ip = InsertionPoint.at_block_terminator(entry_block)
- ip.insert(Operation.create("custom.op2"))
- # CHECK: "custom.op1"
- # CHECK: "custom.op2"
- module.operation.print()
+ """
+ )
+ entry_block = module.body.operations[0].regions[0].blocks[0]
+ ip = InsertionPoint.at_block_terminator(entry_block)
+ ip.insert(Operation.create("custom.op2"))
+ # CHECK: "custom.op1"
+ # CHECK: "custom.op2"
+ module.operation.print()
+
run(test_insert_at_terminator)
# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
def test_insert_at_block_terminator_missing():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with ctx:
- module = Module.parse(r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with ctx:
+ module = Module.parse(
+ r"""
func.func @foo() -> () {
"custom.op1"() : () -> ()
}
- """)
- entry_block = module.body.operations[0].regions[0].blocks[0]
- try:
- ip = InsertionPoint.at_block_terminator(entry_block)
- except ValueError as e:
- # CHECK: Block has no terminator
- print(e)
- else:
- assert False, "Expected exception"
+ """
+ )
+ entry_block = module.body.operations[0].regions[0].blocks[0]
+ try:
+ ip = InsertionPoint.at_block_terminator(entry_block)
+ except ValueError as e:
+ # CHECK: Block has no terminator
+ print(e)
+ else:
+ assert False, "Expected exception"
+
run(test_insert_at_block_terminator_missing)
# CHECK-LABEL: TEST: test_insert_at_end_with_terminator_errors
def test_insert_at_end_with_terminator_errors():
- with Context() as ctx, Location.unknown():
- ctx.allow_unregistered_dialects = True
- module = Module.parse(r"""
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
func.func @foo() -> () {
return
}
- """)
- entry_block = module.body.operations[0].regions[0].blocks[0]
- with InsertionPoint(entry_block):
- try:
- Operation.create("custom.op1", results=[], operands=[])
- except IndexError as e:
- # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
- print(f"ERROR: {e}")
+ """
+ )
+ entry_block = module.body.operations[0].regions[0].blocks[0]
+ with InsertionPoint(entry_block):
+ try:
+ Operation.create("custom.op1", results=[], operands=[])
+ except IndexError as e:
+ # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
+ print(f"ERROR: {e}")
+
run(test_insert_at_end_with_terminator_errors)
# CHECK-LABEL: TEST: test_insertion_point_context
def test_insertion_point_context():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- module = Module.parse(r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ module = Module.parse(
+ r"""
func.func @foo() -> () {
"custom.op1"() : () -> ()
}
- """)
- entry_block = module.body.operations[0].regions[0].blocks[0]
- with InsertionPoint(entry_block):
- Operation.create("custom.op2")
- with InsertionPoint.at_block_begin(entry_block):
- Operation.create("custom.opa")
- Operation.create("custom.opb")
- Operation.create("custom.op3")
- # CHECK: "custom.opa"
- # CHECK: "custom.opb"
- # CHECK: "custom.op1"
- # CHECK: "custom.op2"
- # CHECK: "custom.op3"
- module.operation.print()
+ """
+ )
+ entry_block = module.body.operations[0].regions[0].blocks[0]
+ with InsertionPoint(entry_block):
+ Operation.create("custom.op2")
+ with InsertionPoint.at_block_begin(entry_block):
+ Operation.create("custom.opa")
+ Operation.create("custom.opb")
+ Operation.create("custom.op3")
+ # CHECK: "custom.opa"
+ # CHECK: "custom.opb"
+ # CHECK: "custom.op1"
+ # CHECK: "custom.op2"
+ # CHECK: "custom.op3"
+ module.operation.print()
+
run(test_insertion_point_context)
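A condensed sketch of insertion-point placement as used throughout this file (the `custom.*` op names are illustrative and require allow_unregistered_dialects):

from mlir.ir import Context, Location, Module, Operation, InsertionPoint

ctx = Context()
ctx.allow_unregistered_dialects = True
with Location.unknown(ctx):
    module = Module.parse(
        r"""
        func.func @foo() -> () {
          "custom.op1"() : () -> ()
          return
        }
        """
    )
    entry_block = module.body.operations[0].regions[0].blocks[0]
    # Ops created while an insertion point is active are inserted there;
    # at_block_terminator places them just before the block terminator.
    with InsertionPoint.at_block_terminator(entry_block):
        Operation.create("custom.op2")
    module.operation.print()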
diff --git a/mlir/test/python/ir/integer_set.py b/mlir/test/python/ir/integer_set.py
index d9f158c0d29ce..9fe0480c33a2a 100644
--- a/mlir/test/python/ir/integer_set.py
+++ b/mlir/test/python/ir/integer_set.py
@@ -3,139 +3,140 @@
import gc
from mlir.ir import *
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
# CHECK-LABEL: TEST: testIntegerSetCapsule
@run
def testIntegerSetCapsule():
- with Context() as ctx:
- is1 = IntegerSet.get_empty(1, 1, ctx)
- capsule = is1._CAPIPtr
- # CHECK: mlir.ir.IntegerSet._CAPIPtr
- print(capsule)
- is2 = IntegerSet._CAPICreate(capsule)
- assert is1 == is2
- assert is2.context is ctx
+ with Context() as ctx:
+ is1 = IntegerSet.get_empty(1, 1, ctx)
+ capsule = is1._CAPIPtr
+ # CHECK: mlir.ir.IntegerSet._CAPIPtr
+ print(capsule)
+ is2 = IntegerSet._CAPICreate(capsule)
+ assert is1 == is2
+ assert is2.context is ctx
# CHECK-LABEL: TEST: testIntegerSetGet
@run
def testIntegerSetGet():
- with Context():
- d0 = AffineDimExpr.get(0)
- d1 = AffineDimExpr.get(1)
- s0 = AffineSymbolExpr.get(0)
- c42 = AffineConstantExpr.get(42)
-
- # CHECK: (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)
- set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False])
- print(set0)
-
- # CHECK: (d0)[s0] : (1 == 0)
- set1 = IntegerSet.get_empty(1, 1)
- print(set1)
-
- # CHECK: (d0)[s0, s1] : (d0 - s1 == 0, s0 - 42 >= 0)
- set2 = set0.get_replaced([d0, AffineSymbolExpr.get(1)], [s0], 1, 2)
- print(set2)
-
- try:
- IntegerSet.get(2, 1, [], [])
- except ValueError as e:
- # CHECK: Expected non-empty list of constraints
- print(e)
-
- try:
- IntegerSet.get(2, 1, [d0 - d1], [True, False])
- except ValueError as e:
- # CHECK: Expected the number of constraints to match that of equality flags
- print(e)
-
- try:
- IntegerSet.get(2, 1, [0], [True])
- except RuntimeError as e:
- # CHECK: Invalid expression when attempting to create an IntegerSet
- print(e)
-
- try:
- IntegerSet.get(2, 1, [None], [True])
- except RuntimeError as e:
- # CHECK: Invalid expression (None?) when attempting to create an IntegerSet
- print(e)
-
- try:
- set0.get_replaced([d0], [s0], 1, 1)
- except ValueError as e:
- # CHECK: Expected the number of dimension replacement expressions to match that of dimensions
- print(e)
-
- try:
- set0.get_replaced([d0, d1], [s0, s0], 1, 1)
- except ValueError as e:
- # CHECK: Expected the number of symbol replacement expressions to match that of symbols
- print(e)
-
- try:
- set0.get_replaced([d0, 1], [s0], 1, 1)
- except RuntimeError as e:
- # CHECK: Invalid expression when attempting to create an IntegerSet by replacing dimensions
- print(e)
-
- try:
- set0.get_replaced([d0, d1], [None], 1, 1)
- except RuntimeError as e:
- # CHECK: Invalid expression (None?) when attempting to create an IntegerSet by replacing symbols
- print(e)
+ with Context():
+ d0 = AffineDimExpr.get(0)
+ d1 = AffineDimExpr.get(1)
+ s0 = AffineSymbolExpr.get(0)
+ c42 = AffineConstantExpr.get(42)
+
+ # CHECK: (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)
+ set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False])
+ print(set0)
+
+ # CHECK: (d0)[s0] : (1 == 0)
+ set1 = IntegerSet.get_empty(1, 1)
+ print(set1)
+
+ # CHECK: (d0)[s0, s1] : (d0 - s1 == 0, s0 - 42 >= 0)
+ set2 = set0.get_replaced([d0, AffineSymbolExpr.get(1)], [s0], 1, 2)
+ print(set2)
+
+ try:
+ IntegerSet.get(2, 1, [], [])
+ except ValueError as e:
+ # CHECK: Expected non-empty list of constraints
+ print(e)
+
+ try:
+ IntegerSet.get(2, 1, [d0 - d1], [True, False])
+ except ValueError as e:
+ # CHECK: Expected the number of constraints to match that of equality flags
+ print(e)
+
+ try:
+ IntegerSet.get(2, 1, [0], [True])
+ except RuntimeError as e:
+ # CHECK: Invalid expression when attempting to create an IntegerSet
+ print(e)
+
+ try:
+ IntegerSet.get(2, 1, [None], [True])
+ except RuntimeError as e:
+ # CHECK: Invalid expression (None?) when attempting to create an IntegerSet
+ print(e)
+
+ try:
+ set0.get_replaced([d0], [s0], 1, 1)
+ except ValueError as e:
+ # CHECK: Expected the number of dimension replacement expressions to match that of dimensions
+ print(e)
+
+ try:
+ set0.get_replaced([d0, d1], [s0, s0], 1, 1)
+ except ValueError as e:
+ # CHECK: Expected the number of symbol replacement expressions to match that of symbols
+ print(e)
+
+ try:
+ set0.get_replaced([d0, 1], [s0], 1, 1)
+ except RuntimeError as e:
+ # CHECK: Invalid expression when attempting to create an IntegerSet by replacing dimensions
+ print(e)
+
+ try:
+ set0.get_replaced([d0, d1], [None], 1, 1)
+ except RuntimeError as e:
+ # CHECK: Invalid expression (None?) when attempting to create an IntegerSet by replacing symbols
+ print(e)
# CHECK-LABEL: TEST: testIntegerSetProperties
@run
def testIntegerSetProperties():
- with Context():
- d0 = AffineDimExpr.get(0)
- d1 = AffineDimExpr.get(1)
- s0 = AffineSymbolExpr.get(0)
- c42 = AffineConstantExpr.get(42)
-
- set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42, s0 - d0], [True, False, False])
- # CHECK: 2
- print(set0.n_dims)
- # CHECK: 1
- print(set0.n_symbols)
- # CHECK: 3
- print(set0.n_inputs)
- # CHECK: 1
- print(set0.n_equalities)
- # CHECK: 2
- print(set0.n_inequalities)
-
- # CHECK: 3
- print(len(set0.constraints))
-
- # CHECK-DAG: d0 - d1 == 0
- # CHECK-DAG: s0 - 42 >= 0
- # CHECK-DAG: -d0 + s0 >= 0
- for cstr in set0.constraints:
- print(cstr.expr, end='')
- print(" == 0" if cstr.is_eq else " >= 0")
+ with Context():
+ d0 = AffineDimExpr.get(0)
+ d1 = AffineDimExpr.get(1)
+ s0 = AffineSymbolExpr.get(0)
+ c42 = AffineConstantExpr.get(42)
+
+ set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42, s0 - d0], [True, False, False])
+ # CHECK: 2
+ print(set0.n_dims)
+ # CHECK: 1
+ print(set0.n_symbols)
+ # CHECK: 3
+ print(set0.n_inputs)
+ # CHECK: 1
+ print(set0.n_equalities)
+ # CHECK: 2
+ print(set0.n_inequalities)
+
+ # CHECK: 3
+ print(len(set0.constraints))
+
+ # CHECK-DAG: d0 - d1 == 0
+ # CHECK-DAG: s0 - 42 >= 0
+ # CHECK-DAG: -d0 + s0 >= 0
+ for cstr in set0.constraints:
+ print(cstr.expr, end="")
+ print(" == 0" if cstr.is_eq else " >= 0")
# TODO-LABEL: TEST: testHash
@run
def testHash():
- with Context():
- d0 = AffineDimExpr.get(0)
- d1 = AffineDimExpr.get(1)
- set = IntegerSet.get(2, 0, [d0 + d1], [True])
+ with Context():
+ d0 = AffineDimExpr.get(0)
+ d1 = AffineDimExpr.get(1)
+ set = IntegerSet.get(2, 0, [d0 + d1], [True])
- assert hash(set) == hash(IntegerSet.get(2, 0, [d0 + d1], [True]))
+ assert hash(set) == hash(IntegerSet.get(2, 0, [d0 + d1], [True]))
- dictionary = dict()
- dictionary[set] = 42
- assert set in dictionary
+ dictionary = dict()
+ dictionary[set] = 42
+ assert set in dictionary
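A brief sketch of IntegerSet construction and inspection mirroring the tests above:

from mlir.ir import Context, AffineDimExpr, AffineSymbolExpr, AffineConstantExpr, IntegerSet

with Context():
    d0 = AffineDimExpr.get(0)
    d1 = AffineDimExpr.get(1)
    s0 = AffineSymbolExpr.get(0)
    c42 = AffineConstantExpr.get(42)
    # Constraints are affine expressions; the boolean flags mark equalities.
    s = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False])
    print(s)  # (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)
    print(s.n_dims, s.n_symbols, s.n_equalities, s.n_inequalities)
    for cstr in s.constraints:
        print(cstr.expr, "== 0" if cstr.is_eq else ">= 0")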
diff --git a/mlir/test/python/ir/location.py b/mlir/test/python/ir/location.py
index 6a30a1d25b318..f66d6c501dcf5 100644
--- a/mlir/test/python/ir/location.py
+++ b/mlir/test/python/ir/location.py
@@ -3,143 +3,150 @@
import gc
from mlir.ir import *
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
# CHECK-LABEL: TEST: testUnknown
def testUnknown():
- with Context() as ctx:
- loc = Location.unknown()
- assert loc.context is ctx
- ctx = None
- gc.collect()
- # CHECK: unknown str: loc(unknown)
- print("unknown str:", str(loc))
- # CHECK: unknown repr: loc(unknown)
- print("unknown repr:", repr(loc))
+ with Context() as ctx:
+ loc = Location.unknown()
+ assert loc.context is ctx
+ ctx = None
+ gc.collect()
+ # CHECK: unknown str: loc(unknown)
+ print("unknown str:", str(loc))
+ # CHECK: unknown repr: loc(unknown)
+ print("unknown repr:", repr(loc))
+
run(testUnknown)
# CHECK-LABEL: TEST: testLocationAttr
def testLocationAttr():
- with Context() as ctxt:
- loc = Location.unknown()
- attr = loc.attr
- clone = Location.from_attr(attr)
- gc.collect()
- # CHECK: loc: loc(unknown)
- print("loc:", str(loc))
- # CHECK: clone: loc(unknown)
- print("clone:", str(clone))
- assert loc == clone
+ with Context() as ctxt:
+ loc = Location.unknown()
+ attr = loc.attr
+ clone = Location.from_attr(attr)
+ gc.collect()
+ # CHECK: loc: loc(unknown)
+ print("loc:", str(loc))
+ # CHECK: clone: loc(unknown)
+ print("clone:", str(clone))
+ assert loc == clone
+
run(testLocationAttr)
# CHECK-LABEL: TEST: testFileLineCol
def testFileLineCol():
- with Context() as ctx:
- loc = Location.file("foo.txt", 123, 56)
- ctx = None
- gc.collect()
- # CHECK: file str: loc("foo.txt":123:56)
- print("file str:", str(loc))
- # CHECK: file repr: loc("foo.txt":123:56)
- print("file repr:", repr(loc))
+ with Context() as ctx:
+ loc = Location.file("foo.txt", 123, 56)
+ ctx = None
+ gc.collect()
+ # CHECK: file str: loc("foo.txt":123:56)
+ print("file str:", str(loc))
+ # CHECK: file repr: loc("foo.txt":123:56)
+ print("file repr:", repr(loc))
+
run(testFileLineCol)
# CHECK-LABEL: TEST: testName
def testName():
- with Context() as ctx:
- loc = Location.name("nombre")
- locWithChildLoc = Location.name("naam", loc)
- ctx = None
- gc.collect()
- # CHECK: file str: loc("nombre")
- print("file str:", str(loc))
- # CHECK: file repr: loc("nombre")
- print("file repr:", repr(loc))
- # CHECK: file str: loc("naam"("nombre"))
- print("file str:", str(locWithChildLoc))
- # CHECK: file repr: loc("naam"("nombre"))
- print("file repr:", repr(locWithChildLoc))
+ with Context() as ctx:
+ loc = Location.name("nombre")
+ locWithChildLoc = Location.name("naam", loc)
+ ctx = None
+ gc.collect()
+ # CHECK: file str: loc("nombre")
+ print("file str:", str(loc))
+ # CHECK: file repr: loc("nombre")
+ print("file repr:", repr(loc))
+ # CHECK: file str: loc("naam"("nombre"))
+ print("file str:", str(locWithChildLoc))
+ # CHECK: file repr: loc("naam"("nombre"))
+ print("file repr:", repr(locWithChildLoc))
+
run(testName)
# CHECK-LABEL: TEST: testCallSite
def testCallSite():
- with Context() as ctx:
- loc = Location.callsite(
- Location.file("foo.text", 123, 45), [
- Location.file("util.foo", 379, 21),
- Location.file("main.foo", 100, 63)
- ])
- ctx = None
- # CHECK: file str: loc(callsite("foo.text":123:45 at callsite("util.foo":379:21 at "main.foo":100:63))
- print("file str:", str(loc))
- # CHECK: file repr: loc(callsite("foo.text":123:45 at callsite("util.foo":379:21 at "main.foo":100:63))
- print("file repr:", repr(loc))
+ with Context() as ctx:
+ loc = Location.callsite(
+ Location.file("foo.text", 123, 45),
+ [Location.file("util.foo", 379, 21), Location.file("main.foo", 100, 63)],
+ )
+ ctx = None
+ # CHECK: file str: loc(callsite("foo.text":123:45 at callsite("util.foo":379:21 at "main.foo":100:63))
+ print("file str:", str(loc))
+ # CHECK: file repr: loc(callsite("foo.text":123:45 at callsite("util.foo":379:21 at "main.foo":100:63))
+ print("file repr:", repr(loc))
+
run(testCallSite)
# CHECK-LABEL: TEST: testFused
def testFused():
- with Context() as ctx:
- loc_single = Location.fused([Location.name("apple")])
- loc = Location.fused(
- [Location.name("apple"), Location.name("banana")])
- attr = Attribute.parse('"sauteed"')
- loc_attr = Location.fused([Location.name("carrot"),
- Location.name("potatoes")], attr)
- loc_empty = Location.fused([])
- loc_empty_attr = Location.fused([], attr)
- loc_single_attr = Location.fused([Location.name("apple")], attr)
- ctx = None
- # CHECK: file str: loc("apple")
- print("file str:", str(loc_single))
- # CHECK: file repr: loc("apple")
- print("file repr:", repr(loc_single))
- # CHECK: file str: loc(fused["apple", "banana"])
- print("file str:", str(loc))
- # CHECK: file repr: loc(fused["apple", "banana"])
- print("file repr:", repr(loc))
- # CHECK: file str: loc(fused<"sauteed">["carrot", "potatoes"])
- print("file str:", str(loc_attr))
- # CHECK: file repr: loc(fused<"sauteed">["carrot", "potatoes"])
- print("file repr:", repr(loc_attr))
- # CHECK: file str: loc(unknown)
- print("file str:", str(loc_empty))
- # CHECK: file repr: loc(unknown)
- print("file repr:", repr(loc_empty))
- # CHECK: file str: loc(fused<"sauteed">[unknown])
- print("file str:", str(loc_empty_attr))
- # CHECK: file repr: loc(fused<"sauteed">[unknown])
- print("file repr:", repr(loc_empty_attr))
- # CHECK: file str: loc(fused<"sauteed">["apple"])
- print("file str:", str(loc_single_attr))
- # CHECK: file repr: loc(fused<"sauteed">["apple"])
- print("file repr:", repr(loc_single_attr))
+ with Context() as ctx:
+ loc_single = Location.fused([Location.name("apple")])
+ loc = Location.fused([Location.name("apple"), Location.name("banana")])
+ attr = Attribute.parse('"sauteed"')
+ loc_attr = Location.fused(
+ [Location.name("carrot"), Location.name("potatoes")], attr
+ )
+ loc_empty = Location.fused([])
+ loc_empty_attr = Location.fused([], attr)
+ loc_single_attr = Location.fused([Location.name("apple")], attr)
+ ctx = None
+ # CHECK: file str: loc("apple")
+ print("file str:", str(loc_single))
+ # CHECK: file repr: loc("apple")
+ print("file repr:", repr(loc_single))
+ # CHECK: file str: loc(fused["apple", "banana"])
+ print("file str:", str(loc))
+ # CHECK: file repr: loc(fused["apple", "banana"])
+ print("file repr:", repr(loc))
+ # CHECK: file str: loc(fused<"sauteed">["carrot", "potatoes"])
+ print("file str:", str(loc_attr))
+ # CHECK: file repr: loc(fused<"sauteed">["carrot", "potatoes"])
+ print("file repr:", repr(loc_attr))
+ # CHECK: file str: loc(unknown)
+ print("file str:", str(loc_empty))
+ # CHECK: file repr: loc(unknown)
+ print("file repr:", repr(loc_empty))
+ # CHECK: file str: loc(fused<"sauteed">[unknown])
+ print("file str:", str(loc_empty_attr))
+ # CHECK: file repr: loc(fused<"sauteed">[unknown])
+ print("file repr:", repr(loc_empty_attr))
+ # CHECK: file str: loc(fused<"sauteed">["apple"])
+ print("file str:", str(loc_single_attr))
+ # CHECK: file repr: loc(fused<"sauteed">["apple"])
+ print("file repr:", repr(loc_single_attr))
+
run(testFused)
# CHECK-LABEL: TEST: testLocationCapsule
def testLocationCapsule():
- with Context() as ctx:
- loc1 = Location.file("foo.txt", 123, 56)
- # CHECK: mlir.ir.Location._CAPIPtr
- loc_capsule = loc1._CAPIPtr
- print(loc_capsule)
- loc2 = Location._CAPICreate(loc_capsule)
- assert loc2 == loc1
- assert loc2.context is ctx
+ with Context() as ctx:
+ loc1 = Location.file("foo.txt", 123, 56)
+ # CHECK: mlir.ir.Location._CAPIPtr
+ loc_capsule = loc1._CAPIPtr
+ print(loc_capsule)
+ loc2 = Location._CAPICreate(loc_capsule)
+ assert loc2 == loc1
+ assert loc2.context is ctx
+
run(testLocationCapsule)
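A short sketch of the Location constructors exercised above (the file and call-site names are illustrative):

from mlir.ir import Context, Location

with Context():
    file_loc = Location.file("example.py", 12, 3)
    named = Location.name("stage", file_loc)  # name location wrapping a child loc
    callsite = Location.callsite(file_loc, [Location.file("caller.py", 1, 1)])
    fused = Location.fused([Location.name("apple"), Location.name("banana")])
    print(file_loc)  # loc("example.py":12:3)
    print(named)     # loc("stage"("example.py":12:3))
    print(callsite)
    print(fused)     # loc(fused["apple", "banana"])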
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index 2d00923683339..a5c38a6b0b076 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -3,12 +3,13 @@
import gc
from mlir.ir import *
+
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
# Verify successful parse.
@@ -16,14 +17,14 @@ def run(f):
# CHECK: module @successfulParse
@run
def testParseSuccess():
- ctx = Context()
- module = Module.parse(r"""module @successfulParse {}""", ctx)
- assert module.context is ctx
- print("CLEAR CONTEXT")
- ctx = None # Ensure that module captures the context.
- gc.collect()
- module.dump() # Just outputs to stderr. Verifies that it functions.
- print(str(module))
+ ctx = Context()
+ module = Module.parse(r"""module @successfulParse {}""", ctx)
+ assert module.context is ctx
+ print("CLEAR CONTEXT")
+ ctx = None # Ensure that module captures the context.
+ gc.collect()
+ module.dump() # Just outputs to stderr. Verifies that it functions.
+ print(str(module))
# Verify parse error.
@@ -34,13 +35,13 @@ def testParseSuccess():
# CHECK: >
@run
def testParseError():
- ctx = Context()
- try:
- module = Module.parse(r"""}SYNTAX ERROR{""", ctx)
- except MLIRError as e:
- print(f"testParseError: <{e}>")
- else:
- print("Exception not produced")
+ ctx = Context()
+ try:
+ module = Module.parse(r"""}SYNTAX ERROR{""", ctx)
+ except MLIRError as e:
+ print(f"testParseError: <{e}>")
+ else:
+ print("Exception not produced")
# Verify successful parse.
@@ -48,13 +49,13 @@ def testParseError():
# CHECK: module {
@run
def testCreateEmpty():
- ctx = Context()
- loc = Location.unknown(ctx)
- module = Module.create(loc)
- print("CLEAR CONTEXT")
- ctx = None # Ensure that module captures the context.
- gc.collect()
- print(str(module))
+ ctx = Context()
+ loc = Location.unknown(ctx)
+ module = Module.create(loc)
+ print("CLEAR CONTEXT")
+ ctx = None # Ensure that module captures the context.
+ gc.collect()
+ print(str(module))
# Verify round-trip of ASM that contains unicode.
@@ -65,11 +66,14 @@ def testCreateEmpty():
# CHECK: foo = "\F0\9F\98\8A"
@run
def testRoundtripUnicode():
- ctx = Context()
- module = Module.parse(r"""
+ ctx = Context()
+ module = Module.parse(
+ r"""
func.func private @roundtripUnicode() attributes { foo = "😊" }
- """, ctx)
- print(str(module))
+ """,
+ ctx,
+ )
+ print(str(module))
# Verify round-trip of ASM that contains unicode.
@@ -80,73 +84,74 @@ def testRoundtripUnicode():
# CHECK: foo = "\F0\9F\98\8A"
@run
def testRoundtripBinary():
- with Context():
- module = Module.parse(r"""
+ with Context():
+ module = Module.parse(
+ r"""
func.func private @roundtripUnicode() attributes { foo = "😊" }
- """)
- binary_asm = module.operation.get_asm(binary=True)
- assert isinstance(binary_asm, bytes)
- module = Module.parse(binary_asm)
- print(module)
+ """
+ )
+ binary_asm = module.operation.get_asm(binary=True)
+ assert isinstance(binary_asm, bytes)
+ module = Module.parse(binary_asm)
+ print(module)
# Tests that module.operation works and correctly interns instances.
# CHECK-LABEL: TEST: testModuleOperation
@run
def testModuleOperation():
- ctx = Context()
- module = Module.parse(r"""module @successfulParse {}""", ctx)
- assert ctx._get_live_module_count() == 1
- op1 = module.operation
- assert ctx._get_live_operation_count() == 1
- # CHECK: module @successfulParse
- print(op1)
-
- # Ensure that operations are the same on multiple calls.
- op2 = module.operation
- assert ctx._get_live_operation_count() == 1
- assert op1 is op2
-
- # Test live operation clearing.
- op1 = module.operation
- assert ctx._get_live_operation_count() == 1
- num_invalidated = ctx._clear_live_operations()
- assert num_invalidated == 1
- assert ctx._get_live_operation_count() == 0
- op1 = None
- gc.collect()
- op1 = module.operation
-
- # Ensure that if module is de-referenced, the operations are still valid.
- module = None
- gc.collect()
- print(op1)
-
- # Collect and verify lifetime.
- op1 = None
- op2 = None
- gc.collect()
- print("LIVE OPERATIONS:", ctx._get_live_operation_count())
- assert ctx._get_live_operation_count() == 0
- assert ctx._get_live_module_count() == 0
+ ctx = Context()
+ module = Module.parse(r"""module @successfulParse {}""", ctx)
+ assert ctx._get_live_module_count() == 1
+ op1 = module.operation
+ assert ctx._get_live_operation_count() == 1
+ # CHECK: module @successfulParse
+ print(op1)
+
+ # Ensure that operations are the same on multiple calls.
+ op2 = module.operation
+ assert ctx._get_live_operation_count() == 1
+ assert op1 is op2
+
+ # Test live operation clearing.
+ op1 = module.operation
+ assert ctx._get_live_operation_count() == 1
+ num_invalidated = ctx._clear_live_operations()
+ assert num_invalidated == 1
+ assert ctx._get_live_operation_count() == 0
+ op1 = None
+ gc.collect()
+ op1 = module.operation
+
+ # Ensure that if module is de-referenced, the operations are still valid.
+ module = None
+ gc.collect()
+ print(op1)
+
+ # Collect and verify lifetime.
+ op1 = None
+ op2 = None
+ gc.collect()
+ print("LIVE OPERATIONS:", ctx._get_live_operation_count())
+ assert ctx._get_live_operation_count() == 0
+ assert ctx._get_live_module_count() == 0
# CHECK-LABEL: TEST: testModuleCapsule
@run
def testModuleCapsule():
- ctx = Context()
- module = Module.parse(r"""module @successfulParse {}""", ctx)
- assert ctx._get_live_module_count() == 1
- # CHECK: "mlir.ir.Module._CAPIPtr"
- module_capsule = module._CAPIPtr
- print(module_capsule)
- module_dup = Module._CAPICreate(module_capsule)
- assert module is module_dup
- assert module_dup.context is ctx
- # Gc and verify destructed.
- module = None
- module_capsule = None
- module_dup = None
- gc.collect()
- assert ctx._get_live_module_count() == 0
-
+ ctx = Context()
+ module = Module.parse(r"""module @successfulParse {}""", ctx)
+ assert ctx._get_live_module_count() == 1
+ # CHECK: "mlir.ir.Module._CAPIPtr"
+ module_capsule = module._CAPIPtr
+ print(module_capsule)
+ module_dup = Module._CAPICreate(module_capsule)
+ assert module is module_dup
+ assert module_dup.context is ctx
+ # Gc and verify destructed.
+ module = None
+ module_capsule = None
+ module_dup = None
+ gc.collect()
+ assert ctx._get_live_module_count() == 0
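
For readers following the reformatted module.py tests above, here is a minimal standalone sketch (not part of this commit) of the context/module lifetime behavior they exercise: a parsed module keeps its Context alive, and repeated `module.operation` accesses return the same interned wrapper. It assumes only that the `mlir` Python bindings used by these tests are importable.

import gc

from mlir.ir import Context, Module

ctx = Context()
module = Module.parse(r"module @example {}", ctx)

# Repeated accesses intern the same wrapper object, which is why the live
# operation count in testModuleOperation stays at 1.
op_a = module.operation
op_b = module.operation
assert op_a is op_b

# The module keeps its context alive even after the local reference is
# dropped, mirroring testCreateEmpty above.
ctx = None
gc.collect()
print(module)
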
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 22a80891c7a90..639f8ff2b4255 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -8,232 +8,242 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
def expect_index_error(callback):
- try:
- _ = callback()
- raise RuntimeError("Expected IndexError")
- except IndexError:
- pass
+ try:
+ _ = callback()
+ raise RuntimeError("Expected IndexError")
+ except IndexError:
+ pass
# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- module = Module.parse(
- r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
- """, ctx)
- op = module.operation
- assert op.context is ctx
- # Get the block using iterators off of the named collections.
- regions = list(op.regions)
- blocks = list(regions[0].blocks)
- # CHECK: MODULE REGIONS=1 BLOCKS=1
- print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
-
- # Should verify.
- # CHECK: .verify = True
- print(f".verify = {module.operation.verify()}")
-
- # Get the blocks from the default collection.
- default_blocks = list(regions[0])
- # They should compare equal regardless of how obtained.
- assert default_blocks == blocks
-
- # Should be able to get the operations from either the named collection
- # or the block.
- operations = list(blocks[0].operations)
- default_operations = list(blocks[0])
- assert default_operations == operations
-
- def walk_operations(indent, op):
- for i, region in enumerate(op.regions):
- print(f"{indent}REGION {i}:")
- for j, block in enumerate(region):
- print(f"{indent} BLOCK {j}:")
- for k, child_op in enumerate(block):
- print(f"{indent} OP {k}: {child_op}")
- walk_operations(indent + " ", child_op)
-
- # CHECK: REGION 0:
- # CHECK: BLOCK 0:
- # CHECK: OP 0: func
- # CHECK: REGION 0:
- # CHECK: BLOCK 0:
- # CHECK: OP 0: %0 = "custom.addi"
- # CHECK: OP 1: func.return
- walk_operations("", op)
-
- # CHECK: Region iter: <mlir.{{.+}}.RegionIterator
- # CHECK: Block iter: <mlir.{{.+}}.BlockIterator
- # CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
- print(" Region iter:", iter(op.regions))
- print(" Block iter:", iter(op.regions[0]))
- print("Operation iter:", iter(op.regions[0].blocks[0]))
+ """,
+ ctx,
+ )
+ op = module.operation
+ assert op.context is ctx
+ # Get the block using iterators off of the named collections.
+ regions = list(op.regions)
+ blocks = list(regions[0].blocks)
+ # CHECK: MODULE REGIONS=1 BLOCKS=1
+ print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
+
+ # Should verify.
+ # CHECK: .verify = True
+ print(f".verify = {module.operation.verify()}")
+
+ # Get the blocks from the default collection.
+ default_blocks = list(regions[0])
+ # They should compare equal regardless of how obtained.
+ assert default_blocks == blocks
+
+ # Should be able to get the operations from either the named collection
+ # or the block.
+ operations = list(blocks[0].operations)
+ default_operations = list(blocks[0])
+ assert default_operations == operations
+
+ def walk_operations(indent, op):
+ for i, region in enumerate(op.regions):
+ print(f"{indent}REGION {i}:")
+ for j, block in enumerate(region):
+ print(f"{indent} BLOCK {j}:")
+ for k, child_op in enumerate(block):
+ print(f"{indent} OP {k}: {child_op}")
+ walk_operations(indent + " ", child_op)
+
+ # CHECK: REGION 0:
+ # CHECK: BLOCK 0:
+ # CHECK: OP 0: func
+ # CHECK: REGION 0:
+ # CHECK: BLOCK 0:
+ # CHECK: OP 0: %0 = "custom.addi"
+ # CHECK: OP 1: func.return
+ walk_operations("", op)
+
+ # CHECK: Region iter: <mlir.{{.+}}.RegionIterator
+ # CHECK: Block iter: <mlir.{{.+}}.BlockIterator
+ # CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
+ print(" Region iter:", iter(op.regions))
+ print(" Block iter:", iter(op.regions[0]))
+ print("Operation iter:", iter(op.regions[0].blocks[0]))
# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- module = Module.parse(
- r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
- """, ctx)
-
- def walk_operations(indent, op):
- for i in range(len(op.regions)):
- region = op.regions[i]
- print(f"{indent}REGION {i}:")
- for j in range(len(region.blocks)):
- block = region.blocks[j]
- print(f"{indent} BLOCK {j}:")
- for k in range(len(block.operations)):
- child_op = block.operations[k]
- print(f"{indent} OP {k}: {child_op}")
- print(f"{indent} OP {k}: parent {child_op.operation.parent.name}")
- walk_operations(indent + " ", child_op)
-
- # CHECK: REGION 0:
- # CHECK: BLOCK 0:
- # CHECK: OP 0: func
- # CHECK: OP 0: parent builtin.module
- # CHECK: REGION 0:
- # CHECK: BLOCK 0:
- # CHECK: OP 0: %0 = "custom.addi"
- # CHECK: OP 0: parent func.func
- # CHECK: OP 1: func.return
- # CHECK: OP 1: parent func.func
- walk_operations("", module.operation)
+ """,
+ ctx,
+ )
+
+ def walk_operations(indent, op):
+ for i in range(len(op.regions)):
+ region = op.regions[i]
+ print(f"{indent}REGION {i}:")
+ for j in range(len(region.blocks)):
+ block = region.blocks[j]
+ print(f"{indent} BLOCK {j}:")
+ for k in range(len(block.operations)):
+ child_op = block.operations[k]
+ print(f"{indent} OP {k}: {child_op}")
+ print(
+ f"{indent} OP {k}: parent {child_op.operation.parent.name}"
+ )
+ walk_operations(indent + " ", child_op)
+
+ # CHECK: REGION 0:
+ # CHECK: BLOCK 0:
+ # CHECK: OP 0: func
+ # CHECK: OP 0: parent builtin.module
+ # CHECK: REGION 0:
+ # CHECK: BLOCK 0:
+ # CHECK: OP 0: %0 = "custom.addi"
+ # CHECK: OP 0: parent func.func
+ # CHECK: OP 1: func.return
+ # CHECK: OP 1: parent func.func
+ walk_operations("", module.operation)
# CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- module = Module.parse(
- r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
builtin.module {
func.func @f() {
func.return
}
}
- """, ctx)
+ """,
+ ctx,
+ )
- assert module.operation.regions[0].owner == module.operation
- assert module.operation.regions[0].blocks[0].owner == module.operation
+ assert module.operation.regions[0].owner == module.operation
+ assert module.operation.regions[0].blocks[0].owner == module.operation
- func = module.body.operations[0]
- assert func.operation.regions[0].owner == func
- assert func.operation.regions[0].blocks[0].owner == func
+ func = module.body.operations[0]
+ assert func.operation.regions[0].owner == func
+ assert func.operation.regions[0].blocks[0].owner == func
# CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
- with Context() as ctx:
- module = Module.parse(
- r"""
+ with Context() as ctx:
+ module = Module.parse(
+ r"""
func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
return
}
- """, ctx)
- func = module.body.operations[0]
- entry_block = func.regions[0].blocks[0]
- assert len(entry_block.arguments) == 3
- # CHECK: Argument 0, type i32
- # CHECK: Argument 1, type f64
- # CHECK: Argument 2, type index
- for arg in entry_block.arguments:
- print(f"Argument {arg.arg_number}, type {arg.type}")
- new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
- arg.set_type(new_type)
-
- # CHECK: Argument 0, type i8
- # CHECK: Argument 1, type i16
- # CHECK: Argument 2, type i24
- for arg in entry_block.arguments:
- print(f"Argument {arg.arg_number}, type {arg.type}")
-
- # Check that slicing works for block argument lists.
- # CHECK: Argument 1, type i16
- # CHECK: Argument 2, type i24
- for arg in entry_block.arguments[1:]:
- print(f"Argument {arg.arg_number}, type {arg.type}")
-
- # Check that we can concatenate slices of argument lists.
- # CHECK: Length: 4
- print("Length: ",
- len(entry_block.arguments[:2] + entry_block.arguments[1:]))
-
- # CHECK: Type: i8
- # CHECK: Type: i16
- # CHECK: Type: i24
- for t in entry_block.arguments.types:
- print("Type: ", t)
-
- # Check that slicing and type access compose.
- # CHECK: Sliced type: i16
- # CHECK: Sliced type: i24
- for t in entry_block.arguments[1:].types:
- print("Sliced type: ", t)
-
- # Check that slice addition works as expected.
- # CHECK: Argument 2, type i24
- # CHECK: Argument 0, type i8
- restructured = entry_block.arguments[-1:] + entry_block.arguments[:1]
- for arg in restructured:
- print(f"Argument {arg.arg_number}, type {arg.type}")
+ """,
+ ctx,
+ )
+ func = module.body.operations[0]
+ entry_block = func.regions[0].blocks[0]
+ assert len(entry_block.arguments) == 3
+ # CHECK: Argument 0, type i32
+ # CHECK: Argument 1, type f64
+ # CHECK: Argument 2, type index
+ for arg in entry_block.arguments:
+ print(f"Argument {arg.arg_number}, type {arg.type}")
+ new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
+ arg.set_type(new_type)
+
+ # CHECK: Argument 0, type i8
+ # CHECK: Argument 1, type i16
+ # CHECK: Argument 2, type i24
+ for arg in entry_block.arguments:
+ print(f"Argument {arg.arg_number}, type {arg.type}")
+
+ # Check that slicing works for block argument lists.
+ # CHECK: Argument 1, type i16
+ # CHECK: Argument 2, type i24
+ for arg in entry_block.arguments[1:]:
+ print(f"Argument {arg.arg_number}, type {arg.type}")
+
+ # Check that we can concatenate slices of argument lists.
+ # CHECK: Length: 4
+ print("Length: ", len(entry_block.arguments[:2] + entry_block.arguments[1:]))
+
+ # CHECK: Type: i8
+ # CHECK: Type: i16
+ # CHECK: Type: i24
+ for t in entry_block.arguments.types:
+ print("Type: ", t)
+
+ # Check that slicing and type access compose.
+ # CHECK: Sliced type: i16
+ # CHECK: Sliced type: i24
+ for t in entry_block.arguments[1:].types:
+ print("Sliced type: ", t)
+
+ # Check that slice addition works as expected.
+ # CHECK: Argument 2, type i24
+ # CHECK: Argument 0, type i8
+ restructured = entry_block.arguments[-1:] + entry_block.arguments[:1]
+ for arg in restructured:
+ print(f"Argument {arg.arg_number}, type {arg.type}")
# CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
- with Context() as ctx:
- ctx.allow_unregistered_dialects = True
- module = Module.parse(r"""
+ with Context() as ctx:
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
func.func @f1(%arg0: i32) {
%0 = "test.producer"() : () -> i64
"test.consumer"(%arg0, %0) : (i32, i64) -> ()
return
- }""")
- func = module.body.operations[0]
- entry_block = func.regions[0].blocks[0]
- consumer = entry_block.operations[1]
- assert len(consumer.operands) == 2
- # CHECK: Operand 0, type i32
- # CHECK: Operand 1, type i64
- for i, operand in enumerate(consumer.operands):
- print(f"Operand {i}, type {operand.type}")
-
-
+ }"""
+ )
+ func = module.body.operations[0]
+ entry_block = func.regions[0].blocks[0]
+ consumer = entry_block.operations[1]
+ assert len(consumer.operands) == 2
+ # CHECK: Operand 0, type i32
+ # CHECK: Operand 1, type i64
+ for i, operand in enumerate(consumer.operands):
+ print(f"Operand {i}, type {operand.type}")
# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
- with Context() as ctx:
- ctx.allow_unregistered_dialects = True
- module = Module.parse(r"""
+ with Context() as ctx:
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
func.func @f1() {
%0 = "test.producer0"() : () -> i64
%1 = "test.producer1"() : () -> i64
@@ -242,708 +252,727 @@ def testOperationOperandsSlice():
%4 = "test.producer4"() : () -> i64
"test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
return
- }""")
- func = module.body.operations[0]
- entry_block = func.regions[0].blocks[0]
- consumer = entry_block.operations[5]
- assert len(consumer.operands) == 5
- for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
- assert left == right
-
- # CHECK: test.producer0
- # CHECK: test.producer1
- # CHECK: test.producer2
- # CHECK: test.producer3
- # CHECK: test.producer4
- full_slice = consumer.operands[:]
- for operand in full_slice:
- print(operand)
-
- # CHECK: test.producer0
- # CHECK: test.producer1
- first_two = consumer.operands[0:2]
- for operand in first_two:
- print(operand)
-
- # CHECK: test.producer3
- # CHECK: test.producer4
- last_two = consumer.operands[3:]
- for operand in last_two:
- print(operand)
-
- # CHECK: test.producer0
- # CHECK: test.producer2
- # CHECK: test.producer4
- even = consumer.operands[::2]
- for operand in even:
- print(operand)
-
- # CHECK: test.producer2
- fourth = consumer.operands[::2][1::2]
- for operand in fourth:
- print(operand)
-
-
+ }"""
+ )
+ func = module.body.operations[0]
+ entry_block = func.regions[0].blocks[0]
+ consumer = entry_block.operations[5]
+ assert len(consumer.operands) == 5
+ for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
+ assert left == right
+
+ # CHECK: test.producer0
+ # CHECK: test.producer1
+ # CHECK: test.producer2
+ # CHECK: test.producer3
+ # CHECK: test.producer4
+ full_slice = consumer.operands[:]
+ for operand in full_slice:
+ print(operand)
+
+ # CHECK: test.producer0
+ # CHECK: test.producer1
+ first_two = consumer.operands[0:2]
+ for operand in first_two:
+ print(operand)
+
+ # CHECK: test.producer3
+ # CHECK: test.producer4
+ last_two = consumer.operands[3:]
+ for operand in last_two:
+ print(operand)
+
+ # CHECK: test.producer0
+ # CHECK: test.producer2
+ # CHECK: test.producer4
+ even = consumer.operands[::2]
+ for operand in even:
+ print(operand)
+
+ # CHECK: test.producer2
+ fourth = consumer.operands[::2][1::2]
+ for operand in fourth:
+ print(operand)
# CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
- with Context() as ctx, Location.unknown(ctx):
- ctx.allow_unregistered_dialects = True
- module = Module.parse(r"""
+ with Context() as ctx, Location.unknown(ctx):
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
func.func @f1() {
%0 = "test.producer0"() : () -> i64
%1 = "test.producer1"() : () -> i64
%2 = "test.producer2"() : () -> i64
"test.consumer"(%0) : (i64) -> ()
return
- }""")
- func = module.body.operations[0]
- entry_block = func.regions[0].blocks[0]
- producer1 = entry_block.operations[1]
- producer2 = entry_block.operations[2]
- consumer = entry_block.operations[3]
- assert len(consumer.operands) == 1
- type = consumer.operands[0].type
-
- # CHECK: test.producer1
- consumer.operands[0] = producer1.result
- print(consumer.operands[0])
-
- # CHECK: test.producer2
- consumer.operands[-1] = producer2.result
- print(consumer.operands[0])
+ }"""
+ )
+ func = module.body.operations[0]
+ entry_block = func.regions[0].blocks[0]
+ producer1 = entry_block.operations[1]
+ producer2 = entry_block.operations[2]
+ consumer = entry_block.operations[3]
+ assert len(consumer.operands) == 1
+ type = consumer.operands[0].type
+ # CHECK: test.producer1
+ consumer.operands[0] = producer1.result
+ print(consumer.operands[0])
+ # CHECK: test.producer2
+ consumer.operands[-1] = producer2.result
+ print(consumer.operands[0])
# CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- i32 = IntegerType.get_signed(32)
- op1 = Operation.create(
- "custom.op1",
- results=[i32, i32],
- regions=1,
- attributes={
- "foo": StringAttr.get("foo_value"),
- "bar": StringAttr.get("bar_value"),
- })
- # CHECK: %0:2 = "custom.op1"() ({
- # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
- print(op1)
-
- # TODO: Check successors once enough infra exists to do it properly.
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signed(32)
+ op1 = Operation.create(
+ "custom.op1",
+ results=[i32, i32],
+ regions=1,
+ attributes={
+ "foo": StringAttr.get("foo_value"),
+ "bar": StringAttr.get("bar_value"),
+ },
+ )
+ # CHECK: %0:2 = "custom.op1"() ({
+ # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
+ print(op1)
+
+ # TODO: Check successors once enough infra exists to do it properly.
# CHECK-LABEL: TEST: testOperationInsertionPoint
@run
def testOperationInsertionPoint():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- module = Module.parse(
- r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
- """, ctx)
-
- # Create test op.
- with Location.unknown(ctx):
- op1 = Operation.create("custom.op1")
- op2 = Operation.create("custom.op2")
-
- func = module.body.operations[0]
- entry_block = func.regions[0].blocks[0]
- ip = InsertionPoint.at_block_begin(entry_block)
- ip.insert(op1)
- ip.insert(op2)
- # CHECK: func @f1
- # CHECK: "custom.op1"()
- # CHECK: "custom.op2"()
- # CHECK: %0 = "custom.addi"
- print(module)
-
- # Trying to add a previously added op should raise.
- try:
- ip.insert(op1)
- except ValueError:
- pass
- else:
- assert False, "expected insert of attached op to raise"
+ """,
+ ctx,
+ )
+
+ # Create test op.
+ with Location.unknown(ctx):
+ op1 = Operation.create("custom.op1")
+ op2 = Operation.create("custom.op2")
+
+ func = module.body.operations[0]
+ entry_block = func.regions[0].blocks[0]
+ ip = InsertionPoint.at_block_begin(entry_block)
+ ip.insert(op1)
+ ip.insert(op2)
+ # CHECK: func @f1
+ # CHECK: "custom.op1"()
+ # CHECK: "custom.op2"()
+ # CHECK: %0 = "custom.addi"
+ print(module)
+
+ # Trying to add a previously added op should raise.
+ try:
+ ip.insert(op1)
+ except ValueError:
+ pass
+ else:
+ assert False, "expected insert of attached op to raise"
# CHECK-LABEL: TEST: testOperationWithRegion
@run
def testOperationWithRegion():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- i32 = IntegerType.get_signed(32)
- op1 = Operation.create("custom.op1", regions=1)
- block = op1.regions[0].blocks.append(i32, i32)
- # CHECK: "custom.op1"() ({
- # CHECK: ^bb0(%arg0: si32, %arg1: si32):
- # CHECK: "custom.terminator"() : () -> ()
- # CHECK: }) : () -> ()
- terminator = Operation.create("custom.terminator")
- ip = InsertionPoint(block)
- ip.insert(terminator)
- print(op1)
-
- # Now add the whole operation to another op.
- # TODO: Verify lifetime hazard by nulling out the new owning module and
- # accessing op1.
- # TODO: Also verify accessing the terminator once both parents are nulled
- # out.
- module = Module.parse(r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signed(32)
+ op1 = Operation.create("custom.op1", regions=1)
+ block = op1.regions[0].blocks.append(i32, i32)
+ # CHECK: "custom.op1"() ({
+ # CHECK: ^bb0(%arg0: si32, %arg1: si32):
+ # CHECK: "custom.terminator"() : () -> ()
+ # CHECK: }) : () -> ()
+ terminator = Operation.create("custom.terminator")
+ ip = InsertionPoint(block)
+ ip.insert(terminator)
+ print(op1)
+
+ # Now add the whole operation to another op.
+ # TODO: Verify lifetime hazard by nulling out the new owning module and
+ # accessing op1.
+ # TODO: Also verify accessing the terminator once both parents are nulled
+ # out.
+ module = Module.parse(
+ r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
- """)
- func = module.body.operations[0]
- entry_block = func.regions[0].blocks[0]
- ip = InsertionPoint.at_block_begin(entry_block)
- ip.insert(op1)
- # CHECK: func @f1
- # CHECK: "custom.op1"()
- # CHECK: "custom.terminator"
- # CHECK: %0 = "custom.addi"
- print(module)
+ """
+ )
+ func = module.body.operations[0]
+ entry_block = func.regions[0].blocks[0]
+ ip = InsertionPoint.at_block_begin(entry_block)
+ ip.insert(op1)
+ # CHECK: func @f1
+ # CHECK: "custom.op1"()
+ # CHECK: "custom.terminator"
+ # CHECK: %0 = "custom.addi"
+ print(module)
# CHECK-LABEL: TEST: testOperationResultList
@run
def testOperationResultList():
- ctx = Context()
- module = Module.parse(
- r"""
+ ctx = Context()
+ module = Module.parse(
+ r"""
func.func @f1() {
%0:3 = call @f2() : () -> (i32, f64, index)
return
}
func.func private @f2() -> (i32, f64, index)
- """, ctx)
- caller = module.body.operations[0]
- call = caller.regions[0].blocks[0].operations[0]
- assert len(call.results) == 3
- # CHECK: Result 0, type i32
- # CHECK: Result 1, type f64
- # CHECK: Result 2, type index
- for res in call.results:
- print(f"Result {res.result_number}, type {res.type}")
-
- # CHECK: Result type i32
- # CHECK: Result type f64
- # CHECK: Result type index
- for t in call.results.types:
- print(f"Result type {t}")
-
- # Out of range
- expect_index_error(lambda: call.results[3])
- expect_index_error(lambda: call.results[-4])
+ """,
+ ctx,
+ )
+ caller = module.body.operations[0]
+ call = caller.regions[0].blocks[0].operations[0]
+ assert len(call.results) == 3
+ # CHECK: Result 0, type i32
+ # CHECK: Result 1, type f64
+ # CHECK: Result 2, type index
+ for res in call.results:
+ print(f"Result {res.result_number}, type {res.type}")
+
+ # CHECK: Result type i32
+ # CHECK: Result type f64
+ # CHECK: Result type index
+ for t in call.results.types:
+ print(f"Result type {t}")
+
+ # Out of range
+ expect_index_error(lambda: call.results[3])
+ expect_index_error(lambda: call.results[-4])
# CHECK-LABEL: TEST: testOperationResultListSlice
@run
def testOperationResultListSlice():
- with Context() as ctx:
- ctx.allow_unregistered_dialects = True
- module = Module.parse(r"""
+ with Context() as ctx:
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
func.func @f1() {
"some.op"() : () -> (i1, i2, i3, i4, i5)
return
}
- """)
- func = module.body.operations[0]
- entry_block = func.regions[0].blocks[0]
- producer = entry_block.operations[0]
-
- assert len(producer.results) == 5
- for left, right in zip(producer.results, producer.results[::-1][::-1]):
- assert left == right
- assert left.result_number == right.result_number
-
- # CHECK: Result 0, type i1
- # CHECK: Result 1, type i2
- # CHECK: Result 2, type i3
- # CHECK: Result 3, type i4
- # CHECK: Result 4, type i5
- full_slice = producer.results[:]
- for res in full_slice:
- print(f"Result {res.result_number}, type {res.type}")
-
- # CHECK: Result 1, type i2
- # CHECK: Result 2, type i3
- # CHECK: Result 3, type i4
- middle = producer.results[1:4]
- for res in middle:
- print(f"Result {res.result_number}, type {res.type}")
-
- # CHECK: Result 1, type i2
- # CHECK: Result 3, type i4
- odd = producer.results[1::2]
- for res in odd:
- print(f"Result {res.result_number}, type {res.type}")
-
- # CHECK: Result 3, type i4
- # CHECK: Result 1, type i2
- inverted_middle = producer.results[-2:0:-2]
- for res in inverted_middle:
- print(f"Result {res.result_number}, type {res.type}")
+ """
+ )
+ func = module.body.operations[0]
+ entry_block = func.regions[0].blocks[0]
+ producer = entry_block.operations[0]
+
+ assert len(producer.results) == 5
+ for left, right in zip(producer.results, producer.results[::-1][::-1]):
+ assert left == right
+ assert left.result_number == right.result_number
+
+ # CHECK: Result 0, type i1
+ # CHECK: Result 1, type i2
+ # CHECK: Result 2, type i3
+ # CHECK: Result 3, type i4
+ # CHECK: Result 4, type i5
+ full_slice = producer.results[:]
+ for res in full_slice:
+ print(f"Result {res.result_number}, type {res.type}")
+
+ # CHECK: Result 1, type i2
+ # CHECK: Result 2, type i3
+ # CHECK: Result 3, type i4
+ middle = producer.results[1:4]
+ for res in middle:
+ print(f"Result {res.result_number}, type {res.type}")
+
+ # CHECK: Result 1, type i2
+ # CHECK: Result 3, type i4
+ odd = producer.results[1::2]
+ for res in odd:
+ print(f"Result {res.result_number}, type {res.type}")
+
+ # CHECK: Result 3, type i4
+ # CHECK: Result 1, type i2
+ inverted_middle = producer.results[-2:0:-2]
+ for res in inverted_middle:
+ print(f"Result {res.result_number}, type {res.type}")
# CHECK-LABEL: TEST: testOperationAttributes
@run
def testOperationAttributes():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- module = Module.parse(
- r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
"some.op"() { some.attribute = 1 : i8,
other.attribute = 3.0,
dependent = "text" } : () -> ()
- """, ctx)
- op = module.body.operations[0]
- assert len(op.attributes) == 3
- iattr = IntegerAttr(op.attributes["some.attribute"])
- fattr = FloatAttr(op.attributes["other.attribute"])
- sattr = StringAttr(op.attributes["dependent"])
- # CHECK: Attribute type i8, value 1
- print(f"Attribute type {iattr.type}, value {iattr.value}")
- # CHECK: Attribute type f64, value 3.0
- print(f"Attribute type {fattr.type}, value {fattr.value}")
- # CHECK: Attribute value text
- print(f"Attribute value {sattr.value}")
- # CHECK: Attribute value b'text'
- print(f"Attribute value {sattr.value_bytes}")
-
- # We don't know in which order the attributes are stored.
- # CHECK-DAG: NamedAttribute(dependent="text")
- # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
- # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
- for attr in op.attributes:
- print(str(attr))
-
- # Check that exceptions are raised as expected.
- try:
- op.attributes["does_not_exist"]
- except KeyError:
- pass
- else:
- assert False, "expected KeyError on accessing a non-existent attribute"
-
- try:
- op.attributes[42]
- except IndexError:
- pass
- else:
- assert False, "expected IndexError on accessing an out-of-bounds attribute"
-
+ """,
+ ctx,
+ )
+ op = module.body.operations[0]
+ assert len(op.attributes) == 3
+ iattr = IntegerAttr(op.attributes["some.attribute"])
+ fattr = FloatAttr(op.attributes["other.attribute"])
+ sattr = StringAttr(op.attributes["dependent"])
+ # CHECK: Attribute type i8, value 1
+ print(f"Attribute type {iattr.type}, value {iattr.value}")
+ # CHECK: Attribute type f64, value 3.0
+ print(f"Attribute type {fattr.type}, value {fattr.value}")
+ # CHECK: Attribute value text
+ print(f"Attribute value {sattr.value}")
+ # CHECK: Attribute value b'text'
+ print(f"Attribute value {sattr.value_bytes}")
+
+ # We don't know in which order the attributes are stored.
+ # CHECK-DAG: NamedAttribute(dependent="text")
+ # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
+ # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
+ for attr in op.attributes:
+ print(str(attr))
+
+ # Check that exceptions are raised as expected.
+ try:
+ op.attributes["does_not_exist"]
+ except KeyError:
+ pass
+ else:
+ assert False, "expected KeyError on accessing a non-existent attribute"
+ try:
+ op.attributes[42]
+ except IndexError:
+ pass
+ else:
+ assert False, "expected IndexError on accessing an out-of-bounds attribute"
# CHECK-LABEL: TEST: testOperationPrint
@run
def testOperationPrint():
- ctx = Context()
- module = Module.parse(
- r"""
+ ctx = Context()
+ module = Module.parse(
+ r"""
func.func @f1(%arg0: i32) -> i32 {
%0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
return %arg0 : i32
}
- """, ctx)
-
- # Test print to stdout.
- # CHECK: return %arg0 : i32
- module.operation.print()
-
- # Test print to text file.
- f = io.StringIO()
- # CHECK: <class 'str'>
- # CHECK: return %arg0 : i32
- module.operation.print(file=f)
- str_value = f.getvalue()
- print(str_value.__class__)
- print(f.getvalue())
-
- # Test roundtrip to bytecode.
- bytecode_stream = io.BytesIO()
- module.operation.write_bytecode(bytecode_stream, desired_version=1)
- bytecode = bytecode_stream.getvalue()
- assert bytecode.startswith(b'ML\xefR'), "Expected bytecode to start with MLïR"
- module_roundtrip = Module.parse(bytecode, ctx)
- f = io.StringIO()
- module_roundtrip.operation.print(file=f)
- roundtrip_value = f.getvalue()
- assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode"
-
-
- # Test print to binary file.
- f = io.BytesIO()
- # CHECK: <class 'bytes'>
- # CHECK: return %arg0 : i32
- module.operation.print(file=f, binary=True)
- bytes_value = f.getvalue()
- print(bytes_value.__class__)
- print(bytes_value)
-
- # Test get_asm local_scope.
- # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
- module.operation.print(enable_debug_info=True, use_local_scope=True)
-
- # Test get_asm with options.
- # CHECK: value = dense_resource<__elided__> : tensor<4xi32>
- # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
- module.operation.print(
- large_elements_limit=2,
- enable_debug_info=True,
- pretty_debug_info=True,
- print_generic_op_form=True,
- use_local_scope=True)
-
-
+ """,
+ ctx,
+ )
+
+ # Test print to stdout.
+ # CHECK: return %arg0 : i32
+ module.operation.print()
+
+ # Test print to text file.
+ f = io.StringIO()
+ # CHECK: <class 'str'>
+ # CHECK: return %arg0 : i32
+ module.operation.print(file=f)
+ str_value = f.getvalue()
+ print(str_value.__class__)
+ print(f.getvalue())
+
+ # Test roundtrip to bytecode.
+ bytecode_stream = io.BytesIO()
+ module.operation.write_bytecode(bytecode_stream, desired_version=1)
+ bytecode = bytecode_stream.getvalue()
+ assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
+ module_roundtrip = Module.parse(bytecode, ctx)
+ f = io.StringIO()
+ module_roundtrip.operation.print(file=f)
+ roundtrip_value = f.getvalue()
+ assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode"
+
+ # Test print to binary file.
+ f = io.BytesIO()
+ # CHECK: <class 'bytes'>
+ # CHECK: return %arg0 : i32
+ module.operation.print(file=f, binary=True)
+ bytes_value = f.getvalue()
+ print(bytes_value.__class__)
+ print(bytes_value)
+
+ # Test get_asm local_scope.
+ # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
+ module.operation.print(enable_debug_info=True, use_local_scope=True)
+
+ # Test get_asm with options.
+ # CHECK: value = dense_resource<__elided__> : tensor<4xi32>
+ # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
+ module.operation.print(
+ large_elements_limit=2,
+ enable_debug_info=True,
+ pretty_debug_info=True,
+ print_generic_op_form=True,
+ use_local_scope=True,
+ )
# CHECK-LABEL: TEST: testKnownOpView
@run
def testKnownOpView():
- with Context(), Location.unknown():
- Context.current.allow_unregistered_dialects = True
- module = Module.parse(r"""
+ with Context(), Location.unknown():
+ Context.current.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
%1 = "custom.f32"() : () -> f32
%2 = "custom.f32"() : () -> f32
%3 = arith.addf %1, %2 : f32
- """)
- print(module)
+ """
+ )
+ print(module)
- # addf should map to a known OpView class in the arithmetic dialect.
- # We know the OpView for it defines an 'lhs' attribute.
- addf = module.body.operations[2]
- # CHECK: <mlir.dialects._arith_ops_gen.AddFOp object
- print(repr(addf))
- # CHECK: "custom.f32"()
- print(addf.lhs)
+ # addf should map to a known OpView class in the arithmetic dialect.
+ # We know the OpView for it defines an 'lhs' attribute.
+ addf = module.body.operations[2]
+ # CHECK: <mlir.dialects._arith_ops_gen.AddFOp object
+ print(repr(addf))
+ # CHECK: "custom.f32"()
+ print(addf.lhs)
- # One of the custom ops should resolve to the default OpView.
- custom = module.body.operations[0]
- # CHECK: OpView object
- print(repr(custom))
+ # One of the custom ops should resolve to the default OpView.
+ custom = module.body.operations[0]
+ # CHECK: OpView object
+ print(repr(custom))
- # Check again to make sure negative caching works.
- custom = module.body.operations[0]
- # CHECK: OpView object
- print(repr(custom))
+ # Check again to make sure negative caching works.
+ custom = module.body.operations[0]
+ # CHECK: OpView object
+ print(repr(custom))
# CHECK-LABEL: TEST: testSingleResultProperty
@run
def testSingleResultProperty():
- with Context(), Location.unknown():
- Context.current.allow_unregistered_dialects = True
- module = Module.parse(r"""
+ with Context(), Location.unknown():
+ Context.current.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
"custom.no_result"() : () -> ()
%0:2 = "custom.two_result"() : () -> (f32, f32)
%1 = "custom.one_result"() : () -> f32
- """)
- print(module)
+ """
+ )
+ print(module)
- try:
- module.body.operations[0].result
- except ValueError as e:
- # CHECK: Cannot call .result on operation custom.no_result which has 0 results
- print(e)
- else:
- assert False, "Expected exception"
+ try:
+ module.body.operations[0].result
+ except ValueError as e:
+ # CHECK: Cannot call .result on operation custom.no_result which has 0 results
+ print(e)
+ else:
+ assert False, "Expected exception"
- try:
- module.body.operations[1].result
- except ValueError as e:
- # CHECK: Cannot call .result on operation custom.two_result which has 2 results
- print(e)
- else:
- assert False, "Expected exception"
+ try:
+ module.body.operations[1].result
+ except ValueError as e:
+ # CHECK: Cannot call .result on operation custom.two_result which has 2 results
+ print(e)
+ else:
+ assert False, "Expected exception"
- # CHECK: %1 = "custom.one_result"() : () -> f32
- print(module.body.operations[2])
+ # CHECK: %1 = "custom.one_result"() : () -> f32
+ print(module.body.operations[2])
def create_invalid_operation():
- # This module has two region and is invalid verify that we fallback
- # to the generic printer for safety.
- op = Operation.create("builtin.module", regions=2)
- op.regions[0].blocks.append()
- return op
+    # This module has two regions and is invalid; verify that we fall back
+    # to the generic printer for safety.
+ op = Operation.create("builtin.module", regions=2)
+ op.regions[0].blocks.append()
+ return op
+
# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
@run
def testInvalidOperationStrSoftFails():
- ctx = Context()
- with Location.unknown(ctx):
- invalid_op = create_invalid_operation()
- # Verify that we fallback to the generic printer for safety.
- # CHECK: "builtin.module"() ({
- # CHECK: }) : () -> ()
- print(invalid_op)
- try:
- invalid_op.verify()
- except MLIRError as e:
- # CHECK: Exception: <
- # CHECK: Verification failed:
- # CHECK: error: unknown: 'builtin.module' op requires one region
- # CHECK: note: unknown: see current operation:
- # CHECK: "builtin.module"() ({
- # CHECK: ^bb0:
- # CHECK: }, {
- # CHECK: }) : () -> ()
- # CHECK: >
- print(f"Exception: <{e}>")
+ ctx = Context()
+ with Location.unknown(ctx):
+ invalid_op = create_invalid_operation()
+        # Verify that we fall back to the generic printer for safety.
+ # CHECK: "builtin.module"() ({
+ # CHECK: }) : () -> ()
+ print(invalid_op)
+ try:
+ invalid_op.verify()
+ except MLIRError as e:
+ # CHECK: Exception: <
+ # CHECK: Verification failed:
+ # CHECK: error: unknown: 'builtin.module' op requires one region
+ # CHECK: note: unknown: see current operation:
+ # CHECK: "builtin.module"() ({
+ # CHECK: ^bb0:
+ # CHECK: }, {
+ # CHECK: }) : () -> ()
+ # CHECK: >
+ print(f"Exception: <{e}>")
# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
@run
def testInvalidModuleStrSoftFails():
- ctx = Context()
- with Location.unknown(ctx):
- module = Module.create()
- with InsertionPoint(module.body):
- invalid_op = create_invalid_operation()
- # Verify that we fallback to the generic printer for safety.
- # CHECK: "builtin.module"() ({
- # CHECK: }) : () -> ()
- print(module)
+ ctx = Context()
+ with Location.unknown(ctx):
+ module = Module.create()
+ with InsertionPoint(module.body):
+ invalid_op = create_invalid_operation()
+        # Verify that we fall back to the generic printer for safety.
+ # CHECK: "builtin.module"() ({
+ # CHECK: }) : () -> ()
+ print(module)
# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
@run
def testInvalidOperationGetAsmBinarySoftFails():
- ctx = Context()
- with Location.unknown(ctx):
- invalid_op = create_invalid_operation()
- # Verify that we fallback to the generic printer for safety.
- # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n'
- print(invalid_op.get_asm(binary=True))
+ ctx = Context()
+ with Location.unknown(ctx):
+ invalid_op = create_invalid_operation()
+        # Verify that we fall back to the generic printer for safety.
+ # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n'
+ print(invalid_op.get_asm(binary=True))
# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
@run
def testCreateWithInvalidAttributes():
- ctx = Context()
- with Location.unknown(ctx):
- try:
- Operation.create(
- "builtin.module", attributes={None: StringAttr.get("name")})
- except Exception as e:
- # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
- print(e)
- try:
- Operation.create(
- "builtin.module", attributes={42: StringAttr.get("name")})
- except Exception as e:
- # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
- print(e)
- try:
- Operation.create("builtin.module", attributes={"some_key": ctx})
- except Exception as e:
- # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
- print(e)
- try:
- Operation.create("builtin.module", attributes={"some_key": None})
- except Exception as e:
- # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
- print(e)
+ ctx = Context()
+ with Location.unknown(ctx):
+ try:
+ Operation.create(
+ "builtin.module", attributes={None: StringAttr.get("name")}
+ )
+ except Exception as e:
+ # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
+ print(e)
+ try:
+ Operation.create("builtin.module", attributes={42: StringAttr.get("name")})
+ except Exception as e:
+ # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
+ print(e)
+ try:
+ Operation.create("builtin.module", attributes={"some_key": ctx})
+ except Exception as e:
+ # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
+ print(e)
+ try:
+ Operation.create("builtin.module", attributes={"some_key": None})
+ except Exception as e:
+ # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
+ print(e)
# CHECK-LABEL: TEST: testOperationName
@run
def testOperationName():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- module = Module.parse(
- r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
%0 = "custom.op1"() : () -> f32
%1 = "custom.op2"() : () -> i32
%2 = "custom.op1"() : () -> f32
- """, ctx)
+ """,
+ ctx,
+ )
- # CHECK: custom.op1
- # CHECK: custom.op2
- # CHECK: custom.op1
- for op in module.body.operations:
- print(op.operation.name)
+ # CHECK: custom.op1
+ # CHECK: custom.op2
+ # CHECK: custom.op1
+ for op in module.body.operations:
+ print(op.operation.name)
# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- m = Operation.create("custom.op1").operation
- m_capsule = m._CAPIPtr
- assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
- m2 = Operation._CAPICreate(m_capsule)
- assert m2 is m
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ m = Operation.create("custom.op1").operation
+ m_capsule = m._CAPIPtr
+ assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
+ m2 = Operation._CAPICreate(m_capsule)
+ assert m2 is m
# CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- m = Module.create()
- with InsertionPoint(m.body):
- op = Operation.create("custom.op1")
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ m = Module.create()
+ with InsertionPoint(m.body):
+ op = Operation.create("custom.op1")
- # CHECK: "custom.op1"
- print(m)
+ # CHECK: "custom.op1"
+ print(m)
- op.operation.erase()
+ op.operation.erase()
- # CHECK-NOT: "custom.op1"
- print(m)
+ # CHECK-NOT: "custom.op1"
+ print(m)
- # Ensure we can create another operation
- Operation.create("custom.op2")
+ # Ensure we can create another operation
+ Operation.create("custom.op2")
# CHECK-LABEL: TEST: testOperationClone
@run
def testOperationClone():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- m = Module.create()
- with InsertionPoint(m.body):
- op = Operation.create("custom.op1")
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ m = Module.create()
+ with InsertionPoint(m.body):
+ op = Operation.create("custom.op1")
- # CHECK: "custom.op1"
- print(m)
+ # CHECK: "custom.op1"
+ print(m)
- clone = op.operation.clone()
- op.operation.erase()
+ clone = op.operation.clone()
+ op.operation.erase()
- # CHECK: "custom.op1"
- print(m)
+ # CHECK: "custom.op1"
+ print(m)
# CHECK-LABEL: TEST: testOperationLoc
@run
def testOperationLoc():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with ctx:
- loc = Location.name("loc")
- op = Operation.create("custom.op", loc=loc)
- assert op.location == loc
- assert op.operation.location == loc
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with ctx:
+ loc = Location.name("loc")
+ op = Operation.create("custom.op", loc=loc)
+ assert op.location == loc
+ assert op.operation.location == loc
# CHECK-LABEL: TEST: testModuleMerge
@run
def testModuleMerge():
- with Context():
- m1 = Module.parse("func.func private @foo()")
- m2 = Module.parse("""
+ with Context():
+ m1 = Module.parse("func.func private @foo()")
+ m2 = Module.parse(
+ """
func.func private @bar()
func.func private @qux()
- """)
- foo = m1.body.operations[0]
- bar = m2.body.operations[0]
- qux = m2.body.operations[1]
- bar.move_before(foo)
- qux.move_after(foo)
+ """
+ )
+ foo = m1.body.operations[0]
+ bar = m2.body.operations[0]
+ qux = m2.body.operations[1]
+ bar.move_before(foo)
+ qux.move_after(foo)
- # CHECK: module
- # CHECK: func private @bar
- # CHECK: func private @foo
- # CHECK: func private @qux
- print(m1)
+ # CHECK: module
+ # CHECK: func private @bar
+ # CHECK: func private @foo
+ # CHECK: func private @qux
+ print(m1)
- # CHECK: module {
- # CHECK-NEXT: }
- print(m2)
+ # CHECK: module {
+ # CHECK-NEXT: }
+ print(m2)
# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
@run
def testAppendMoveFromAnotherBlock():
- with Context():
- m1 = Module.parse("func.func private @foo()")
- m2 = Module.parse("func.func private @bar()")
- func = m1.body.operations[0]
- m2.body.append(func)
+ with Context():
+ m1 = Module.parse("func.func private @foo()")
+ m2 = Module.parse("func.func private @bar()")
+ func = m1.body.operations[0]
+ m2.body.append(func)
- # CHECK: module
- # CHECK: func private @bar
- # CHECK: func private @foo
+ # CHECK: module
+ # CHECK: func private @bar
+ # CHECK: func private @foo
- print(m2)
- # CHECK: module {
- # CHECK-NEXT: }
- print(m1)
+ print(m2)
+ # CHECK: module {
+ # CHECK-NEXT: }
+ print(m1)
# CHECK-LABEL: TEST: testDetachFromParent
@run
def testDetachFromParent():
- with Context():
- m1 = Module.parse("func.func private @foo()")
- func = m1.body.operations[0].detach_from_parent()
+ with Context():
+ m1 = Module.parse("func.func private @foo()")
+ func = m1.body.operations[0].detach_from_parent()
- try:
- func.detach_from_parent()
- except ValueError as e:
- if "has no parent" not in str(e):
- raise
- else:
- assert False, "expected ValueError when detaching a detached operation"
+ try:
+ func.detach_from_parent()
+ except ValueError as e:
+ if "has no parent" not in str(e):
+ raise
+ else:
+ assert False, "expected ValueError when detaching a detached operation"
- print(m1)
- # CHECK-NOT: func private @foo
+ print(m1)
+ # CHECK-NOT: func private @foo
# CHECK-LABEL: TEST: testOperationHash
@run
def testOperationHash():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with ctx, Location.unknown():
- op = Operation.create("custom.op1")
- assert hash(op) == hash(op.operation)
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with ctx, Location.unknown():
+ op = Operation.create("custom.op1")
+ assert hash(op) == hash(op.operation)
# CHECK-LABEL: TEST: testOperationParse
@run
def testOperationParse():
- with Context() as ctx:
- ctx.allow_unregistered_dialects = True
-
- # Generic operation parsing.
- m = Operation.parse('module {}')
- o = Operation.parse('"test.foo"() : () -> ()')
- assert isinstance(m, ModuleOp)
- assert type(o) is OpView
-
- # Parsing specific operation.
- m = ModuleOp.parse('module {}')
- assert isinstance(m, ModuleOp)
- try:
- ModuleOp.parse('"test.foo"() : () -> ()')
- except MLIRError as e:
- # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
- print(f"error: {e}")
- else:
- assert False, "expected error"
-
- o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string")
- # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1)
- print(f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}")
+ with Context() as ctx:
+ ctx.allow_unregistered_dialects = True
+
+ # Generic operation parsing.
+ m = Operation.parse("module {}")
+ o = Operation.parse('"test.foo"() : () -> ()')
+ assert isinstance(m, ModuleOp)
+ assert type(o) is OpView
+
+ # Parsing specific operation.
+ m = ModuleOp.parse("module {}")
+ assert isinstance(m, ModuleOp)
+ try:
+ ModuleOp.parse('"test.foo"() : () -> ()')
+ except MLIRError as e:
+ # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
+ print(f"error: {e}")
+ else:
+ assert False, "expected error"
+
+ o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string")
+ # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1)
+ print(
+ f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
+ )
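
As a companion to the operation.py tests above, the following standalone sketch (not part of this commit, and using placeholder op names) strings together a few of the APIs they cover: parsing a module, creating a detached unregistered op, inserting it at the top of the entry block, and printing with the same kinds of options. It assumes the `mlir` Python bindings are importable.

from mlir.ir import Context, InsertionPoint, Location, Module, Operation

ctx = Context()
ctx.allow_unregistered_dialects = True
with Location.unknown(ctx):
    module = Module.parse(
        r"""
  func.func @f(%arg0: i32) -> i32 {
    return %arg0 : i32
  }
""",
        ctx,
    )
    # Create a detached operation and attach it at the start of the entry block.
    op = Operation.create("custom.op1")
    entry_block = module.body.operations[0].regions[0].blocks[0]
    InsertionPoint.at_block_begin(entry_block).insert(op)
    # Print the generic form with debug info, as testOperationPrint does.
    module.operation.print(print_generic_op_form=True, enable_debug_info=True)
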
diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
index 9ce89591c5e12..17f3e354bee2b 100644
--- a/mlir/test/python/ir/symbol_table.py
+++ b/mlir/test/python/ir/symbol_table.py
@@ -7,150 +7,162 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
# CHECK-LABEL: TEST: testSymbolTableInsert
@run
def testSymbolTableInsert():
- with Context() as ctx:
- ctx.allow_unregistered_dialects = True
- m1 = Module.parse("""
+ with Context() as ctx:
+ ctx.allow_unregistered_dialects = True
+ m1 = Module.parse(
+ """
func.func private @foo()
- func.func private @bar()""")
- m2 = Module.parse("""
+ func.func private @bar()"""
+ )
+ m2 = Module.parse(
+ """
func.func private @qux()
func.func private @foo()
- "foo.bar"() : () -> ()""")
-
- symbol_table = SymbolTable(m1.operation)
-
- # CHECK: func private @foo
- # CHECK: func private @bar
- assert "foo" in symbol_table
- print(symbol_table["foo"])
- assert "bar" in symbol_table
- bar = symbol_table["bar"]
- print(symbol_table["bar"])
-
- assert "qux" not in symbol_table
-
- del symbol_table["bar"]
- try:
- symbol_table.erase(symbol_table["bar"])
- except KeyError:
- pass
- else:
- assert False, "expected KeyError"
-
- # CHECK: module
- # CHECK: func private @foo()
- print(m1)
- assert "bar" not in symbol_table
-
- try:
- print(bar)
- except RuntimeError as e:
- if "the operation has been invalidated" not in str(e):
- raise
- else:
- assert False, "expected RuntimeError due to invalidated operation"
-
- qux = m2.body.operations[0]
- m1.body.append(qux)
- symbol_table.insert(qux)
- assert "qux" in symbol_table
-
- # Check that insertion actually renames this symbol in the symbol table.
- foo2 = m2.body.operations[0]
- m1.body.append(foo2)
- updated_name = symbol_table.insert(foo2)
- assert foo2.name.value != "foo"
- assert foo2.name == updated_name
-
- # CHECK: module
- # CHECK: func private @foo()
- # CHECK: func private @qux()
- # CHECK: func private @foo{{.*}}
- print(m1)
-
- try:
- symbol_table.insert(m2.body.operations[0])
- except ValueError as e:
- if "Expected operation to have a symbol name" not in str(e):
- raise
- else:
- assert False, "exepcted ValueError when adding a non-symbol"
+ "foo.bar"() : () -> ()"""
+ )
+
+ symbol_table = SymbolTable(m1.operation)
+
+ # CHECK: func private @foo
+ # CHECK: func private @bar
+ assert "foo" in symbol_table
+ print(symbol_table["foo"])
+ assert "bar" in symbol_table
+ bar = symbol_table["bar"]
+ print(symbol_table["bar"])
+
+ assert "qux" not in symbol_table
+
+ del symbol_table["bar"]
+ try:
+ symbol_table.erase(symbol_table["bar"])
+ except KeyError:
+ pass
+ else:
+ assert False, "expected KeyError"
+
+ # CHECK: module
+ # CHECK: func private @foo()
+ print(m1)
+ assert "bar" not in symbol_table
+
+ try:
+ print(bar)
+ except RuntimeError as e:
+ if "the operation has been invalidated" not in str(e):
+ raise
+ else:
+ assert False, "expected RuntimeError due to invalidated operation"
+
+ qux = m2.body.operations[0]
+ m1.body.append(qux)
+ symbol_table.insert(qux)
+ assert "qux" in symbol_table
+
+ # Check that insertion actually renames this symbol in the symbol table.
+ foo2 = m2.body.operations[0]
+ m1.body.append(foo2)
+ updated_name = symbol_table.insert(foo2)
+ assert foo2.name.value != "foo"
+ assert foo2.name == updated_name
+
+ # CHECK: module
+ # CHECK: func private @foo()
+ # CHECK: func private @qux()
+ # CHECK: func private @foo{{.*}}
+ print(m1)
+
+ try:
+ symbol_table.insert(m2.body.operations[0])
+ except ValueError as e:
+ if "Expected operation to have a symbol name" not in str(e):
+ raise
+ else:
+            assert False, "expected ValueError when adding a non-symbol"
# CHECK-LABEL: testSymbolTableRAUW
@run
def testSymbolTableRAUW():
- with Context() as ctx:
- m = Module.parse("""
+ with Context() as ctx:
+ m = Module.parse(
+ """
func.func private @foo() {
call @bar() : () -> ()
return
}
func.func private @bar()
- """)
- foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]
- SymbolTable.set_symbol_name(bar, "bam")
- # Note that module.operation counts as a "nested symbol table" which won't
- # be traversed into, so it is necessary to traverse its children.
- SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
- # CHECK: call @bam()
- # CHECK: func private @bam
- print(m)
- # CHECK: Foo symbol: "foo"
- # CHECK: Bar symbol: "bam"
- print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}")
- print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}")
+ """
+ )
+ foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]
+ SymbolTable.set_symbol_name(bar, "bam")
+ # Note that module.operation counts as a "nested symbol table" which won't
+ # be traversed into, so it is necessary to traverse its children.
+ SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
+ # CHECK: call @bam()
+ # CHECK: func private @bam
+ print(m)
+ # CHECK: Foo symbol: "foo"
+ # CHECK: Bar symbol: "bam"
+ print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}")
+ print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}")
# CHECK-LABEL: testSymbolTableVisibility
@run
def testSymbolTableVisibility():
- with Context() as ctx:
- m = Module.parse("""
+ with Context() as ctx:
+ m = Module.parse(
+ """
func.func private @foo() {
return
}
- """)
- foo = m.operation.regions[0].blocks[0].operations[0]
- # CHECK: Existing visibility: "private"
- print(f"Existing visibility: {SymbolTable.get_visibility(foo)}")
- SymbolTable.set_visibility(foo, "public")
- # CHECK: func public @foo
- print(m)
+ """
+ )
+ foo = m.operation.regions[0].blocks[0].operations[0]
+ # CHECK: Existing visibility: "private"
+ print(f"Existing visibility: {SymbolTable.get_visibility(foo)}")
+ SymbolTable.set_visibility(foo, "public")
+ # CHECK: func public @foo
+ print(m)
# CHECK: testWalkSymbolTables
@run
def testWalkSymbolTables():
- with Context() as ctx:
- m = Module.parse("""
+ with Context() as ctx:
+ m = Module.parse(
+ """
module @outer {
module @inner{
}
}
- """)
- def callback(symbol_table_op, uses_visible):
- print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")
- # CHECK: SYMBOL TABLE: True: module @inner
- # CHECK: SYMBOL TABLE: True: module @outer
- SymbolTable.walk_symbol_tables(m.operation, True, callback)
-
- # Make sure exceptions in the callback are handled.
- def error_callback(symbol_table_op, uses_visible):
- assert False, "Raised from python"
- try:
- SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
- except RuntimeError as e:
- # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
- print(f"GOT EXCEPTION: {e}")
+ """
+ )
+ def callback(symbol_table_op, uses_visible):
+ print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")
+
+ # CHECK: SYMBOL TABLE: True: module @inner
+ # CHECK: SYMBOL TABLE: True: module @outer
+ SymbolTable.walk_symbol_tables(m.operation, True, callback)
+
+ # Make sure exceptions in the callback are handled.
+ def error_callback(symbol_table_op, uses_visible):
+ assert False, "Raised from python"
+
+ try:
+ SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
+ except RuntimeError as e:
+ # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
+ print(f"GOT EXCEPTION: {e}")
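
To summarize the SymbolTable helpers exercised in the symbol_table.py tests above, here is a small standalone sketch (not part of this commit; the symbol names are placeholders). It builds a symbol table, then renames a symbol and rewrites its uses, assuming the `mlir` Python bindings are importable.

from mlir.ir import Context, Module, SymbolTable

with Context():
    m = Module.parse(
        r"""
  func.func private @foo() {
    call @bar() : () -> ()
    return
  }
  func.func private @bar()
"""
    )
    symbol_table = SymbolTable(m.operation)
    assert "foo" in symbol_table and "bar" in symbol_table

    foo = m.body.operations[0]
    bar = m.body.operations[1]
    # Rename @bar and rewrite the reference inside @foo, as testSymbolTableRAUW
    # does; the module itself counts as a nested symbol table, so uses are
    # rewritten starting from @foo rather than from the module.
    SymbolTable.set_symbol_name(bar, "baz")
    SymbolTable.replace_all_symbol_uses("bar", "baz", foo)
    # The printed module now contains `call @baz()` and `func.func private @baz`.
    print(m)
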
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 66568c426216a..8a2ada1f78f1c 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -6,229 +6,235 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- i32 = IntegerType.get_signless(32)
- value = Operation.create("custom.op1", results=[i32]).result
- value_capsule = value._CAPIPtr
- assert '"mlir.ir.Value._CAPIPtr"' in repr(value_capsule)
- value2 = Value._CAPICreate(value_capsule)
- assert value2 == value
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signless(32)
+ value = Operation.create("custom.op1", results=[i32]).result
+ value_capsule = value._CAPIPtr
+ assert '"mlir.ir.Value._CAPIPtr"' in repr(value_capsule)
+ value2 = Value._CAPICreate(value_capsule)
+ assert value2 == value
# CHECK-LABEL: TEST: testOpResultOwner
@run
def testOpResultOwner():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- i32 = IntegerType.get_signless(32)
- op = Operation.create("custom.op1", results=[i32])
- assert op.result.owner == op
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signless(32)
+ op = Operation.create("custom.op1", results=[i32])
+ assert op.result.owner == op
# CHECK-LABEL: TEST: testBlockArgOwner
@run
def testBlockArgOwner():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- module = Module.parse(
- r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
func.func @foo(%arg0: f32) {
return
- }""", ctx)
- func = module.body.operations[0]
- block = func.regions[0].blocks[0]
- assert block.arguments[0].owner == block
+ }""",
+ ctx,
+ )
+ func = module.body.operations[0]
+ block = func.regions[0].blocks[0]
+ assert block.arguments[0].owner == block
# CHECK-LABEL: TEST: testValueIsInstance
@run
def testValueIsInstance():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- module = Module.parse(
- r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
func.func @foo(%arg0: f32) {
%0 = "some_dialect.some_op"() : () -> f64
return
- }""", ctx)
- func = module.body.operations[0]
- assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0])
- assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0])
+ }""",
+ ctx,
+ )
+ func = module.body.operations[0]
+ assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0])
+ assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0])
- op = func.regions[0].blocks[0].operations[0]
- assert not BlockArgument.isinstance(op.results[0])
- assert OpResult.isinstance(op.results[0])
+ op = func.regions[0].blocks[0].operations[0]
+ assert not BlockArgument.isinstance(op.results[0])
+ assert OpResult.isinstance(op.results[0])
# CHECK-LABEL: TEST: testValueHash
@run
def testValueHash():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- module = Module.parse(
- r"""
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
func.func @foo(%arg0: f32) -> f32 {
%0 = "some_dialect.some_op"(%arg0) : (f32) -> f32
return %0 : f32
- }""", ctx)
+ }""",
+ ctx,
+ )
- [func] = module.body.operations
- block = func.entry_block
- op, ret = block.operations
- assert hash(block.arguments[0]) == hash(op.operands[0])
- assert hash(op.result) == hash(ret.operands[0])
+ [func] = module.body.operations
+ block = func.entry_block
+ op, ret = block.operations
+ assert hash(block.arguments[0]) == hash(op.operands[0])
+ assert hash(op.result) == hash(ret.operands[0])
# CHECK-LABEL: TEST: testValueUses
@run
def testValueUses():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- i32 = IntegerType.get_signless(32)
- module = Module.create()
- with InsertionPoint(module.body):
- value = Operation.create("custom.op1", results=[i32]).results[0]
- op1 = Operation.create("custom.op2", operands=[value])
- op2 = Operation.create("custom.op2", operands=[value])
-
- # CHECK: Use owner: "custom.op2"
- # CHECK: Use operand_number: 0
- # CHECK: Use owner: "custom.op2"
- # CHECK: Use operand_number: 0
- for use in value.uses:
- assert use.owner in [op1, op2]
- print(f"Use owner: {use.owner}")
- print(f"Use operand_number: {use.operand_number}")
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signless(32)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ value = Operation.create("custom.op1", results=[i32]).results[0]
+ op1 = Operation.create("custom.op2", operands=[value])
+ op2 = Operation.create("custom.op2", operands=[value])
+
+ # CHECK: Use owner: "custom.op2"
+ # CHECK: Use operand_number: 0
+ # CHECK: Use owner: "custom.op2"
+ # CHECK: Use operand_number: 0
+ for use in value.uses:
+ assert use.owner in [op1, op2]
+ print(f"Use owner: {use.owner}")
+ print(f"Use operand_number: {use.operand_number}")
# CHECK-LABEL: TEST: testValueReplaceAllUsesWith
@run
def testValueReplaceAllUsesWith():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- i32 = IntegerType.get_signless(32)
- module = Module.create()
- with InsertionPoint(module.body):
- value = Operation.create("custom.op1", results=[i32]).results[0]
- op1 = Operation.create("custom.op2", operands=[value])
- op2 = Operation.create("custom.op2", operands=[value])
- value2 = Operation.create("custom.op3", results=[i32]).results[0]
- value.replace_all_uses_with(value2)
-
- assert len(list(value.uses)) == 0
-
- # CHECK: Use owner: "custom.op2"
- # CHECK: Use operand_number: 0
- # CHECK: Use owner: "custom.op2"
- # CHECK: Use operand_number: 0
- for use in value2.uses:
- assert use.owner in [op1, op2]
- print(f"Use owner: {use.owner}")
- print(f"Use operand_number: {use.operand_number}")
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signless(32)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ value = Operation.create("custom.op1", results=[i32]).results[0]
+ op1 = Operation.create("custom.op2", operands=[value])
+ op2 = Operation.create("custom.op2", operands=[value])
+ value2 = Operation.create("custom.op3", results=[i32]).results[0]
+ value.replace_all_uses_with(value2)
+
+ assert len(list(value.uses)) == 0
+
+ # CHECK: Use owner: "custom.op2"
+ # CHECK: Use operand_number: 0
+ # CHECK: Use owner: "custom.op2"
+ # CHECK: Use operand_number: 0
+ for use in value2.uses:
+ assert use.owner in [op1, op2]
+ print(f"Use owner: {use.owner}")
+ print(f"Use operand_number: {use.operand_number}")
# CHECK-LABEL: TEST: testValuePrintAsOperand
@run
def testValuePrintAsOperand():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- i32 = IntegerType.get_signless(32)
- module = Module.create()
- with InsertionPoint(module.body):
- value = Operation.create("custom.op1", results=[i32]).results[0]
- # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
- print(value)
-
- value2 = Operation.create("custom.op2", results=[i32]).results[0]
- # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
- print(value2)
-
- f = func.FuncOp("test", ([i32, i32], []))
- entry_block1 = Block.create_at_start(f.operation.regions[0], [i32, i32])
-
- with InsertionPoint(entry_block1):
- value3 = Operation.create("custom.op3", results=[i32]).results[0]
- # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
- print(value3)
- value4 = Operation.create("custom.op4", results=[i32]).results[0]
- # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
- print(value4)
-
- f = func.FuncOp("test", ([i32, i32], []))
- entry_block2 = Block.create_at_start(f.operation.regions[0], [i32, i32])
- with InsertionPoint(entry_block2):
- value5 = Operation.create("custom.op5", results=[i32]).results[0]
- # CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32)
- print(value5)
- value6 = Operation.create("custom.op6", results=[i32]).results[0]
- # CHECK: Value(%[[VAL6:.*]] = "custom.op6"() : () -> i32)
- print(value6)
-
- func.ReturnOp([])
-
- func.ReturnOp([])
-
- # CHECK: %[[VAL1]]
- print(value.get_name())
- # CHECK: %[[VAL2]]
- print(value2.get_name())
- # CHECK: %[[VAL3]]
- print(value3.get_name())
- # CHECK: %[[VAL4]]
- print(value4.get_name())
-
- # CHECK: %0
- print(value3.get_name(use_local_scope=True))
- # CHECK: %1
- print(value4.get_name(use_local_scope=True))
-
- # CHECK: %[[VAL5]]
- print(value5.get_name())
- # CHECK: %[[VAL6]]
- print(value6.get_name())
-
- # CHECK: %[[ARG0:.*]]
- print(entry_block1.arguments[0].get_name())
- # CHECK: %[[ARG1:.*]]
- print(entry_block1.arguments[1].get_name())
-
- # CHECK: %[[ARG2:.*]]
- print(entry_block2.arguments[0].get_name())
- # CHECK: %[[ARG3:.*]]
- print(entry_block2.arguments[1].get_name())
-
- # CHECK: module {
- # CHECK: %[[VAL1]] = "custom.op1"() : () -> i32
- # CHECK: %[[VAL2]] = "custom.op2"() : () -> i32
- # CHECK: func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
- # CHECK: %[[VAL3]] = "custom.op3"() : () -> i32
- # CHECK: %[[VAL4]] = "custom.op4"() : () -> i32
- # CHECK: func @test(%[[ARG2]]: i32, %[[ARG3]]: i32) {
- # CHECK: %[[VAL5]] = "custom.op5"() : () -> i32
- # CHECK: %[[VAL6]] = "custom.op6"() : () -> i32
- # CHECK: return
- # CHECK: }
- # CHECK: return
- # CHECK: }
- # CHECK: }
- print(module)
-
- value2.owner.detach_from_parent()
- # CHECK: %0
- print(value2.get_name())
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signless(32)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ value = Operation.create("custom.op1", results=[i32]).results[0]
+ # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
+ print(value)
+
+ value2 = Operation.create("custom.op2", results=[i32]).results[0]
+ # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
+ print(value2)
+
+ f = func.FuncOp("test", ([i32, i32], []))
+ entry_block1 = Block.create_at_start(f.operation.regions[0], [i32, i32])
+
+ with InsertionPoint(entry_block1):
+ value3 = Operation.create("custom.op3", results=[i32]).results[0]
+ # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
+ print(value3)
+ value4 = Operation.create("custom.op4", results=[i32]).results[0]
+ # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
+ print(value4)
+
+ f = func.FuncOp("test", ([i32, i32], []))
+ entry_block2 = Block.create_at_start(f.operation.regions[0], [i32, i32])
+ with InsertionPoint(entry_block2):
+ value5 = Operation.create("custom.op5", results=[i32]).results[0]
+ # CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32)
+ print(value5)
+ value6 = Operation.create("custom.op6", results=[i32]).results[0]
+ # CHECK: Value(%[[VAL6:.*]] = "custom.op6"() : () -> i32)
+ print(value6)
+
+ func.ReturnOp([])
+
+ func.ReturnOp([])
+
+ # CHECK: %[[VAL1]]
+ print(value.get_name())
+ # CHECK: %[[VAL2]]
+ print(value2.get_name())
+ # CHECK: %[[VAL3]]
+ print(value3.get_name())
+ # CHECK: %[[VAL4]]
+ print(value4.get_name())
+
+ # CHECK: %0
+ print(value3.get_name(use_local_scope=True))
+ # CHECK: %1
+ print(value4.get_name(use_local_scope=True))
+
+ # CHECK: %[[VAL5]]
+ print(value5.get_name())
+ # CHECK: %[[VAL6]]
+ print(value6.get_name())
+
+ # CHECK: %[[ARG0:.*]]
+ print(entry_block1.arguments[0].get_name())
+ # CHECK: %[[ARG1:.*]]
+ print(entry_block1.arguments[1].get_name())
+
+ # CHECK: %[[ARG2:.*]]
+ print(entry_block2.arguments[0].get_name())
+ # CHECK: %[[ARG3:.*]]
+ print(entry_block2.arguments[1].get_name())
+
+ # CHECK: module {
+ # CHECK: %[[VAL1]] = "custom.op1"() : () -> i32
+ # CHECK: %[[VAL2]] = "custom.op2"() : () -> i32
+ # CHECK: func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
+ # CHECK: %[[VAL3]] = "custom.op3"() : () -> i32
+ # CHECK: %[[VAL4]] = "custom.op4"() : () -> i32
+ # CHECK: func @test(%[[ARG2]]: i32, %[[ARG3]]: i32) {
+ # CHECK: %[[VAL5]] = "custom.op5"() : () -> i32
+ # CHECK: %[[VAL6]] = "custom.op6"() : () -> i32
+ # CHECK: return
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ # CHECK: }
+ print(module)
+
+ value2.owner.detach_from_parent()
+ # CHECK: %0
+ print(value2.get_name())
diff --git a/mlir/test/python/lit.local.cfg b/mlir/test/python/lit.local.cfg
index 8a98474044e0a..12d6e1f22744a 100644
--- a/mlir/test/python/lit.local.cfg
+++ b/mlir/test/python/lit.local.cfg
@@ -1,4 +1,4 @@
-config.environment['ASAN_OPTIONS'] = 'detect_leaks=0'
+config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
if not config.enable_bindings_python:
- config.unsupported = True
-config.excludes.add('python_test_ops.td')
+ config.unsupported = True
+config.excludes.add("python_test_ops.td")
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 8b276537dddcc..4b3a02ac42bd9 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -8,122 +8,140 @@
# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
- print(*args, file=sys.stderr)
- sys.stderr.flush()
+ print(*args, file=sys.stderr)
+ sys.stderr.flush()
+
def run(f):
- log("\nTEST:", f.__name__)
- f()
- gc.collect()
- assert Context._get_live_count() == 0
+ log("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+
# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
- with Context():
- pm = PassManager()
- pm_capsule = pm._CAPIPtr
- assert '"mlir.passmanager.PassManager._CAPIPtr"' in repr(pm_capsule)
- pm._testing_release()
- pm1 = PassManager._CAPICreate(pm_capsule)
- assert pm1 is not None # And does not crash.
+ with Context():
+ pm = PassManager()
+ pm_capsule = pm._CAPIPtr
+ assert '"mlir.passmanager.PassManager._CAPIPtr"' in repr(pm_capsule)
+ pm._testing_release()
+ pm1 = PassManager._CAPICreate(pm_capsule)
+ assert pm1 is not None # And does not crash.
+
+
run(testCapsule)
# CHECK-LABEL: TEST: testConstruct
@run
def testConstruct():
- with Context():
- # CHECK: pm1: 'any()'
- # CHECK: pm2: 'builtin.module()'
- pm1 = PassManager()
- pm2 = PassManager("builtin.module")
- log(f"pm1: '{pm1}'")
- log(f"pm2: '{pm2}'")
+ with Context():
+ # CHECK: pm1: 'any()'
+ # CHECK: pm2: 'builtin.module()'
+ pm1 = PassManager()
+ pm2 = PassManager("builtin.module")
+ log(f"pm1: '{pm1}'")
+ log(f"pm2: '{pm2}'")
# Verify successful round-trip.
# CHECK-LABEL: TEST: testParseSuccess
def testParseSuccess():
- with Context():
- # An unregistered pass should not parse.
- try:
- pm = PassManager.parse("builtin.module(func.func(not-existing-pass{json=false}))")
- except ValueError as e:
- # CHECK: ValueError exception: {{.+}} 'not-existing-pass' does not refer to a registered pass
- log("ValueError exception:", e)
- else:
- log("Exception not produced")
-
- # A registered pass should parse successfully.
- pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))")
- # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
- log("Roundtrip: ", pm)
+ with Context():
+ # An unregistered pass should not parse.
+ try:
+ pm = PassManager.parse(
+ "builtin.module(func.func(not-existing-pass{json=false}))"
+ )
+ except ValueError as e:
+ # CHECK: ValueError exception: {{.+}} 'not-existing-pass' does not refer to a registered pass
+ log("ValueError exception:", e)
+ else:
+ log("Exception not produced")
+
+ # A registered pass should parse successfully.
+ pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))")
+ # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
+ log("Roundtrip: ", pm)
+
+
run(testParseSuccess)
# Verify successful round-trip.
# CHECK-LABEL: TEST: testParseSpacedPipeline
def testParseSpacedPipeline():
- with Context():
-    # A registered pass should parse successfully even if it has extra spaces for readability
- pm = PassManager.parse("""builtin.module(
+ with Context():
+        # A registered pass should parse successfully even if it has extra spaces for readability
+ pm = PassManager.parse(
+ """builtin.module(
func.func( print-op-stats{ json=false } )
- )""")
- # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
- log("Roundtrip: ", pm)
+ )"""
+ )
+ # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
+ log("Roundtrip: ", pm)
+
+
run(testParseSpacedPipeline)
# Verify failure on unregistered pass.
# CHECK-LABEL: TEST: testParseFail
def testParseFail():
- with Context():
- try:
- pm = PassManager.parse("any(unknown-pass)")
- except ValueError as e:
- # CHECK: ValueError exception: MLIR Textual PassPipeline Parser:1:1: error:
- # CHECK-SAME: 'unknown-pass' does not refer to a registered pass or pass pipeline
- # CHECK: unknown-pass
- # CHECK: ^
- log("ValueError exception:", e)
- else:
- log("Exception not produced")
+ with Context():
+ try:
+ pm = PassManager.parse("any(unknown-pass)")
+ except ValueError as e:
+ # CHECK: ValueError exception: MLIR Textual PassPipeline Parser:1:1: error:
+ # CHECK-SAME: 'unknown-pass' does not refer to a registered pass or pass pipeline
+ # CHECK: unknown-pass
+ # CHECK: ^
+ log("ValueError exception:", e)
+ else:
+ log("Exception not produced")
+
+
run(testParseFail)
# Check that adding to a pass manager works
# CHECK-LABEL: TEST: testAdd
@run
def testAdd():
- pm = PassManager("any", Context())
- # CHECK: pm: 'any()'
- log(f"pm: '{pm}'")
- # CHECK: pm: 'any(cse)'
- pm.add("cse")
- log(f"pm: '{pm}'")
- # CHECK: pm: 'any(cse,cse)'
- pm.add("cse")
- log(f"pm: '{pm}'")
+ pm = PassManager("any", Context())
+ # CHECK: pm: 'any()'
+ log(f"pm: '{pm}'")
+ # CHECK: pm: 'any(cse)'
+ pm.add("cse")
+ log(f"pm: '{pm}'")
+ # CHECK: pm: 'any(cse,cse)'
+ pm.add("cse")
+ log(f"pm: '{pm}'")
# Verify failure on incorrect level of nesting.
# CHECK-LABEL: TEST: testInvalidNesting
def testInvalidNesting():
- with Context():
- try:
- pm = PassManager.parse("func.func(normalize-memrefs)")
- except ValueError as e:
- # CHECK: ValueError exception: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
- log("ValueError exception:", e)
- else:
- log("Exception not produced")
+ with Context():
+ try:
+ pm = PassManager.parse("func.func(normalize-memrefs)")
+ except ValueError as e:
+ # CHECK: ValueError exception: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
+ log("ValueError exception:", e)
+ else:
+ log("Exception not produced")
+
+
run(testInvalidNesting)
# Verify that a pass manager can execute on IR
# CHECK-LABEL: TEST: testRunPipeline
def testRunPipeline():
- with Context():
- pm = PassManager.parse("any(print-op-stats{json=false})")
- func = FuncOp.parse(r"""func.func @successfulParse() { return }""")
- pm.run(func)
+ with Context():
+ pm = PassManager.parse("any(print-op-stats{json=false})")
+ func = FuncOp.parse(r"""func.func @successfulParse() { return }""")
+ pm.run(func)
+
+
# CHECK: Operations encountered:
# CHECK: func.func , 1
# CHECK: func.return , 1
@@ -132,16 +150,16 @@ def testRunPipeline():
# CHECK-LABEL: TEST: testRunPipelineError
@run
def testRunPipelineError():
- with Context() as ctx:
- ctx.allow_unregistered_dialects = True
- op = Operation.parse('"test.op"() : () -> ()')
- pm = PassManager.parse("any(cse)")
- try:
- pm.run(op)
- except MLIRError as e:
- # CHECK: Exception: <
- # CHECK: Failure while executing pass pipeline:
- # CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
- # CHECK: note: "-":1:1: see current operation: "test.op"() : () -> ()
- # CHECK: >
- print(f"Exception: <{e}>")
+ with Context() as ctx:
+ ctx.allow_unregistered_dialects = True
+ op = Operation.parse('"test.op"() : () -> ()')
+ pm = PassManager.parse("any(cse)")
+ try:
+ pm.run(op)
+ except MLIRError as e:
+ # CHECK: Exception: <
+ # CHECK: Failure while executing pass pipeline:
+ # CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
+ # CHECK: note: "-":1:1: see current operation: "test.op"() : () -> ()
+ # CHECK: >
+ print(f"Exception: <{e}>")
diff --git a/mlir/test/tblgen-lsp-server/lit.local.cfg b/mlir/test/tblgen-lsp-server/lit.local.cfg
index 25d08c7aba306..aa35dbfa8c01f 100644
--- a/mlir/test/tblgen-lsp-server/lit.local.cfg
+++ b/mlir/test/tblgen-lsp-server/lit.local.cfg
@@ -1 +1 @@
-config.excludes = ['include']
+config.excludes = ["include"]
diff --git a/mlir/utils/gdb-scripts/prettyprinters.py b/mlir/utils/gdb-scripts/prettyprinters.py
index 85a1a14b177d1..9ea8bdbe86d77 100644
--- a/mlir/utils/gdb-scripts/prettyprinters.py
+++ b/mlir/utils/gdb-scripts/prettyprinters.py
@@ -4,216 +4,223 @@
class StoragePrinter:
- """Prints bases of a struct and its fields."""
+ """Prints bases of a struct and its fields."""
- def __init__(self, val):
- self.val = val
+ def __init__(self, val):
+ self.val = val
- def children(self):
- for field in self.val.type.fields():
- if field.is_base_class:
- yield '<%s>' % field.name, self.val.cast(field.type)
- else:
- yield field.name, self.val[field.name]
+ def children(self):
+ for field in self.val.type.fields():
+ if field.is_base_class:
+ yield "<%s>" % field.name, self.val.cast(field.type)
+ else:
+ yield field.name, self.val[field.name]
+
+ def to_string(self):
+ return "mlir::Storage"
- def to_string(self):
- return 'mlir::Storage'
class TupleTypeStoragePrinter(StoragePrinter):
+ def children(self):
+ for child in StoragePrinter.children(self):
+ yield child
+ pointer_type = gdb.lookup_type("mlir::Type").pointer()
+ elements = (self.val.address + 1).cast(pointer_type)
+ for i in range(self.val["numElements"]):
+ yield "elements[%u]" % i, elements[i]
- def children(self):
- for child in StoragePrinter.children(self):
- yield child
- pointer_type = gdb.lookup_type('mlir::Type').pointer()
- elements = (self.val.address + 1).cast(pointer_type)
- for i in range(self.val['numElements']):
- yield 'elements[%u]' % i, elements[i]
+ def to_string(self):
+ return "mlir::TupleTypeStorage of %u elements" % self.val["numElements"]
- def to_string(self):
- return 'mlir::TupleTypeStorage of %u elements' % self.val['numElements']
class FusedLocationStoragePrinter(StoragePrinter):
+ def children(self):
+ for child in StoragePrinter.children(self):
+ yield child
+ pointer_type = gdb.lookup_type("mlir::Location").pointer()
+ elements = (self.val.address + 1).cast(pointer_type)
+ for i in range(self.val["numLocs"]):
+ yield "locs[%u]" % i, elements[i]
- def children(self):
- for child in StoragePrinter.children(self):
- yield child
- pointer_type = gdb.lookup_type('mlir::Location').pointer()
- elements = (self.val.address + 1).cast(pointer_type)
- for i in range(self.val['numLocs']):
- yield 'locs[%u]' % i, elements[i]
-
- def to_string(self):
- return 'mlir::FusedLocationStorage of %u locs' % self.val['numLocs']
+ def to_string(self):
+ return "mlir::FusedLocationStorage of %u locs" % self.val["numLocs"]
class StorageTypeMap:
- """Maps a TypeID to the corresponding concrete type.
-
- Types need to be registered by name before the first lookup.
- """
-
- def __init__(self):
- self.map = None
- self.type_names = []
-
- def register_type(self, type_name):
- assert not self.map, 'register_type called after __getitem__'
- self.type_names += [type_name]
-
- def _init_map(self):
- """Lazy initialization of self.map."""
- if self.map:
- return
- self.map = {}
- for type_name in self.type_names:
- concrete_type = gdb.lookup_type(type_name)
- try:
- storage = gdb.parse_and_eval(
- "&'mlir::detail::TypeIDExported::get<%s>()::instance'" % type_name)
- except gdb.error:
- # Skip when TypeID instance cannot be found in current context.
- continue
- if concrete_type and storage:
- self.map[int(storage)] = concrete_type
-
- def __getitem__(self, type_id):
- self._init_map()
- return self.map.get(int(type_id['storage']))
+ """Maps a TypeID to the corresponding concrete type.
+
+ Types need to be registered by name before the first lookup.
+ """
+
+ def __init__(self):
+ self.map = None
+ self.type_names = []
+
+ def register_type(self, type_name):
+ assert not self.map, "register_type called after __getitem__"
+ self.type_names += [type_name]
+
+ def _init_map(self):
+ """Lazy initialization of self.map."""
+ if self.map:
+ return
+ self.map = {}
+ for type_name in self.type_names:
+ concrete_type = gdb.lookup_type(type_name)
+ try:
+ storage = gdb.parse_and_eval(
+ "&'mlir::detail::TypeIDExported::get<%s>()::instance'" % type_name
+ )
+ except gdb.error:
+ # Skip when TypeID instance cannot be found in current context.
+ continue
+ if concrete_type and storage:
+ self.map[int(storage)] = concrete_type
+
+ def __getitem__(self, type_id):
+ self._init_map()
+ return self.map.get(int(type_id["storage"]))
storage_type_map = StorageTypeMap()
def get_type_id_printer(val):
- """Returns a printer of the name of a mlir::TypeID."""
-
- class TypeIdPrinter:
+ """Returns a printer of the name of a mlir::TypeID."""
- def __init__(self, string):
- self.string = string
+ class TypeIdPrinter:
+ def __init__(self, string):
+ self.string = string
- def to_string(self):
- return self.string
+ def to_string(self):
+ return self.string
- concrete_type = storage_type_map[val]
- if not concrete_type:
- return None
- return TypeIdPrinter('mlir::TypeID::get<%s>()' % concrete_type)
+ concrete_type = storage_type_map[val]
+ if not concrete_type:
+ return None
+ return TypeIdPrinter("mlir::TypeID::get<%s>()" % concrete_type)
def get_attr_or_type_printer(val, get_type_id):
- """Returns a printer for mlir::Attribute or mlir::Type."""
-
- class AttrOrTypePrinter:
-
- def __init__(self, type_id, impl):
- self.type_id = type_id
- self.impl = impl
-
- def children(self):
- yield 'typeID', self.type_id
- yield 'impl', self.impl
-
- def to_string(self):
- return 'cast<%s>' % self.impl.type
-
- if not val['impl']:
- return None
- impl = val['impl'].dereference()
- type_id = get_type_id(impl)
- concrete_type = storage_type_map[type_id]
- if not concrete_type:
- return None
- # 3rd template argument of StorageUserBase is the storage type.
- storage_type = concrete_type.fields()[0].type.template_argument(2)
- if not storage_type:
- return None
- return AttrOrTypePrinter(type_id, impl.cast(storage_type))
+ """Returns a printer for mlir::Attribute or mlir::Type."""
+
+ class AttrOrTypePrinter:
+ def __init__(self, type_id, impl):
+ self.type_id = type_id
+ self.impl = impl
+
+ def children(self):
+ yield "typeID", self.type_id
+ yield "impl", self.impl
+
+ def to_string(self):
+ return "cast<%s>" % self.impl.type
+
+ if not val["impl"]:
+ return None
+ impl = val["impl"].dereference()
+ type_id = get_type_id(impl)
+ concrete_type = storage_type_map[type_id]
+ if not concrete_type:
+ return None
+ # 3rd template argument of StorageUserBase is the storage type.
+ storage_type = concrete_type.fields()[0].type.template_argument(2)
+ if not storage_type:
+ return None
+ return AttrOrTypePrinter(type_id, impl.cast(storage_type))
class ImplPrinter:
- """Printer for an instance with a single 'impl' member pointer."""
+ """Printer for an instance with a single 'impl' member pointer."""
- def __init__(self, val):
- self.val = val
- self.impl = val['impl']
+ def __init__(self, val):
+ self.val = val
+ self.impl = val["impl"]
- def children(self):
- if self.impl:
- yield 'impl', self.impl.dereference()
+ def children(self):
+ if self.impl:
+ yield "impl", self.impl.dereference()
- def to_string(self):
- return self.val.type.name
+ def to_string(self):
+ return self.val.type.name
# Printers of types deriving from Attribute::AttrBase or Type::TypeBase.
for name in [
# mlir/IR/Attributes.h
- 'ArrayAttr',
- 'DictionaryAttr',
- 'FloatAttr',
- 'IntegerAttr',
- 'IntegerSetAttr',
- 'OpaqueAttr',
- 'StringAttr',
- 'SymbolRefAttr',
- 'TypeAttr',
- 'UnitAttr',
- 'DenseStringElementsAttr',
- 'DenseIntOrFPElementsAttr',
- 'SparseElementsAttr',
+ "ArrayAttr",
+ "DictionaryAttr",
+ "FloatAttr",
+ "IntegerAttr",
+ "IntegerSetAttr",
+ "OpaqueAttr",
+ "StringAttr",
+ "SymbolRefAttr",
+ "TypeAttr",
+ "UnitAttr",
+ "DenseStringElementsAttr",
+ "DenseIntOrFPElementsAttr",
+ "SparseElementsAttr",
# mlir/IR/BuiltinTypes.h
- 'ComplexType',
- 'IndexType',
- 'IntegerType',
- 'Float16Type',
- 'Float32Type',
- 'Float64Type',
- 'Float80Type',
- 'Float128Type',
- 'NoneType',
- 'VectorType',
- 'RankedTensorType',
- 'UnrankedTensorType',
- 'MemRefType',
- 'UnrankedMemRefType',
- 'TupleType',
+ "ComplexType",
+ "IndexType",
+ "IntegerType",
+ "Float16Type",
+ "Float32Type",
+ "Float64Type",
+ "Float80Type",
+ "Float128Type",
+ "NoneType",
+ "VectorType",
+ "RankedTensorType",
+ "UnrankedTensorType",
+ "MemRefType",
+ "UnrankedMemRefType",
+ "TupleType",
# mlir/IR/Location.h
- 'CallSiteLoc',
- 'FileLineColLoc',
- 'FusedLoc',
- 'NameLoc',
- 'OpaqueLoc',
- 'UnknownLoc'
+ "CallSiteLoc",
+ "FileLineColLoc",
+ "FusedLoc",
+ "NameLoc",
+ "OpaqueLoc",
+ "UnknownLoc",
]:
- storage_type_map.register_type('mlir::%s' % name) # Register for upcasting.
-storage_type_map.register_type('void') # Register default.
+ storage_type_map.register_type("mlir::%s" % name) # Register for upcasting.
+storage_type_map.register_type("void") # Register default.
-pp = gdb.printing.RegexpCollectionPrettyPrinter('MLIRSupport')
+pp = gdb.printing.RegexpCollectionPrettyPrinter("MLIRSupport")
-pp.add_printer('mlir::OperationName', '^mlir::OperationName$', ImplPrinter)
-pp.add_printer('mlir::Value', '^mlir::Value$', ImplPrinter)
+pp.add_printer("mlir::OperationName", "^mlir::OperationName$", ImplPrinter)
+pp.add_printer("mlir::Value", "^mlir::Value$", ImplPrinter)
# Printers for types deriving from AttributeStorage or TypeStorage.
-pp.add_printer('mlir::detail::FusedLocationStorage',
- '^mlir::detail::FusedLocationStorage',
- FusedLocationStoragePrinter)
-pp.add_printer('mlir::detail::TupleTypeStorage',
- '^mlir::detail::TupleTypeStorage$', TupleTypeStoragePrinter)
+pp.add_printer(
+ "mlir::detail::FusedLocationStorage",
+ "^mlir::detail::FusedLocationStorage",
+ FusedLocationStoragePrinter,
+)
+pp.add_printer(
+ "mlir::detail::TupleTypeStorage",
+ "^mlir::detail::TupleTypeStorage$",
+ TupleTypeStoragePrinter,
+)
-pp.add_printer('mlir::TypeID', '^mlir::TypeID$', get_type_id_printer)
+pp.add_printer("mlir::TypeID", "^mlir::TypeID$", get_type_id_printer)
def add_attr_or_type_printers(name):
- """Adds printers for mlir::Attribute or mlir::Type and their Storage type."""
- get_type_id = lambda val: val['abstract%s' % name]['typeID']
- pp.add_printer('mlir::%s' % name, '^mlir::%s$' % name,
- lambda val: get_attr_or_type_printer(val, get_type_id))
+ """Adds printers for mlir::Attribute or mlir::Type and their Storage type."""
+ get_type_id = lambda val: val["abstract%s" % name]["typeID"]
+ pp.add_printer(
+ "mlir::%s" % name,
+ "^mlir::%s$" % name,
+ lambda val: get_attr_or_type_printer(val, get_type_id),
+ )
# Upcasting printers of mlir::Attribute and mlir::Type.
-for name in ['Attribute', 'Type']:
- add_attr_or_type_printers(name)
+for name in ["Attribute", "Type"]:
+ add_attr_or_type_printers(name)
gdb.printing.register_pretty_printer(gdb.current_objfile(), pp)
diff --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py
index 474f812c9c0bc..0210d7a56ebf5 100755
--- a/mlir/utils/generate-test-checks.py
+++ b/mlir/utils/generate-test-checks.py
@@ -32,7 +32,7 @@
import re
import sys
-ADVERT_BEGIN = '// NOTE: Assertions have been autogenerated by '
+ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by "
ADVERT_END = """
// The script is designed to make adding checks to
// a test case fast, it is *not* designed to be authoritative
@@ -42,250 +42,249 @@
# Regex command to match an SSA identifier.
-SSA_RE_STR = '[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*'
+SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
SSA_RE = re.compile(SSA_RE_STR)
# Class used to generate and manage string substitution blocks for SSA value
# names.
class SSAVariableNamer:
+ def __init__(self):
+ self.scopes = []
+ self.name_counter = 0
- def __init__(self):
- self.scopes = []
- self.name_counter = 0
+ # Generate a substitution name for the given ssa value name.
+ def generate_name(self, ssa_name):
+ variable = "VAL_" + str(self.name_counter)
+ self.name_counter += 1
+ self.scopes[-1][ssa_name] = variable
+ return variable
- # Generate a substitution name for the given ssa value name.
- def generate_name(self, ssa_name):
- variable = 'VAL_' + str(self.name_counter)
- self.name_counter += 1
- self.scopes[-1][ssa_name] = variable
- return variable
+ # Push a new variable name scope.
+ def push_name_scope(self):
+ self.scopes.append({})
- # Push a new variable name scope.
- def push_name_scope(self):
- self.scopes.append({})
+ # Pop the last variable name scope.
+ def pop_name_scope(self):
+ self.scopes.pop()
- # Pop the last variable name scope.
- def pop_name_scope(self):
- self.scopes.pop()
+ # Return the level of nesting (number of pushed scopes).
+ def num_scopes(self):
+ return len(self.scopes)
- # Return the level of nesting (number of pushed scopes).
- def num_scopes(self):
- return len(self.scopes)
-
- # Reset the counter.
- def clear_counter(self):
- self.name_counter = 0
+ # Reset the counter.
+ def clear_counter(self):
+ self.name_counter = 0
# Process a line of input that has been split at each SSA identifier '%'.
def process_line(line_chunks, variable_namer):
- output_line = ''
-
- # Process the rest that contained an SSA value name.
- for chunk in line_chunks:
- m = SSA_RE.match(chunk)
- ssa_name = m.group(0)
-
- # Check if an existing variable exists for this name.
- variable = None
- for scope in variable_namer.scopes:
- variable = scope.get(ssa_name)
- if variable is not None:
- break
-
- # If one exists, then output the existing name.
- if variable is not None:
- output_line += '%[[' + variable + ']]'
- else:
- # Otherwise, generate a new variable.
- variable = variable_namer.generate_name(ssa_name)
- output_line += '%[[' + variable + ':.*]]'
+ output_line = ""
+
+ # Process the rest that contained an SSA value name.
+ for chunk in line_chunks:
+ m = SSA_RE.match(chunk)
+ ssa_name = m.group(0)
+
+ # Check if an existing variable exists for this name.
+ variable = None
+ for scope in variable_namer.scopes:
+ variable = scope.get(ssa_name)
+ if variable is not None:
+ break
- # Append the non named group.
- output_line += chunk[len(ssa_name):]
+ # If one exists, then output the existing name.
+ if variable is not None:
+ output_line += "%[[" + variable + "]]"
+ else:
+ # Otherwise, generate a new variable.
+ variable = variable_namer.generate_name(ssa_name)
+ output_line += "%[[" + variable + ":.*]]"
- return output_line.rstrip() + '\n'
+ # Append the non named group.
+ output_line += chunk[len(ssa_name) :]
+
+ return output_line.rstrip() + "\n"
# Process the source file lines. The source file doesn't have to be .mlir.
def process_source_lines(source_lines, note, args):
- source_split_re = re.compile(args.source_delim_regex)
+ source_split_re = re.compile(args.source_delim_regex)
- source_segments = [[]]
- for line in source_lines:
- # Remove previous note.
- if line == note:
- continue
- # Remove previous CHECK lines.
- if line.find(args.check_prefix) != -1:
- continue
- # Segment the file based on --source_delim_regex.
- if source_split_re.search(line):
- source_segments.append([])
+ source_segments = [[]]
+ for line in source_lines:
+ # Remove previous note.
+ if line == note:
+ continue
+ # Remove previous CHECK lines.
+ if line.find(args.check_prefix) != -1:
+ continue
+ # Segment the file based on --source_delim_regex.
+ if source_split_re.search(line):
+ source_segments.append([])
- source_segments[-1].append(line + '\n')
- return source_segments
+ source_segments[-1].append(line + "\n")
+ return source_segments
# Pre-process a line of input to remove any character sequences that will be
# problematic with FileCheck.
def preprocess_line(line):
- # Replace any double brackets, '[[' with escaped replacements. '[['
- # corresponds to variable names in FileCheck.
- output_line = line.replace('[[', '{{\\[\\[}}')
+ # Replace any double brackets, '[[' with escaped replacements. '[['
+ # corresponds to variable names in FileCheck.
+ output_line = line.replace("[[", "{{\\[\\[}}")
- # Replace any single brackets that are followed by an SSA identifier, the
-  # identifier will be replaced by a variable, creating the same situation as
- # above.
- output_line = output_line.replace('[%', '{{\\[}}%')
+ # Replace any single brackets that are followed by an SSA identifier, the
+    # identifier will be replaced by a variable, creating the same situation as
+ # above.
+ output_line = output_line.replace("[%", "{{\\[}}%")
- return output_line
+ return output_line
def main():
- parser = argparse.ArgumentParser(
- description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
- parser.add_argument(
- '--check-prefix', default='CHECK', help='Prefix to use from check file.')
- parser.add_argument(
- '-o',
- '--output',
- nargs='?',
- type=argparse.FileType('w'),
- default=None)
- parser.add_argument(
- 'input',
- nargs='?',
- type=argparse.FileType('r'),
- default=sys.stdin)
- parser.add_argument(
- '--source', type=str,
-      help='Print each CHECK chunk before each delimiter line in the source '
-      'file, respectively. The delimiter lines are identified by '
- '--source_delim_regex.')
- parser.add_argument('--source_delim_regex', type=str, default='func @')
- parser.add_argument(
- '--starts_from_scope', type=int, default=1,
- help='Omit the top specified level of content. For example, by default '
- 'it omits "module {"')
- parser.add_argument('-i', '--inplace', action='store_true', default=False)
-
- args = parser.parse_args()
-
- # Open the given input file.
- input_lines = [l.rstrip() for l in args.input]
- args.input.close()
-
- # Generate a note used for the generated check file.
- script_name = os.path.basename(__file__)
- autogenerated_note = (ADVERT_BEGIN + 'utils/' + script_name + "\n" + ADVERT_END)
-
- source_segments = None
- if args.source:
- source_segments = process_source_lines(
- [l.rstrip() for l in open(args.source, 'r')],
- autogenerated_note,
- args
+ parser = argparse.ArgumentParser(
+ description=__doc__, formatter_class=argparse.RawTextHelpFormatter
+ )
+ parser.add_argument(
+ "--check-prefix", default="CHECK", help="Prefix to use from check file."
+ )
+ parser.add_argument(
+ "-o", "--output", nargs="?", type=argparse.FileType("w"), default=None
+ )
+ parser.add_argument(
+ "input", nargs="?", type=argparse.FileType("r"), default=sys.stdin
)
+ parser.add_argument(
+ "--source",
+ type=str,
+ help="Print each CHECK chunk before each delimeter line in the source"
+ "file, respectively. The delimeter lines are identified by "
+ "--source_delim_regex.",
+ )
+ parser.add_argument("--source_delim_regex", type=str, default="func @")
+ parser.add_argument(
+ "--starts_from_scope",
+ type=int,
+ default=1,
+ help="Omit the top specified level of content. For example, by default "
+ 'it omits "module {"',
+ )
+ parser.add_argument("-i", "--inplace", action="store_true", default=False)
+
+ args = parser.parse_args()
+
+ # Open the given input file.
+ input_lines = [l.rstrip() for l in args.input]
+ args.input.close()
- if args.inplace:
- assert args.output is None
- output = open(args.source, 'w')
- elif args.output is None:
- output = sys.stdout
- else:
- output = args.output
-
- output_segments = [[]]
- # A map containing data used for naming SSA value names.
- variable_namer = SSAVariableNamer()
- for input_line in input_lines:
- if not input_line:
- continue
- lstripped_input_line = input_line.lstrip()
-
- # Lines with blocks begin with a ^. These lines have a trailing comment
- # that needs to be stripped.
- is_block = lstripped_input_line[0] == '^'
- if is_block:
- input_line = input_line.rsplit('//', 1)[0].rstrip()
-
- cur_level = variable_namer.num_scopes()
-
- # If the line starts with a '}', pop the last name scope.
- if lstripped_input_line[0] == '}':
- variable_namer.pop_name_scope()
- cur_level = variable_namer.num_scopes()
-
- # If the line ends with a '{', push a new name scope.
- if input_line[-1] == '{':
- variable_namer.push_name_scope()
- if cur_level == args.starts_from_scope:
- output_segments.append([])
-
- # Omit lines at the near top level e.g. "module {".
- if cur_level < args.starts_from_scope:
- continue
-
- if len(output_segments[-1]) == 0:
- variable_namer.clear_counter()
-
- # Preprocess the input to remove any sequences that may be problematic with
- # FileCheck.
- input_line = preprocess_line(input_line)
-
-    # Split the line at each SSA value name.
- ssa_split = input_line.split('%')
-
- # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
- if len(output_segments[-1]) != 0 or not ssa_split[0]:
- output_line = '// ' + args.check_prefix + ': '
- # Pad to align with the 'LABEL' statements.
- output_line += (' ' * len('-LABEL'))
-
- # Output the first line chunk that does not contain an SSA name.
- output_line += ssa_split[0]
-
- # Process the rest of the input line.
- output_line += process_line(ssa_split[1:], variable_namer)
+ # Generate a note used for the generated check file.
+ script_name = os.path.basename(__file__)
+ autogenerated_note = ADVERT_BEGIN + "utils/" + script_name + "\n" + ADVERT_END
+ source_segments = None
+ if args.source:
+ source_segments = process_source_lines(
+ [l.rstrip() for l in open(args.source, "r")], autogenerated_note, args
+ )
+
+ if args.inplace:
+ assert args.output is None
+ output = open(args.source, "w")
+ elif args.output is None:
+ output = sys.stdout
+ else:
+ output = args.output
+
+ output_segments = [[]]
+ # A map containing data used for naming SSA value names.
+ variable_namer = SSAVariableNamer()
+ for input_line in input_lines:
+ if not input_line:
+ continue
+ lstripped_input_line = input_line.lstrip()
+
+ # Lines with blocks begin with a ^. These lines have a trailing comment
+ # that needs to be stripped.
+ is_block = lstripped_input_line[0] == "^"
+ if is_block:
+ input_line = input_line.rsplit("//", 1)[0].rstrip()
+
+ cur_level = variable_namer.num_scopes()
+
+ # If the line starts with a '}', pop the last name scope.
+ if lstripped_input_line[0] == "}":
+ variable_namer.pop_name_scope()
+ cur_level = variable_namer.num_scopes()
+
+ # If the line ends with a '{', push a new name scope.
+ if input_line[-1] == "{":
+ variable_namer.push_name_scope()
+ if cur_level == args.starts_from_scope:
+ output_segments.append([])
+
+ # Omit lines at the near top level e.g. "module {".
+ if cur_level < args.starts_from_scope:
+ continue
+
+ if len(output_segments[-1]) == 0:
+ variable_namer.clear_counter()
+
+ # Preprocess the input to remove any sequences that may be problematic with
+ # FileCheck.
+ input_line = preprocess_line(input_line)
+
+        # Split the line at each SSA value name.
+ ssa_split = input_line.split("%")
+
+ # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
+ if len(output_segments[-1]) != 0 or not ssa_split[0]:
+ output_line = "// " + args.check_prefix + ": "
+ # Pad to align with the 'LABEL' statements.
+ output_line += " " * len("-LABEL")
+
+ # Output the first line chunk that does not contain an SSA name.
+ output_line += ssa_split[0]
+
+ # Process the rest of the input line.
+ output_line += process_line(ssa_split[1:], variable_namer)
+
+ else:
+ # Output the first line chunk that does not contain an SSA name for the
+ # label.
+ output_line = "// " + args.check_prefix + "-LABEL: " + ssa_split[0] + "\n"
+
+ # Process the rest of the input line on separate check lines.
+ for argument in ssa_split[1:]:
+ output_line += "// " + args.check_prefix + "-SAME: "
+
+ # Pad to align with the original position in the line.
+ output_line += " " * len(ssa_split[0])
+
+ # Process the rest of the line.
+ output_line += process_line([argument], variable_namer)
+
+ # Append the output line.
+ output_segments[-1].append(output_line)
+
+ output.write(autogenerated_note + "\n")
+
+ # Write the output.
+ if source_segments:
+ assert len(output_segments) == len(source_segments)
+ for check_segment, source_segment in zip(output_segments, source_segments):
+ for line in check_segment:
+ output.write(line)
+ for line in source_segment:
+ output.write(line)
else:
- # Output the first line chunk that does not contain an SSA name for the
- # label.
- output_line = '// ' + args.check_prefix + '-LABEL: ' + ssa_split[0] + '\n'
-
- # Process the rest of the input line on separate check lines.
- for argument in ssa_split[1:]:
- output_line += '// ' + args.check_prefix + '-SAME: '
-
- # Pad to align with the original position in the line.
- output_line += ' ' * len(ssa_split[0])
-
- # Process the rest of the line.
- output_line += process_line([argument], variable_namer)
-
- # Append the output line.
- output_segments[-1].append(output_line)
-
- output.write(autogenerated_note + '\n')
-
- # Write the output.
- if source_segments:
- assert len(output_segments) == len(source_segments)
- for check_segment, source_segment in zip(output_segments, source_segments):
- for line in check_segment:
- output.write(line)
- for line in source_segment:
- output.write(line)
- else:
- for segment in output_segments:
- output.write('\n')
- for output_line in segment:
- output.write(output_line)
- output.write('\n')
- output.close()
-
-
-if __name__ == '__main__':
- main()
+ for segment in output_segments:
+ output.write("\n")
+ for output_line in segment:
+ output.write(output_line)
+ output.write("\n")
+ output.close()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/mlir/utils/jupyter/mlir_opt_kernel/__main__.py b/mlir/utils/jupyter/mlir_opt_kernel/__main__.py
index 02582f9b5bfb0..21994ff6ec6fa 100644
--- a/mlir/utils/jupyter/mlir_opt_kernel/__main__.py
+++ b/mlir/utils/jupyter/mlir_opt_kernel/__main__.py
@@ -4,4 +4,5 @@
from ipykernel.kernelapp import IPKernelApp
from .kernel import MlirOptKernel
+
IPKernelApp.launch_instance(kernel_class=MlirOptKernel)
diff --git a/mlir/utils/jupyter/mlir_opt_kernel/install.py b/mlir/utils/jupyter/mlir_opt_kernel/install.py
index ddb37c87dfea9..bd7b1d10b734f 100644
--- a/mlir/utils/jupyter/mlir_opt_kernel/install.py
+++ b/mlir/utils/jupyter/mlir_opt_kernel/install.py
@@ -10,12 +10,11 @@
def install_my_kernel_spec(user=True, prefix=None):
"""Install the kernel spec for user in given prefix."""
- print('Installing mlir-opt IPython kernel spec')
+ print("Installing mlir-opt IPython kernel spec")
pkgroot = os.path.dirname(__file__)
- KernelSpecManager().install_kernel_spec(os.path.join(pkgroot, 'assets'),
- 'mlir',
- user=user,
- prefix=prefix)
+ KernelSpecManager().install_kernel_spec(
+ os.path.join(pkgroot, "assets"), "mlir", user=user, prefix=prefix
+ )
def _is_root():
@@ -29,15 +28,16 @@ def _is_root():
def main(argv=None):
parser = argparse.ArgumentParser(
- description='Install KernelSpec for MlirOpt Kernel')
+ description="Install KernelSpec for MlirOpt Kernel"
+ )
prefix_locations = parser.add_mutually_exclusive_group()
- prefix_locations.add_argument('--user',
- help='Install in user home directory',
- action='store_true')
- prefix_locations.add_argument('--prefix',
- help='Install directory prefix',
- default=None)
+ prefix_locations.add_argument(
+ "--user", help="Install in user home directory", action="store_true"
+ )
+ prefix_locations.add_argument(
+ "--prefix", help="Install directory prefix", default=None
+ )
args = parser.parse_args(argv)
@@ -47,5 +47,5 @@ def main(argv=None):
install_my_kernel_spec(user=user, prefix=prefix)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/mlir/utils/jupyter/mlir_opt_kernel/kernel.py b/mlir/utils/jupyter/mlir_opt_kernel/kernel.py
index 85462dad83a30..c0e4fc1db4c8a 100644
--- a/mlir/utils/jupyter/mlir_opt_kernel/kernel.py
+++ b/mlir/utils/jupyter/mlir_opt_kernel/kernel.py
@@ -9,7 +9,7 @@
import traceback
from ipykernel.kernelbase import Kernel
-__version__ = '0.0.1'
+__version__ = "0.0.1"
def _get_executable():
@@ -19,7 +19,7 @@ def is_exe(fpath):
"""Returns whether executable file."""
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
- program = os.environ.get('MLIR_OPT_EXECUTABLE', 'mlir-opt')
+ program = os.environ.get("MLIR_OPT_EXECUTABLE", "mlir-opt")
path, name = os.path.split(program)
# Attempt to get the executable
if path:
@@ -30,7 +30,7 @@ def is_exe(fpath):
file = os.path.join(path, name)
if is_exe(file):
return file
- raise OSError('mlir-opt not found, please see README')
+ raise OSError("mlir-opt not found, please see README")
class MlirOptKernel(Kernel):
@@ -51,19 +51,17 @@ class MlirOptKernel(Kernel):
```
"""
- implementation = 'mlir'
+ implementation = "mlir"
implementation_version = __version__
language_version = __version__
language = "mlir"
language_info = {
"name": "mlir",
- "codemirror_mode": {
- "name": "mlir"
- },
+ "codemirror_mode": {"name": "mlir"},
"mimetype": "text/x-mlir",
"file_extension": ".mlir",
- "pygments_lexer": "text"
+ "pygments_lexer": "text",
}
@property
@@ -88,31 +86,28 @@ def process_output(self, output):
"""Reports regular command output."""
if not self.silent:
# Send standard output
- stream_content = {'name': 'stdout', 'text': output}
- self.send_response(self.iopub_socket, 'stream', stream_content)
+ stream_content = {"name": "stdout", "text": output}
+ self.send_response(self.iopub_socket, "stream", stream_content)
def process_error(self, output):
"""Reports error response."""
if not self.silent:
# Send standard error
- stream_content = {'name': 'stderr', 'text': output}
- self.send_response(self.iopub_socket, 'stream', stream_content)
-
- def do_execute(self,
- code,
- silent,
- store_history=True,
- user_expressions=None,
- allow_stdin=False):
+ stream_content = {"name": "stderr", "text": output}
+ self.send_response(self.iopub_socket, "stream", stream_content)
+
+ def do_execute(
+ self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
+ ):
"""Execute user code using mlir-opt binary."""
def ok_status():
"""Returns OK status."""
return {
- 'status': 'ok',
- 'execution_count': self.execution_count,
- 'payload': [],
- 'user_expressions': {}
+ "status": "ok",
+ "execution_count": self.execution_count,
+ "payload": [],
+ "user_expressions": {},
}
def run(code):
@@ -123,29 +118,27 @@ def run(code):
# Specify input and output file to error out if also
# set as arg.
self.get_executable(),
- '--color',
+ "--color",
inputmlir.name,
- '-o',
- '-'
+ "-o",
+ "-",
]
# Simple handling of repeating last line.
- if code.endswith('\n_'):
+ if code.endswith("\n_"):
if not self._:
- raise NameError('No previous result set')
+ raise NameError("No previous result set")
code = code[:-1] + self._
inputmlir.write(code.encode("utf-8"))
inputmlir.close()
- pipe = Popen(command,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE)
+ pipe = Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, errors = pipe.communicate()
exitcode = pipe.returncode
finally:
os.unlink(inputmlir.name)
-# Replace temporary filename with placeholder. This takes the very
-# remote chance where the full input filename (generated above)
-# overlaps with something in the dump unrelated to the file.
+ # Replace temporary filename with placeholder. This takes the very
+ # remote chance where the full input filename (generated above)
+ # overlaps with something in the dump unrelated to the file.
fname = inputmlir.name.encode("utf-8")
output = output.replace(fname, b"<<input>>")
errors = errors.replace(fname, b"<<input>>")
@@ -163,7 +156,7 @@ def run(code):
else:
self._ = output.decode("utf-8")
except KeyboardInterrupt:
- return {'status': 'abort', 'execution_count': self.execution_count}
+ return {"status": "abort", "execution_count": self.execution_count}
except Exception as error:
# Print traceback for local debugging.
traceback.print_exc()
@@ -172,24 +165,24 @@ def run(code):
errors = repr(error).encode("utf-8")
if exitcode:
- content = {'ename': '', 'evalue': str(exitcode), 'traceback': []}
+ content = {"ename": "", "evalue": str(exitcode), "traceback": []}
- self.send_response(self.iopub_socket, 'error', content)
+ self.send_response(self.iopub_socket, "error", content)
self.process_error(errors.decode("utf-8"))
- content['execution_count'] = self.execution_count
- content['status'] = 'error'
+ content["execution_count"] = self.execution_count
+ content["status"] = "error"
return content
if not silent:
data = {}
- data['text/x-mlir'] = self._
+ data["text/x-mlir"] = self._
content = {
- 'execution_count': self.execution_count,
- 'data': data,
- 'metadata': {}
+ "execution_count": self.execution_count,
+ "data": data,
+ "metadata": {},
}
- self.send_response(self.iopub_socket, 'execute_result', content)
+ self.send_response(self.iopub_socket, "execute_result", content)
self.process_output(self._)
self.process_error(errors.decode("utf-8"))
return ok_status()
diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index bfd76a7d0ca28..5d06b400334c8 100644
--- a/mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -521,8 +521,7 @@ def update(self):
class IPListRangeSynthProvider:
- """Define an LLDB synthetic children provider for an IPList.
- """
+ """Define an LLDB synthetic children provider for an IPList."""
def __init__(self, valobj, internal_dict):
self.valobj = valobj
@@ -575,8 +574,7 @@ def update(self):
class ValueSynthProvider:
- """Define an LLDB synthetic children provider for Values.
- """
+ """Define an LLDB synthetic children provider for Values."""
def __init__(self, valobj, internal_dict):
self.valobj = valobj
@@ -677,8 +675,7 @@ def update(self):
def ValueSummaryProvider(valobj: lldb.SBValue, internal_dict):
- """Define an LLDB summary provider for Values.
- """
+ """Define an LLDB summary provider for Values."""
index = valobj.GetChildMemberWithName("index").GetValueAsUnsigned()
# Check if this is a block argument or not (block arguments have locations).
diff --git a/mlir/utils/mbr/mbr/__init__.py b/mlir/utils/mbr/mbr/__init__.py
index 3e47ec861b684..d01befd8fedba 100644
--- a/mlir/utils/mbr/mbr/__init__.py
+++ b/mlir/utils/mbr/mbr/__init__.py
@@ -9,5 +9,6 @@ class BenchmarkRunConfig:
class. The `compiler` attribute is optional, for example for python
benchmarks.
"""
+
runner: typing.Callable
compiler: typing.Optional[typing.Callable] = None
diff --git a/mlir/utils/mbr/mbr/discovery.py b/mlir/utils/mbr/mbr/discovery.py
index 37cc458b31a00..6c9803e0d3c33 100644
--- a/mlir/utils/mbr/mbr/discovery.py
+++ b/mlir/utils/mbr/mbr/discovery.py
@@ -16,21 +16,17 @@ def discover_benchmark_modules(top_level_path):
defaults to "benchmark_"
"""
config = configparser.ConfigParser()
- config.read(
- os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini")
- )
+ config.read(os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini"))
if "discovery" in config.sections():
filename_prefix = config["discovery"]["filename_prefix"]
else:
filename_prefix = "benchmark_"
- if re.search(fr"{filename_prefix}.*.py$", top_level_path):
+ if re.search(rf"{filename_prefix}.*.py$", top_level_path):
# A specific python file so just include that.
benchmark_files = [top_level_path]
else:
# A directory so recursively search for all python files.
- benchmark_files = pathlib.Path(
- top_level_path
- ).rglob(f"{filename_prefix}*.py")
+ benchmark_files = pathlib.Path(top_level_path).rglob(f"{filename_prefix}*.py")
for benchmark_filename in benchmark_files:
benchmark_abs_dir = os.path.abspath(os.path.dirname(benchmark_filename))
sys.path.append(benchmark_abs_dir)
@@ -46,9 +42,7 @@ def get_benchmark_functions(module, benchmark_function_name=None):
a specific prefix, which defaults to "benchmark_".
"""
config = configparser.ConfigParser()
- config.read(
- os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini")
- )
+ config.read(os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini"))
if "discovery" in config.sections():
function_prefix = config["discovery"].get("function_prefix")
else:
@@ -57,9 +51,8 @@ def get_benchmark_functions(module, benchmark_function_name=None):
module_functions = []
for attribute_name in dir(module):
attribute = getattr(module, attribute_name)
- if (
- isinstance(attribute, types.FunctionType)
- and attribute_name.startswith(function_prefix)
+ if isinstance(attribute, types.FunctionType) and attribute_name.startswith(
+ function_prefix
):
module_functions.append(attribute)
diff --git a/mlir/utils/mbr/mbr/main.py b/mlir/utils/mbr/mbr/main.py
index 0f67454878bb1..5d301abcc2393 100644
--- a/mlir/utils/mbr/mbr/main.py
+++ b/mlir/utils/mbr/mbr/main.py
@@ -12,8 +12,7 @@
def main(top_level_path, stop_on_error):
- """Top level function called when the CLI is invoked.
- """
+ """Top level function called when the CLI is invoked."""
if "::" in top_level_path:
if top_level_path.count("::") > 1:
raise AssertionError(f"Invalid path {top_level_path}")
@@ -22,16 +21,14 @@ def main(top_level_path, stop_on_error):
benchmark_function_name = None
if not os.path.exists(top_level_path):
- raise AssertionError(
- f"The top-level path {top_level_path} doesn't exist"
- )
+ raise AssertionError(f"The top-level path {top_level_path} doesn't exist")
modules = [module for module in discover_benchmark_modules(top_level_path)]
benchmark_dicts = []
for module in modules:
benchmark_functions = [
- function for function in
- get_benchmark_functions(module, benchmark_function_name)
+ function
+ for function in get_benchmark_functions(module, benchmark_function_name)
]
for benchmark_function in benchmark_functions:
try:
@@ -96,10 +93,9 @@ def main(top_level_path, stop_on_error):
if len(measurements_ns) > 0:
measurements_s = [t * 1e-9 for t in measurements_ns]
- benchmark_identifier = ":".join([
- module.__name__,
- benchmark_function.__name__
- ])
+ benchmark_identifier = ":".join(
+ [module.__name__, benchmark_function.__name__]
+ )
benchmark_dicts.append(
{
"name": benchmark_identifier,
diff --git a/mlir/utils/mbr/mbr/stats.py b/mlir/utils/mbr/mbr/stats.py
index 32880212013e5..9b7a3dce23bfa 100644
--- a/mlir/utils/mbr/mbr/stats.py
+++ b/mlir/utils/mbr/mbr/stats.py
@@ -16,9 +16,7 @@ def has_enough_measurements(measurements):
If 1. is true, 2. doesn't need to be true.
"""
config = configparser.ConfigParser()
- config.read(
- os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.cfg")
- )
+ config.read(os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.cfg"))
if "stats" in config:
stats_dict = {
"max_number_of_measurements": int(
@@ -34,6 +32,6 @@ def has_enough_measurements(measurements):
"max_time_for_a_benchmark_ns": 1e9,
}
return (
- np.sum(measurements) >= stats_dict["max_time_for_a_benchmark_ns"] or
- np.size(measurements) >= stats_dict["max_number_of_measurements"]
+ np.sum(measurements) >= stats_dict["max_time_for_a_benchmark_ns"]
+ or np.size(measurements) >= stats_dict["max_number_of_measurements"]
)
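To make the OR-ed stopping rule above concrete, a small sketch with made-up numbers (the count threshold of 100 is hypothetical; the real value comes from config.cfg or the fallback dict):

    import numpy as np

    measurements = [4e8, 4e8, 4e8]               # three samples, 1.2e9 ns in total
    enough_time  = np.sum(measurements) >= 1e9   # True: already spent >= 1 s on this benchmark
    enough_count = np.size(measurements) >= 100  # False: fewer than 100 samples (hypothetical limit)
    print(enough_time or enough_count)           # True, so has_enough_measurements() would return True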
diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py
index aeb1827e7285f..426bfca1b4f88 100755
--- a/mlir/utils/spirv/gen_spirv_dialect.py
+++ b/mlir/utils/spirv/gen_spirv_dialect.py
@@ -23,1088 +23,1164 @@
import textwrap
import yaml
-SPIRV_HTML_SPEC_URL = 'https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html'
-SPIRV_JSON_SPEC_URL = 'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json'
+SPIRV_HTML_SPEC_URL = (
+ "https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html"
+)
+SPIRV_JSON_SPEC_URL = "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json"
-SPIRV_CL_EXT_HTML_SPEC_URL = 'https://www.khronos.org/registry/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html'
-SPIRV_CL_EXT_JSON_SPEC_URL = 'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/extinst.opencl.std.100.grammar.json'
+SPIRV_CL_EXT_HTML_SPEC_URL = "https://www.khronos.org/registry/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html"
+SPIRV_CL_EXT_JSON_SPEC_URL = "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/extinst.opencl.std.100.grammar.json"
-AUTOGEN_OP_DEF_SEPARATOR = '\n// -----\n\n'
-AUTOGEN_ENUM_SECTION_MARKER = 'enum section. Generated from SPIR-V spec; DO NOT MODIFY!'
+AUTOGEN_OP_DEF_SEPARATOR = "\n// -----\n\n"
+AUTOGEN_ENUM_SECTION_MARKER = "enum section. Generated from SPIR-V spec; DO NOT MODIFY!"
AUTOGEN_OPCODE_SECTION_MARKER = (
- 'opcode section. Generated from SPIR-V spec; DO NOT MODIFY!')
+ "opcode section. Generated from SPIR-V spec; DO NOT MODIFY!"
+)
+
def get_spirv_doc_from_html_spec(url, settings):
- """Extracts instruction documentation from SPIR-V HTML spec.
-
- Returns:
- - A dict mapping from instruction opcode to documentation.
- """
- if url is None:
- url = SPIRV_HTML_SPEC_URL
-
- response = requests.get(url)
- spec = response.content
-
- from bs4 import BeautifulSoup
- spirv = BeautifulSoup(spec, 'html.parser')
-
- doc = {}
-
- if settings.gen_cl_ops:
- section_anchor = spirv.find('h2', {'id': '_binary_form'})
- for section in section_anchor.parent.find_all('div', {'class': 'sect2'}):
- for table in section.find_all('table'):
- inst_html = table.tbody.tr.td
- opname = inst_html.a['id']
- # Ignore the first line, which is just the opname.
- doc[opname] = inst_html.text.split('\n', 1)[1].strip()
- else:
- section_anchor = spirv.find('h3', {'id': '_instructions_3'})
- for section in section_anchor.parent.find_all('div', {'class': 'sect3'}):
- for table in section.find_all('table'):
- inst_html = table.tbody.tr.td.p
- opname = inst_html.a['id']
- # Ignore the first line, which is just the opname.
- doc[opname] = inst_html.text.split('\n', 1)[1].strip()
-
- return doc
+ """Extracts instruction documentation from SPIR-V HTML spec.
+
+ Returns:
+ - A dict mapping from instruction opcode to documentation.
+ """
+ if url is None:
+ url = SPIRV_HTML_SPEC_URL
+
+ response = requests.get(url)
+ spec = response.content
+
+ from bs4 import BeautifulSoup
+
+ spirv = BeautifulSoup(spec, "html.parser")
+
+ doc = {}
+
+ if settings.gen_cl_ops:
+ section_anchor = spirv.find("h2", {"id": "_binary_form"})
+ for section in section_anchor.parent.find_all("div", {"class": "sect2"}):
+ for table in section.find_all("table"):
+ inst_html = table.tbody.tr.td
+ opname = inst_html.a["id"]
+ # Ignore the first line, which is just the opname.
+ doc[opname] = inst_html.text.split("\n", 1)[1].strip()
+ else:
+ section_anchor = spirv.find("h3", {"id": "_instructions_3"})
+ for section in section_anchor.parent.find_all("div", {"class": "sect3"}):
+ for table in section.find_all("table"):
+ inst_html = table.tbody.tr.td.p
+ opname = inst_html.a["id"]
+ # Ignore the first line, which is just the opname.
+ doc[opname] = inst_html.text.split("\n", 1)[1].strip()
+
+ return doc
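The dict returned above is keyed by the instruction's anchor id in the HTML spec, i.e. its opname, with the remaining table text as the value; roughly (contents abridged and illustrative):

    doc = {
        "OpNop": "Does nothing.",
        "OpUndef": "Make an intermediate object whose value is undefined.",
        # ...
    }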
def get_spirv_grammar_from_json_spec(url):
- """Extracts operand kind and instruction grammar from SPIR-V JSON spec.
+ """Extracts operand kind and instruction grammar from SPIR-V JSON spec.
- Returns:
- - A list containing all operand kinds' grammar
- - A list containing all instructions' grammar
- """
- response = requests.get(SPIRV_JSON_SPEC_URL)
- spec = response.content
+ Returns:
+ - A list containing all operand kinds' grammar
+ - A list containing all instructions' grammar
+ """
+ response = requests.get(SPIRV_JSON_SPEC_URL)
+ spec = response.content
- import json
- spirv = json.loads(spec)
+ import json
- if url is None:
- return spirv['operand_kinds'], spirv['instructions']
+ spirv = json.loads(spec)
- response_ext = requests.get(url)
- spec_ext = response_ext.content
- spirv_ext = json.loads(spec_ext)
+ if url is None:
+ return spirv["operand_kinds"], spirv["instructions"]
- return spirv['operand_kinds'], spirv_ext['instructions']
+ response_ext = requests.get(url)
+ spec_ext = response_ext.content
+ spirv_ext = json.loads(spec_ext)
+
+ return spirv["operand_kinds"], spirv_ext["instructions"]
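For orientation, the two lists returned here follow the SPIRV-Headers core-grammar layout that the rest of this script indexes into; a trimmed, illustrative excerpt of the expected shape:

    operand_kinds = [
        {"category": "ValueEnum", "kind": "Capability",
         "enumerants": [{"enumerant": "Matrix", "value": 0},
                        {"enumerant": "Shader", "value": 1, "capabilities": ["Matrix"]}]},
    ]
    instructions = [
        {"opname": "OpNop", "opcode": 0, "operands": []},
    ]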
def split_list_into_sublists(items):
- """Split the list of items into multiple sublists.
+ """Split the list of items into multiple sublists.
- This is to make sure the string composed from each sublist won't exceed
- 80 characters.
+ This is to make sure the string composed from each sublist won't exceed
+ 80 characters.
- Arguments:
- - items: a list of strings
- """
- chuncks = []
- chunk = []
- chunk_len = 0
+ Arguments:
+ - items: a list of strings
+ """
+ chuncks = []
+ chunk = []
+ chunk_len = 0
- for item in items:
- chunk_len += len(item) + 2
- if chunk_len > 80:
- chuncks.append(chunk)
- chunk = []
- chunk_len = len(item) + 2
- chunk.append(item)
+ for item in items:
+ chunk_len += len(item) + 2
+ if chunk_len > 80:
+ chuncks.append(chunk)
+ chunk = []
+ chunk_len = len(item) + 2
+ chunk.append(item)
- if len(chunk) != 0:
- chuncks.append(chunk)
+ if len(chunk) != 0:
+ chuncks.append(chunk)
- return chuncks
+ return chuncks
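A quick illustrative run of this helper (names hypothetical): ten 14-character entries cost 16 columns each including the ", " separator, so they land in two sublists of five, keeping every joined line within 80 columns.

    names = ["SPIRV_C_Matrix"] * 10
    print([len(chunk) for chunk in split_list_into_sublists(names)])  # [5, 5]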
def uniquify_enum_cases(lst):
- """Prunes duplicate enum cases from the list.
-
- Arguments:
- - lst: List whose elements are to be uniqued. Assumes each element is a
- (symbol, value) pair and elements already sorted according to value.
-
- Returns:
- - A list with all duplicates removed. The elements are sorted according to
- value and, for each value, uniqued according to symbol.
- original list,
- - A map from deduplicated cases to the uniqued case.
- """
- cases = lst
- uniqued_cases = []
- duplicated_cases = {}
-
- # First sort according to the value
- cases.sort(key=lambda x: x[1])
-
- # Then group them according to the value
- for _, groups in itertools.groupby(cases, key=lambda x: x[1]):
- # For each value, sort according to the enumerant symbol.
- sorted_group = sorted(groups, key=lambda x: x[0])
- # Keep the "smallest" case, which is typically the symbol without extension
- # suffix. But we have special cases that we want to fix.
- case = sorted_group[0]
- for i in range(1, len(sorted_group)):
- duplicated_cases[sorted_group[i][0]] = case[0]
- if case[0] == 'HlslSemanticGOOGLE':
- assert len(sorted_group) == 2, 'unexpected new variant for HlslSemantic'
- case = sorted_group[1]
- duplicated_cases[sorted_group[0][0]] = case[0]
- uniqued_cases.append(case)
-
- return uniqued_cases, duplicated_cases
+ """Prunes duplicate enum cases from the list.
+
+ Arguments:
+ - lst: List whose elements are to be uniqued. Assumes each element is a
+ (symbol, value) pair and elements already sorted according to value.
+
+ Returns:
+ - A list with all duplicates removed. The elements are sorted according to
+ value and, for each value, uniqued according to symbol.
+ original list,
+ - A map from deduplicated cases to the uniqued case.
+ """
+ cases = lst
+ uniqued_cases = []
+ duplicated_cases = {}
+
+ # First sort according to the value
+ cases.sort(key=lambda x: x[1])
+
+ # Then group them according to the value
+ for _, groups in itertools.groupby(cases, key=lambda x: x[1]):
+ # For each value, sort according to the enumerant symbol.
+ sorted_group = sorted(groups, key=lambda x: x[0])
+ # Keep the "smallest" case, which is typically the symbol without extension
+ # suffix. But we have special cases that we want to fix.
+ case = sorted_group[0]
+ for i in range(1, len(sorted_group)):
+ duplicated_cases[sorted_group[i][0]] = case[0]
+ if case[0] == "HlslSemanticGOOGLE":
+ assert len(sorted_group) == 2, "unexpected new variant for HlslSemantic"
+ case = sorted_group[1]
+ duplicated_cases[sorted_group[0][0]] = case[0]
+ uniqued_cases.append(case)
+
+ return uniqued_cases, duplicated_cases
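Illustrating the deduplication with a hypothetical alias pair sharing one value: the alphabetically smallest symbol is kept and the dropped alias is recorded in the mapping.

    cases = [("Shader", 1), ("ShaderKHR", 1), ("Kernel", 6)]  # 'ShaderKHR' is a made-up alias
    uniqued, dupes = uniquify_enum_cases(cases)
    print(uniqued)  # [('Shader', 1), ('Kernel', 6)]
    print(dupes)    # {'ShaderKHR': 'Shader'}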
def toposort(dag, sort_fn):
- """Topologically sorts the given dag.
+ """Topologically sorts the given dag.
- Arguments:
- - dag: a dict mapping from a node to its incoming nodes.
- - sort_fn: a function for sorting nodes in the same batch.
+ Arguments:
+ - dag: a dict mapping from a node to its incoming nodes.
+ - sort_fn: a function for sorting nodes in the same batch.
- Returns:
- A list containing topologically sorted nodes.
- """
+ Returns:
+ A list containing topologically sorted nodes.
+ """
- # Returns the next batch of nodes without incoming edges
- def get_next_batch(dag):
- while True:
- no_prev_nodes = set(node for node, prev in dag.items() if not prev)
- if not no_prev_nodes:
- break
- yield sorted(no_prev_nodes, key=sort_fn)
- dag = {
- node: (prev - no_prev_nodes)
- for node, prev in dag.items()
- if node not in no_prev_nodes
- }
- assert not dag, 'found cyclic dependency'
+ # Returns the next batch of nodes without incoming edges
+ def get_next_batch(dag):
+ while True:
+ no_prev_nodes = set(node for node, prev in dag.items() if not prev)
+ if not no_prev_nodes:
+ break
+ yield sorted(no_prev_nodes, key=sort_fn)
+ dag = {
+ node: (prev - no_prev_nodes)
+ for node, prev in dag.items()
+ if node not in no_prev_nodes
+ }
+ assert not dag, "found cyclic dependency"
- sorted_nodes = []
- for batch in get_next_batch(dag):
- sorted_nodes.extend(batch)
+ sorted_nodes = []
+ for batch in get_next_batch(dag):
+ sorted_nodes.extend(batch)
- return sorted_nodes
+ return sorted_nodes
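A small worked example of the batch-style topological sort above (node names and edges illustrative): nodes with no prerequisites come out first, ordered by sort_fn, followed by the nodes they unblock.

    dag = {"Matrix": set(), "Kernel": set(), "Shader": {"Matrix"}}  # node -> prerequisite nodes
    print(toposort(dag, lambda name: name))  # ['Kernel', 'Matrix', 'Shader']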
def toposort_capabilities(all_cases, capability_mapping):
- """Returns topologically sorted capability (symbol, value) pairs.
-
- Arguments:
- - all_cases: all capability cases (containing symbol, value, and implied
- capabilities).
- - capability_mapping: mapping from duplicated capability symbols to the
- canonicalized symbol chosen for SPIRVBase.td.
-
- Returns:
- A list containing topologically sorted capability (symbol, value) pairs.
- """
- dag = {}
- name_to_value = {}
- for case in all_cases:
- # Get the current capability.
- cur = case['enumerant']
- name_to_value[cur] = case['value']
- # Ignore duplicated symbols.
- if cur in capability_mapping:
- continue
-
- # Get capabilities implied by the current capability.
- prev = case.get('capabilities', [])
- uniqued_prev = set([capability_mapping.get(c, c) for c in prev])
- dag[cur] = uniqued_prev
-
- sorted_caps = toposort(dag, lambda x: name_to_value[x])
- # Attach the capability's value as the second component of the pair.
- return [(c, name_to_value[c]) for c in sorted_caps]
+ """Returns topologically sorted capability (symbol, value) pairs.
+
+ Arguments:
+ - all_cases: all capability cases (containing symbol, value, and implied
+ capabilities).
+ - capability_mapping: mapping from duplicated capability symbols to the
+ canonicalized symbol chosen for SPIRVBase.td.
+
+ Returns:
+ A list containing topologically sorted capability (symbol, value) pairs.
+ """
+ dag = {}
+ name_to_value = {}
+ for case in all_cases:
+ # Get the current capability.
+ cur = case["enumerant"]
+ name_to_value[cur] = case["value"]
+ # Ignore duplicated symbols.
+ if cur in capability_mapping:
+ continue
+
+ # Get capabilities implied by the current capability.
+ prev = case.get("capabilities", [])
+ uniqued_prev = set([capability_mapping.get(c, c) for c in prev])
+ dag[cur] = uniqued_prev
+
+ sorted_caps = toposort(dag, lambda x: name_to_value[x])
+ # Attach the capability's value as the second component of the pair.
+ return [(c, name_to_value[c]) for c in sorted_caps]
def get_capability_mapping(operand_kinds):
- """Returns the capability mapping from duplicated cases to canonicalized ones.
+ """Returns the capability mapping from duplicated cases to canonicalized ones.
- Arguments:
- - operand_kinds: all operand kinds' grammar spec
+ Arguments:
+ - operand_kinds: all operand kinds' grammar spec
- Returns:
- - A map mapping from duplicated capability symbols to the canonicalized
- symbol chosen for SPIRVBase.td.
- """
- # Find the operand kind for capability
- cap_kind = {}
- for kind in operand_kinds:
- if kind['kind'] == 'Capability':
- cap_kind = kind
+ Returns:
+ - A map mapping from duplicated capability symbols to the canonicalized
+ symbol chosen for SPIRVBase.td.
+ """
+ # Find the operand kind for capability
+ cap_kind = {}
+ for kind in operand_kinds:
+ if kind["kind"] == "Capability":
+ cap_kind = kind
- kind_cases = [
- (case['enumerant'], case['value']) for case in cap_kind['enumerants']
- ]
- _, capability_mapping = uniquify_enum_cases(kind_cases)
+ kind_cases = [(case["enumerant"], case["value"]) for case in cap_kind["enumerants"]]
+ _, capability_mapping = uniquify_enum_cases(kind_cases)
- return capability_mapping
+ return capability_mapping
def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
- """Returns the availability specification string for the given enum case.
-
- Arguments:
- - enum_case: the enum case to generate availability spec for. It may contain
- 'version', 'lastVersion', 'extensions', or 'capabilities'.
- - capability_mapping: mapping from duplicated capability symbols to the
- canonicalized symbol chosen for SPIRVBase.td.
- - for_op: bool value indicating whether this is the availability spec for an
- op itself.
- - for_cap: bool value indicating whether this is the availability spec for
- capabilities themselves.
-
- Returns:
- - A `let availability = [...];` string if with availability spec or
- empty string if without availability spec
- """
- assert not (for_op and for_cap), 'cannot set both for_op and for_cap'
-
- DEFAULT_MIN_VERSION = 'MinVersion<SPIRV_V_1_0>'
- DEFAULT_MAX_VERSION = 'MaxVersion<SPIRV_V_1_6>'
- DEFAULT_CAP = 'Capability<[]>'
- DEFAULT_EXT = 'Extension<[]>'
-
- min_version = enum_case.get('version', '')
- if min_version == 'None':
- min_version = ''
- elif min_version:
- min_version = 'MinVersion<SPIRV_V_{}>'.format(min_version.replace('.', '_'))
- # TODO: delete this once ODS can support dialect-specific content
- # and we can use omission to mean no requirements.
- if for_op and not min_version:
- min_version = DEFAULT_MIN_VERSION
-
- max_version = enum_case.get('lastVersion', '')
- if max_version:
- max_version = 'MaxVersion<SPIRV_V_{}>'.format(max_version.replace('.', '_'))
- # TODO: delete this once ODS can support dialect-specific content
- # and we can use omission to mean no requirements.
- if for_op and not max_version:
- max_version = DEFAULT_MAX_VERSION
-
- exts = enum_case.get('extensions', [])
- if exts:
- exts = 'Extension<[{}]>'.format(', '.join(sorted(set(exts))))
- # We need to strip the minimal version requirement if this symbol is
- # available via an extension, which means *any* SPIR-V version can support
- # it as long as the extension is provided. The grammar's 'version' field
- # under such case should be interpreted as this symbol is introduced as
- # a core symbol since the given version, rather than a minimal version
- # requirement.
- min_version = DEFAULT_MIN_VERSION if for_op else ''
- # TODO: delete this once ODS can support dialect-specific content
- # and we can use omission to mean no requirements.
- if for_op and not exts:
- exts = DEFAULT_EXT
-
- caps = enum_case.get('capabilities', [])
- implies = ''
- if caps:
- canonicalized_caps = []
- for c in caps:
- if c in capability_mapping:
- canonicalized_caps.append(capability_mapping[c])
- else:
- canonicalized_caps.append(c)
- prefixed_caps = [
- 'SPIRV_C_{}'.format(c) for c in sorted(set(canonicalized_caps))
- ]
- if for_cap:
- # If this is generating the availability for capabilities, we need to
- # put the capability "requirements" in implies field because now
- # the "capabilities" field in the source grammar means so.
- caps = ''
- implies = 'list<I32EnumAttrCase> implies = [{}];'.format(
- ', '.join(prefixed_caps))
- else:
- caps = 'Capability<[{}]>'.format(', '.join(prefixed_caps))
- implies = ''
- # TODO: delete this once ODS can support dialect-specific content
- # and we can use omission to mean no requirements.
- if for_op and not caps:
- caps = DEFAULT_CAP
-
- avail = ''
- # Compose availability spec if any of the requirements is not empty.
- # For ops, because we have a default in SPIRV_Op class, omit if the spec
- # is the same.
- if (min_version or max_version or caps or exts) and not (
- for_op and min_version == DEFAULT_MIN_VERSION and
- max_version == DEFAULT_MAX_VERSION and caps == DEFAULT_CAP and
- exts == DEFAULT_EXT):
- joined_spec = ',\n '.join(
- [e for e in [min_version, max_version, exts, caps] if e])
- avail = '{} availability = [\n {}\n ];'.format(
- 'let' if for_op else 'list<Availability>', joined_spec)
-
- return '{}{}{}'.format(implies, '\n ' if implies and avail else '', avail)
+ """Returns the availability specification string for the given enum case.
+
+ Arguments:
+ - enum_case: the enum case to generate availability spec for. It may contain
+ 'version', 'lastVersion', 'extensions', or 'capabilities'.
+ - capability_mapping: mapping from duplicated capability symbols to the
+ canonicalized symbol chosen for SPIRVBase.td.
+ - for_op: bool value indicating whether this is the availability spec for an
+ op itself.
+ - for_cap: bool value indicating whether this is the availability spec for
+ capabilities themselves.
+
+ Returns:
+ - A `let availability = [...];` string if with availability spec or
+ empty string if without availability spec
+ """
+ assert not (for_op and for_cap), "cannot set both for_op and for_cap"
+
+ DEFAULT_MIN_VERSION = "MinVersion<SPIRV_V_1_0>"
+ DEFAULT_MAX_VERSION = "MaxVersion<SPIRV_V_1_6>"
+ DEFAULT_CAP = "Capability<[]>"
+ DEFAULT_EXT = "Extension<[]>"
+
+ min_version = enum_case.get("version", "")
+ if min_version == "None":
+ min_version = ""
+ elif min_version:
+ min_version = "MinVersion<SPIRV_V_{}>".format(min_version.replace(".", "_"))
+ # TODO: delete this once ODS can support dialect-specific content
+ # and we can use omission to mean no requirements.
+ if for_op and not min_version:
+ min_version = DEFAULT_MIN_VERSION
+
+ max_version = enum_case.get("lastVersion", "")
+ if max_version:
+ max_version = "MaxVersion<SPIRV_V_{}>".format(max_version.replace(".", "_"))
+ # TODO: delete this once ODS can support dialect-specific content
+ # and we can use omission to mean no requirements.
+ if for_op and not max_version:
+ max_version = DEFAULT_MAX_VERSION
+
+ exts = enum_case.get("extensions", [])
+ if exts:
+ exts = "Extension<[{}]>".format(", ".join(sorted(set(exts))))
+ # We need to strip the minimal version requirement if this symbol is
+ # available via an extension, which means *any* SPIR-V version can support
+ # it as long as the extension is provided. The grammar's 'version' field
+ # under such case should be interpreted as this symbol is introduced as
+ # a core symbol since the given version, rather than a minimal version
+ # requirement.
+ min_version = DEFAULT_MIN_VERSION if for_op else ""
+ # TODO: delete this once ODS can support dialect-specific content
+ # and we can use omission to mean no requirements.
+ if for_op and not exts:
+ exts = DEFAULT_EXT
+
+ caps = enum_case.get("capabilities", [])
+ implies = ""
+ if caps:
+ canonicalized_caps = []
+ for c in caps:
+ if c in capability_mapping:
+ canonicalized_caps.append(capability_mapping[c])
+ else:
+ canonicalized_caps.append(c)
+ prefixed_caps = [
+ "SPIRV_C_{}".format(c) for c in sorted(set(canonicalized_caps))
+ ]
+ if for_cap:
+ # If this is generating the availability for capabilities, we need to
+ # put the capability "requirements" in implies field because now
+ # the "capabilities" field in the source grammar means so.
+ caps = ""
+ implies = "list<I32EnumAttrCase> implies = [{}];".format(
+ ", ".join(prefixed_caps)
+ )
+ else:
+ caps = "Capability<[{}]>".format(", ".join(prefixed_caps))
+ implies = ""
+ # TODO: delete this once ODS can support dialect-specific content
+ # and we can use omission to mean no requirements.
+ if for_op and not caps:
+ caps = DEFAULT_CAP
+
+ avail = ""
+ # Compose availability spec if any of the requirements is not empty.
+ # For ops, because we have a default in SPIRV_Op class, omit if the spec
+ # is the same.
+ if (min_version or max_version or caps or exts) and not (
+ for_op
+ and min_version == DEFAULT_MIN_VERSION
+ and max_version == DEFAULT_MAX_VERSION
+ and caps == DEFAULT_CAP
+ and exts == DEFAULT_EXT
+ ):
+ joined_spec = ",\n ".join(
+ [e for e in [min_version, max_version, exts, caps] if e]
+ )
+ avail = "{} availability = [\n {}\n ];".format(
+ "let" if for_op else "list<Availability>", joined_spec
+ )
+
+ return "{}{}{}".format(implies, "\n " if implies and avail else "", avail)
def gen_operand_kind_enum_attr(operand_kind, capability_mapping):
- """Generates the TableGen EnumAttr definition for the given operand kind.
-
- Returns:
- - The operand kind's name
- - A string containing the TableGen EnumAttr definition
- """
- if 'enumerants' not in operand_kind:
- return '', ''
-
- # Returns a symbol for the given case in the given kind. This function
- # handles Dim specially to avoid having numbers as the start of symbols,
- # which does not play well with C++ and the MLIR parser.
- def get_case_symbol(kind_name, case_name):
- if kind_name == 'Dim':
- if case_name == '1D' or case_name == '2D' or case_name == '3D':
- return 'Dim{}'.format(case_name)
- return case_name
-
- kind_name = operand_kind['kind']
- is_bit_enum = operand_kind['category'] == 'BitEnum'
- kind_acronym = ''.join([c for c in kind_name if c >= 'A' and c <= 'Z'])
-
- name_to_case_dict = {}
- for case in operand_kind['enumerants']:
- name_to_case_dict[case['enumerant']] = case
-
- if kind_name == 'Capability':
- # Special treatment for capability cases: we need to sort them topologically
- # because a capability can refer to another via the 'implies' field.
- kind_cases = toposort_capabilities(operand_kind['enumerants'],
- capability_mapping)
- else:
- kind_cases = [(case['enumerant'], case['value'])
- for case in operand_kind['enumerants']]
- kind_cases, _ = uniquify_enum_cases(kind_cases)
- max_len = max([len(symbol) for (symbol, _) in kind_cases])
-
- # Generate the definition for each enum case
- case_category = 'I32Bit' if is_bit_enum else 'I32'
- fmt_str = 'def SPIRV_{acronym}_{case_name} {colon:>{offset}} '\
- '{category}EnumAttrCase{suffix}<"{symbol}"{case_value_part}>{avail}'
- case_defs = []
- for case_pair in kind_cases:
- name = case_pair[0]
- if is_bit_enum:
- value = int(case_pair[1], base=16)
+ """Generates the TableGen EnumAttr definition for the given operand kind.
+
+ Returns:
+ - The operand kind's name
+ - A string containing the TableGen EnumAttr definition
+ """
+ if "enumerants" not in operand_kind:
+ return "", ""
+
+ # Returns a symbol for the given case in the given kind. This function
+ # handles Dim specially to avoid having numbers as the start of symbols,
+ # which does not play well with C++ and the MLIR parser.
+ def get_case_symbol(kind_name, case_name):
+ if kind_name == "Dim":
+ if case_name == "1D" or case_name == "2D" or case_name == "3D":
+ return "Dim{}".format(case_name)
+ return case_name
+
+ kind_name = operand_kind["kind"]
+ is_bit_enum = operand_kind["category"] == "BitEnum"
+ kind_acronym = "".join([c for c in kind_name if c >= "A" and c <= "Z"])
+
+ name_to_case_dict = {}
+ for case in operand_kind["enumerants"]:
+ name_to_case_dict[case["enumerant"]] = case
+
+ if kind_name == "Capability":
+ # Special treatment for capability cases: we need to sort them topologically
+ # because a capability can refer to another via the 'implies' field.
+ kind_cases = toposort_capabilities(
+ operand_kind["enumerants"], capability_mapping
+ )
else:
- value = int(case_pair[1])
- avail = get_availability_spec(name_to_case_dict[name],
- capability_mapping,
- False, kind_name == 'Capability')
- if is_bit_enum:
- if value == 0:
- suffix = 'None'
- value = ''
- else:
- suffix = "Bit"
- value = ', {}'.format(int(math.log2(value)))
- else:
- suffix = ''
- value = ', {}'.format(value)
-
- case_def = fmt_str.format(
- category=case_category,
- suffix=suffix,
- acronym=kind_acronym,
- case_name=name,
- symbol=get_case_symbol(kind_name, name),
- case_value_part=value,
- avail=' {{\n {}\n}}'.format(avail) if avail else ';',
- colon=':',
- offset=(max_len + 1 - len(name)))
- case_defs.append(case_def)
- case_defs = '\n'.join(case_defs)
-
- # Generate the list of enum case names
- fmt_str = 'SPIRV_{acronym}_{symbol}';
- case_names = [fmt_str.format(acronym=kind_acronym,symbol=case[0])
- for case in kind_cases]
-
- # Split them into sublists and concatenate into multiple lines
- case_names = split_list_into_sublists(case_names)
- case_names = ['{:6}'.format('') + ', '.join(sublist)
- for sublist in case_names]
- case_names = ',\n'.join(case_names)
-
- # Generate the enum attribute definition
- kind_category = 'Bit' if is_bit_enum else 'I32'
- enum_attr = '''def SPIRV_{name}Attr :
+ kind_cases = [
+ (case["enumerant"], case["value"]) for case in operand_kind["enumerants"]
+ ]
+ kind_cases, _ = uniquify_enum_cases(kind_cases)
+ max_len = max([len(symbol) for (symbol, _) in kind_cases])
+
+ # Generate the definition for each enum case
+ case_category = "I32Bit" if is_bit_enum else "I32"
+ fmt_str = (
+ "def SPIRV_{acronym}_{case_name} {colon:>{offset}} "
+ '{category}EnumAttrCase{suffix}<"{symbol}"{case_value_part}>{avail}'
+ )
+ case_defs = []
+ for case_pair in kind_cases:
+ name = case_pair[0]
+ if is_bit_enum:
+ value = int(case_pair[1], base=16)
+ else:
+ value = int(case_pair[1])
+ avail = get_availability_spec(
+ name_to_case_dict[name],
+ capability_mapping,
+ False,
+ kind_name == "Capability",
+ )
+ if is_bit_enum:
+ if value == 0:
+ suffix = "None"
+ value = ""
+ else:
+ suffix = "Bit"
+ value = ", {}".format(int(math.log2(value)))
+ else:
+ suffix = ""
+ value = ", {}".format(value)
+
+ case_def = fmt_str.format(
+ category=case_category,
+ suffix=suffix,
+ acronym=kind_acronym,
+ case_name=name,
+ symbol=get_case_symbol(kind_name, name),
+ case_value_part=value,
+ avail=" {{\n {}\n}}".format(avail) if avail else ";",
+ colon=":",
+ offset=(max_len + 1 - len(name)),
+ )
+ case_defs.append(case_def)
+ case_defs = "\n".join(case_defs)
+
+ # Generate the list of enum case names
+ fmt_str = "SPIRV_{acronym}_{symbol}"
+ case_names = [
+ fmt_str.format(acronym=kind_acronym, symbol=case[0]) for case in kind_cases
+ ]
+
+ # Split them into sublists and concatenate into multiple lines
+ case_names = split_list_into_sublists(case_names)
+ case_names = ["{:6}".format("") + ", ".join(sublist) for sublist in case_names]
+ case_names = ",\n".join(case_names)
+
+ # Generate the enum attribute definition
+ kind_category = "Bit" if is_bit_enum else "I32"
+ enum_attr = """def SPIRV_{name}Attr :
SPIRV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", "{snake_name}", [
{cases}
- ]>;'''.format(
- name=kind_name,
- snake_name=snake_casify(kind_name),
- category=kind_category,
- cases=case_names)
- return kind_name, case_defs + '\n\n' + enum_attr
+ ]>;""".format(
+ name=kind_name,
+ snake_name=snake_casify(kind_name),
+ category=kind_category,
+ cases=case_names,
+ )
+ return kind_name, case_defs + "\n\n" + enum_attr
def gen_opcode(instructions):
- """ Generates the TableGen definition to map opname to opcode
-
- Returns:
- - A string containing the TableGen SPIRV_OpCode definition
- """
-
- max_len = max([len(inst['opname']) for inst in instructions])
- def_fmt_str = 'def SPIRV_OC_{name} {colon:>{offset}} '\
- 'I32EnumAttrCase<"{name}", {value}>;'
- opcode_defs = [
- def_fmt_str.format(
- name=inst['opname'],
- value=inst['opcode'],
- colon=':',
- offset=(max_len + 1 - len(inst['opname']))) for inst in instructions
- ]
- opcode_str = '\n'.join(opcode_defs)
-
- decl_fmt_str = 'SPIRV_OC_{name}'
- opcode_list = [
- decl_fmt_str.format(name=inst['opname']) for inst in instructions
- ]
- opcode_list = split_list_into_sublists(opcode_list)
- opcode_list = [
- '{:6}'.format('') + ', '.join(sublist) for sublist in opcode_list
- ]
- opcode_list = ',\n'.join(opcode_list)
- enum_attr = 'def SPIRV_OpcodeAttr :\n'\
- ' SPIRV_I32EnumAttr<"{name}", "valid SPIR-V instructions", '\
- '"opcode", [\n'\
- '{lst}\n'\
- ' ]>;'.format(name='Opcode', lst=opcode_list)
- return opcode_str + '\n\n' + enum_attr
+ """Generates the TableGen definition to map opname to opcode
+
+ Returns:
+ - A string containing the TableGen SPIRV_OpCode definition
+ """
+
+ max_len = max([len(inst["opname"]) for inst in instructions])
+ def_fmt_str = (
+ "def SPIRV_OC_{name} {colon:>{offset}} " 'I32EnumAttrCase<"{name}", {value}>;'
+ )
+ opcode_defs = [
+ def_fmt_str.format(
+ name=inst["opname"],
+ value=inst["opcode"],
+ colon=":",
+ offset=(max_len + 1 - len(inst["opname"])),
+ )
+ for inst in instructions
+ ]
+ opcode_str = "\n".join(opcode_defs)
+
+ decl_fmt_str = "SPIRV_OC_{name}"
+ opcode_list = [decl_fmt_str.format(name=inst["opname"]) for inst in instructions]
+ opcode_list = split_list_into_sublists(opcode_list)
+ opcode_list = ["{:6}".format("") + ", ".join(sublist) for sublist in opcode_list]
+ opcode_list = ",\n".join(opcode_list)
+ enum_attr = (
+ "def SPIRV_OpcodeAttr :\n"
+ ' SPIRV_I32EnumAttr<"{name}", "valid SPIR-V instructions", '
+ '"opcode", [\n'
+ "{lst}\n"
+ " ]>;".format(name="Opcode", lst=opcode_list)
+ )
+ return opcode_str + "\n\n" + enum_attr
+
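For reference, the two pieces returned above look roughly like this for a toy two-instruction input (opnames are real SPIR-V opcodes; the output is reproduced from the format strings above):

    insts = [{"opname": "OpNop", "opcode": 0}, {"opname": "OpUndef", "opcode": 1}]
    print(gen_opcode(insts))
    # def SPIRV_OC_OpNop   : I32EnumAttrCase<"OpNop", 0>;
    # def SPIRV_OC_OpUndef : I32EnumAttrCase<"OpUndef", 1>;
    #
    # def SPIRV_OpcodeAttr :
    #     SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
    #       SPIRV_OC_OpNop, SPIRV_OC_OpUndef
    #     ]>;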
def map_cap_to_opnames(instructions):
- """Maps capabilities to instructions enabled by those capabilities
+ """Maps capabilities to instructions enabled by those capabilities
- Arguments:
- - instructions: a list containing a subset of SPIR-V instructions' grammar
- Returns:
- - A map with keys representing capabilities and values of lists of
- instructions enabled by the corresponding key
- """
- cap_to_inst = {}
+ Arguments:
+ - instructions: a list containing a subset of SPIR-V instructions' grammar
+ Returns:
+ - A map with keys representing capabilities and values of lists of
+ instructions enabled by the corresponding key
+ """
+ cap_to_inst = {}
- for inst in instructions:
- caps = inst['capabilities'] if 'capabilities' in inst else ['0_core_0']
- for cap in caps:
- if cap not in cap_to_inst:
- cap_to_inst[cap] = []
- cap_to_inst[cap].append(inst['opname'])
+ for inst in instructions:
+ caps = inst["capabilities"] if "capabilities" in inst else ["0_core_0"]
+ for cap in caps:
+ if cap not in cap_to_inst:
+ cap_to_inst[cap] = []
+ cap_to_inst[cap].append(inst["opname"])
+
+ return cap_to_inst
- return cap_to_inst
def gen_instr_coverage_report(path, instructions):
- """Dumps to standard output a YAML report of current instruction coverage
+ """Dumps to standard output a YAML report of current instruction coverage
- Arguments:
- - path: the path to SPIRBase.td
- - instructions: a list containing all SPIR-V instructions' grammar
- """
- with open(path, 'r') as f:
- content = f.read()
+ Arguments:
+ - path: the path to SPIRBase.td
+ - instructions: a list containing all SPIR-V instructions' grammar
+ """
+ with open(path, "r") as f:
+ content = f.read()
- content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
+ content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
- existing_opcodes = [k[11:] for k in re.findall('def SPIRV_OC_\w+', content[1])]
- existing_instructions = list(
- filter(lambda inst: (inst['opname'] in existing_opcodes),
- instructions))
+ existing_opcodes = [k[11:] for k in re.findall("def SPIRV_OC_\w+", content[1])]
+ existing_instructions = list(
+ filter(lambda inst: (inst["opname"] in existing_opcodes), instructions)
+ )
- instructions_opnames = [inst['opname'] for inst in instructions]
+ instructions_opnames = [inst["opname"] for inst in instructions]
- remaining_opcodes = list(set(instructions_opnames) - set(existing_opcodes))
- remaining_instructions = list(
- filter(lambda inst: (inst['opname'] in remaining_opcodes),
- instructions))
+ remaining_opcodes = list(set(instructions_opnames) - set(existing_opcodes))
+ remaining_instructions = list(
+ filter(lambda inst: (inst["opname"] in remaining_opcodes), instructions)
+ )
- rem_cap_to_instr = map_cap_to_opnames(remaining_instructions)
- ex_cap_to_instr = map_cap_to_opnames(existing_instructions)
+ rem_cap_to_instr = map_cap_to_opnames(remaining_instructions)
+ ex_cap_to_instr = map_cap_to_opnames(existing_instructions)
- rem_cap_to_cov = {}
+ rem_cap_to_cov = {}
- # Calculate coverage for each capability
- for cap in rem_cap_to_instr:
- if cap not in ex_cap_to_instr:
- rem_cap_to_cov[cap] = 0.0
- else:
- rem_cap_to_cov[cap] = \
- (len(ex_cap_to_instr[cap]) / (len(ex_cap_to_instr[cap]) \
- + len(rem_cap_to_instr[cap])))
+ # Calculate coverage for each capability
+ for cap in rem_cap_to_instr:
+ if cap not in ex_cap_to_instr:
+ rem_cap_to_cov[cap] = 0.0
+ else:
+ rem_cap_to_cov[cap] = len(ex_cap_to_instr[cap]) / (
+ len(ex_cap_to_instr[cap]) + len(rem_cap_to_instr[cap])
+ )
- report = {}
+ report = {}
- # Merge the 3 maps into one report
- for cap in rem_cap_to_instr:
- report[cap] = {}
- report[cap]['Supported Instructions'] = \
+ # Merge the 3 maps into one report
+ for cap in rem_cap_to_instr:
+ report[cap] = {}
+ report[cap]["Supported Instructions"] = (
ex_cap_to_instr[cap] if cap in ex_cap_to_instr else []
- report[cap]['Unsupported Instructions'] = rem_cap_to_instr[cap]
- report[cap]['Coverage'] = '{}%'.format(int(rem_cap_to_cov[cap] * 100))
+ )
+ report[cap]["Unsupported Instructions"] = rem_cap_to_instr[cap]
+ report[cap]["Coverage"] = "{}%".format(int(rem_cap_to_cov[cap] * 100))
+
+ print(yaml.dump(report))
- print(yaml.dump(report))
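The dumped YAML groups instructions per capability together with the percentage of ops already defined in the dialect; a sketch of the shape (capability, instruction names, and numbers hypothetical):

    # Shader:
    #   Coverage: 85%
    #   Supported Instructions:
    #   - OpImageSampleImplicitLod
    #   Unsupported Instructions:
    #   - OpImageSparseSampleImplicitLod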
def update_td_opcodes(path, instructions, filter_list):
- """Updates SPIRBase.td with new generated opcode cases.
-
- Arguments:
- - path: the path to SPIRBase.td
- - instructions: a list containing all SPIR-V instructions' grammar
- - filter_list: a list containing new opnames to add
- """
-
- with open(path, 'r') as f:
- content = f.read()
-
- content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
- assert len(content) == 3
-
- # Extend opcode list with existing list
- prefix = 'def SPIRV_OC_'
- existing_opcodes = [k[len(prefix):] for k in re.findall(prefix + '\w+', content[1])]
- filter_list.extend(existing_opcodes)
- filter_list = list(set(filter_list))
-
- # Generate the opcode for all instructions in SPIR-V
- filter_instrs = list(
- filter(lambda inst: (inst['opname'] in filter_list), instructions))
- # Sort instruction based on opcode
- filter_instrs.sort(key=lambda inst: inst['opcode'])
- opcode = gen_opcode(filter_instrs)
-
- # Substitute the opcode
- content = content[0] + AUTOGEN_OPCODE_SECTION_MARKER + '\n\n' + \
- opcode + '\n\n// End ' + AUTOGEN_OPCODE_SECTION_MARKER \
+ """Updates SPIRBase.td with new generated opcode cases.
+
+ Arguments:
+ - path: the path to SPIRBase.td
+ - instructions: a list containing all SPIR-V instructions' grammar
+ - filter_list: a list containing new opnames to add
+ """
+
+ with open(path, "r") as f:
+ content = f.read()
+
+ content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
+ assert len(content) == 3
+
+ # Extend opcode list with existing list
+ prefix = "def SPIRV_OC_"
+ existing_opcodes = [
+ k[len(prefix) :] for k in re.findall(prefix + "\w+", content[1])
+ ]
+ filter_list.extend(existing_opcodes)
+ filter_list = list(set(filter_list))
+
+ # Generate the opcode for all instructions in SPIR-V
+ filter_instrs = list(
+ filter(lambda inst: (inst["opname"] in filter_list), instructions)
+ )
+ # Sort instruction based on opcode
+ filter_instrs.sort(key=lambda inst: inst["opcode"])
+ opcode = gen_opcode(filter_instrs)
+
+ # Substitute the opcode
+ content = (
+ content[0]
+ + AUTOGEN_OPCODE_SECTION_MARKER
+ + "\n\n"
+ + opcode
+ + "\n\n// End "
+ + AUTOGEN_OPCODE_SECTION_MARKER
+        + content[2]
+ )
- with open(path, 'w') as f:
- f.write(content)
+ with open(path, "w") as f:
+ f.write(content)
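Both this function and update_td_enum_attrs below rely on the same convention: the autogenerated region of the .td file is bracketed by the marker string and its "// End" counterpart, so splitting on the marker yields exactly three pieces and only the middle one is regenerated. Schematically (surrounding comment text abbreviated):

    # SPIRVBase.td layout, roughly:
    #   ... hand-written prologue ...
    #   // ... opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
    #   <regenerated opcode definitions>
    #   // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
    #   ... hand-written epilogue ...
    prologue, generated, epilogue = content.split(AUTOGEN_OPCODE_SECTION_MARKER)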
def update_td_enum_attrs(path, operand_kinds, filter_list):
- """Updates SPIRBase.td with new generated enum definitions.
-
- Arguments:
- - path: the path to SPIRBase.td
- - operand_kinds: a list containing all operand kinds' grammar
- - filter_list: a list containing new enums to add
- """
- with open(path, 'r') as f:
- content = f.read()
-
- content = content.split(AUTOGEN_ENUM_SECTION_MARKER)
- assert len(content) == 3
-
- # Extend filter list with existing enum definitions
- existing_kinds = [
- k[8:-4] for k in re.findall('def SPIRV_\w+Attr', content[1])]
- filter_list.extend(existing_kinds)
-
- capability_mapping = get_capability_mapping(operand_kinds)
-
- # Generate definitions for all enums in filter list
- defs = [
- gen_operand_kind_enum_attr(kind, capability_mapping)
- for kind in operand_kinds
- if kind['kind'] in filter_list
- ]
- # Sort alphabetically according to enum name
- defs.sort(key=lambda enum : enum[0])
- # Only keep the definitions from now on
- # Put Capability's definition at the very beginning because capability cases
- # will be referenced later
- defs = [enum[1] for enum in defs if enum[0] == 'Capability'
- ] + [enum[1] for enum in defs if enum[0] != 'Capability']
-
- # Substitute the old section
- content = content[0] + AUTOGEN_ENUM_SECTION_MARKER + '\n\n' + \
- '\n\n'.join(defs) + "\n\n// End " + AUTOGEN_ENUM_SECTION_MARKER \
- + content[2];
-
- with open(path, 'w') as f:
- f.write(content)
+ """Updates SPIRBase.td with new generated enum definitions.
+
+ Arguments:
+ - path: the path to SPIRBase.td
+ - operand_kinds: a list containing all operand kinds' grammar
+ - filter_list: a list containing new enums to add
+ """
+ with open(path, "r") as f:
+ content = f.read()
+
+ content = content.split(AUTOGEN_ENUM_SECTION_MARKER)
+ assert len(content) == 3
+
+ # Extend filter list with existing enum definitions
+ existing_kinds = [k[8:-4] for k in re.findall("def SPIRV_\w+Attr", content[1])]
+ filter_list.extend(existing_kinds)
+
+ capability_mapping = get_capability_mapping(operand_kinds)
+
+ # Generate definitions for all enums in filter list
+ defs = [
+ gen_operand_kind_enum_attr(kind, capability_mapping)
+ for kind in operand_kinds
+ if kind["kind"] in filter_list
+ ]
+ # Sort alphabetically according to enum name
+ defs.sort(key=lambda enum: enum[0])
+ # Only keep the definitions from now on
+ # Put Capability's definition at the very beginning because capability cases
+ # will be referenced later
+ defs = [enum[1] for enum in defs if enum[0] == "Capability"] + [
+ enum[1] for enum in defs if enum[0] != "Capability"
+ ]
+
+ # Substitute the old section
+ content = (
+ content[0]
+ + AUTOGEN_ENUM_SECTION_MARKER
+ + "\n\n"
+ + "\n\n".join(defs)
+ + "\n\n// End "
+ + AUTOGEN_ENUM_SECTION_MARKER
+ + content[2]
+ )
+
+ with open(path, "w") as f:
+ f.write(content)
def snake_casify(name):
- """Turns the given name to follow snake_case convention."""
- return re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()
+ """Turns the given name to follow snake_case convention."""
+ return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
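A couple of illustrative conversions of the helper above:

    print(snake_casify("MemoryAccess"))   # memory_access
    print(snake_casify("ImageOperands"))  # image_operands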
def map_spec_operand_to_ods_argument(operand):
- """Maps an operand in SPIR-V JSON spec to an op argument in ODS.
-
- Arguments:
- - A dict containing the operand's kind, quantifier, and name
-
- Returns:
- - A string containing both the type and name for the argument
- """
- kind = operand['kind']
- quantifier = operand.get('quantifier', '')
-
- # These instruction "operands" are for encoding the results; they should
- # not be handled here.
- assert kind != 'IdResultType', 'unexpected to handle "IdResultType" kind'
- assert kind != 'IdResult', 'unexpected to handle "IdResult" kind'
-
- if kind == 'IdRef':
- if quantifier == '':
- arg_type = 'SPIRV_Type'
- elif quantifier == '?':
- arg_type = 'Optional<SPIRV_Type>'
- else:
- arg_type = 'Variadic<SPIRV_Type>'
- elif kind == 'IdMemorySemantics' or kind == 'IdScope':
- # TODO: Need to further constrain 'IdMemorySemantics'
- # and 'IdScope' given that they should be generated from OpConstant.
- assert quantifier == '', ('unexpected to have optional/variadic memory '
- 'semantics or scope <id>')
- arg_type = 'SPIRV_' + kind[2:] + 'Attr'
- elif kind == 'LiteralInteger':
- if quantifier == '':
- arg_type = 'I32Attr'
- elif quantifier == '?':
- arg_type = 'OptionalAttr<I32Attr>'
+ """Maps an operand in SPIR-V JSON spec to an op argument in ODS.
+
+ Arguments:
+ - A dict containing the operand's kind, quantifier, and name
+
+ Returns:
+ - A string containing both the type and name for the argument
+ """
+ kind = operand["kind"]
+ quantifier = operand.get("quantifier", "")
+
+ # These instruction "operands" are for encoding the results; they should
+ # not be handled here.
+ assert kind != "IdResultType", 'unexpected to handle "IdResultType" kind'
+ assert kind != "IdResult", 'unexpected to handle "IdResult" kind'
+
+ if kind == "IdRef":
+ if quantifier == "":
+ arg_type = "SPIRV_Type"
+ elif quantifier == "?":
+ arg_type = "Optional<SPIRV_Type>"
+ else:
+ arg_type = "Variadic<SPIRV_Type>"
+ elif kind == "IdMemorySemantics" or kind == "IdScope":
+ # TODO: Need to further constrain 'IdMemorySemantics'
+ # and 'IdScope' given that they should be generated from OpConstant.
+ assert quantifier == "", (
+ "unexpected to have optional/variadic memory " "semantics or scope <id>"
+ )
+ arg_type = "SPIRV_" + kind[2:] + "Attr"
+ elif kind == "LiteralInteger":
+ if quantifier == "":
+ arg_type = "I32Attr"
+ elif quantifier == "?":
+ arg_type = "OptionalAttr<I32Attr>"
+ else:
+ arg_type = "OptionalAttr<I32ArrayAttr>"
+ elif (
+ kind == "LiteralString"
+ or kind == "LiteralContextDependentNumber"
+ or kind == "LiteralExtInstInteger"
+ or kind == "LiteralSpecConstantOpInteger"
+ or kind == "PairLiteralIntegerIdRef"
+ or kind == "PairIdRefLiteralInteger"
+ or kind == "PairIdRefIdRef"
+ ):
+ assert False, '"{}" kind unimplemented'.format(kind)
else:
- arg_type = 'OptionalAttr<I32ArrayAttr>'
- elif kind == 'LiteralString' or \
- kind == 'LiteralContextDependentNumber' or \
- kind == 'LiteralExtInstInteger' or \
- kind == 'LiteralSpecConstantOpInteger' or \
- kind == 'PairLiteralIntegerIdRef' or \
- kind == 'PairIdRefLiteralInteger' or \
- kind == 'PairIdRefIdRef':
- assert False, '"{}" kind unimplemented'.format(kind)
- else:
- # The rest are all enum operands that we represent with op attributes.
- assert quantifier != '*', 'unexpected to have variadic enum attribute'
- arg_type = 'SPIRV_{}Attr'.format(kind)
- if quantifier == '?':
- arg_type = 'OptionalAttr<{}>'.format(arg_type)
-
- name = operand.get('name', '')
- name = snake_casify(name) if name else kind.lower()
-
- return '{}:${}'.format(arg_type, name)
+ # The rest are all enum operands that we represent with op attributes.
+ assert quantifier != "*", "unexpected to have variadic enum attribute"
+ arg_type = "SPIRV_{}Attr".format(kind)
+ if quantifier == "?":
+ arg_type = "OptionalAttr<{}>".format(arg_type)
+
+ name = operand.get("name", "")
+ name = snake_casify(name) if name else kind.lower()
+
+ return "{}:${}".format(arg_type, name)
def get_description(text, appendix):
- """Generates the description for the given SPIR-V instruction.
-
- Arguments:
- - text: Textual description of the operation as string.
- - appendix: Additional contents to attach in description as string,
-    including IR examples, and others.
-
- Returns:
- - A string that corresponds to the description of the Tablegen op.
- """
- fmt_str = '{text}\n\n <!-- End of AutoGen section -->\n{appendix}\n '
- return fmt_str.format(text=text, appendix=appendix)
-
-
-def get_op_definition(instruction, opname, doc, existing_info, capability_mapping, settings):
- """Generates the TableGen op definition for the given SPIR-V instruction.
-
- Arguments:
- - instruction: the instruction's SPIR-V JSON grammar
- - doc: the instruction's SPIR-V HTML doc
- - existing_info: a dict containing potential manually specified sections for
- this instruction
- - capability_mapping: mapping from duplicated capability symbols to the
- canonicalized symbol chosen for SPIRVBase.td
-
- Returns:
- - A string containing the TableGen op definition
- """
- if settings.gen_cl_ops:
- fmt_str = ('def SPIRV_{opname}Op : '
- 'SPIRV_{inst_category}<"{opname_src}", {opcode}, <<Insert result type>> > '
- '{{\n let summary = {summary};\n\n let description = '
- '[{{\n{description}}}];{availability}\n')
- else:
- fmt_str = ('def SPIRV_{vendor_name}{opname_src}Op : '
- 'SPIRV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
- '{{\n let summary = {summary};\n\n let description = '
- '[{{\n{description}}}];{availability}\n')
-
- vendor_name = ''
- inst_category = existing_info.get('inst_category', 'Op')
- if inst_category == 'Op':
- fmt_str +='\n let arguments = (ins{args});\n\n'\
- ' let results = (outs{results});\n'
- elif inst_category.endswith('VendorOp'):
- vendor_name = inst_category.split('VendorOp')[0].upper()
- assert len(vendor_name) != 0, 'Invalid instruction category'
-
- fmt_str +='{extras}'\
- '}}\n'
-
- opname_src = instruction['opname']
- if opname.startswith('Op'):
- opname_src = opname_src[2:]
- if len(vendor_name) > 0:
- assert opname_src.endswith(vendor_name), "op name does not match the instruction category"
- opname_src = opname_src[:-len(vendor_name)]
-
- category_args = existing_info.get('category_args', '')
-
- if '\n' in doc:
- summary, text = doc.split('\n', 1)
- else:
- summary = doc
- text = ''
- wrapper = textwrap.TextWrapper(
- width=76, initial_indent=' ', subsequent_indent=' ')
-
- # Format summary. If the summary can fit in the same line, we print it out
- # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
- summary = summary.strip()
- if len(summary) + len(' let summary = "";') <= 80:
- summary = '"{}"'.format(summary)
- else:
- summary = '[{{\n{}\n }}]'.format(wrapper.fill(summary))
-
- # Wrap text
- text = text.split('\n')
- text = [wrapper.fill(line) for line in text if line]
- text = '\n\n'.join(text)
-
- operands = instruction.get('operands', [])
-
- # Op availability
- avail = get_availability_spec(instruction, capability_mapping, True, False)
- if avail:
- avail = '\n\n {0}'.format(avail)
-
- # Set op's result
- results = ''
- if len(operands) > 0 and operands[0]['kind'] == 'IdResultType':
- results = '\n SPIRV_Type:$result\n '
- operands = operands[1:]
- if 'results' in existing_info:
- results = existing_info['results']
-
- # Ignore the operand standing for the result <id>
- if len(operands) > 0 and operands[0]['kind'] == 'IdResult':
- operands = operands[1:]
-
-  # Set op's argument
- arguments = existing_info.get('arguments', None)
- if arguments is None:
- arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
- arguments = ',\n '.join(arguments)
- if arguments:
- # Prepend and append whitespace for formatting
- arguments = '\n {}\n '.format(arguments)
-
- description = existing_info.get('description', None)
- if description is None:
- assembly = '\n ```\n'\
- ' [TODO]\n'\
- ' ```\n\n'\
- ' #### Example:\n\n'\
- ' ```mlir\n'\
- ' [TODO]\n' \
- ' ```'
- description = get_description(text, assembly)
-
- return fmt_str.format(
- opname=opname,
- opname_src=opname_src,
- opcode=instruction['opcode'],
- category_args=category_args,
- inst_category=inst_category,
- vendor_name=vendor_name,
- traits=existing_info.get('traits', ''),
- summary=summary,
- description=description,
- availability=avail,
- args=arguments,
- results=results,
- extras=existing_info.get('extras', ''))
+ """Generates the description for the given SPIR-V instruction.
+
+ Arguments:
+ - text: Textual description of the operation as string.
+ - appendix: Additional contents to attach in description as string,
+      including IR examples, and others.
+
+ Returns:
+ - A string that corresponds to the description of the Tablegen op.
+ """
+ fmt_str = "{text}\n\n <!-- End of AutoGen section -->\n{appendix}\n "
+ return fmt_str.format(text=text, appendix=appendix)
+
+
+def get_op_definition(
+ instruction, opname, doc, existing_info, capability_mapping, settings
+):
+ """Generates the TableGen op definition for the given SPIR-V instruction.
+
+ Arguments:
+ - instruction: the instruction's SPIR-V JSON grammar
+ - doc: the instruction's SPIR-V HTML doc
+ - existing_info: a dict containing potential manually specified sections for
+ this instruction
+ - capability_mapping: mapping from duplicated capability symbols to the
+ canonicalized symbol chosen for SPIRVBase.td
+
+ Returns:
+ - A string containing the TableGen op definition
+ """
+ if settings.gen_cl_ops:
+ fmt_str = (
+ "def SPIRV_{opname}Op : "
+ 'SPIRV_{inst_category}<"{opname_src}", {opcode}, <<Insert result type>> > '
+ "{{\n let summary = {summary};\n\n let description = "
+ "[{{\n{description}}}];{availability}\n"
+ )
+ else:
+ fmt_str = (
+ "def SPIRV_{vendor_name}{opname_src}Op : "
+ 'SPIRV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
+ "{{\n let summary = {summary};\n\n let description = "
+ "[{{\n{description}}}];{availability}\n"
+ )
+
+ vendor_name = ""
+ inst_category = existing_info.get("inst_category", "Op")
+ if inst_category == "Op":
+ fmt_str += (
+ "\n let arguments = (ins{args});\n\n" " let results = (outs{results});\n"
+ )
+ elif inst_category.endswith("VendorOp"):
+ vendor_name = inst_category.split("VendorOp")[0].upper()
+ assert len(vendor_name) != 0, "Invalid instruction category"
+
+ fmt_str += "{extras}" "}}\n"
+
+ opname_src = instruction["opname"]
+ if opname.startswith("Op"):
+ opname_src = opname_src[2:]
+ if len(vendor_name) > 0:
+ assert opname_src.endswith(
+ vendor_name
+ ), "op name does not match the instruction category"
+ opname_src = opname_src[: -len(vendor_name)]
+
+ category_args = existing_info.get("category_args", "")
+
+ if "\n" in doc:
+ summary, text = doc.split("\n", 1)
+ else:
+ summary = doc
+ text = ""
+ wrapper = textwrap.TextWrapper(
+ width=76, initial_indent=" ", subsequent_indent=" "
+ )
+
+ # Format summary. If the summary can fit in the same line, we print it out
+ # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
+ summary = summary.strip()
+ if len(summary) + len(' let summary = "";') <= 80:
+ summary = '"{}"'.format(summary)
+ else:
+ summary = "[{{\n{}\n }}]".format(wrapper.fill(summary))
+
+ # Wrap text
+ text = text.split("\n")
+ text = [wrapper.fill(line) for line in text if line]
+ text = "\n\n".join(text)
+
+ operands = instruction.get("operands", [])
+
+ # Op availability
+ avail = get_availability_spec(instruction, capability_mapping, True, False)
+ if avail:
+ avail = "\n\n {0}".format(avail)
+
+ # Set op's result
+ results = ""
+ if len(operands) > 0 and operands[0]["kind"] == "IdResultType":
+ results = "\n SPIRV_Type:$result\n "
+ operands = operands[1:]
+ if "results" in existing_info:
+ results = existing_info["results"]
+
+ # Ignore the operand standing for the result <id>
+ if len(operands) > 0 and operands[0]["kind"] == "IdResult":
+ operands = operands[1:]
+
+    # Set op's argument
+ arguments = existing_info.get("arguments", None)
+ if arguments is None:
+ arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
+ arguments = ",\n ".join(arguments)
+ if arguments:
+ # Prepend and append whitespace for formatting
+ arguments = "\n {}\n ".format(arguments)
+
+ description = existing_info.get("description", None)
+ if description is None:
+ assembly = (
+ "\n ```\n"
+ " [TODO]\n"
+ " ```\n\n"
+ " #### Example:\n\n"
+ " ```mlir\n"
+ " [TODO]\n"
+ " ```"
+ )
+ description = get_description(text, assembly)
+
+ return fmt_str.format(
+ opname=opname,
+ opname_src=opname_src,
+ opcode=instruction["opcode"],
+ category_args=category_args,
+ inst_category=inst_category,
+ vendor_name=vendor_name,
+ traits=existing_info.get("traits", ""),
+ summary=summary,
+ description=description,
+ availability=avail,
+ args=arguments,
+ results=results,
+ extras=existing_info.get("extras", ""),
+ )
def get_string_between(base, start, end):
- """Extracts a substring with a specified start and end from a string.
-
- Arguments:
- - base: string to extract from.
- - start: string to use as the start of the substring.
- - end: string to use as the end of the substring.
-
- Returns:
- - The substring if found
- - The part of the base after end of the substring. Is the base string itself
-    if the substring wasn't found.
- """
- split = base.split(start, 1)
- if len(split) == 2:
- rest = split[1].split(end, 1)
- assert len(rest) == 2, \
- 'cannot find end "{end}" while extracting substring '\
- 'starting with {start}'.format(start=start, end=end)
- return rest[0].rstrip(end), rest[1]
- return '', split[0]
+ """Extracts a substring with a specified start and end from a string.
+
+ Arguments:
+ - base: string to extract from.
+ - start: string to use as the start of the substring.
+ - end: string to use as the end of the substring.
+
+ Returns:
+ - The substring if found
+ - The part of the base after end of the substring. Is the base string itself
+      if the substring wasn't found.
+ """
+ split = base.split(start, 1)
+ if len(split) == 2:
+ rest = split[1].split(end, 1)
+ assert len(rest) == 2, (
+ 'cannot find end "{end}" while extracting substring '
+ "starting with {start}".format(start=start, end=end)
+ )
+ return rest[0].rstrip(end), rest[1]
+ return "", split[0]
def get_string_between_nested(base, start, end):
- """Extracts a substring with a nested start and end from a string.
-
- Arguments:
- - base: string to extract from.
- - start: string to use as the start of the substring.
- - end: string to use as the end of the substring.
-
- Returns:
- - The substring if found
- - The part of the base after end of the substring. Is the base string itself
- if the substring wasn't found.
- """
- split = base.split(start, 1)
- if len(split) == 2:
- # Handle nesting delimiters
- rest = split[1]
- unmatched_start = 1
- index = 0
- while unmatched_start > 0 and index < len(rest):
- if rest[index:].startswith(end):
- unmatched_start -= 1
- if unmatched_start == 0:
- break
- index += len(end)
- elif rest[index:].startswith(start):
- unmatched_start += 1
- index += len(start)
- else:
- index += 1
-
- assert index < len(rest), \
- 'cannot find end "{end}" while extracting substring '\
- 'starting with "{start}"'.format(start=start, end=end)
- return rest[:index], rest[index + len(end):]
- return '', split[0]
+ """Extracts a substring with a nested start and end from a string.
+
+ Arguments:
+ - base: string to extract from.
+ - start: string to use as the start of the substring.
+ - end: string to use as the end of the substring.
+
+ Returns:
+ - The substring if found
+ - The part of the base after end of the substring. Is the base string itself
+ if the substring wasn't found.
+ """
+ split = base.split(start, 1)
+ if len(split) == 2:
+ # Handle nesting delimiters
+ rest = split[1]
+ unmatched_start = 1
+ index = 0
+ while unmatched_start > 0 and index < len(rest):
+ if rest[index:].startswith(end):
+ unmatched_start -= 1
+ if unmatched_start == 0:
+ break
+ index += len(end)
+ elif rest[index:].startswith(start):
+ unmatched_start += 1
+ index += len(start)
+ else:
+ index += 1
+
+ assert index < len(rest), (
+ 'cannot find end "{end}" while extracting substring '
+ 'starting with "{start}"'.format(start=start, end=end)
+ )
+ return rest[:index], rest[index + len(end) :]
+ return "", split[0]
def extract_td_op_info(op_def):
- """Extracts potentially manually specified sections in op's definition.
-
- Arguments: - A string containing the op's TableGen definition
-
- Returns:
- - A dict containing potential manually specified sections
- """
- # Get opname
- opname = [o[8:-2] for o in re.findall('def SPIRV_\w+Op', op_def)]
- assert len(opname) == 1, 'more than one op in the same section!'
- opname = opname[0]
-
- # Get instruction category
- inst_category = [
- o[4:] for o in re.findall('SPIRV_\w+Op',
- op_def.split(':', 1)[1])
- ]
- assert len(inst_category) <= 1, 'more than one op in the same section!'
- inst_category = inst_category[0] if len(inst_category) == 1 else 'Op'
-
- # Get category_args
- op_tmpl_params, _ = get_string_between_nested(op_def, '<', '>')
- opstringname, rest = get_string_between(op_tmpl_params, '"', '"')
- category_args = rest.split('[', 1)[0]
-
- # Get traits
- traits, _ = get_string_between_nested(rest, '[', ']')
-
- # Get description
- description, rest = get_string_between(op_def, 'let description = [{\n',
- '}];\n')
-
- # Get arguments
- args, rest = get_string_between(rest, ' let arguments = (ins', ');\n')
-
- # Get results
- results, rest = get_string_between(rest, ' let results = (outs', ');\n')
-
- extras = rest.strip(' }\n')
- if extras:
- extras = '\n {}\n'.format(extras)
-
- return {
- # Prefix with 'Op' to make it consistent with SPIR-V spec
- 'opname': 'Op{}'.format(opname),
- 'inst_category': inst_category,
- 'category_args': category_args,
- 'traits': traits,
- 'description': description,
- 'arguments': args,
- 'results': results,
- 'extras': extras
- }
-
-
-def update_td_op_definitions(path, instructions, docs, filter_list,
- inst_category, capability_mapping, settings):
- """Updates SPIRVOps.td with newly generated op definition.
-
- Arguments:
- - path: path to SPIRVOps.td
- - instructions: SPIR-V JSON grammar for all instructions
- - docs: SPIR-V HTML doc for all instructions
- - filter_list: a list containing new opnames to include
- - capability_mapping: mapping from duplicated capability symbols to the
- canonicalized symbol chosen for SPIRVBase.td.
-
- Returns:
- - None; the updated op definitions are written back to the file at the given path.
- """
- with open(path, 'r') as f:
- content = f.read()
-
- # Split the file into chunks, each containing one op.
- ops = content.split(AUTOGEN_OP_DEF_SEPARATOR)
- header = ops[0]
- footer = ops[-1]
- ops = ops[1:-1]
-
- # For each existing op, extract the manually-written sections out to retain
- # them when re-generating the ops. Also append the existing ops to filter
- # list.
- name_op_map = {} # Map from opname to its existing ODS definition
- op_info_dict = {}
- for op in ops:
- info_dict = extract_td_op_info(op)
- opname = info_dict['opname']
- name_op_map[opname] = op
- op_info_dict[opname] = info_dict
- filter_list.append(opname)
- filter_list = sorted(list(set(filter_list)))
-
- op_defs = []
-
- if settings.gen_cl_ops:
- fix_opname = lambda src: src.replace('CL','').lower()
- else:
- fix_opname = lambda src: src
-
- for opname in filter_list:
- # Find the grammar spec for this op
- try:
- fixed_opname = fix_opname(opname)
- instruction = next(
- inst for inst in instructions if inst['opname'] == fixed_opname)
-
- op_defs.append(
- get_op_definition(
- instruction, opname, docs[fixed_opname],
- op_info_dict.get(opname, {'inst_category': inst_category}),
- capability_mapping, settings))
- except StopIteration:
- # This is an op added by us; use the existing ODS definition.
- op_defs.append(name_op_map[opname])
-
- # Substitute the old op definitions
- op_defs = [header] + op_defs + [footer]
- content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs)
-
- with open(path, 'w') as f:
- f.write(content)
-
-
-if __name__ == '__main__':
- import argparse
-
- cli_parser = argparse.ArgumentParser(
- description='Update SPIR-V dialect definitions using SPIR-V spec')
-
- cli_parser.add_argument(
- '--base-td-path',
- dest='base_td_path',
- type=str,
- default=None,
- help='Path to SPIRVBase.td')
- cli_parser.add_argument(
- '--op-td-path',
- dest='op_td_path',
- type=str,
- default=None,
- help='Path to SPIRVOps.td')
-
- cli_parser.add_argument(
- '--new-enum',
- dest='new_enum',
- type=str,
- default=None,
- help='SPIR-V enum to be added to SPIRVBase.td')
- cli_parser.add_argument(
- '--new-opcodes',
- dest='new_opcodes',
- type=str,
- default=None,
- nargs='*',
- help='update SPIR-V opcodes in SPIRVBase.td')
- cli_parser.add_argument(
- '--new-inst',
- dest='new_inst',
- type=str,
- default=None,
- nargs='*',
- help='SPIR-V instruction to be added to ops file')
- cli_parser.add_argument(
- '--inst-category',
- dest='inst_category',
- type=str,
- default='Op',
- help='SPIR-V instruction category used for choosing '\
- 'the TableGen base class to define this op')
- cli_parser.add_argument(
- '--gen-cl-ops',
- dest='gen_cl_ops',
- help='Generate OpenCL Extended Instruction Set op',
- action='store_true')
- cli_parser.set_defaults(gen_cl_ops=False)
- cli_parser.add_argument('--gen-inst-coverage', dest='gen_inst_coverage', action='store_true')
- cli_parser.set_defaults(gen_inst_coverage=False)
-
- args = cli_parser.parse_args()
-
- if args.gen_cl_ops:
- ext_html_url = SPIRV_CL_EXT_HTML_SPEC_URL
- ext_json_url = SPIRV_CL_EXT_JSON_SPEC_URL
- else:
- ext_html_url = None
- ext_json_url = None
-
- operand_kinds, instructions = get_spirv_grammar_from_json_spec(ext_json_url)
-
- # Define new enum attr
- if args.new_enum is not None:
- assert args.base_td_path is not None
- filter_list = [args.new_enum] if args.new_enum else []
- update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list)
-
- # Define new opcode
- if args.new_opcodes is not None:
- assert args.base_td_path is not None
- update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)
-
- # Define new op
- if args.new_inst is not None:
- assert args.op_td_path is not None
- docs = get_spirv_doc_from_html_spec(ext_html_url, args)
- capability_mapping = get_capability_mapping(operand_kinds)
- update_td_op_definitions(args.op_td_path, instructions, docs, args.new_inst,
- args.inst_category, capability_mapping, args)
- print('Done. Note that this script just generates a template; ', end='')
- print('please read the spec and update traits, arguments, and ', end='')
- print('results accordingly.')
-
- if args.gen_inst_coverage:
- gen_instr_coverage_report(args.base_td_path, instructions)
+ """Extracts potentially manually specified sections in op's definition.
+
+ Arguments: - A string containing the op's TableGen definition
+
+ Returns:
+ - A dict containing potential manually specified sections
+ """
+ # Get opname
+ opname = [o[8:-2] for o in re.findall("def SPIRV_\w+Op", op_def)]
+ assert len(opname) == 1, "more than one op in the same section!"
+ opname = opname[0]
+
+ # Get instruction category
+ inst_category = [o[4:] for o in re.findall("SPIRV_\w+Op", op_def.split(":", 1)[1])]
+ assert len(inst_category) <= 1, "more than one op in the same section!"
+ inst_category = inst_category[0] if len(inst_category) == 1 else "Op"
+
+ # Get category_args
+ op_tmpl_params, _ = get_string_between_nested(op_def, "<", ">")
+ opstringname, rest = get_string_between(op_tmpl_params, '"', '"')
+ category_args = rest.split("[", 1)[0]
+
+ # Get traits
+ traits, _ = get_string_between_nested(rest, "[", "]")
+
+ # Get description
+ description, rest = get_string_between(op_def, "let description = [{\n", "}];\n")
+
+ # Get arguments
+ args, rest = get_string_between(rest, " let arguments = (ins", ");\n")
+
+ # Get results
+ results, rest = get_string_between(rest, " let results = (outs", ");\n")
+
+ extras = rest.strip(" }\n")
+ if extras:
+ extras = "\n {}\n".format(extras)
+
+ return {
+ # Prefix with 'Op' to make it consistent with SPIR-V spec
+ "opname": "Op{}".format(opname),
+ "inst_category": inst_category,
+ "category_args": category_args,
+ "traits": traits,
+ "description": description,
+ "arguments": args,
+ "results": results,
+ "extras": extras,
+ }
+
+
+def update_td_op_definitions(
+ path, instructions, docs, filter_list, inst_category, capability_mapping, settings
+):
+ """Updates SPIRVOps.td with newly generated op definition.
+
+ Arguments:
+ - path: path to SPIRVOps.td
+ - instructions: SPIR-V JSON grammar for all instructions
+ - docs: SPIR-V HTML doc for all instructions
+ - filter_list: a list containing new opnames to include
+ - capability_mapping: mapping from duplicated capability symbols to the
+ canonicalized symbol chosen for SPIRVBase.td.
+
+ Returns:
+ - None; the updated op definitions are written back to the file at the given path.
+ """
+ with open(path, "r") as f:
+ content = f.read()
+
+ # Split the file into chunks, each containing one op.
+ ops = content.split(AUTOGEN_OP_DEF_SEPARATOR)
+ header = ops[0]
+ footer = ops[-1]
+ ops = ops[1:-1]
+
+ # For each existing op, extract the manually-written sections out to retain
+ # them when re-generating the ops. Also append the existing ops to filter
+ # list.
+ name_op_map = {} # Map from opname to its existing ODS definition
+ op_info_dict = {}
+ for op in ops:
+ info_dict = extract_td_op_info(op)
+ opname = info_dict["opname"]
+ name_op_map[opname] = op
+ op_info_dict[opname] = info_dict
+ filter_list.append(opname)
+ filter_list = sorted(list(set(filter_list)))
+
+ op_defs = []
+
+ if settings.gen_cl_ops:
+ fix_opname = lambda src: src.replace("CL", "").lower()
+ else:
+ fix_opname = lambda src: src
+
+ for opname in filter_list:
+ # Find the grammar spec for this op
+ try:
+ fixed_opname = fix_opname(opname)
+ instruction = next(
+ inst for inst in instructions if inst["opname"] == fixed_opname
+ )
+
+ op_defs.append(
+ get_op_definition(
+ instruction,
+ opname,
+ docs[fixed_opname],
+ op_info_dict.get(opname, {"inst_category": inst_category}),
+ capability_mapping,
+ settings,
+ )
+ )
+ except StopIteration:
+ # This is an op added by us; use the existing ODS definition.
+ op_defs.append(name_op_map[opname])
+
+ # Substitute the old op definitions
+ op_defs = [header] + op_defs + [footer]
+ content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs)
+
+ with open(path, "w") as f:
+ f.write(content)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ cli_parser = argparse.ArgumentParser(
+ description="Update SPIR-V dialect definitions using SPIR-V spec"
+ )
+
+ cli_parser.add_argument(
+ "--base-td-path",
+ dest="base_td_path",
+ type=str,
+ default=None,
+ help="Path to SPIRVBase.td",
+ )
+ cli_parser.add_argument(
+ "--op-td-path",
+ dest="op_td_path",
+ type=str,
+ default=None,
+ help="Path to SPIRVOps.td",
+ )
+
+ cli_parser.add_argument(
+ "--new-enum",
+ dest="new_enum",
+ type=str,
+ default=None,
+ help="SPIR-V enum to be added to SPIRVBase.td",
+ )
+ cli_parser.add_argument(
+ "--new-opcodes",
+ dest="new_opcodes",
+ type=str,
+ default=None,
+ nargs="*",
+ help="update SPIR-V opcodes in SPIRVBase.td",
+ )
+ cli_parser.add_argument(
+ "--new-inst",
+ dest="new_inst",
+ type=str,
+ default=None,
+ nargs="*",
+ help="SPIR-V instruction to be added to ops file",
+ )
+ cli_parser.add_argument(
+ "--inst-category",
+ dest="inst_category",
+ type=str,
+ default="Op",
+ help="SPIR-V instruction category used for choosing "
+ "the TableGen base class to define this op",
+ )
+ cli_parser.add_argument(
+ "--gen-cl-ops",
+ dest="gen_cl_ops",
+ help="Generate OpenCL Extended Instruction Set op",
+ action="store_true",
+ )
+ cli_parser.set_defaults(gen_cl_ops=False)
+ cli_parser.add_argument(
+ "--gen-inst-coverage", dest="gen_inst_coverage", action="store_true"
+ )
+ cli_parser.set_defaults(gen_inst_coverage=False)
+
+ args = cli_parser.parse_args()
+
+ if args.gen_cl_ops:
+ ext_html_url = SPIRV_CL_EXT_HTML_SPEC_URL
+ ext_json_url = SPIRV_CL_EXT_JSON_SPEC_URL
+ else:
+ ext_html_url = None
+ ext_json_url = None
+
+ operand_kinds, instructions = get_spirv_grammar_from_json_spec(ext_json_url)
+
+ # Define new enum attr
+ if args.new_enum is not None:
+ assert args.base_td_path is not None
+ filter_list = [args.new_enum] if args.new_enum else []
+ update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list)
+
+ # Define new opcode
+ if args.new_opcodes is not None:
+ assert args.base_td_path is not None
+ update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)
+
+ # Define new op
+ if args.new_inst is not None:
+ assert args.op_td_path is not None
+ docs = get_spirv_doc_from_html_spec(ext_html_url, args)
+ capability_mapping = get_capability_mapping(operand_kinds)
+ update_td_op_definitions(
+ args.op_td_path,
+ instructions,
+ docs,
+ args.new_inst,
+ args.inst_category,
+ capability_mapping,
+ args,
+ )
+ print("Done. Note that this script just generates a template; ", end="")
+ print("please read the spec and update traits, arguments, and ", end="")
+ print("results accordingly.")
+
+ if args.gen_inst_coverage:
+ gen_instr_coverage_report(args.base_td_path, instructions)
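
For readers of the reformatted helpers above: get_string_between_nested keeps a count of unmatched openers so that template parameters containing nested delimiters (for example a trait list holding a parameterized trait) are captured as a single unit. Below is a minimal, self-contained sketch of that logic together with a small driver; the sample TableGen-like snippet and the variable names in the driver are illustrative assumptions, not part of this commit.

    # Standalone sketch of the nested-delimiter extraction shown in the diff
    # above. Given a start and end delimiter, it returns the text between the
    # first opener and its matching closer, plus whatever follows the closer.
    def get_string_between_nested(base, start, end):
        split = base.split(start, 1)
        if len(split) == 2:
            rest = split[1]
            unmatched_start = 1
            index = 0
            while unmatched_start > 0 and index < len(rest):
                if rest[index:].startswith(end):
                    unmatched_start -= 1
                    if unmatched_start == 0:
                        break
                    index += len(end)
                elif rest[index:].startswith(start):
                    unmatched_start += 1
                    index += len(start)
                else:
                    index += 1
            assert index < len(rest), "unbalanced delimiters"
            return rest[:index], rest[index + len(end) :]
        return "", split[0]

    # Hypothetical TableGen-like input: the nested "<...>" inside the trait
    # list does not end the match prematurely.
    params, remainder = get_string_between_nested(
        'SPIRV_Op<"OpNop", [Pure, SomeTrait<SomeParam>]> { ... }', "<", ">"
    )
    assert params == '"OpNop", [Pure, SomeTrait<SomeParam>]'
    assert remainder == " { ... }"

In the script itself this helper feeds extract_td_op_info, which in turn lets update_td_op_definitions retain the manually written sections of each op when the command-line entry point (for example --new-inst combined with --op-td-path) regenerates definitions in SPIRVOps.td.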