[Mlir-commits] [mlir] f9008e6 - [NFC][Py Reformat] Reformat python files in mlir subdir

Tobias Hieta llvmlistbot at llvm.org
Thu May 25 23:05:53 PDT 2023


Author: Tobias Hieta
Date: 2023-05-26T08:05:40+02:00
New Revision: f9008e6366c2496b1ca1785b891d5578174ad63e

URL: https://github.com/llvm/llvm-project/commit/f9008e6366c2496b1ca1785b891d5578174ad63e
DIFF: https://github.com/llvm/llvm-project/commit/f9008e6366c2496b1ca1785b891d5578174ad63e.diff

LOG: [NFC][Py Reformat] Reformat python files in mlir subdir

This is an ongoing series of commits that are reformatting our
Python code.

Reformatting is done with `black`.

If you end up having problems merging this commit because you
have made changes to a Python file, the best way to handle that
is to run `git checkout --ours <yourfile>` and then reformat it
with `black`.

If you run into any problems, post to Discourse about it and
we will try to help.

RFC Thread below:

https://discourse.llvm.org/t/rfc-document-and-standardize-python-code-style

Differential Revision: https://reviews.llvm.org/D150782

Added: 
    

Modified: 
    mlir/benchmark/python/benchmark_sparse.py
    mlir/benchmark/python/common.py
    mlir/examples/standalone/test/CAPI/lit.local.cfg
    mlir/examples/standalone/test/lit.cfg.py
    mlir/examples/standalone/test/python/lit.local.cfg
    mlir/examples/standalone/test/python/smoketest.py
    mlir/python/mlir/_mlir_libs/__init__.py
    mlir/python/mlir/dialects/_arith_ops_ext.py
    mlir/python/mlir/dialects/_bufferization_ops_ext.py
    mlir/python/mlir/dialects/_builtin_ops_ext.py
    mlir/python/mlir/dialects/_func_ops_ext.py
    mlir/python/mlir/dialects/_linalg_ops_ext.py
    mlir/python/mlir/dialects/_loop_transform_ops_ext.py
    mlir/python/mlir/dialects/_memref_ops_ext.py
    mlir/python/mlir/dialects/_ml_program_ops_ext.py
    mlir/python/mlir/dialects/_ods_common.py
    mlir/python/mlir/dialects/_pdl_ops_ext.py
    mlir/python/mlir/dialects/_scf_ops_ext.py
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py
    mlir/python/mlir/dialects/_tensor_ops_ext.py
    mlir/python/mlir/dialects/_transform_ops_ext.py
    mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/types.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/python/mlir/dialects/python_test.py
    mlir/python/mlir/dialects/transform/__init__.py
    mlir/python/mlir/execution_engine.py
    mlir/python/mlir/ir.py
    mlir/python/mlir/runtime/np_to_memref.py
    mlir/test/CAPI/lit.local.cfg
    mlir/test/Conversion/GPUToCUDA/lit.local.cfg
    mlir/test/Conversion/GPUToROCm/lit.local.cfg
    mlir/test/Examples/Toy/Ch6/lit.local.cfg
    mlir/test/Examples/Toy/Ch7/lit.local.cfg
    mlir/test/Examples/lit.local.cfg
    mlir/test/Examples/standalone/lit.local.cfg
    mlir/test/Integration/Dialect/Async/CPU/lit.local.cfg
    mlir/test/Integration/Dialect/LLVMIR/CPU/X86/lit.local.cfg
    mlir/test/Integration/Dialect/LLVMIR/CPU/lit.local.cfg
    mlir/test/Integration/Dialect/SparseTensor/CPU/lit.local.cfg
    mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg
    mlir/test/Integration/Dialect/SparseTensor/python/lit.local.cfg
    mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
    mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
    mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
    mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
    mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
    mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py
    mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py
    mlir/test/Integration/Dialect/SparseTensor/taco/lit.local.cfg
    mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py
    mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
    mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py
    mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py
    mlir/test/Integration/Dialect/SparseTensor/taco/test_Tensor.py
    mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
    mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py
    mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py
    mlir/test/Integration/Dialect/SparseTensor/taco/test_true_dense_tensor_algebra.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py
    mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
    mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py
    mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_utils.py
    mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/lit.local.cfg
    mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/lit.local.cfg
    mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg
    mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg
    mlir/test/Integration/GPU/CUDA/TensorCore/lit.local.cfg
    mlir/test/Integration/GPU/CUDA/lit.local.cfg
    mlir/test/Integration/GPU/ROCM/lit.local.cfg
    mlir/test/Integration/lit.local.cfg
    mlir/test/Unit/lit.cfg.py
    mlir/test/lib/Dialect/Test/lit.local.cfg
    mlir/test/lib/Dialect/Transform/lit.local.cfg
    mlir/test/lib/Tools/PDLL/lit.local.cfg
    mlir/test/lib/Transforms/lit.local.cfg
    mlir/test/lit.cfg.py
    mlir/test/mlir-cpu-runner/lit.local.cfg
    mlir/test/mlir-pdll-lsp-server/lit.local.cfg
    mlir/test/mlir-pdll/lit.local.cfg
    mlir/test/mlir-spirv-cpu-runner/lit.local.cfg
    mlir/test/mlir-vulkan-runner/lit.local.cfg
    mlir/test/python/develoment_files.py
    mlir/test/python/dialects/arith_dialect.py
    mlir/test/python/dialects/async_dialect.py
    mlir/test/python/dialects/builtin.py
    mlir/test/python/dialects/complex_dialect.py
    mlir/test/python/dialects/func.py
    mlir/test/python/dialects/gpu.py
    mlir/test/python/dialects/linalg/opdsl/arguments.py
    mlir/test/python/dialects/linalg/opdsl/assignments.py
    mlir/test/python/dialects/linalg/opdsl/doctests.py
    mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
    mlir/test/python/dialects/linalg/opdsl/emit_fill.py
    mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
    mlir/test/python/dialects/linalg/opdsl/emit_misc.py
    mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
    mlir/test/python/dialects/linalg/opdsl/lit.local.cfg
    mlir/test/python/dialects/linalg/opdsl/metadata.py
    mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
    mlir/test/python/dialects/linalg/ops.py
    mlir/test/python/dialects/math_dialect.py
    mlir/test/python/dialects/memref.py
    mlir/test/python/dialects/ml_program.py
    mlir/test/python/dialects/ods_helpers.py
    mlir/test/python/dialects/pdl_ops.py
    mlir/test/python/dialects/python_test.py
    mlir/test/python/dialects/quant.py
    mlir/test/python/dialects/scf.py
    mlir/test/python/dialects/shape.py
    mlir/test/python/dialects/sparse_tensor/dialect.py
    mlir/test/python/dialects/sparse_tensor/passes.py
    mlir/test/python/dialects/tensor.py
    mlir/test/python/dialects/transform.py
    mlir/test/python/dialects/transform_loop_ext.py
    mlir/test/python/dialects/transform_structured_ext.py
    mlir/test/python/dialects/vector.py
    mlir/test/python/execution_engine.py
    mlir/test/python/integration/dialects/linalg/opsrun.py
    mlir/test/python/ir/affine_expr.py
    mlir/test/python/ir/affine_map.py
    mlir/test/python/ir/array_attributes.py
    mlir/test/python/ir/attributes.py
    mlir/test/python/ir/blocks.py
    mlir/test/python/ir/builtin_types.py
    mlir/test/python/ir/context_managers.py
    mlir/test/python/ir/debug.py
    mlir/test/python/ir/diagnostic_handler.py
    mlir/test/python/ir/dialects.py
    mlir/test/python/ir/exception.py
    mlir/test/python/ir/insertion_point.py
    mlir/test/python/ir/integer_set.py
    mlir/test/python/ir/location.py
    mlir/test/python/ir/module.py
    mlir/test/python/ir/operation.py
    mlir/test/python/ir/symbol_table.py
    mlir/test/python/ir/value.py
    mlir/test/python/lit.local.cfg
    mlir/test/python/pass_manager.py
    mlir/test/tblgen-lsp-server/lit.local.cfg
    mlir/utils/gdb-scripts/prettyprinters.py
    mlir/utils/generate-test-checks.py
    mlir/utils/jupyter/mlir_opt_kernel/__main__.py
    mlir/utils/jupyter/mlir_opt_kernel/install.py
    mlir/utils/jupyter/mlir_opt_kernel/kernel.py
    mlir/utils/lldb-scripts/mlirDataFormatters.py
    mlir/utils/mbr/mbr/__init__.py
    mlir/utils/mbr/mbr/discovery.py
    mlir/utils/mbr/mbr/main.py
    mlir/utils/mbr/mbr/stats.py
    mlir/utils/spirv/gen_spirv_dialect.py

Removed: 
    


################################################################################
diff  --git a/mlir/benchmark/python/benchmark_sparse.py b/mlir/benchmark/python/benchmark_sparse.py
index 6d7a39690734c..72b3ef1a3f424 100644
--- a/mlir/benchmark/python/benchmark_sparse.py
+++ b/mlir/benchmark/python/benchmark_sparse.py
@@ -25,7 +25,7 @@
 def matmul_dsl(
     A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
     B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
-    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)
+    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
 ):
     """Helper function for mlir sparse matrix multiplication benchmark."""
     C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
@@ -43,6 +43,7 @@ def benchmark_sparse_mlir_multiplication():
         param2_type = ir.RankedTensorType.get([1500, 2000], f64)
         result_type = ir.RankedTensorType.get([1000, 2000], f64)
         with ir.InsertionPoint(module.body):
+
             @func.FuncOp.from_py_func(param1_type, param2_type, result_type)
             def sparse_kernel(x, y, z):
                 return matmul_dsl(x, y, outs=[z])
@@ -51,37 +52,34 @@ def compiler():
         with ir.Context(), ir.Location.unknown():
             kernel_func = get_kernel_func_from_module(module)
             timer_func = emit_timer_func()
-            wrapped_func = emit_benchmark_wrapped_main_func(
-                kernel_func,
-                timer_func
-            )
+            wrapped_func = emit_benchmark_wrapped_main_func(kernel_func, timer_func)
             main_module_with_benchmark = ir.Module.parse(
                 str(timer_func) + str(wrapped_func) + str(kernel_func)
             )
             setup_passes(main_module_with_benchmark)
             c_runner_utils = os.getenv("MLIR_C_RUNNER_UTILS", "")
-            assert os.path.exists(c_runner_utils),\
-                f"{c_runner_utils} does not exist." \
-                f" Please pass a valid value for" \
+            assert os.path.exists(c_runner_utils), (
+                f"{c_runner_utils} does not exist."
+                f" Please pass a valid value for"
                 f" MLIR_C_RUNNER_UTILS environment variable."
+            )
             runner_utils = os.getenv("MLIR_RUNNER_UTILS", "")
-            assert os.path.exists(runner_utils),\
-                f"{runner_utils} does not exist." \
-                f" Please pass a valid value for MLIR_RUNNER_UTILS" \
+            assert os.path.exists(runner_utils), (
+                f"{runner_utils} does not exist."
+                f" Please pass a valid value for MLIR_RUNNER_UTILS"
                 f" environment variable."
+            )
 
             engine = ExecutionEngine(
                 main_module_with_benchmark,
                 3,
-                shared_libs=[c_runner_utils, runner_utils]
+                shared_libs=[c_runner_utils, runner_utils],
             )
             return engine.invoke
 
     def runner(engine_invoke):
         compiled_program_args = []
-        for argument_type in [
-            result_type, param1_type, param2_type, result_type
-        ]:
+        for argument_type in [result_type, param1_type, param2_type, result_type]:
             argument_type_str = str(argument_type)
             dimensions_str = re.sub("<|>|tensor", "", argument_type_str)
             dimensions = [int(dim) for dim in dimensions_str.split("x")[:-1]]
@@ -111,6 +109,7 @@ def benchmark_np_matrix_multiplication():
     benchmark, we don't have any `compiler` function returned. We just return
     the `runner` function.
     """
+
     def runner():
         argument1 = np.random.uniform(low=0.0, high=100.0, size=(1000, 1500))
         argument2 = np.random.uniform(low=0.0, high=100.0, size=(1500, 2000))

diff  --git a/mlir/benchmark/python/common.py b/mlir/benchmark/python/common.py
index 3634641074f3e..c605726df2a5f 100644
--- a/mlir/benchmark/python/common.py
+++ b/mlir/benchmark/python/common.py
@@ -10,8 +10,7 @@
 
 
 def setup_passes(mlir_module):
-    """Setup pass pipeline parameters for benchmark functions.
-    """
+    """Setup pass pipeline parameters for benchmark functions."""
     opt = (
         "parallelization-strategy=none"
         " vectorization-strategy=none vl=1 enable-simd-index32=False"
@@ -43,12 +42,15 @@ def get_kernel_func_from_module(module: ir.Module) -> func.FuncOp:
     This function only works for a module with one region, one block, and one
     operation.
     """
-    assert len(module.operation.regions) == 1, \
-        "Expected kernel module to have only one region"
-    assert len(module.operation.regions[0].blocks) == 1, \
-        "Expected kernel module to have only one block"
-    assert len(module.operation.regions[0].blocks[0].operations) == 1, \
-        "Expected kernel module to have only one operation"
+    assert (
+        len(module.operation.regions) == 1
+    ), "Expected kernel module to have only one region"
+    assert (
+        len(module.operation.regions[0].blocks) == 1
+    ), "Expected kernel module to have only one block"
+    assert (
+        len(module.operation.regions[0].blocks[0].operations) == 1
+    ), "Expected kernel module to have only one operation"
     return module.operation.regions[0].blocks[0].operations[0]
 
 
@@ -57,8 +59,7 @@ def emit_timer_func() -> func.FuncOp:
     used, the `MLIR_RUNNER_UTILS` and `MLIR_C_RUNNER_UTILS` must be included.
     """
     i64_type = ir.IntegerType.get_signless(64)
-    nanoTime = func.FuncOp(
-        "nanoTime", ([], [i64_type]), visibility="private")
+    nanoTime = func.FuncOp("nanoTime", ([], [i64_type]), visibility="private")
     nanoTime.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
     return nanoTime
 
@@ -76,9 +77,8 @@ def emit_benchmark_wrapped_main_func(kernel_func, timer_func):
     wrapped_func = func.FuncOp(
         # Same signature and an extra buffer of indices to save timings.
         "main",
-        (kernel_func.arguments.types + [memref_of_i64_type],
-         kernel_func.type.results),
-        visibility="public"
+        (kernel_func.arguments.types + [memref_of_i64_type], kernel_func.type.results),
+        visibility="public",
     )
     wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
 
@@ -88,13 +88,13 @@ def emit_benchmark_wrapped_main_func(kernel_func, timer_func):
         zero = arith.ConstantOp.create_index(0)
         n_iterations = memref.DimOp(ir.IndexType.get(), timer_buffer, zero)
         one = arith.ConstantOp.create_index(1)
-        iter_args = list(wrapped_func.arguments[-num_results - 1:-1])
+        iter_args = list(wrapped_func.arguments[-num_results - 1 : -1])
         loop = scf.ForOp(zero, n_iterations, one, iter_args)
         with ir.InsertionPoint(loop.body):
             start = func.CallOp(timer_func, [])
             call = func.CallOp(
                 kernel_func,
-                wrapped_func.arguments[:-num_results - 1] + loop.inner_iter_args
+                wrapped_func.arguments[: -num_results - 1] + loop.inner_iter_args,
             )
             end = func.CallOp(timer_func, [])
             time_taken = arith.SubIOp(end, start)

diff  --git a/mlir/examples/standalone/test/CAPI/lit.local.cfg b/mlir/examples/standalone/test/CAPI/lit.local.cfg
index f08a0de488ddd..bb0c17cdbada7 100644
--- a/mlir/examples/standalone/test/CAPI/lit.local.cfg
+++ b/mlir/examples/standalone/test/CAPI/lit.local.cfg
@@ -1 +1 @@
-config.suffixes.add('.c')
+config.suffixes.add(".c")

diff  --git a/mlir/examples/standalone/test/lit.cfg.py b/mlir/examples/standalone/test/lit.cfg.py
index 601ac8f769f93..e27dddd7fb0b9 100644
--- a/mlir/examples/standalone/test/lit.cfg.py
+++ b/mlir/examples/standalone/test/lit.cfg.py
@@ -16,52 +16,55 @@
 # Configuration file for the 'lit' test runner.
 
 # name: The name of this test suite.
-config.name = 'STANDALONE'
+config.name = "STANDALONE"
 
 config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
 
 # suffixes: A list of file extensions to treat as test files.
-config.suffixes = ['.mlir']
+config.suffixes = [".mlir"]
 
 # test_source_root: The root path where tests are located.
 config.test_source_root = os.path.dirname(__file__)
 
 # test_exec_root: The root path where tests should be run.
-config.test_exec_root = os.path.join(config.standalone_obj_root, 'test')
+config.test_exec_root = os.path.join(config.standalone_obj_root, "test")
 
-config.substitutions.append(('%PATH%', config.environment['PATH']))
-config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
+config.substitutions.append(("%PATH%", config.environment["PATH"]))
+config.substitutions.append(("%shlibext", config.llvm_shlib_ext))
 
-llvm_config.with_system_environment(
-    ['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
+llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])
 
 llvm_config.use_default_substitutions()
 
 # excludes: A list of directories to exclude from the testsuite. The 'Inputs'
 # subdirectories contain auxiliary inputs for various tests in their parent
 # directories.
-config.excludes = ['Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt']
+config.excludes = ["Inputs", "Examples", "CMakeLists.txt", "README.txt", "LICENSE.txt"]
 
 # test_exec_root: The root path where tests should be run.
-config.test_exec_root = os.path.join(config.standalone_obj_root, 'test')
-config.standalone_tools_dir = os.path.join(config.standalone_obj_root, 'bin')
-config.standalone_libs_dir = os.path.join(config.standalone_obj_root, 'lib')
+config.test_exec_root = os.path.join(config.standalone_obj_root, "test")
+config.standalone_tools_dir = os.path.join(config.standalone_obj_root, "bin")
+config.standalone_libs_dir = os.path.join(config.standalone_obj_root, "lib")
 
-config.substitutions.append(('%standalone_libs', config.standalone_libs_dir))
+config.substitutions.append(("%standalone_libs", config.standalone_libs_dir))
 
 # Tweak the PATH to include the tools dir.
-llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
+llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)
 
 tool_dirs = [config.standalone_tools_dir, config.llvm_tools_dir]
 tools = [
-    'mlir-opt',
-    'standalone-capi-test',
-    'standalone-opt',
-    'standalone-translate',
+    "mlir-opt",
+    "standalone-capi-test",
+    "standalone-opt",
+    "standalone-translate",
 ]
 
 llvm_config.add_tool_substitutions(tools, tool_dirs)
 
-llvm_config.with_environment('PYTHONPATH', [
-    os.path.join(config.mlir_obj_dir, 'python_packages', 'standalone'),
-], append_path=True)
+llvm_config.with_environment(
+    "PYTHONPATH",
+    [
+        os.path.join(config.mlir_obj_dir, "python_packages", "standalone"),
+    ],
+    append_path=True,
+)

diff  --git a/mlir/examples/standalone/test/python/lit.local.cfg b/mlir/examples/standalone/test/python/lit.local.cfg
index b70b9d7a34fdd..3394f180e5121 100644
--- a/mlir/examples/standalone/test/python/lit.local.cfg
+++ b/mlir/examples/standalone/test/python/lit.local.cfg
@@ -1,4 +1,4 @@
-config.suffixes.add('.py')
+config.suffixes.add(".py")
 
 if not config.enable_bindings_python:
-  config.unsupported = True
+    config.unsupported = True

diff  --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index 0d8f41c27e8ef..08e08cbd2fe24 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -1,17 +1,16 @@
 # RUN: %python %s | FileCheck %s
 
 from mlir_standalone.ir import *
-from mlir_standalone.dialects import (
-  builtin as builtin_d,
-  standalone as standalone_d
-)
+from mlir_standalone.dialects import builtin as builtin_d, standalone as standalone_d
 
 with Context():
-  standalone_d.register_dialect()
-  module = Module.parse("""
+    standalone_d.register_dialect()
+    module = Module.parse(
+        """
     %0 = arith.constant 2 : i32
     %1 = standalone.foo %0 : i32
-    """)
-  # CHECK: %[[C:.*]] = arith.constant 2 : i32
-  # CHECK: standalone.foo %[[C]] : i32
-  print(str(module))
+    """
+    )
+    # CHECK: %[[C:.*]] = arith.constant 2 : i32
+    # CHECK: standalone.foo %[[C]] : i32
+    print(str(module))

diff  --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 7d3d1f6ca873a..03fcb10130c3a 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -10,26 +10,26 @@
 
 
 def get_lib_dirs() -> Sequence[str]:
-  """Gets the lib directory for linking to shared libraries.
+    """Gets the lib directory for linking to shared libraries.
 
-  On some platforms, the package may need to be built specially to export
-  development libraries.
-  """
-  return [_this_dir]
+    On some platforms, the package may need to be built specially to export
+    development libraries.
+    """
+    return [_this_dir]
 
 
 def get_include_dirs() -> Sequence[str]:
-  """Gets the include directory for compiling against exported C libraries.
+    """Gets the include directory for compiling against exported C libraries.
 
-  Depending on how the package was build, development C libraries may or may
-  not be present.
-  """
-  return [os.path.join(_this_dir, "include")]
+    Depending on how the package was build, development C libraries may or may
+    not be present.
+    """
+    return [os.path.join(_this_dir, "include")]
 
 
 # Perform Python level site initialization. This involves:
 #   1. Attempting to load initializer modules, specific to the distribution.
-#   2. Defining the concrete mlir.ir.Context that does site specific 
+#   2. Defining the concrete mlir.ir.Context that does site specific
 #      initialization.
 #
 # Aside from just being far more convenient to do this at the Python level,
@@ -38,91 +38,106 @@ def get_include_dirs() -> Sequence[str]:
 # in the scope of the base class __init__).
 #
 # For #1, we:
-#   a. Probe for modules named '_mlirRegisterEverything' and 
-#     '_site_initialize_{i}', where 'i' is a number starting at zero and 
+#   a. Probe for modules named '_mlirRegisterEverything' and
+#     '_site_initialize_{i}', where 'i' is a number starting at zero and
 #     proceeding so long as a module with the name is found.
 #   b. If the module has a 'register_dialects' attribute, it will be called
 #     immediately with a DialectRegistry to populate.
 #   c. If the module has a 'context_init_hook', it will be added to a list
-#     of callbacks that are invoked as the last step of Context 
+#     of callbacks that are invoked as the last step of Context
 #     initialization (and passed the Context under construction).
 #
 # This facility allows downstreams to customize Context creation to their
 # needs.
 def _site_initialize():
-  import importlib
-  import itertools
-  import logging
-  from ._mlir import ir
-  logger = logging.getLogger(__name__)
-  registry = ir.DialectRegistry()
-  post_init_hooks = []
-
-  def process_initializer_module(module_name):
-    try:
-      m = importlib.import_module(f".{module_name}", __name__)
-    except ModuleNotFoundError:
-      return False
-    except ImportError:
-      message = (f"Error importing mlir initializer {module_name}. This may "
-      "happen in unclean incremental builds but is likely a real bug if "
-      "encountered otherwise and the MLIR Python API may not function.")
-      logger.warning(message, exc_info=True)
-
-    logger.debug("Initializing MLIR with module: %s", module_name)
-    if hasattr(m, "register_dialects"):
-      logger.debug("Registering dialects from initializer %r", m)
-      m.register_dialects(registry)
-    if hasattr(m, "context_init_hook"):
-      logger.debug("Adding context init hook from %r", m)
-      post_init_hooks.append(m.context_init_hook)
-    return True
-
-
-  # If _mlirRegisterEverything is built, then include it as an initializer
-  # module.
-  process_initializer_module("_mlirRegisterEverything")
-
-  # Load all _site_initialize_{i} modules, where 'i' is a number starting
-  # at 0.
-  for i in itertools.count():
-    module_name = f"_site_initialize_{i}"
-    if not process_initializer_module(module_name):
-      break
-
-  class Context(ir._BaseContext):
-    def __init__(self, *args, **kwargs):
-      super().__init__(*args, **kwargs)
-      self.append_dialect_registry(registry)
-      for hook in post_init_hooks:
-        hook(self)
-      # TODO: There is some debate about whether we should eagerly load
-      # all dialects. It is being done here in order to preserve existing
-      # behavior. See: https://github.com/llvm/llvm-project/issues/56037
-      self.load_all_available_dialects()
-  ir.Context = Context
-
-  class MLIRError(Exception):
-    """
-    An exception with diagnostic information. Has the following fields:
-      message: str
-      error_diagnostics: List[ir.DiagnosticInfo]
-    """
-    def __init__(self, message, error_diagnostics):
-      self.message = message
-      self.error_diagnostics = error_diagnostics
-      super().__init__(message, error_diagnostics)
-
-    def __str__(self):
-      s = self.message
-      if self.error_diagnostics:
-        s += ':'
-      for diag in self.error_diagnostics:
-        s += "\nerror: "  + str(diag.location)[4:-1] + ": " + diag.message.replace('\n', '\n  ')
-        for note in diag.notes:
-          s += "\n note: "  + str(note.location)[4:-1] + ": " + note.message.replace('\n', '\n  ')
-      return s
-  ir.MLIRError = MLIRError
+    import importlib
+    import itertools
+    import logging
+    from ._mlir import ir
+
+    logger = logging.getLogger(__name__)
+    registry = ir.DialectRegistry()
+    post_init_hooks = []
+
+    def process_initializer_module(module_name):
+        try:
+            m = importlib.import_module(f".{module_name}", __name__)
+        except ModuleNotFoundError:
+            return False
+        except ImportError:
+            message = (
+                f"Error importing mlir initializer {module_name}. This may "
+                "happen in unclean incremental builds but is likely a real bug if "
+                "encountered otherwise and the MLIR Python API may not function."
+            )
+            logger.warning(message, exc_info=True)
+
+        logger.debug("Initializing MLIR with module: %s", module_name)
+        if hasattr(m, "register_dialects"):
+            logger.debug("Registering dialects from initializer %r", m)
+            m.register_dialects(registry)
+        if hasattr(m, "context_init_hook"):
+            logger.debug("Adding context init hook from %r", m)
+            post_init_hooks.append(m.context_init_hook)
+        return True
+
+    # If _mlirRegisterEverything is built, then include it as an initializer
+    # module.
+    process_initializer_module("_mlirRegisterEverything")
+
+    # Load all _site_initialize_{i} modules, where 'i' is a number starting
+    # at 0.
+    for i in itertools.count():
+        module_name = f"_site_initialize_{i}"
+        if not process_initializer_module(module_name):
+            break
+
+    class Context(ir._BaseContext):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.append_dialect_registry(registry)
+            for hook in post_init_hooks:
+                hook(self)
+            # TODO: There is some debate about whether we should eagerly load
+            # all dialects. It is being done here in order to preserve existing
+            # behavior. See: https://github.com/llvm/llvm-project/issues/56037
+            self.load_all_available_dialects()
+
+    ir.Context = Context
+
+    class MLIRError(Exception):
+        """
+        An exception with diagnostic information. Has the following fields:
+          message: str
+          error_diagnostics: List[ir.DiagnosticInfo]
+        """
+
+        def __init__(self, message, error_diagnostics):
+            self.message = message
+            self.error_diagnostics = error_diagnostics
+            super().__init__(message, error_diagnostics)
+
+        def __str__(self):
+            s = self.message
+            if self.error_diagnostics:
+                s += ":"
+            for diag in self.error_diagnostics:
+                s += (
+                    "\nerror: "
+                    + str(diag.location)[4:-1]
+                    + ": "
+                    + diag.message.replace("\n", "\n  ")
+                )
+                for note in diag.notes:
+                    s += (
+                        "\n note: "
+                        + str(note.location)[4:-1]
+                        + ": "
+                        + note.message.replace("\n", "\n  ")
+                    )
+            return s
+
+    ir.MLIRError = MLIRError
 
 
 _site_initialize()

diff  --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py
index 240859352ce21..df38f871710fe 100644
--- a/mlir/python/mlir/dialects/_arith_ops_ext.py
+++ b/mlir/python/mlir/dialects/_arith_ops_ext.py
@@ -3,72 +3,67 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
-  from ._ods_common import get_default_loc_context as _get_default_loc_context
+    from ..ir import *
+    from ._ods_common import get_default_loc_context as _get_default_loc_context
 
-  from typing import Any, List, Union
+    from typing import Any, List, Union
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 
 def _isa(obj: Any, cls: type):
-  try:
-    cls(obj)
-  except ValueError:
-    return False
-  return True
+    try:
+        cls(obj)
+    except ValueError:
+        return False
+    return True
 
 
 def _is_any_of(obj: Any, classes: List[type]):
-  return any(_isa(obj, cls) for cls in classes)
+    return any(_isa(obj, cls) for cls in classes)
 
 
 def _is_integer_like_type(type: Type):
-  return _is_any_of(type, [IntegerType, IndexType])
+    return _is_any_of(type, [IntegerType, IndexType])
 
 
 def _is_float_type(type: Type):
-  return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
+    return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
 
 
 class ConstantOp:
-  """Specialization for the constant op class."""
-
-  def __init__(self,
-               result: Type,
-               value: Union[int, float, Attribute],
-               *,
-               loc=None,
-               ip=None):
-    if isinstance(value, int):
-      super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
-    elif isinstance(value, float):
-      super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
-    else:
-      super().__init__(value, loc=loc, ip=ip)
-
-  @classmethod
-  def create_index(cls, value: int, *, loc=None, ip=None):
-    """Create an index-typed constant."""
-    return cls(
-        IndexType.get(context=_get_default_loc_context(loc)),
-        value,
-        loc=loc,
-        ip=ip)
-
-  @property
-  def type(self):
-    return self.results[0].type
-
-  @property
-  def value(self):
-    return Attribute(self.operation.attributes["value"])
-
-  @property
-  def literal_value(self) -> Union[int, float]:
-    if _is_integer_like_type(self.type):
-      return IntegerAttr(self.value).value
-    elif _is_float_type(self.type):
-      return FloatAttr(self.value).value
-    else:
-      raise ValueError("only integer and float constants have literal values")
+    """Specialization for the constant op class."""
+
+    def __init__(
+        self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+    ):
+        if isinstance(value, int):
+            super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
+        elif isinstance(value, float):
+            super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+        else:
+            super().__init__(value, loc=loc, ip=ip)
+
+    @classmethod
+    def create_index(cls, value: int, *, loc=None, ip=None):
+        """Create an index-typed constant."""
+        return cls(
+            IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
+        )
+
+    @property
+    def type(self):
+        return self.results[0].type
+
+    @property
+    def value(self):
+        return Attribute(self.operation.attributes["value"])
+
+    @property
+    def literal_value(self) -> Union[int, float]:
+        if _is_integer_like_type(self.type):
+            return IntegerAttr(self.value).value
+        elif _is_float_type(self.type):
+            return FloatAttr(self.value).value
+        else:
+            raise ValueError("only integer and float constants have literal values")

diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py
index 6ed35f4445c56..1066cb4c775ca 100644
--- a/mlir/python/mlir/dialects/_bufferization_ops_ext.py
+++ b/mlir/python/mlir/dialects/_bufferization_ops_ext.py
@@ -3,36 +3,39 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from typing import Sequence, Union
-  from ..ir import *
-  from ._ods_common import get_default_loc_context
+    from typing import Sequence, Union
+    from ..ir import *
+    from ._ods_common import get_default_loc_context
 
-  from typing import Any, List, Union
+    from typing import Any, List, Union
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 
 class AllocTensorOp:
-  """Extends the bufferization.alloc_tensor op."""
+    """Extends the bufferization.alloc_tensor op."""
 
-  def __init__(self,
-               tensor_type: Type,
-               dynamic_sizes: Sequence[Value],
-               copy: Value,
-               size_hint: Value,
-               escape: BoolAttr,
-               *,
-               loc=None,
-               ip=None):
-    """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
-    context = get_default_loc_context(loc)
-    attributes = {}
-    if escape:
-      attributes["escape"] = escape
-    op = self.build_generic(
-        results=[tensor_type],
-        operands=[dynamic_sizes, copy, size_hint],
-        attributes=attributes,
-        loc=loc,
-        ip=ip)
-    OpView.__init__(self, op)
+    def __init__(
+        self,
+        tensor_type: Type,
+        dynamic_sizes: Sequence[Value],
+        copy: Value,
+        size_hint: Value,
+        escape: BoolAttr,
+        *,
+        loc=None,
+        ip=None
+    ):
+        """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
+        context = get_default_loc_context(loc)
+        attributes = {}
+        if escape:
+            attributes["escape"] = escape
+        op = self.build_generic(
+            results=[tensor_type],
+            operands=[dynamic_sizes, copy, size_hint],
+            attributes=attributes,
+            loc=loc,
+            ip=ip,
+        )
+        OpView.__init__(self, op)

diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py
index b69163fa41519..27a60123050ac 100644
--- a/mlir/python/mlir/dialects/_builtin_ops_ext.py
+++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py
@@ -3,18 +3,18 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
+    from ..ir import *
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
+
 
 class ModuleOp:
-  """Specialization for the module op class."""
+    """Specialization for the module op class."""
 
-  def __init__(self, *, loc=None, ip=None):
-    super().__init__(self.build_generic(results=[], operands=[], loc=loc,
-                                        ip=ip))
-    body = self.regions[0].blocks.append()
+    def __init__(self, *, loc=None, ip=None):
+        super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip))
+        body = self.regions[0].blocks.append()
 
-  @property
-  def body(self):
-    return self.regions[0].blocks[0]
+    @property
+    def body(self):
+        return self.regions[0].blocks[0]

diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py
index 56df423d30a0f..6d264c33f1f9d 100644
--- a/mlir/python/mlir/dialects/_func_ops_ext.py
+++ b/mlir/python/mlir/dialects/_func_ops_ext.py
@@ -3,298 +3,317 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
-  from ._ods_common import get_default_loc_context as _get_default_loc_context
+    from ..ir import *
+    from ._ods_common import get_default_loc_context as _get_default_loc_context
 
-  import inspect
+    import inspect
 
-  from typing import Any, List, Optional, Sequence, Union
+    from typing import Any, List, Optional, Sequence, Union
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
 RESULT_ATTRIBUTE_NAME = "res_attrs"
 
+
 class ConstantOp:
-  """Specialization for the constant op class."""
+    """Specialization for the constant op class."""
 
-  def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
-    super().__init__(result, value, loc=loc, ip=ip)
+    def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
+        super().__init__(result, value, loc=loc, ip=ip)
 
-  @property
-  def type(self):
-    return self.results[0].type
+    @property
+    def type(self):
+        return self.results[0].type
 
 
 class FuncOp:
-  """Specialization for the func op class."""
-
-  def __init__(self,
-               name,
-               type,
-               *,
-               visibility=None,
-               body_builder=None,
-               loc=None,
-               ip=None):
-    """
-    Create a FuncOp with the provided `name`, `type`, and `visibility`.
-    - `name` is a string representing the function name.
-    - `type` is either a FunctionType or a pair of list describing inputs and
-      results.
-    - `visibility` is a string matching `public`, `private`, or `nested`. None
-      implies private visibility.
-    - `body_builder` is an optional callback, when provided a new entry block
-      is created and the callback is invoked with the new op as argument within
-      an InsertionPoint context already set for the block. The callback is
-      expected to insert a terminator in the block.
-    """
-    sym_name = StringAttr.get(str(name))
-
-    # If the type is passed as a tuple, build a FunctionType on the fly.
-    if isinstance(type, tuple):
-      type = FunctionType.get(inputs=type[0], results=type[1])
-
-    type = TypeAttr.get(type)
-    sym_visibility = StringAttr.get(
-        str(visibility)) if visibility is not None else None
-    super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
-    if body_builder:
-      entry_block = self.add_entry_block()
-      with InsertionPoint(entry_block):
-        body_builder(self)
-
-  @property
-  def is_external(self):
-    return len(self.regions[0].blocks) == 0
-
-  @property
-  def body(self):
-    return self.regions[0]
-
-  @property
-  def type(self):
-    return FunctionType(TypeAttr(self.attributes["function_type"]).value)
-
-  @property
-  def visibility(self):
-    return self.attributes["sym_visibility"]
-
-  @property
-  def name(self) -> StringAttr:
-    return StringAttr(self.attributes["sym_name"])
-
-  @property
-  def entry_block(self):
-    if self.is_external:
-      raise IndexError('External function does not have a body')
-    return self.regions[0].blocks[0]
-
-  def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
-    """
-    Add an entry block to the function body using the function signature to
-    infer block arguments.
-    Returns the newly created block
-    """
-    if not self.is_external:
-      raise IndexError('The function already has an entry block!')
-    self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
-    return self.body.blocks[0]
-
-  @property
-  def arg_attrs(self):
-    return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
-
-  @arg_attrs.setter
-  def arg_attrs(self, attribute: Union[ArrayAttr, list]):
-    if isinstance(attribute, ArrayAttr):
-      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
-    else:
-      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
-          attribute, context=self.context)
-
-  @property
-  def arguments(self):
-    return self.entry_block.arguments
-
-  @property
-  def result_attrs(self):
-    return self.attributes[RESULT_ATTRIBUTE_NAME]
-
-  @result_attrs.setter
-  def result_attrs(self, attribute: ArrayAttr):
-    self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
-
-  @classmethod
-  def from_py_func(FuncOp,
-                   *inputs: Type,
-                   results: Optional[Sequence[Type]] = None,
-                   name: Optional[str] = None):
-    """Decorator to define an MLIR FuncOp specified as a python function.
-
-    Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
-    active for the current thread (i.e. established in a `with` block).
-
-    When applied as a decorator to a Python function, an entry block will
-    be constructed for the FuncOp with types as specified in `*inputs`. The
-    block arguments will be passed positionally to the Python function. In
-    addition, if the Python function accepts keyword arguments generally or
-    has a corresponding keyword argument, the following will be passed:
-      * `func_op`: The `func` op being defined.
-
-    By default, the function name will be the Python function `__name__`. This
-    can be overriden by passing the `name` argument to the decorator.
-
-    If `results` is not specified, then the decorator will implicitly
-    insert a `ReturnOp` with the `Value`'s returned from the decorated
-    function. It will also set the `FuncOp` type with the actual return
-    value types. If `results` is specified, then the decorated function
-    must return `None` and no implicit `ReturnOp` is added (nor are the result
-    types updated). The implicit behavior is intended for simple, single-block
-    cases, and users should specify result types explicitly for any complicated
-    cases.
-
-    The decorated function can further be called from Python and will insert
-    a `CallOp` at the then-current insertion point, returning either None (
-    if no return values), a unary Value (for one result), or a list of Values).
-    This mechanism cannot be used to emit recursive calls (by construction).
-    """
-
-    def decorator(f):
-      from . import func
-      # Introspect the callable for optional features.
-      sig = inspect.signature(f)
-      has_arg_func_op = False
-      for param in sig.parameters.values():
-        if param.kind == param.VAR_KEYWORD:
-          has_arg_func_op = True
-        if param.name == "func_op" and (param.kind
-                                        == param.POSITIONAL_OR_KEYWORD or
-                                        param.kind == param.KEYWORD_ONLY):
-          has_arg_func_op = True
-
-      # Emit the FuncOp.
-      implicit_return = results is None
-      symbol_name = name or f.__name__
-      function_type = FunctionType.get(
-          inputs=inputs, results=[] if implicit_return else results)
-      func_op = FuncOp(name=symbol_name, type=function_type)
-      with InsertionPoint(func_op.add_entry_block()):
-        func_args = func_op.entry_block.arguments
-        func_kwargs = {}
-        if has_arg_func_op:
-          func_kwargs["func_op"] = func_op
-        return_values = f(*func_args, **func_kwargs)
-        if not implicit_return:
-          return_types = list(results)
-          assert return_values is None, (
-              "Capturing a python function with explicit `results=` "
-              "requires that the wrapped function returns None.")
-        else:
-          # Coerce return values, add ReturnOp and rewrite func type.
-          if return_values is None:
-            return_values = []
-          elif isinstance(return_values, tuple):
-            return_values = list(return_values)
-          elif isinstance(return_values, Value):
-            # Returning a single value is fine, coerce it into a list.
-            return_values = [return_values]
-          elif isinstance(return_values, OpView):
-            # Returning a single operation is fine, coerce its results a list.
-            return_values = return_values.operation.results
-          elif isinstance(return_values, Operation):
-            # Returning a single operation is fine, coerce its results a list.
-            return_values = return_values.results
-          else:
-            return_values = list(return_values)
-          func.ReturnOp(return_values)
-          # Recompute the function type.
-          return_types = [v.type for v in return_values]
-          function_type = FunctionType.get(inputs=inputs, results=return_types)
-          func_op.attributes["function_type"] = TypeAttr.get(function_type)
-
-      def emit_call_op(*call_args):
-        call_op = func.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name),
-                              call_args)
-        if return_types is None:
-          return None
-        elif len(return_types) == 1:
-          return call_op.result
+    """Specialization for the func op class."""
+
+    def __init__(
+        self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
+    ):
+        """
+        Create a FuncOp with the provided `name`, `type`, and `visibility`.
+        - `name` is a string representing the function name.
+        - `type` is either a FunctionType or a pair of list describing inputs and
+          results.
+        - `visibility` is a string matching `public`, `private`, or `nested`. None
+          implies private visibility.
+        - `body_builder` is an optional callback, when provided a new entry block
+          is created and the callback is invoked with the new op as argument within
+          an InsertionPoint context already set for the block. The callback is
+          expected to insert a terminator in the block.
+        """
+        sym_name = StringAttr.get(str(name))
+
+        # If the type is passed as a tuple, build a FunctionType on the fly.
+        if isinstance(type, tuple):
+            type = FunctionType.get(inputs=type[0], results=type[1])
+
+        type = TypeAttr.get(type)
+        sym_visibility = (
+            StringAttr.get(str(visibility)) if visibility is not None else None
+        )
+        super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
+        if body_builder:
+            entry_block = self.add_entry_block()
+            with InsertionPoint(entry_block):
+                body_builder(self)
+
+    @property
+    def is_external(self):
+        return len(self.regions[0].blocks) == 0
+
+    @property
+    def body(self):
+        return self.regions[0]
+
+    @property
+    def type(self):
+        return FunctionType(TypeAttr(self.attributes["function_type"]).value)
+
+    @property
+    def visibility(self):
+        return self.attributes["sym_visibility"]
+
+    @property
+    def name(self) -> StringAttr:
+        return StringAttr(self.attributes["sym_name"])
+
+    @property
+    def entry_block(self):
+        if self.is_external:
+            raise IndexError("External function does not have a body")
+        return self.regions[0].blocks[0]
+
+    def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
+        """
+        Add an entry block to the function body using the function signature to
+        infer block arguments.
+        Returns the newly created block
+        """
+        if not self.is_external:
+            raise IndexError("The function already has an entry block!")
+        self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
+        return self.body.blocks[0]
+
+    @property
+    def arg_attrs(self):
+        return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
+
+    @arg_attrs.setter
+    def arg_attrs(self, attribute: Union[ArrayAttr, list]):
+        if isinstance(attribute, ArrayAttr):
+            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
         else:
-          return call_op.results
-
-      wrapped = emit_call_op
-      wrapped.__name__ = f.__name__
-      wrapped.func_op = func_op
-      return wrapped
+            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+                attribute, context=self.context
+            )
+
+    @property
+    def arguments(self):
+        return self.entry_block.arguments
+
+    @property
+    def result_attrs(self):
+        return self.attributes[RESULT_ATTRIBUTE_NAME]
+
+    @result_attrs.setter
+    def result_attrs(self, attribute: ArrayAttr):
+        self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
+
+    @classmethod
+    def from_py_func(
+        FuncOp,
+        *inputs: Type,
+        results: Optional[Sequence[Type]] = None,
+        name: Optional[str] = None,
+    ):
+        """Decorator to define an MLIR FuncOp specified as a python function.
+
+        Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
+        active for the current thread (i.e. established in a `with` block).
+
+        When applied as a decorator to a Python function, an entry block will
+        be constructed for the FuncOp with types as specified in `*inputs`. The
+        block arguments will be passed positionally to the Python function. In
+        addition, if the Python function accepts keyword arguments generally or
+        has a corresponding keyword argument, the following will be passed:
+          * `func_op`: The `func` op being defined.
+
+        By default, the function name will be the Python function `__name__`. This
+        can be overriden by passing the `name` argument to the decorator.
+
+        If `results` is not specified, then the decorator will implicitly
+        insert a `ReturnOp` with the `Value`'s returned from the decorated
+        function. It will also set the `FuncOp` type with the actual return
+        value types. If `results` is specified, then the decorated function
+        must return `None` and no implicit `ReturnOp` is added (nor are the result
+        types updated). The implicit behavior is intended for simple, single-block
+        cases, and users should specify result types explicitly for any complicated
+        cases.
+
+        The decorated function can further be called from Python and will insert
+        a `CallOp` at the then-current insertion point, returning either None (
+        if no return values), a unary Value (for one result), or a list of Values).
+        This mechanism cannot be used to emit recursive calls (by construction).
+        """
+
+        def decorator(f):
+            from . import func
+
+            # Introspect the callable for optional features.
+            sig = inspect.signature(f)
+            has_arg_func_op = False
+            for param in sig.parameters.values():
+                if param.kind == param.VAR_KEYWORD:
+                    has_arg_func_op = True
+                if param.name == "func_op" and (
+                    param.kind == param.POSITIONAL_OR_KEYWORD
+                    or param.kind == param.KEYWORD_ONLY
+                ):
+                    has_arg_func_op = True
+
+            # Emit the FuncOp.
+            implicit_return = results is None
+            symbol_name = name or f.__name__
+            function_type = FunctionType.get(
+                inputs=inputs, results=[] if implicit_return else results
+            )
+            func_op = FuncOp(name=symbol_name, type=function_type)
+            with InsertionPoint(func_op.add_entry_block()):
+                func_args = func_op.entry_block.arguments
+                func_kwargs = {}
+                if has_arg_func_op:
+                    func_kwargs["func_op"] = func_op
+                return_values = f(*func_args, **func_kwargs)
+                if not implicit_return:
+                    return_types = list(results)
+                    assert return_values is None, (
+                        "Capturing a python function with explicit `results=` "
+                        "requires that the wrapped function returns None."
+                    )
+                else:
+                    # Coerce return values, add ReturnOp and rewrite func type.
+                    if return_values is None:
+                        return_values = []
+                    elif isinstance(return_values, tuple):
+                        return_values = list(return_values)
+                    elif isinstance(return_values, Value):
+                        # Returning a single value is fine, coerce it into a list.
+                        return_values = [return_values]
+                    elif isinstance(return_values, OpView):
+                        # Returning a single operation is fine, coerce its results a list.
+                        return_values = return_values.operation.results
+                    elif isinstance(return_values, Operation):
+                        # Returning a single operation is fine, coerce its results a list.
+                        return_values = return_values.results
+                    else:
+                        return_values = list(return_values)
+                    func.ReturnOp(return_values)
+                    # Recompute the function type.
+                    return_types = [v.type for v in return_values]
+                    function_type = FunctionType.get(
+                        inputs=inputs, results=return_types
+                    )
+                    func_op.attributes["function_type"] = TypeAttr.get(function_type)
+
+            def emit_call_op(*call_args):
+                call_op = func.CallOp(
+                    return_types, FlatSymbolRefAttr.get(symbol_name), call_args
+                )
+                if return_types is None:
+                    return None
+                elif len(return_types) == 1:
+                    return call_op.result
+                else:
+                    return call_op.results
+
+            wrapped = emit_call_op
+            wrapped.__name__ = f.__name__
+            wrapped.func_op = func_op
+            return wrapped
+
+        return decorator
 
-    return decorator
 
 class CallOp:
-  """Specialization for the call op class."""
-
-  def __init__(self,
-               calleeOrResults: Union[FuncOp, List[Type]],
-               argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
-               arguments: Optional[List] = None,
-               *,
-               loc=None,
-               ip=None):
-    """Creates an call operation.
-
-    The constructor accepts three different forms:
-
-      1. A function op to be called followed by a list of arguments.
-      2. A list of result types, followed by the name of the function to be
-         called as string, following by a list of arguments.
-      3. A list of result types, followed by the name of the function to be
-         called as symbol reference attribute, followed by a list of arguments.
-
-    For example
-
-        f = func.FuncOp("foo", ...)
-        func.CallOp(f, [args])
-        func.CallOp([result_types], "foo", [args])
-
-    In all cases, the location and insertion point may be specified as keyword
-    arguments if not provided by the surrounding context managers.
-    """
-
-    # TODO: consider supporting constructor "overloads", e.g., through a custom
-    # or pybind-provided metaclass.
-    if isinstance(calleeOrResults, FuncOp):
-      if not isinstance(argumentsOrCallee, list):
-        raise ValueError(
-            "when constructing a call to a function, expected " +
-            "the second argument to be a list of call arguments, " +
-            f"got {type(argumentsOrCallee)}")
-      if arguments is not None:
-        raise ValueError("unexpected third argument when constructing a call" +
-                         "to a function")
-
-      super().__init__(
-          calleeOrResults.type.results,
-          FlatSymbolRefAttr.get(
-              calleeOrResults.name.value,
-              context=_get_default_loc_context(loc)),
-          argumentsOrCallee,
-          loc=loc,
-          ip=ip)
-      return
-
-    if isinstance(argumentsOrCallee, list):
-      raise ValueError("when constructing a call to a function by name, " +
-                       "expected the second argument to be a string or a " +
-                       f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}")
-
-    if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
-      super().__init__(
-          calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip)
-    elif isinstance(argumentsOrCallee, str):
-      super().__init__(
-          calleeOrResults,
-          FlatSymbolRefAttr.get(
-              argumentsOrCallee, context=_get_default_loc_context(loc)),
-          arguments,
-          loc=loc,
-          ip=ip)
+    """Specialization for the call op class."""
+
+    def __init__(
+        self,
+        calleeOrResults: Union[FuncOp, List[Type]],
+        argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
+        arguments: Optional[List] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        """Creates an call operation.
+
+        The constructor accepts three different forms:
+
+          1. A function op to be called followed by a list of arguments.
+          2. A list of result types, followed by the name of the function to be
+             called as string, following by a list of arguments.
+          3. A list of result types, followed by the name of the function to be
+             called as symbol reference attribute, followed by a list of arguments.
+
+        For example
+
+            f = func.FuncOp("foo", ...)
+            func.CallOp(f, [args])
+            func.CallOp([result_types], "foo", [args])
+
+        In all cases, the location and insertion point may be specified as keyword
+        arguments if not provided by the surrounding context managers.
+        """
+
+        # TODO: consider supporting constructor "overloads", e.g., through a custom
+        # or pybind-provided metaclass.
+        if isinstance(calleeOrResults, FuncOp):
+            if not isinstance(argumentsOrCallee, list):
+                raise ValueError(
+                    "when constructing a call to a function, expected "
+                    + "the second argument to be a list of call arguments, "
+                    + f"got {type(argumentsOrCallee)}"
+                )
+            if arguments is not None:
+                raise ValueError(
+                    "unexpected third argument when constructing a call"
+                    + "to a function"
+                )
+
+            super().__init__(
+                calleeOrResults.type.results,
+                FlatSymbolRefAttr.get(
+                    calleeOrResults.name.value, context=_get_default_loc_context(loc)
+                ),
+                argumentsOrCallee,
+                loc=loc,
+                ip=ip,
+            )
+            return
+
+        if isinstance(argumentsOrCallee, list):
+            raise ValueError(
+                "when constructing a call to a function by name, "
+                + "expected the second argument to be a string or a "
+                + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
+            )
+
+        if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
+            super().__init__(
+                calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip
+            )
+        elif isinstance(argumentsOrCallee, str):
+            super().__init__(
+                calleeOrResults,
+                FlatSymbolRefAttr.get(
+                    argumentsOrCallee, context=_get_default_loc_context(loc)
+                ),
+                arguments,
+                loc=loc,
+                ip=ip,
+            )

diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
index eb9e969f33602..3f6d854ca3e2b 100644
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py
@@ -3,39 +3,45 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from typing import Optional, Sequence, Union
-  from ..ir import *
-  from ._ods_common import get_default_loc_context
-  from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
+    from typing import Optional, Sequence, Union
+    from ..ir import *
+    from ._ods_common import get_default_loc_context
+    from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from ._ods_common import get_op_result_or_value as _get_op_result_or_value
 
+
 def isa(cls: Type, ty: Type):
-  try:
-    cls(ty)
-    return True
-  except ValueError:
-    return False
+    try:
+        cls(ty)
+        return True
+    except ValueError:
+        return False
 
 
 class StructuredOpMixin:
-  """All structured ops use the same mixin class."""
+    """All structured ops use the same mixin class."""
 
-  def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
-    super().__init__(
-        self.build_generic(results=list(results),
-                           operands=[list(inputs), list(outputs)],
-                           loc=loc,
-                           ip=ip))
+    def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
+        super().__init__(
+            self.build_generic(
+                results=list(results),
+                operands=[list(inputs), list(outputs)],
+                loc=loc,
+                ip=ip,
+            )
+        )
 
 
 def select_opview_mixin(parent_opview_cls):
-  # TODO: This shouldn't be a heuristic: we should have a way to annotate
-  # the OpView to note that it is a structured op.
-  if ("__init__" not in parent_opview_cls.__dict__ and
-      hasattr(parent_opview_cls, "inputs") and
-      hasattr(parent_opview_cls, "outputs") and
-      hasattr(parent_opview_cls, "result_tensors")):
-    return StructuredOpMixin
+    # TODO: This shouldn't be a heuristic: we should have a way to annotate
+    # the OpView to note that it is a structured op.
+    if (
+        "__init__" not in parent_opview_cls.__dict__
+        and hasattr(parent_opview_cls, "inputs")
+        and hasattr(parent_opview_cls, "outputs")
+        and hasattr(parent_opview_cls, "result_tensors")
+    ):
+        return StructuredOpMixin

diff  --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
index 10079d32fd925..3536d45ab7369 100644
--- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
@@ -3,125 +3,130 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
-  from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+    from ..ir import *
+    from ._ods_common import get_op_result_or_value as _get_op_result_or_value
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Optional, Union
 
 
 class GetParentForOp:
-  """Extension for GetParentForOp."""
-
-  def __init__(
-      self,
-      result_type: Type,
-      target: Union[Operation, Value],
-      *,
-      num_loops: Optional[int] = None,
-      ip=None,
-      loc=None,
-  ):
-    if num_loops is None:
-      num_loops = 1
-    super().__init__(
-        result_type,
-        _get_op_result_or_value(target),
-        num_loops=num_loops,
-        ip=ip,
-        loc=loc,
-    )
+    """Extension for GetParentForOp."""
+
+    def __init__(
+        self,
+        result_type: Type,
+        target: Union[Operation, Value],
+        *,
+        num_loops: Optional[int] = None,
+        ip=None,
+        loc=None,
+    ):
+        if num_loops is None:
+            num_loops = 1
+        super().__init__(
+            result_type,
+            _get_op_result_or_value(target),
+            num_loops=num_loops,
+            ip=ip,
+            loc=loc,
+        )
 
 
 class LoopOutlineOp:
-  """Extension for LoopOutlineOp."""
-
-  def __init__(
-      self,
-      function_type: Type,
-      call_type: Type,
-      target: Union[Operation, Value],
-      *,
-      func_name: Union[str, StringAttr],
-      ip=None,
-      loc=None,
-  ):
-    super().__init__(
-        function_type,
-        call_type,
-        _get_op_result_or_value(target),
-        func_name=(func_name if isinstance(func_name, StringAttr) else
-                   StringAttr.get(func_name)),
-        ip=ip,
-        loc=loc,
-    )
+    """Extension for LoopOutlineOp."""
+
+    def __init__(
+        self,
+        function_type: Type,
+        call_type: Type,
+        target: Union[Operation, Value],
+        *,
+        func_name: Union[str, StringAttr],
+        ip=None,
+        loc=None,
+    ):
+        super().__init__(
+            function_type,
+            call_type,
+            _get_op_result_or_value(target),
+            func_name=(
+                func_name
+                if isinstance(func_name, StringAttr)
+                else StringAttr.get(func_name)
+            ),
+            ip=ip,
+            loc=loc,
+        )
 
 
 class LoopPeelOp:
-  """Extension for LoopPeelOp."""
-
-  def __init__(
-      self,
-      result_type: Type,
-      target: Union[Operation, Value],
-      *,
-      fail_if_already_divisible: Union[bool, BoolAttr] = False,
-      ip=None,
-      loc=None,
-  ):
-    super().__init__(
-        result_type,
-        _get_op_result_or_value(target),
-        fail_if_already_divisible=(fail_if_already_divisible if isinstance(
-            fail_if_already_divisible, BoolAttr) else
-                                   BoolAttr.get(fail_if_already_divisible)),
-        ip=ip,
-        loc=loc,
-    )
+    """Extension for LoopPeelOp."""
+
+    def __init__(
+        self,
+        result_type: Type,
+        target: Union[Operation, Value],
+        *,
+        fail_if_already_divisible: Union[bool, BoolAttr] = False,
+        ip=None,
+        loc=None,
+    ):
+        super().__init__(
+            result_type,
+            _get_op_result_or_value(target),
+            fail_if_already_divisible=(
+                fail_if_already_divisible
+                if isinstance(fail_if_already_divisible, BoolAttr)
+                else BoolAttr.get(fail_if_already_divisible)
+            ),
+            ip=ip,
+            loc=loc,
+        )
 
 
 class LoopPipelineOp:
-  """Extension for LoopPipelineOp."""
-
-  def __init__(
-      self,
-      result_type: Type,
-      target: Union[Operation, Value],
-      *,
-      iteration_interval: Optional[Union[int, IntegerAttr]] = None,
-      read_latency: Optional[Union[int, IntegerAttr]] = None,
-      ip=None,
-      loc=None,
-  ):
-    if iteration_interval is None:
-      iteration_interval = 1
-    if read_latency is None:
-      read_latency = 10
-    super().__init__(
-        result_type,
-        _get_op_result_or_value(target),
-        iteration_interval=iteration_interval,
-        read_latency=read_latency,
-        ip=ip,
-        loc=loc,
-    )
+    """Extension for LoopPipelineOp."""
+
+    def __init__(
+        self,
+        result_type: Type,
+        target: Union[Operation, Value],
+        *,
+        iteration_interval: Optional[Union[int, IntegerAttr]] = None,
+        read_latency: Optional[Union[int, IntegerAttr]] = None,
+        ip=None,
+        loc=None,
+    ):
+        if iteration_interval is None:
+            iteration_interval = 1
+        if read_latency is None:
+            read_latency = 10
+        super().__init__(
+            result_type,
+            _get_op_result_or_value(target),
+            iteration_interval=iteration_interval,
+            read_latency=read_latency,
+            ip=ip,
+            loc=loc,
+        )
 
 
 class LoopUnrollOp:
-  """Extension for LoopUnrollOp."""
-
-  def __init__(
-      self,
-      target: Union[Operation, Value],
-      *,
-      factor: Union[int, IntegerAttr],
-      ip=None,
-      loc=None,
-  ):
-    super().__init__(
-        _get_op_result_or_value(target),
-        factor=factor,
-        ip=ip,
-        loc=loc,
-    )
+    """Extension for LoopUnrollOp."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        *,
+        factor: Union[int, IntegerAttr],
+        ip=None,
+        loc=None,
+    ):
+        super().__init__(
+            _get_op_result_or_value(target),
+            factor=factor,
+            ip=ip,
+            loc=loc,
+        )

diff  --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py
index a00a087be79b2..825f1a0a7a6fa 100644
--- a/mlir/python/mlir/dialects/_memref_ops_ext.py
+++ b/mlir/python/mlir/dialects/_memref_ops_ext.py
@@ -3,34 +3,34 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
-  from ._ods_common import get_op_result_or_value as _get_op_result_or_value
-  from ._ods_common import get_op_results_or_values as _get_op_results_or_values
+    from ..ir import *
+    from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+    from ._ods_common import get_op_results_or_values as _get_op_results_or_values
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Optional, Sequence, Union
 
 
 class LoadOp:
-  """Specialization for the MemRef load operation."""
+    """Specialization for the MemRef load operation."""
 
-  def __init__(self,
-               memref: Union[Operation, OpView, Value],
-               indices: Optional[Union[Operation, OpView,
-                                       Sequence[Value]]] = None,
-               *,
-               loc=None,
-               ip=None):
-    """Creates a memref load operation.
+    def __init__(
+        self,
+        memref: Union[Operation, OpView, Value],
+        indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+        *,
+        loc=None,
+        ip=None
+    ):
+        """Creates a memref load operation.
 
-    Args:
-      memref: the buffer to load from.
-      indices: the list of subscripts, may be empty for zero-dimensional
-        buffers.
-      loc: user-visible location of the operation.
-      ip: insertion point.
-    """
-    indices_resolved = [] if indices is None else _get_op_results_or_values(
-        indices)
-    super().__init__(memref, indices_resolved, loc=loc, ip=ip)
+        Args:
+          memref: the buffer to load from.
+          indices: the list of subscripts, may be empty for zero-dimensional
+            buffers.
+          loc: user-visible location of the operation.
+          ip: insertion point.
+        """
+        indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
+        super().__init__(memref, indices_resolved, loc=loc, ip=ip)

diff  --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/_ml_program_ops_ext.py
index 8db82cf81c678..c84d23c16ef93 100644
--- a/mlir/python/mlir/dialects/_ml_program_ops_ext.py
+++ b/mlir/python/mlir/dialects/_ml_program_ops_ext.py
@@ -3,11 +3,11 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from typing import Union
-  from ..ir import *
-  from ._ods_common import get_default_loc_context as _get_default_loc_context
+    from typing import Union
+    from ..ir import *
+    from ._ods_common import get_default_loc_context as _get_default_loc_context
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from ._ml_program_ops_gen import *
 
@@ -17,100 +17,97 @@
 
 
 class FuncOp:
-  """Specialization for the func op class."""
-
-  def __init__(self,
-               name,
-               type,
-               *,
-               visibility=None,
-               body_builder=None,
-               loc=None,
-               ip=None):
-    """
-    Create a FuncOp with the provided `name`, `type`, and `visibility`.
-    - `name` is a string representing the function name.
-    - `type` is either a FunctionType or a pair of list describing inputs and
-      results.
-    - `visibility` is a string matching `public`, `private`, or `nested`. None
-      implies private visibility.
-    - `body_builder` is an optional callback, when provided a new entry block
-      is created and the callback is invoked with the new op as argument within
-      an InsertionPoint context already set for the block. The callback is
-      expected to insert a terminator in the block.
-    """
-    sym_name = StringAttr.get(str(name))
-
-    # If the type is passed as a tuple, build a FunctionType on the fly.
-    if isinstance(type, tuple):
-      type = FunctionType.get(inputs=type[0], results=type[1])
-
-    type = TypeAttr.get(type)
-    sym_visibility = StringAttr.get(
-        str(visibility)) if visibility is not None else None
-    super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
-    if body_builder:
-      entry_block = self.add_entry_block()
-      with InsertionPoint(entry_block):
-        body_builder(self)
-
-  @property
-  def is_external(self):
-    return len(self.regions[0].blocks) == 0
-
-  @property
-  def body(self):
-    return self.regions[0]
-
-  @property
-  def type(self):
-    return FunctionType(TypeAttr(self.attributes["function_type"]).value)
-
-  @property
-  def visibility(self):
-    return self.attributes["sym_visibility"]
-
-  @property
-  def name(self) -> StringAttr:
-    return StringAttr(self.attributes["sym_name"])
-
-  @property
-  def entry_block(self):
-    if self.is_external:
-      raise IndexError('External function does not have a body')
-    return self.regions[0].blocks[0]
-
-  def add_entry_block(self):
-    """
-    Add an entry block to the function body using the function signature to
-    infer block arguments.
-    Returns the newly created block
-    """
-    if not self.is_external:
-      raise IndexError('The function already has an entry block!')
-    self.body.blocks.append(*self.type.inputs)
-    return self.body.blocks[0]
-
-  @property
-  def arg_attrs(self):
-    return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
-
-  @arg_attrs.setter
-  def arg_attrs(self, attribute: Union[ArrayAttr, list]):
-    if isinstance(attribute, ArrayAttr):
-      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
-    else:
-      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
-          attribute, context=self.context)
-
-  @property
-  def arguments(self):
-    return self.entry_block.arguments
-
-  @property
-  def result_attrs(self):
-    return self.attributes[RESULT_ATTRIBUTE_NAME]
-
-  @result_attrs.setter
-  def result_attrs(self, attribute: ArrayAttr):
-    self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
+    """Specialization for the func op class."""
+
+    def __init__(
+        self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
+    ):
+        """
+        Create a FuncOp with the provided `name`, `type`, and `visibility`.
+        - `name` is a string representing the function name.
+        - `type` is either a FunctionType or a pair of list describing inputs and
+          results.
+        - `visibility` is a string matching `public`, `private`, or `nested`. None
+          implies private visibility.
+        - `body_builder` is an optional callback, when provided a new entry block
+          is created and the callback is invoked with the new op as argument within
+          an InsertionPoint context already set for the block. The callback is
+          expected to insert a terminator in the block.
+        """
+        sym_name = StringAttr.get(str(name))
+
+        # If the type is passed as a tuple, build a FunctionType on the fly.
+        if isinstance(type, tuple):
+            type = FunctionType.get(inputs=type[0], results=type[1])
+
+        type = TypeAttr.get(type)
+        sym_visibility = (
+            StringAttr.get(str(visibility)) if visibility is not None else None
+        )
+        super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
+        if body_builder:
+            entry_block = self.add_entry_block()
+            with InsertionPoint(entry_block):
+                body_builder(self)
+
+    @property
+    def is_external(self):
+        return len(self.regions[0].blocks) == 0
+
+    @property
+    def body(self):
+        return self.regions[0]
+
+    @property
+    def type(self):
+        return FunctionType(TypeAttr(self.attributes["function_type"]).value)
+
+    @property
+    def visibility(self):
+        return self.attributes["sym_visibility"]
+
+    @property
+    def name(self) -> StringAttr:
+        return StringAttr(self.attributes["sym_name"])
+
+    @property
+    def entry_block(self):
+        if self.is_external:
+            raise IndexError("External function does not have a body")
+        return self.regions[0].blocks[0]
+
+    def add_entry_block(self):
+        """
+        Add an entry block to the function body using the function signature to
+        infer block arguments.
+        Returns the newly created block
+        """
+        if not self.is_external:
+            raise IndexError("The function already has an entry block!")
+        self.body.blocks.append(*self.type.inputs)
+        return self.body.blocks[0]
+
+    @property
+    def arg_attrs(self):
+        return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
+
+    @arg_attrs.setter
+    def arg_attrs(self, attribute: Union[ArrayAttr, list]):
+        if isinstance(attribute, ArrayAttr):
+            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+        else:
+            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+                attribute, context=self.context
+            )
+
+    @property
+    def arguments(self):
+        return self.entry_block.arguments
+
+    @property
+    def result_attrs(self):
+        return self.attributes[RESULT_ATTRIBUTE_NAME]
+
+    @result_attrs.setter
+    def result_attrs(self, attribute: ArrayAttr):
+        self.attributes[RESULT_ATTRIBUTE_NAME] = attribute

diff  --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 51b90081973c4..7655629a55425 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -18,144 +18,152 @@
 
 
 def extend_opview_class(ext_module):
-  """Decorator to extend an OpView class from an extension module.
-
-  Extension modules can expose various entry-points:
-    Stand-alone class with the same name as a parent OpView class (i.e.
-    "ReturnOp"). A name-based match is attempted first before falling back
-    to a below mechanism.
-
-    def select_opview_mixin(parent_opview_cls):
-      If defined, allows an appropriate mixin class to be selected dynamically
-      based on the parent OpView class. Should return NotImplemented if a
-      decision is not made.
-
-  Args:
-    ext_module: A module from which to locate extensions. Can be None if not
-      available.
-
-  Returns:
-    A decorator that takes an OpView subclass and further extends it as
-    needed.
-  """
-
-  def class_decorator(parent_opview_cls: type):
-    if ext_module is None:
-      return parent_opview_cls
-    mixin_cls = NotImplemented
-    # First try to resolve by name.
-    try:
-      mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
-    except AttributeError:
-      # Fall back to a select_opview_mixin hook.
-      try:
-        select_mixin = getattr(ext_module, "select_opview_mixin")
-      except AttributeError:
-        pass
-      else:
-        mixin_cls = select_mixin(parent_opview_cls)
-
-    if mixin_cls is NotImplemented or mixin_cls is None:
-      return parent_opview_cls
-
-    # Have a mixin_cls. Create an appropriate subclass.
-    try:
-
-      class LocalOpView(mixin_cls, parent_opview_cls):
-        pass
-    except TypeError as e:
-      raise TypeError(
-          f"Could not mixin {mixin_cls} into {parent_opview_cls}") from e
-    LocalOpView.__name__ = parent_opview_cls.__name__
-    LocalOpView.__qualname__ = parent_opview_cls.__qualname__
-    return LocalOpView
-
-  return class_decorator
+    """Decorator to extend an OpView class from an extension module.
+
+    Extension modules can expose various entry-points:
+      Stand-alone class with the same name as a parent OpView class (i.e.
+      "ReturnOp"). A name-based match is attempted first before falling back
+      to a below mechanism.
+
+      def select_opview_mixin(parent_opview_cls):
+        If defined, allows an appropriate mixin class to be selected dynamically
+        based on the parent OpView class. Should return NotImplemented if a
+        decision is not made.
+
+    Args:
+      ext_module: A module from which to locate extensions. Can be None if not
+        available.
+
+    Returns:
+      A decorator that takes an OpView subclass and further extends it as
+      needed.
+    """
+
+    def class_decorator(parent_opview_cls: type):
+        if ext_module is None:
+            return parent_opview_cls
+        mixin_cls = NotImplemented
+        # First try to resolve by name.
+        try:
+            mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
+        except AttributeError:
+            # Fall back to a select_opview_mixin hook.
+            try:
+                select_mixin = getattr(ext_module, "select_opview_mixin")
+            except AttributeError:
+                pass
+            else:
+                mixin_cls = select_mixin(parent_opview_cls)
+
+        if mixin_cls is NotImplemented or mixin_cls is None:
+            return parent_opview_cls
+
+        # Have a mixin_cls. Create an appropriate subclass.
+        try:
+
+            class LocalOpView(mixin_cls, parent_opview_cls):
+                pass
+
+        except TypeError as e:
+            raise TypeError(
+                f"Could not mixin {mixin_cls} into {parent_opview_cls}"
+            ) from e
+        LocalOpView.__name__ = parent_opview_cls.__name__
+        LocalOpView.__qualname__ = parent_opview_cls.__qualname__
+        return LocalOpView
+
+    return class_decorator
 
 
 def segmented_accessor(elements, raw_segments, idx):
-  """
-  Returns a slice of elements corresponding to the idx-th segment.
-
-    elements: a sliceable container (operands or results).
-    raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing
-        sizes of the segments.
-    idx: index of the segment.
-  """
-  segments = _cext.ir.DenseI32ArrayAttr(raw_segments)
-  start = sum(segments[i] for i in range(idx))
-  end = start + segments[idx]
-  return elements[start:end]
-
-
-def equally_sized_accessor(elements, n_variadic, n_preceding_simple,
-                           n_preceding_variadic):
-  """
-  Returns a starting position and a number of elements per variadic group
-  assuming equally-sized groups and the given numbers of preceding groups.
-
-    elements: a sequential container.
-    n_variadic: the number of variadic groups in the container.
-    n_preceding_simple: the number of non-variadic groups preceding the current
-        group.
-    n_preceding_variadic: the number of variadic groups preceding the current
-        group.
-  """
-
-  total_variadic_length = len(elements) - n_variadic + 1
-  # This should be enforced by the C++-side trait verifier.
-  assert total_variadic_length % n_variadic == 0
-
-  elements_per_group = total_variadic_length // n_variadic
-  start = n_preceding_simple + n_preceding_variadic * elements_per_group
-  return start, elements_per_group
+    """
+    Returns a slice of elements corresponding to the idx-th segment.
+
+      elements: a sliceable container (operands or results).
+      raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing
+          sizes of the segments.
+      idx: index of the segment.
+    """
+    segments = _cext.ir.DenseI32ArrayAttr(raw_segments)
+    start = sum(segments[i] for i in range(idx))
+    end = start + segments[idx]
+    return elements[start:end]
+
+
+def equally_sized_accessor(
+    elements, n_variadic, n_preceding_simple, n_preceding_variadic
+):
+    """
+    Returns a starting position and a number of elements per variadic group
+    assuming equally-sized groups and the given numbers of preceding groups.
+
+      elements: a sequential container.
+      n_variadic: the number of variadic groups in the container.
+      n_preceding_simple: the number of non-variadic groups preceding the current
+          group.
+      n_preceding_variadic: the number of variadic groups preceding the current
+          group.
+    """
+
+    total_variadic_length = len(elements) - n_variadic + 1
+    # This should be enforced by the C++-side trait verifier.
+    assert total_variadic_length % n_variadic == 0
+
+    elements_per_group = total_variadic_length // n_variadic
+    start = n_preceding_simple + n_preceding_variadic * elements_per_group
+    return start, elements_per_group
 
 
 def get_default_loc_context(location=None):
-  """
-  Returns a context in which the defaulted location is created. If the location
-  is None, takes the current location from the stack, raises ValueError if there
-  is no location on the stack.
-  """
-  if location is None:
-    # Location.current raises ValueError if there is no current location.
-    return _cext.ir.Location.current.context
-  return location.context
+    """
+    Returns a context in which the defaulted location is created. If the location
+    is None, takes the current location from the stack, raises ValueError if there
+    is no location on the stack.
+    """
+    if location is None:
+        # Location.current raises ValueError if there is no current location.
+        return _cext.ir.Location.current.context
+    return location.context
 
 
 def get_op_result_or_value(
-    arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList]
+    arg: _Union[
+        _cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList
+    ]
 ) -> _cext.ir.Value:
-  """Returns the given value or the single result of the given op.
-
-  This is useful to implement op constructors so that they can take other ops as
-  arguments instead of requiring the caller to extract results for every op.
-  Raises ValueError if provided with an op that doesn't have a single result.
-  """
-  if isinstance(arg, _cext.ir.OpView):
-    return arg.operation.result
-  elif isinstance(arg, _cext.ir.Operation):
-    return arg.result
-  elif isinstance(arg, _cext.ir.OpResultList):
-    return arg[0]
-  else:
-    assert isinstance(arg, _cext.ir.Value)
-    return arg
+    """Returns the given value or the single result of the given op.
+
+    This is useful to implement op constructors so that they can take other ops as
+    arguments instead of requiring the caller to extract results for every op.
+    Raises ValueError if provided with an op that doesn't have a single result.
+    """
+    if isinstance(arg, _cext.ir.OpView):
+        return arg.operation.result
+    elif isinstance(arg, _cext.ir.Operation):
+        return arg.result
+    elif isinstance(arg, _cext.ir.OpResultList):
+        return arg[0]
+    else:
+        assert isinstance(arg, _cext.ir.Value)
+        return arg
 
 
 def get_op_results_or_values(
-    arg: _Union[_cext.ir.OpView, _cext.ir.Operation,
-                _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]]]
+    arg: _Union[
+        _cext.ir.OpView,
+        _cext.ir.Operation,
+        _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
+    ]
 ) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
-  """Returns the given sequence of values or the results of the given op.
-
-  This is useful to implement op constructors so that they can take other ops as
-  lists of arguments instead of requiring the caller to extract results for
-  every op.
-  """
-  if isinstance(arg, _cext.ir.OpView):
-    return arg.operation.results
-  elif isinstance(arg, _cext.ir.Operation):
-    return arg.results
-  else:
-    return [get_op_result_or_value(element) for element in arg]
+    """Returns the given sequence of values or the results of the given op.
+
+    This is useful to implement op constructors so that they can take other ops as
+    lists of arguments instead of requiring the caller to extract results for
+    every op.
+    """
+    if isinstance(arg, _cext.ir.OpView):
+        return arg.operation.results
+    elif isinstance(arg, _cext.ir.Operation):
+        return arg.results
+    else:
+        return [get_op_result_or_value(element) for element in arg]

diff  --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py
index 40ccbef6351dc..fc9de0b7f7db6 100644
--- a/mlir/python/mlir/dialects/_pdl_ops_ext.py
+++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py
@@ -3,10 +3,10 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
-  from ..dialects import pdl
+    from ..ir import *
+    from ..dialects import pdl
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Union, Optional, Sequence, Mapping
 from ._ods_common import (
@@ -16,264 +16,256 @@
 
 
 class ApplyNativeConstraintOp:
-  """Specialization for PDL apply native constraint op class."""
-
-  def __init__(
-      self,
-      name: Union[str, StringAttr],
-      args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    if args is None:
-      args = []
-    args = _get_values(args)
-    super().__init__(name, args, loc=loc, ip=ip)
+    """Specialization for PDL apply native constraint op class."""
+
+    def __init__(
+        self,
+        name: Union[str, StringAttr],
+        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if args is None:
+            args = []
+        args = _get_values(args)
+        super().__init__(name, args, loc=loc, ip=ip)
 
 
 class ApplyNativeRewriteOp:
-  """Specialization for PDL apply native rewrite op class."""
-
-  def __init__(
-      self,
-      results: Sequence[Type],
-      name: Union[str, StringAttr],
-      args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    if args is None:
-      args = []
-    args = _get_values(args)
-    super().__init__(results, name, args, loc=loc, ip=ip)
+    """Specialization for PDL apply native rewrite op class."""
+
+    def __init__(
+        self,
+        results: Sequence[Type],
+        name: Union[str, StringAttr],
+        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if args is None:
+            args = []
+        args = _get_values(args)
+        super().__init__(results, name, args, loc=loc, ip=ip)
 
 
 class AttributeOp:
-  """Specialization for PDL attribute op class."""
+    """Specialization for PDL attribute op class."""
 
-  def __init__(
-      self,
-      valueType: Optional[Union[OpView, Operation, Value]] = None,
-      value: Optional[Attribute] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    valueType = valueType if valueType is None else _get_value(valueType)
-    result = pdl.AttributeType.get()
-    super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
+    def __init__(
+        self,
+        valueType: Optional[Union[OpView, Operation, Value]] = None,
+        value: Optional[Attribute] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        valueType = valueType if valueType is None else _get_value(valueType)
+        result = pdl.AttributeType.get()
+        super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
 
 
 class EraseOp:
-  """Specialization for PDL erase op class."""
+    """Specialization for PDL erase op class."""
 
-  def __init__(
-      self,
-      operation: Optional[Union[OpView, Operation, Value]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    operation = _get_value(operation)
-    super().__init__(operation, loc=loc, ip=ip)
+    def __init__(
+        self,
+        operation: Optional[Union[OpView, Operation, Value]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        operation = _get_value(operation)
+        super().__init__(operation, loc=loc, ip=ip)
 
 
 class OperandOp:
-  """Specialization for PDL operand op class."""
+    """Specialization for PDL operand op class."""
 
-  def __init__(
-      self,
-      type: Optional[Union[OpView, Operation, Value]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    type = type if type is None else _get_value(type)
-    result = pdl.ValueType.get()
-    super().__init__(result, valueType=type, loc=loc, ip=ip)
+    def __init__(
+        self,
+        type: Optional[Union[OpView, Operation, Value]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        type = type if type is None else _get_value(type)
+        result = pdl.ValueType.get()
+        super().__init__(result, valueType=type, loc=loc, ip=ip)
 
 
 class OperandsOp:
-  """Specialization for PDL operands op class."""
+    """Specialization for PDL operands op class."""
 
-  def __init__(
-      self,
-      types: Optional[Union[OpView, Operation, Value]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    types = types if types is None else _get_value(types)
-    result = pdl.RangeType.get(pdl.ValueType.get())
-    super().__init__(result, valueType=types, loc=loc, ip=ip)
+    def __init__(
+        self,
+        types: Optional[Union[OpView, Operation, Value]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        types = types if types is None else _get_value(types)
+        result = pdl.RangeType.get(pdl.ValueType.get())
+        super().__init__(result, valueType=types, loc=loc, ip=ip)
 
 
 class OperationOp:
-  """Specialization for PDL operand op class."""
-
-  def __init__(
-      self,
-      name: Optional[Union[str, StringAttr]] = None,
-      args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-      attributes: Optional[Mapping[str, Union[OpView, Operation,
-                                              Value]]] = None,
-      types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    if types is None:
-      types = []
-    if attributes is None:
-      attributes = {}
-    if args is None:
-      args = []
-    args = _get_values(args)
-    attrNames = []
-    attrValues = []
-    for attrName, attrValue in attributes.items():
-      attrNames.append(StringAttr.get(attrName))
-      attrValues.append(_get_value(attrValue))
-    attrNames = ArrayAttr.get(attrNames)
-    types = _get_values(types)
-    result = pdl.OperationType.get()
-    super().__init__(result,
-                     args,
-                     attrValues,
-                     attrNames,
-                     types,
-                     opName=name,
-                     loc=loc,
-                     ip=ip)
+    """Specialization for PDL operand op class."""
+
+    def __init__(
+        self,
+        name: Optional[Union[str, StringAttr]] = None,
+        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+        attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
+        types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if types is None:
+            types = []
+        if attributes is None:
+            attributes = {}
+        if args is None:
+            args = []
+        args = _get_values(args)
+        attrNames = []
+        attrValues = []
+        for attrName, attrValue in attributes.items():
+            attrNames.append(StringAttr.get(attrName))
+            attrValues.append(_get_value(attrValue))
+        attrNames = ArrayAttr.get(attrNames)
+        types = _get_values(types)
+        result = pdl.OperationType.get()
+        super().__init__(
+            result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip
+        )
 
 
 class PatternOp:
-  """Specialization for PDL pattern op class."""
-
-  def __init__(
-      self,
-      benefit: Union[IntegerAttr, int],
-      name: Optional[Union[StringAttr, str]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    """Creates an PDL `pattern` operation."""
-    super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
-    self.regions[0].blocks.append()
-
-  @property
-  def body(self):
-    """Return the body (block) of the pattern."""
-    return self.regions[0].blocks[0]
+    """Specialization for PDL pattern op class."""
+
+    def __init__(
+        self,
+        benefit: Union[IntegerAttr, int],
+        name: Optional[Union[StringAttr, str]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        """Creates an PDL `pattern` operation."""
+        super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
+        self.regions[0].blocks.append()
+
+    @property
+    def body(self):
+        """Return the body (block) of the pattern."""
+        return self.regions[0].blocks[0]
 
 
 class ReplaceOp:
-  """Specialization for PDL replace op class."""
-
-  def __init__(
-      self,
-      op: Union[OpView, Operation, Value],
-      *,
-      with_op: Optional[Union[OpView, Operation, Value]] = None,
-      with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-      loc=None,
-      ip=None,
-  ):
-    if with_values is None:
-      with_values = []
-    op = _get_value(op)
-    with_op = with_op if with_op is None else _get_value(with_op)
-    with_values = _get_values(with_values)
-    super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
+    """Specialization for PDL replace op class."""
+
+    def __init__(
+        self,
+        op: Union[OpView, Operation, Value],
+        *,
+        with_op: Optional[Union[OpView, Operation, Value]] = None,
+        with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+        loc=None,
+        ip=None,
+    ):
+        if with_values is None:
+            with_values = []
+        op = _get_value(op)
+        with_op = with_op if with_op is None else _get_value(with_op)
+        with_values = _get_values(with_values)
+        super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
 
 
 class ResultOp:
-  """Specialization for PDL result op class."""
+    """Specialization for PDL result op class."""
 
-  def __init__(
-      self,
-      parent: Union[OpView, Operation, Value],
-      index: Union[IntegerAttr, int],
-      *,
-      loc=None,
-      ip=None,
-  ):
-    parent = _get_value(parent)
-    result = pdl.ValueType.get()
-    super().__init__(result, parent, index, loc=loc, ip=ip)
+    def __init__(
+        self,
+        parent: Union[OpView, Operation, Value],
+        index: Union[IntegerAttr, int],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        parent = _get_value(parent)
+        result = pdl.ValueType.get()
+        super().__init__(result, parent, index, loc=loc, ip=ip)
 
 
 class ResultsOp:
-  """Specialization for PDL results op class."""
+    """Specialization for PDL results op class."""
 
-  def __init__(
-      self,
-      result: Type,
-      parent: Union[OpView, Operation, Value],
-      index: Optional[Union[IntegerAttr, int]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    parent = _get_value(parent)
-    super().__init__(result, parent, index=index, loc=loc, ip=ip)
+    def __init__(
+        self,
+        result: Type,
+        parent: Union[OpView, Operation, Value],
+        index: Optional[Union[IntegerAttr, int]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        parent = _get_value(parent)
+        super().__init__(result, parent, index=index, loc=loc, ip=ip)
 
 
 class RewriteOp:
-  """Specialization for PDL rewrite op class."""
-
-  def __init__(
-      self,
-      root: Optional[Union[OpView, Operation, Value]] = None,
-      name: Optional[Union[StringAttr, str]] = None,
-      args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    if args is None:
-      args = []
-    root = root if root is None else _get_value(root)
-    args = _get_values(args)
-    super().__init__(args, root=root, name=name, loc=loc, ip=ip)
-
-  def add_body(self):
-    """Add body (block) to the rewrite."""
-    self.regions[0].blocks.append()
-    return self.body
-
-  @property
-  def body(self):
-    """Return the body (block) of the rewrite."""
-    return self.regions[0].blocks[0]
+    """Specialization for PDL rewrite op class."""
+
+    def __init__(
+        self,
+        root: Optional[Union[OpView, Operation, Value]] = None,
+        name: Optional[Union[StringAttr, str]] = None,
+        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if args is None:
+            args = []
+        root = root if root is None else _get_value(root)
+        args = _get_values(args)
+        super().__init__(args, root=root, name=name, loc=loc, ip=ip)
+
+    def add_body(self):
+        """Add body (block) to the rewrite."""
+        self.regions[0].blocks.append()
+        return self.body
+
+    @property
+    def body(self):
+        """Return the body (block) of the rewrite."""
+        return self.regions[0].blocks[0]
 
 
 class TypeOp:
-  """Specialization for PDL type op class."""
+    """Specialization for PDL type op class."""
 
-  def __init__(self,
-               constantType: Optional[Union[TypeAttr, Type]] = None,
-               *,
-               loc=None,
-               ip=None):
-    result = pdl.TypeType.get()
-    super().__init__(result, constantType=constantType, loc=loc, ip=ip)
+    def __init__(
+        self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None
+    ):
+        result = pdl.TypeType.get()
+        super().__init__(result, constantType=constantType, loc=loc, ip=ip)
 
 
 class TypesOp:
-  """Specialization for PDL types op class."""
-
-  def __init__(
-      self,
-      constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    if constantTypes is None:
-      constantTypes = []
-    result = pdl.RangeType.get(pdl.TypeType.get())
-    super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
+    """Specialization for PDL types op class."""
+
+    def __init__(
+        self,
+        constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if constantTypes is None:
+            constantTypes = []
+        result = pdl.RangeType.get(pdl.TypeType.get())
+        super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)

diff  --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
index 3c3e673021585..4b2519ef35357 100644
--- a/mlir/python/mlir/dialects/_scf_ops_ext.py
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -3,105 +3,104 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
+    from ..ir import *
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Any, Optional, Sequence, Union
-from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ._ods_common import (
+    get_op_result_or_value as _get_op_result_or_value,
+    get_op_results_or_values as _get_op_results_or_values,
+)
+
 
 class ForOp:
-  """Specialization for the SCF for op class."""
-
-  def __init__(self,
-               lower_bound,
-               upper_bound,
-               step,
-               iter_args: Optional[Union[Operation, OpView,
-                                         Sequence[Value]]] = None,
-               *,
-               loc=None,
-               ip=None):
-    """Creates an SCF `for` operation.
-
-    - `lower_bound` is the value to use as lower bound of the loop.
-    - `upper_bound` is the value to use as upper bound of the loop.
-    - `step` is the value to use as loop step.
-    - `iter_args` is a list of additional loop-carried arguments or an operation
-      producing them as results.
-    """
-    if iter_args is None:
-      iter_args = []
-    iter_args = _get_op_results_or_values(iter_args)
-
-    results = [arg.type for arg in iter_args]
-    super().__init__(
-        self.build_generic(
-            regions=1,
-            results=results,
-            operands=[
-                _get_op_result_or_value(o)
-                for o in [lower_bound, upper_bound, step]
-            ] + list(iter_args),
-            loc=loc,
-            ip=ip))
-    self.regions[0].blocks.append(IndexType.get(), *results)
-
-  @property
-  def body(self):
-    """Returns the body (block) of the loop."""
-    return self.regions[0].blocks[0]
-
-  @property
-  def induction_variable(self):
-    """Returns the induction variable of the loop."""
-    return self.body.arguments[0]
-
-  @property
-  def inner_iter_args(self):
-    """Returns the loop-carried arguments usable within the loop.
-
-    To obtain the loop-carried operands, use `iter_args`.
-    """
-    return self.body.arguments[1:]
+    """Specialization for the SCF for op class."""
+
+    def __init__(
+        self,
+        lower_bound,
+        upper_bound,
+        step,
+        iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+        *,
+        loc=None,
+        ip=None
+    ):
+        """Creates an SCF `for` operation.
+
+        - `lower_bound` is the value to use as lower bound of the loop.
+        - `upper_bound` is the value to use as upper bound of the loop.
+        - `step` is the value to use as loop step.
+        - `iter_args` is a list of additional loop-carried arguments or an operation
+          producing them as results.
+        """
+        if iter_args is None:
+            iter_args = []
+        iter_args = _get_op_results_or_values(iter_args)
+
+        results = [arg.type for arg in iter_args]
+        super().__init__(
+            self.build_generic(
+                regions=1,
+                results=results,
+                operands=[
+                    _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
+                ]
+                + list(iter_args),
+                loc=loc,
+                ip=ip,
+            )
+        )
+        self.regions[0].blocks.append(IndexType.get(), *results)
+
+    @property
+    def body(self):
+        """Returns the body (block) of the loop."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def induction_variable(self):
+        """Returns the induction variable of the loop."""
+        return self.body.arguments[0]
+
+    @property
+    def inner_iter_args(self):
+        """Returns the loop-carried arguments usable within the loop.
+
+        To obtain the loop-carried operands, use `iter_args`.
+        """
+        return self.body.arguments[1:]
 
 
 class IfOp:
-  """Specialization for the SCF if op class."""
-
-  def __init__(self,
-               cond,
-               results_=[],
-               *,
-               hasElse=False,
-               loc=None,
-               ip=None):
-    """Creates an SCF `if` operation.
-
-    - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
-    - `hasElse` determines whether the if operation has the else branch.
-    """
-    operands = []
-    operands.append(cond)
-    results = []
-    results.extend(results_)
-    super().__init__(
-        self.build_generic(
-            regions=2,
-            results=results,
-            operands=operands,
-            loc=loc,
-            ip=ip))
-    self.regions[0].blocks.append(*[])
-    if hasElse:
-        self.regions[1].blocks.append(*[])
-
-  @property
-  def then_block(self):
-    """Returns the then block of the if operation."""
-    return self.regions[0].blocks[0]
-
-  @property
-  def else_block(self):
-    """Returns the else block of the if operation."""
-    return self.regions[1].blocks[0]
+    """Specialization for the SCF if op class."""
+
+    def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
+        """Creates an SCF `if` operation.
+
+        - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
+        - `hasElse` determines whether the if operation has the else branch.
+        """
+        operands = []
+        operands.append(cond)
+        results = []
+        results.extend(results_)
+        super().__init__(
+            self.build_generic(
+                regions=2, results=results, operands=operands, loc=loc, ip=ip
+            )
+        )
+        self.regions[0].blocks.append(*[])
+        if hasElse:
+            self.regions[1].blocks.append(*[])
+
+    @property
+    def then_block(self):
+        """Returns the then block of the if operation."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def else_block(self):
+        """Returns the else block of the if operation."""
+        return self.regions[1].blocks[0]

diff  --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 9c051cd3d146d..30dafff6a11c5 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -3,11 +3,11 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
-  from ._ods_common import get_op_result_or_value as _get_op_result_or_value
-  from ..dialects import pdl, transform
+    from ..ir import *
+    from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+    from ..dialects import pdl, transform
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import List, Optional, Sequence, Union, overload
 
@@ -16,312 +16,315 @@
 
 
 def _get_int_int_array_attr(
-    values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
-                                                     IntOrAttrList]]]]
+    values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
 ) -> ArrayAttr:
-  """Creates an array attribute containing array attributes of integers.
+    """Creates an array attribute containing array attributes of integers.
 
     If the operand is already an array attribute, forwards it. Otherwise treats
     the operand as a list of attributes or integers, potentially interpserced, to
     create a new array-of-array attribute. Expects the thread-local MLIR context
     to have been set by the context manager.
     """
-  if values is None:
-    return ArrayAttr.get([])
-  if isinstance(values, ArrayAttr):
-    return values
-  if isinstance(values, list):
-    values = [
-        ArrayAttr.get(
-            [IntegerAttr.get(IntegerType.get_signless(64), v)
-             for v in value])
-        for value in values
-    ]
+    if values is None:
+        return ArrayAttr.get([])
+    if isinstance(values, ArrayAttr):
+        return values
+    if isinstance(values, list):
+        values = [
+            ArrayAttr.get(
+                [IntegerAttr.get(IntegerType.get_signless(64), v) for v in value]
+            )
+            for value in values
+        ]
 
-  return ArrayAttr.get(values)
+    return ArrayAttr.get(values)
 
 
 class DecomposeOp:
-  """Specialization for DecomposeOp class."""
+    """Specialization for DecomposeOp class."""
 
-  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
-    super().__init__(pdl.OperationType.get(),
-                     _get_op_result_or_value(target),
-                     loc=loc,
-                     ip=ip)
+    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+        super().__init__(
+            pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
+        )
 
 
 class GeneralizeOp:
-  """Specialization for GeneralizeOp class."""
+    """Specialization for GeneralizeOp class."""
 
-  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
-    super().__init__(pdl.OperationType.get(),
-                     _get_op_result_or_value(target),
-                     loc=loc,
-                     ip=ip)
+    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+        super().__init__(
+            pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
+        )
 
 
 class InterchangeOp:
-  """Specialization for InterchangeOp class."""
-
-  def __init__(
-      self,
-      target: Union[Operation, Value],
-      *,
-      iterator_interchange: OptionalIntList = None,
-      loc=None,
-      ip=None,
-  ):
-    pdl_operation_type = pdl.OperationType.get()
-    super().__init__(
-        pdl_operation_type,
-        _get_op_result_or_value(target),
-        iterator_interchange=iterator_interchange,
-        loc=loc,
-        ip=ip,
-    )
+    """Specialization for InterchangeOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        *,
+        iterator_interchange: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        pdl_operation_type = pdl.OperationType.get()
+        super().__init__(
+            pdl_operation_type,
+            _get_op_result_or_value(target),
+            iterator_interchange=iterator_interchange,
+            loc=loc,
+            ip=ip,
+        )
 
 
 class MatchOp:
-  """Specialization for MatchOp class."""
-
-  @classmethod
-  def match_op_names(
-      MatchOp,
-      target: Union[Operation, Value],
-      names: Sequence[str],
-      loc=None,
-      ip=None,
-  ):
-    pdl_operation_type = pdl.OperationType.get()
-    return MatchOp(
-        pdl_operation_type,
-        _get_op_result_or_value(target),
-        ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
-        loc=loc,
-        ip=ip,
-    )
+    """Specialization for MatchOp class."""
+
+    @classmethod
+    def match_op_names(
+        MatchOp,
+        target: Union[Operation, Value],
+        names: Sequence[str],
+        loc=None,
+        ip=None,
+    ):
+        pdl_operation_type = pdl.OperationType.get()
+        return MatchOp(
+            pdl_operation_type,
+            _get_op_result_or_value(target),
+            ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
+            loc=loc,
+            ip=ip,
+        )
 
 
 class MultiTileSizesOp:
-  """Specialization for MultitileSizesOp class."""
-
-  def __init__(
-      self,
-      result_type: Type,
-      target: Union[Operation, Value],
-      *,
-      dimension: Union[int, IntegerAttr],
-      target_size: Union[int, IntegerAttr],
-      divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
-      loc=None,
-      ip=None,
-  ):
-    if divisor is None:
-      divisor = 1
-    super().__init__(
-        result_type,
-        result_type,
-        result_type,
-        _get_op_result_or_value(target),
-        dimension=dimension,
-        target_size=target_size,
-        divisor=divisor,
-        loc=loc,
-        ip=ip,
-    )
+    """Specialization for MultitileSizesOp class."""
+
+    def __init__(
+        self,
+        result_type: Type,
+        target: Union[Operation, Value],
+        *,
+        dimension: Union[int, IntegerAttr],
+        target_size: Union[int, IntegerAttr],
+        divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
+        loc=None,
+        ip=None,
+    ):
+        if divisor is None:
+            divisor = 1
+        super().__init__(
+            result_type,
+            result_type,
+            result_type,
+            _get_op_result_or_value(target),
+            dimension=dimension,
+            target_size=target_size,
+            divisor=divisor,
+            loc=loc,
+            ip=ip,
+        )
 
 
 class PadOp:
-  """Specialization for PadOp class."""
-
-  def __init__(
-      self,
-      target: Union[Operation, Value],
-      *,
-      padding_values: Optional[Optional[Union[ArrayAttr,
-                                              Sequence[Attribute]]]] = None,
-      padding_dimensions: OptionalIntList = None,
-      pack_paddings: OptionalIntList = None,
-      transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
-          ArrayAttr, IntOrAttrList]]]] = None,
-      loc=None,
-      ip=None,
-  ):
-    if transpose_paddings is None:
-      transpose_paddings = []
-    if pack_paddings is None:
-      pack_paddings = []
-    if padding_dimensions is None:
-      padding_dimensions = []
-    if padding_values is None:
-      padding_values = []
-    pdl_operation_type = pdl.OperationType.get()
-    transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
-    super().__init__(
-        pdl_operation_type,
-        _get_op_result_or_value(target),
-        padding_values=padding_values,
-        padding_dimensions=padding_dimensions,
-        pack_paddings=pack_paddings,
-        transpose_paddings=transpose_paddings_attr,
-        loc=loc,
-        ip=ip,
-    )
+    """Specialization for PadOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        *,
+        padding_values: Optional[
+            Optional[Union[ArrayAttr, Sequence[Attribute]]]
+        ] = None,
+        padding_dimensions: OptionalIntList = None,
+        pack_paddings: OptionalIntList = None,
+        transpose_paddings: Optional[
+            Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
+        ] = None,
+        loc=None,
+        ip=None,
+    ):
+        if transpose_paddings is None:
+            transpose_paddings = []
+        if pack_paddings is None:
+            pack_paddings = []
+        if padding_dimensions is None:
+            padding_dimensions = []
+        if padding_values is None:
+            padding_values = []
+        pdl_operation_type = pdl.OperationType.get()
+        transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
+        super().__init__(
+            pdl_operation_type,
+            _get_op_result_or_value(target),
+            padding_values=padding_values,
+            padding_dimensions=padding_dimensions,
+            pack_paddings=pack_paddings,
+            transpose_paddings=transpose_paddings_attr,
+            loc=loc,
+            ip=ip,
+        )
 
 
 class ScalarizeOp:
-  """Specialization for ScalarizeOp class."""
+    """Specialization for ScalarizeOp class."""
 
-  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
-    pdl_operation_type = pdl.OperationType.get()
-    super().__init__(pdl_operation_type,
-                     _get_op_result_or_value(target),
-                     loc=loc,
-                     ip=ip)
+    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+        pdl_operation_type = pdl.OperationType.get()
+        super().__init__(
+            pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip
+        )
 
 
 class SplitOp:
-  """Specialization for SplitOp class."""
-
-  def __init__(
-      self,
-      target: Union[Operation, Value],
-      dimension: Union[int, Attribute],
-      split_point: Union[int, Operation, Value, Attribute],
-      *,
-      loc=None,
-      ip=None,
-  ):
-    if isinstance(split_point, int):
-      static_split_point = split_point
-      dynamic_split_point = None
-    else:
-      static_split_point = ShapedType.get_dynamic_size()
-      dynamic_split_point = _get_op_result_or_value(split_point)
-
-    target = _get_op_result_or_value(target)
-
-    super().__init__(
-        target.type,
-        target.type,
-        target,
-        dimension=dimension,
-        static_split_point=static_split_point,
-        dynamic_split_point=dynamic_split_point,
-        loc=loc,
-        ip=ip,
-    )
+    """Specialization for SplitOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        dimension: Union[int, Attribute],
+        split_point: Union[int, Operation, Value, Attribute],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if isinstance(split_point, int):
+            static_split_point = split_point
+            dynamic_split_point = None
+        else:
+            static_split_point = ShapedType.get_dynamic_size()
+            dynamic_split_point = _get_op_result_or_value(split_point)
+
+        target = _get_op_result_or_value(target)
+
+        super().__init__(
+            target.type,
+            target.type,
+            target,
+            dimension=dimension,
+            static_split_point=static_split_point,
+            dynamic_split_point=dynamic_split_point,
+            loc=loc,
+            ip=ip,
+        )
 
 
 class TileOp:
-  """Specialization for TileOp class."""
-
-  @overload
-  def __init__(
-      self,
-      loop_types: Union[Type, List[Type]],
-      target: Union[Operation, Value],
-      *,
-      sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
-                            ArrayAttr]] = None,
-      interchange: OptionalIntList = None,
-      loc=None,
-      ip=None,
-  ):
-    ...
-
-  @overload
-  def __init__(
-      self,
-      target: Union[Operation, Value, OpView],
-      *,
-      sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
-                            ArrayAttr]] = None,
-      interchange: OptionalIntList = None,
-      loc=None,
-      ip=None,
-  ):
-    ...
-
-  def __init__(
-      self,
-      loop_types_or_target: Union[Type, List[Type], Operation, Value],
-      target_or_none: Optional[Union[Operation, Value, OpView]] = None,
-      *,
-      sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
-                            ArrayAttr]] = None,
-      interchange: OptionalIntList = None,
-      loc=None,
-      ip=None,
-  ):
-    if interchange is None:
-      interchange = []
-    if sizes is None:
-      sizes = []
-
-    static_sizes = []
-    dynamic_sizes = []
-    if isinstance(sizes, ArrayAttr):
-      sizes_attr = sizes
-    else:
-      for size in sizes:
-        if isinstance(size, int):
-          static_sizes.append(size)
+    """Specialization for TileOp class."""
+
+    @overload
+    def __init__(
+        self,
+        loop_types: Union[Type, List[Type]],
+        target: Union[Operation, Value],
+        *,
+        sizes: Optional[
+            Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
+        ] = None,
+        interchange: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        target: Union[Operation, Value, OpView],
+        *,
+        sizes: Optional[
+            Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
+        ] = None,
+        interchange: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    def __init__(
+        self,
+        loop_types_or_target: Union[Type, List[Type], Operation, Value],
+        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+        *,
+        sizes: Optional[
+            Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
+        ] = None,
+        interchange: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        if interchange is None:
+            interchange = []
+        if sizes is None:
+            sizes = []
+
+        static_sizes = []
+        dynamic_sizes = []
+        if isinstance(sizes, ArrayAttr):
+            sizes_attr = sizes
+        else:
+            for size in sizes:
+                if isinstance(size, int):
+                    static_sizes.append(size)
+                else:
+                    static_sizes.append(ShapedType.get_dynamic_size())
+                    dynamic_sizes.append(_get_op_result_or_value(size))
+            sizes_attr = DenseI64ArrayAttr.get(static_sizes)
+
+        num_loops = sum(v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
+
+        if isinstance(loop_types_or_target, (Operation, Value, OpView)):
+            loop_types = [transform.AnyOpType.get()] * num_loops
+            target = loop_types_or_target
+            assert target_or_none is None, "Cannot construct TileOp with two targets."
         else:
-          static_sizes.append(ShapedType.get_dynamic_size())
-          dynamic_sizes.append(_get_op_result_or_value(size))
-      sizes_attr = DenseI64ArrayAttr.get(static_sizes)
-
-    num_loops = sum(
-        v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
-
-    if isinstance(loop_types_or_target, (Operation, Value, OpView)):
-      loop_types = [transform.AnyOpType.get()] * num_loops
-      target = loop_types_or_target
-      assert target_or_none is None, "Cannot construct TileOp with two targets."
-    else:
-      loop_types = (([loop_types_or_target] * num_loops) if isinstance(
-          loop_types_or_target, Type) else loop_types_or_target)
-      target = target_or_none
-
-    target = _get_op_result_or_value(target)
-
-    super().__init__(
-        target.type,
-        loop_types,
-        target,
-        dynamic_sizes=dynamic_sizes,
-        static_sizes=sizes_attr,
-        interchange=interchange,
-        loc=loc,
-        ip=ip,
-    )
-
-  def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
-    if not attr:
-      return []
-    return [element for element in attr]
+            loop_types = (
+                ([loop_types_or_target] * num_loops)
+                if isinstance(loop_types_or_target, Type)
+                else loop_types_or_target
+            )
+            target = target_or_none
+
+        target = _get_op_result_or_value(target)
+
+        super().__init__(
+            target.type,
+            loop_types,
+            target,
+            dynamic_sizes=dynamic_sizes,
+            static_sizes=sizes_attr,
+            interchange=interchange,
+            loc=loc,
+            ip=ip,
+        )
+
+    def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
+        if not attr:
+            return []
+        return [element for element in attr]
 
 
 class VectorizeOp:
-  """Specialization for VectorizeOp class."""
-
-  def __init__(
-      self,
-      target: Union[Operation, Value],
-      *,
-      vectorize_padding: Union[bool, BoolAttr] = False,
-      loc=None,
-      ip=None,
-  ):
-    pdl_operation_type = pdl.OperationType.get()
-    if isinstance(vectorize_padding, bool):
-      vectorize_padding = UnitAttr.get()
-    super().__init__(
-        pdl_operation_type,
-        _get_op_result_or_value(target),
-        vectorize_padding=vectorize_padding,
-        loc=loc,
-        ip=ip,
-    )
+    """Specialization for VectorizeOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        *,
+        vectorize_padding: Union[bool, BoolAttr] = False,
+        loc=None,
+        ip=None,
+    ):
+        pdl_operation_type = pdl.OperationType.get()
+        if isinstance(vectorize_padding, bool):
+            vectorize_padding = UnitAttr.get()
+        super().__init__(
+            pdl_operation_type,
+            _get_op_result_or_value(target),
+            vectorize_padding=vectorize_padding,
+            loc=loc,
+            ip=ip,
+        )

diff  --git a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/_tensor_ops_ext.py
index 51d998b6e3ceb..09b9ec68db7d9 100644
--- a/mlir/python/mlir/dialects/_tensor_ops_ext.py
+++ b/mlir/python/mlir/dialects/_tensor_ops_ext.py
@@ -3,40 +3,42 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
+    from ..ir import *
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Any, Optional, Sequence, Union
-from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ._ods_common import (
+    get_op_result_or_value as _get_op_result_or_value,
+    get_op_results_or_values as _get_op_results_or_values,
+)
 
 
 class EmptyOp:
-  """Extends the tensor.empty op."""
+    """Extends the tensor.empty op."""
 
-  def __init__(self,
-               sizes: Sequence[Union[int, Value]],
-               element_type: Type,
-               *,
-               loc=None,
-               ip=None):
-    """Constructs an `empty` with mixed static/dynamic sizes."""
-    # TODO: Refactor the EmptyOp to take an element type attribute and
-    # then use normal result type inference, unifying the Python and C++ side
-    # with a standard mechanism (versus stashing that in builders).
-    dynamic_sizes = []
-    static_sizes = []
-    for s in sizes:
-      if isinstance(s, int):
-        static_sizes.append(s)
-      else:
-        static_sizes.append(ShapedType.get_dynamic_size())
-        dynamic_sizes.append(s)
-    result_type = RankedTensorType.get(static_sizes, element_type)
-    op = self.build_generic(
-        results=[result_type],
-        operands=dynamic_sizes,
-        attributes={},
-        loc=loc,
-        ip=ip)
-    OpView.__init__(self, op)
+    def __init__(
+        self,
+        sizes: Sequence[Union[int, Value]],
+        element_type: Type,
+        *,
+        loc=None,
+        ip=None
+    ):
+        """Constructs an `empty` with mixed static/dynamic sizes."""
+        # TODO: Refactor the EmptyOp to take an element type attribute and
+        # then use normal result type inference, unifying the Python and C++ side
+        # with a standard mechanism (versus stashing that in builders).
+        dynamic_sizes = []
+        static_sizes = []
+        for s in sizes:
+            if isinstance(s, int):
+                static_sizes.append(s)
+            else:
+                static_sizes.append(ShapedType.get_dynamic_size())
+                dynamic_sizes.append(s)
+        result_type = RankedTensorType.get(static_sizes, element_type)
+        op = self.build_generic(
+            results=[result_type], operands=dynamic_sizes, attributes={}, loc=loc, ip=ip
+        )
+        OpView.__init__(self, op)

diff  --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
index cc4428ea5b115..425ec65859d39 100644
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -3,144 +3,131 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
-  from ._ods_common import (
-      get_op_result_or_value as _get_op_result_or_value,
-      get_op_results_or_values as _get_op_results_or_values,
-  )
+    from ..ir import *
+    from ._ods_common import (
+        get_op_result_or_value as _get_op_result_or_value,
+        get_op_results_or_values as _get_op_results_or_values,
+    )
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Optional, Sequence, Union
 
 
 class CastOp:
-
-  def __init__(self,
-               result_type: Type,
-               target: Union[Operation, Value],
-               *,
-               loc=None,
-               ip=None):
-    super().__init__(result_type,
-                     _get_op_result_or_value(target),
-                     loc=loc,
-                     ip=ip)
+    def __init__(
+        self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None
+    ):
+        super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
 
 
 class GetClosestIsolatedParentOp:
-
-  def __init__(self,
-               result_type: Type,
-               target: Union[Operation, Value],
-               *,
-               loc=None,
-               ip=None):
-    super().__init__(result_type,
-                     _get_op_result_or_value(target),
-                     loc=loc,
-                     ip=ip)
+    def __init__(
+        self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None
+    ):
+        super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
 
 
 class MergeHandlesOp:
-
-  def __init__(
-      self,
-      handles: Sequence[Union[Operation, Value]],
-      *,
-      deduplicate: bool = False,
-      loc=None,
-      ip=None,
-  ):
-    super().__init__(
-        [_get_op_result_or_value(h) for h in handles],
-        deduplicate=deduplicate,
-        loc=loc,
-        ip=ip,
-    )
+    def __init__(
+        self,
+        handles: Sequence[Union[Operation, Value]],
+        *,
+        deduplicate: bool = False,
+        loc=None,
+        ip=None,
+    ):
+        super().__init__(
+            [_get_op_result_or_value(h) for h in handles],
+            deduplicate=deduplicate,
+            loc=loc,
+            ip=ip,
+        )
 
 
 class ReplicateOp:
-
-  def __init__(
-      self,
-      pattern: Union[Operation, Value],
-      handles: Sequence[Union[Operation, Value]],
-      *,
-      loc=None,
-      ip=None,
-  ):
-    super().__init__(
-        [_get_op_result_or_value(h).type for h in handles],
-        _get_op_result_or_value(pattern),
-        [_get_op_result_or_value(h) for h in handles],
-        loc=loc,
-        ip=ip,
-    )
+    def __init__(
+        self,
+        pattern: Union[Operation, Value],
+        handles: Sequence[Union[Operation, Value]],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        super().__init__(
+            [_get_op_result_or_value(h).type for h in handles],
+            _get_op_result_or_value(pattern),
+            [_get_op_result_or_value(h) for h in handles],
+            loc=loc,
+            ip=ip,
+        )
 
 
 class SequenceOp:
-
-  def __init__(
-      self,
-      failure_propagation_mode,
-      results: Sequence[Type],
-      target: Union[Operation, Value, Type],
-      extra_bindings: Optional[Union[Sequence[Value], Sequence[Type], Operation,
-                                     OpView]] = None,
-  ):
-    root = (_get_op_result_or_value(target) if isinstance(
-        target, (Operation, Value)) else None)
-    root_type = root.type if not isinstance(target, Type) else target
-    if not isinstance(failure_propagation_mode, Attribute):
-      failure_propagation_mode_attr = IntegerAttr.get(
-          IntegerType.get_signless(32), failure_propagation_mode._as_int())
-    else:
-      failure_propagation_mode_attr = failure_propagation_mode
-
-    if extra_bindings is None:
-      extra_bindings = []
-    if isinstance(extra_bindings, (Operation, OpView)):
-      extra_bindings = _get_op_results_or_values(extra_bindings)
-
-    extra_binding_types = []
-    if len(extra_bindings) != 0:
-      if isinstance(extra_bindings[0], Type):
-        extra_binding_types = extra_bindings
-        extra_bindings = []
-      else:
-        extra_binding_types = [v.type for v in extra_bindings]
-
-    super().__init__(
-        results_=results,
-        failure_propagation_mode=failure_propagation_mode_attr,
-        root=root,
-        extra_bindings=extra_bindings,
-    )
-    self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
-
-  @property
-  def body(self) -> Block:
-    return self.regions[0].blocks[0]
-
-  @property
-  def bodyTarget(self) -> Value:
-    return self.body.arguments[0]
-
-  @property
-  def bodyExtraArgs(self) -> BlockArgumentList:
-    return self.body.arguments[1:]
+    def __init__(
+        self,
+        failure_propagation_mode,
+        results: Sequence[Type],
+        target: Union[Operation, Value, Type],
+        extra_bindings: Optional[
+            Union[Sequence[Value], Sequence[Type], Operation, OpView]
+        ] = None,
+    ):
+        root = (
+            _get_op_result_or_value(target)
+            if isinstance(target, (Operation, Value))
+            else None
+        )
+        root_type = root.type if not isinstance(target, Type) else target
+        if not isinstance(failure_propagation_mode, Attribute):
+            failure_propagation_mode_attr = IntegerAttr.get(
+                IntegerType.get_signless(32), failure_propagation_mode._as_int()
+            )
+        else:
+            failure_propagation_mode_attr = failure_propagation_mode
+
+        if extra_bindings is None:
+            extra_bindings = []
+        if isinstance(extra_bindings, (Operation, OpView)):
+            extra_bindings = _get_op_results_or_values(extra_bindings)
+
+        extra_binding_types = []
+        if len(extra_bindings) != 0:
+            if isinstance(extra_bindings[0], Type):
+                extra_binding_types = extra_bindings
+                extra_bindings = []
+            else:
+                extra_binding_types = [v.type for v in extra_bindings]
+
+        super().__init__(
+            results_=results,
+            failure_propagation_mode=failure_propagation_mode_attr,
+            root=root,
+            extra_bindings=extra_bindings,
+        )
+        self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
+
+    @property
+    def body(self) -> Block:
+        return self.regions[0].blocks[0]
+
+    @property
+    def bodyTarget(self) -> Value:
+        return self.body.arguments[0]
+
+    @property
+    def bodyExtraArgs(self) -> BlockArgumentList:
+        return self.body.arguments[1:]
 
 
 class YieldOp:
-
-  def __init__(
-      self,
-      operands: Optional[Union[Operation, Sequence[Value]]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    if operands is None:
-      operands = []
-    super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
+    def __init__(
+        self,
+        operands: Optional[Union[Operation, Sequence[Value]]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if operands is None:
+            operands = []
+        super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py
index 5a695d6216770..2f651319930fc 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py
@@ -31,61 +31,60 @@
 
 
 def create_arg_parser() -> argparse.ArgumentParser:
-  p = argparse.ArgumentParser(description="Dump an oplib in various formats")
-  p.add_argument("modules",
-                 metavar="M",
-                 type=str,
-                 nargs="*",
-                 help="Op module to dump")
-  p.add_argument("--file",
-                 metavar="F",
-                 type=str,
-                 nargs="*",
-                 help="Python op file to dump")
-  p.add_argument("--format",
-                 type=str,
-                 dest="format",
-                 default="yaml",
-                 choices=("yaml", "repr"),
-                 help="Format in which to dump")
-  return p
+    p = argparse.ArgumentParser(description="Dump an oplib in various formats")
+    p.add_argument(
+        "modules", metavar="M", type=str, nargs="*", help="Op module to dump"
+    )
+    p.add_argument(
+        "--file", metavar="F", type=str, nargs="*", help="Python op file to dump"
+    )
+    p.add_argument(
+        "--format",
+        type=str,
+        dest="format",
+        default="yaml",
+        choices=("yaml", "repr"),
+        help="Format in which to dump",
+    )
+    return p
 
 
 def load_module_from_file(module_name, file_path):
-  spec = importlib.util.spec_from_file_location(module_name, file_path)
-  m = importlib.util.module_from_spec(spec)
-  spec.loader.exec_module(m)
-  return m
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    m = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(m)
+    return m
 
 
 def main(args):
-  # Load all configs.
-  configs = []
-  modules = []
-  for module_name in args.modules:
-    modules.append(
-        importlib.import_module(module_name,
-                                package="mlir.dialects.linalg.opdsl"))
-  for i, file_path in enumerate(args.file or []):
-    modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path))
-  for m in modules:
-    for attr_name, value in m.__dict__.items():
-      # TODO: This class layering is awkward.
-      if isinstance(value, DefinedOpCallable):
-        try:
-          linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def)
-        except Exception as e:
-          raise ValueError(
-              f"Could not create LinalgOpConfig from {value.op_def}") from e
-        configs.extend(linalg_config)
-
-  # Print.
-  if args.format == "yaml":
-    print(yaml_dump_all(configs))
-  elif args.format == "repr":
-    for config in configs:
-      print(repr(config))
+    # Load all configs.
+    configs = []
+    modules = []
+    for module_name in args.modules:
+        modules.append(
+            importlib.import_module(module_name, package="mlir.dialects.linalg.opdsl")
+        )
+    for i, file_path in enumerate(args.file or []):
+        modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path))
+    for m in modules:
+        for attr_name, value in m.__dict__.items():
+            # TODO: This class layering is awkward.
+            if isinstance(value, DefinedOpCallable):
+                try:
+                    linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def)
+                except Exception as e:
+                    raise ValueError(
+                        f"Could not create LinalgOpConfig from {value.op_def}"
+                    ) from e
+                configs.extend(linalg_config)
+
+    # Print.
+    if args.format == "yaml":
+        print(yaml_dump_all(configs))
+    elif args.format == "repr":
+        for config in configs:
+            print(repr(config))
 
 
 if __name__ == "__main__":
-  main(create_arg_parser().parse_args())
+    main(create_arg_parser().parse_args())

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py
index 038f068345428..9fa626dfa78b1 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py
@@ -66,201 +66,201 @@
 
 
 class AffineBuildState:
-  """Internal state for the AffineExprDef._create impls.
-
-  Note that a "local" AffineBuildState can be created relative to a "global"
-  AffineBuildState. In that case, any affine expressions built will inherit
-  symbol and dim bindings from the global state and will update both as new
-  ones are discovered. This allows for building expressions across contexts
-  which share a common symbol and dim space.
-  """
-
-  def __init__(self,
-               *,
-               global_state: "AffineBuildState" = None,
-               allow_new_symbols: bool = True,
-               allow_new_dims: bool = True):
-    if not global_state:
-      self.all_symbols = dict()  # type: Dict[str, int]
-      self.all_dims = dict()  # type: Dict[str, int]
-    else:
-      # Alias the global dict.
-      self.all_symbols = global_state.all_symbols
-      self.all_dims = global_state.all_dims
-
-    # Map of symbols and dims in the current build.
-    self.local_symbols = dict()  # type: Dict[str, int]
-    self.local_dims = dict()  # type: Dict[str, int]
-    self.allow_new_symbols = allow_new_symbols
-    self.allow_new_dims = allow_new_dims
-
-  def get_dim(self, dimname: str) -> int:
-    """Gets the dim position given a name."""
-    pos = self.all_dims.get(dimname)
-    if pos is None:
-      if not self.allow_new_dims:
-        raise ValueError(
-            f"New dimensions not allowed in the current affine expression: "
-            f"Requested '{dimname}', Availble: {self.all_dims}")
-      pos = len(self.all_dims)
-      self.all_dims[dimname] = pos
-    self.local_dims[dimname] = pos
-    return pos
-
-  def get_symbol(self, symname: str) -> int:
-    """Geta a symbol position given a name."""
-    pos = self.all_symbols.get(symname)
-    if pos is None:
-      if not self.allow_new_symbols:
-        raise ValueError(
-            f"New symbols not allowed in the current affine expression: "
-            f"Requested '{symname}', Availble: {self.all_symbols}")
-      pos = len(self.all_symbols)
-      self.all_symbols[symname] = pos
-    self.local_symbols[symname] = pos
-    return pos
-
-  @property
-  def local_dim_count(self) -> int:
-    return len(self.local_dims)
-
-  @property
-  def local_symbol_count(self) -> int:
-    return len(self.local_symbols)
-
-  @property
-  def dim_count(self) -> int:
-    return len(self.all_dims)
-
-  @property
-  def symbol_count(self) -> int:
-    return len(self.all_symbols)
-
-  def __repr__(self):
-    lines = [f"AffineBuildState<"]
-    lines.append(f"  symbols={self.local_symbols}")
-    lines.append(f"  dims={self.local_dims}>")
-    return "\n".join(lines)
+    """Internal state for the AffineExprDef._create impls.
+
+    Note that a "local" AffineBuildState can be created relative to a "global"
+    AffineBuildState. In that case, any affine expressions built will inherit
+    symbol and dim bindings from the global state and will update both as new
+    ones are discovered. This allows for building expressions across contexts
+    which share a common symbol and dim space.
+    """
+
+    def __init__(
+        self,
+        *,
+        global_state: "AffineBuildState" = None,
+        allow_new_symbols: bool = True,
+        allow_new_dims: bool = True,
+    ):
+        if not global_state:
+            self.all_symbols = dict()  # type: Dict[str, int]
+            self.all_dims = dict()  # type: Dict[str, int]
+        else:
+            # Alias the global dict.
+            self.all_symbols = global_state.all_symbols
+            self.all_dims = global_state.all_dims
+
+        # Map of symbols and dims in the current build.
+        self.local_symbols = dict()  # type: Dict[str, int]
+        self.local_dims = dict()  # type: Dict[str, int]
+        self.allow_new_symbols = allow_new_symbols
+        self.allow_new_dims = allow_new_dims
+
+    def get_dim(self, dimname: str) -> int:
+        """Gets the dim position given a name."""
+        pos = self.all_dims.get(dimname)
+        if pos is None:
+            if not self.allow_new_dims:
+                raise ValueError(
+                    f"New dimensions not allowed in the current affine expression: "
+                    f"Requested '{dimname}', Availble: {self.all_dims}"
+                )
+            pos = len(self.all_dims)
+            self.all_dims[dimname] = pos
+        self.local_dims[dimname] = pos
+        return pos
+
+    def get_symbol(self, symname: str) -> int:
+        """Geta a symbol position given a name."""
+        pos = self.all_symbols.get(symname)
+        if pos is None:
+            if not self.allow_new_symbols:
+                raise ValueError(
+                    f"New symbols not allowed in the current affine expression: "
+                    f"Requested '{symname}', Availble: {self.all_symbols}"
+                )
+            pos = len(self.all_symbols)
+            self.all_symbols[symname] = pos
+        self.local_symbols[symname] = pos
+        return pos
+
+    @property
+    def local_dim_count(self) -> int:
+        return len(self.local_dims)
+
+    @property
+    def local_symbol_count(self) -> int:
+        return len(self.local_symbols)
+
+    @property
+    def dim_count(self) -> int:
+        return len(self.all_dims)
+
+    @property
+    def symbol_count(self) -> int:
+        return len(self.all_symbols)
+
+    def __repr__(self):
+        lines = [f"AffineBuildState<"]
+        lines.append(f"  symbols={self.local_symbols}")
+        lines.append(f"  dims={self.local_dims}>")
+        return "\n".join(lines)
 
 
 class AffineExprDef:
-  """Base class for an affine expression being defined."""
+    """Base class for an affine expression being defined."""
 
-  def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr:
-    """Builds the corresponding _ir.AffineExpr from the definitions.
-    """
-    state = AffineBuildState() if state is None else state
-    expr = self._create(state)
-    return expr
+    def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr:
+        """Builds the corresponding _ir.AffineExpr from the definitions."""
+        state = AffineBuildState() if state is None else state
+        expr = self._create(state)
+        return expr
 
-  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
-    raise NotImplementedError()
+    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+        raise NotImplementedError()
 
-  @staticmethod
-  def coerce_from(py_value):
-    if isinstance(py_value, int):
-      return AffineConstantExpr(py_value)
-    assert isinstance(py_value, AffineExprDef)
-    return py_value
+    @staticmethod
+    def coerce_from(py_value):
+        if isinstance(py_value, int):
+            return AffineConstantExpr(py_value)
+        assert isinstance(py_value, AffineExprDef)
+        return py_value
 
-  def visit_affine_exprs(self, callback):
-    """Visits all AffineExprDefs including self."""
-    callback(self)
+    def visit_affine_exprs(self, callback):
+        """Visits all AffineExprDefs including self."""
+        callback(self)
 
-  def __add__(lhs, rhs):
-    rhs = AffineExprDef.coerce_from(rhs)
-    return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)
+    def __add__(lhs, rhs):
+        rhs = AffineExprDef.coerce_from(rhs)
+        return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)
 
-  def __mul__(lhs, rhs):
-    rhs = AffineExprDef.coerce_from(rhs)
-    return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)
+    def __mul__(lhs, rhs):
+        rhs = AffineExprDef.coerce_from(rhs)
+        return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)
 
-  def __mod__(lhs, rhs):
-    rhs = AffineExprDef.coerce_from(rhs)
-    return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs)
+    def __mod__(lhs, rhs):
+        rhs = AffineExprDef.coerce_from(rhs)
+        return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs)
 
-  def __floordiv__(lhs, rhs):
-    rhs = AffineExprDef.coerce_from(rhs)
-    return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs)
+    def __floordiv__(lhs, rhs):
+        rhs = AffineExprDef.coerce_from(rhs)
+        return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs)
 
-  def __truediv__(lhs, rhs):
-    # TODO: Not really a ceil div - taking liberties for the DSL.
-    rhs = AffineExprDef.coerce_from(rhs)
-    return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs)
+    def __truediv__(lhs, rhs):
+        # TODO: Not really a ceil div - taking liberties for the DSL.
+        rhs = AffineExprDef.coerce_from(rhs)
+        return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs)
 
 
 class AffineConstantExpr(AffineExprDef):
-  """An affine constant being defined."""
+    """An affine constant being defined."""
 
-  def __init__(self, value: int):
-    assert isinstance(value, int)
-    self.value = value
+    def __init__(self, value: int):
+        assert isinstance(value, int)
+        self.value = value
 
-  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
-    return _ir.AffineConstantExpr.get(self.value)
+    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+        return _ir.AffineConstantExpr.get(self.value)
 
-  def __repr__(self):
-    return f"Const({self.value})"
+    def __repr__(self):
+        return f"Const({self.value})"
 
 
 class AffineBinaryExprDef(AffineExprDef):
-  """An affine binary expression being defined."""
+    """An affine binary expression being defined."""
 
-  def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef):
-    self.ir_ctor = ir_ctor
-    self.lhs = lhs
-    self.rhs = rhs
+    def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef):
+        self.ir_ctor = ir_ctor
+        self.lhs = lhs
+        self.rhs = rhs
 
-  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
-    return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state))
+    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+        return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state))
 
-  def visit_affine_exprs(self, callback):
-    """Visits all AffineExprDefs including self."""
-    super().visit_affine_exprs(callback)
-    self.lhs.visit_affine_exprs(callback)
-    self.rhs.visit_affine_exprs(callback)
+    def visit_affine_exprs(self, callback):
+        """Visits all AffineExprDefs including self."""
+        super().visit_affine_exprs(callback)
+        self.lhs.visit_affine_exprs(callback)
+        self.rhs.visit_affine_exprs(callback)
 
-  def __repr__(self):
-    return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})"
+    def __repr__(self):
+        return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})"
 
 
 class DimDef(AffineExprDef):
-  """Represents a named dimension.
-
-  """
-  ALL_DIMS = dict()  # type: Dict[str, "DimDef"]
-
-  def __new__(cls, dimname: str):
-    existing = cls.ALL_DIMS.get(dimname)
-    if existing is not None:
-      return existing
-    new = super().__new__(cls)
-    new.dimname = dimname
-    cls.ALL_DIMS[dimname] = new
-    return new
-
-  def __repr__(self):
-    return f"Dim({self.dimname})"
-
-  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
-    pos = state.get_dim(self.dimname)
-    return _ir.AffineDimExpr.get(position=pos)
-
-  @classmethod
-  def create_expando(cls):
-    """Create an expando class that creates unique symbols based on attr access.
-    """
+    """Represents a named dimension."""
+
+    ALL_DIMS = dict()  # type: Dict[str, "DimDef"]
+
+    def __new__(cls, dimname: str):
+        existing = cls.ALL_DIMS.get(dimname)
+        if existing is not None:
+            return existing
+        new = super().__new__(cls)
+        new.dimname = dimname
+        cls.ALL_DIMS[dimname] = new
+        return new
 
-    class ExpandoDims:
+    def __repr__(self):
+        return f"Dim({self.dimname})"
 
-      def __getattr__(self, n):
-        return cls(n)
+    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+        pos = state.get_dim(self.dimname)
+        return _ir.AffineDimExpr.get(position=pos)
 
-    return ExpandoDims()
+    @classmethod
+    def create_expando(cls):
+        """Create an expando class that creates unique symbols based on attr access."""
+
+        class ExpandoDims:
+            def __getattr__(self, n):
+                return cls(n)
+
+        return ExpandoDims()
 
 
 class SymbolDef(AffineExprDef):
-  """Represents a named symbol.
+    """Represents a named symbol.
 
     >>> s1 = SymbolDef("s1")
     >>> s1
@@ -270,36 +270,35 @@ class SymbolDef(AffineExprDef):
     False
     >>> s1 is SymbolDef("s1")
     True
-  """
-  ALL_SYMBOLS = dict()  # type: Dict[str, "SymbolDef"]
-
-  def __new__(cls, symname: str):
-    existing = cls.ALL_SYMBOLS.get(symname)
-    if existing is not None:
-      return existing
-    new = super().__new__(cls)
-    new.symname = symname
-    cls.ALL_SYMBOLS[symname] = new
-    return new
-
-  def __repr__(self):
-    return f"Symbol({self.symname})"
-
-  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
-    pos = state.get_symbol(self.symname)
-    return _ir.AffineSymbolExpr.get(position=pos)
-
-  @classmethod
-  def create_expando(cls):
-    """Create an expando class that creates unique symbols based on attr access.
     """
 
-    class ExpandoSymbols:
+    ALL_SYMBOLS = dict()  # type: Dict[str, "SymbolDef"]
+
+    def __new__(cls, symname: str):
+        existing = cls.ALL_SYMBOLS.get(symname)
+        if existing is not None:
+            return existing
+        new = super().__new__(cls)
+        new.symname = symname
+        cls.ALL_SYMBOLS[symname] = new
+        return new
+
+    def __repr__(self):
+        return f"Symbol({self.symname})"
+
+    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+        pos = state.get_symbol(self.symname)
+        return _ir.AffineSymbolExpr.get(position=pos)
+
+    @classmethod
+    def create_expando(cls):
+        """Create an expando class that creates unique symbols based on attr access."""
 
-      def __getattr__(self, n):
-        return cls(n)
+        class ExpandoSymbols:
+            def __getattr__(self, n):
+                return cls(n)
 
-    return ExpandoSymbols()
+        return ExpandoSymbols()
 
 
 # Global accessor for on-demand dims and symbols.

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 135f55ea516d0..5d5866fdeabf6 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -23,223 +23,232 @@
 
 
 class TensorExpression:
-  """An expression that can appear on the RHS of a comprehension."""
+    """An expression that can appear on the RHS of a comprehension."""
 
-  def to_scalar_expression(self) -> ScalarExpression:
-    raise NotImplementedError()
+    def to_scalar_expression(self) -> ScalarExpression:
+        raise NotImplementedError()
 
-  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
-    """Visits all tensor expression reachable by the expression."""
-    callback(self)
+    def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
+        """Visits all tensor expression reachable by the expression."""
+        callback(self)
 
-  def collect_dim_uses(self, uses: Set["DimDef"]):
-    """Collects all DimDefs reachable through this expression."""
+    def collect_dim_uses(self, uses: Set["DimDef"]):
+        """Collects all DimDefs reachable through this expression."""
 
-    def visit_dim_def(dim_def: AffineExprDef):
-      if isinstance(dim_def, DimDef):
-        uses.add(dim_def)
+        def visit_dim_def(dim_def: AffineExprDef):
+            if isinstance(dim_def, DimDef):
+                uses.add(dim_def)
 
-    def visit_affine_exprs(expr: "TensorExpression"):
-      if isinstance(expr, TensorUse):
-        for ind in expr.indices:
-          ind.visit_affine_exprs(visit_dim_def)
-      if isinstance(expr, TensorReduceFn):
-        for ind in expr.reduce_fn.reduce_dims:
-          ind.visit_affine_exprs(visit_dim_def)
+        def visit_affine_exprs(expr: "TensorExpression"):
+            if isinstance(expr, TensorUse):
+                for ind in expr.indices:
+                    ind.visit_affine_exprs(visit_dim_def)
+            if isinstance(expr, TensorReduceFn):
+                for ind in expr.reduce_fn.reduce_dims:
+                    ind.visit_affine_exprs(visit_dim_def)
 
-    self.visit_tensor_exprs(visit_affine_exprs)
+        self.visit_tensor_exprs(visit_affine_exprs)
 
-  def collect_tensor_uses(self, uses: Set["TensorUse"]):
-    """Collects all TensorUses reachable through this expression."""
+    def collect_tensor_uses(self, uses: Set["TensorUse"]):
+        """Collects all TensorUses reachable through this expression."""
 
-    def visit_tensor_use(expr: "TensorExpression"):
-      if isinstance(expr, TensorUse):
-        uses.add(expr)
+        def visit_tensor_use(expr: "TensorExpression"):
+            if isinstance(expr, TensorUse):
+                uses.add(expr)
 
-    self.visit_tensor_exprs(visit_tensor_use)
+        self.visit_tensor_exprs(visit_tensor_use)
 
-  def collect_indices(self, indices: Set["index"]):
-    """Collects all index accesses reachable through this expression."""
+    def collect_indices(self, indices: Set["index"]):
+        """Collects all index accesses reachable through this expression."""
 
-    def visit_index(expr: "TensorExpression"):
-      if isinstance(expr, index):
-        indices.add(expr)
+        def visit_index(expr: "TensorExpression"):
+            if isinstance(expr, index):
+                indices.add(expr)
 
-    self.visit_tensor_exprs(visit_index)
+        self.visit_tensor_exprs(visit_index)
 
-  def collect_scalar_uses(self, uses: Set["ScalarDef"]):
-    """Collects all ScalarDefs reachable through this expression."""
+    def collect_scalar_uses(self, uses: Set["ScalarDef"]):
+        """Collects all ScalarDefs reachable through this expression."""
 
-    def visit_scalar_def(expr: "TensorExpression"):
-      if isinstance(expr, ScalarDef):
-        uses.add(expr)
+        def visit_scalar_def(expr: "TensorExpression"):
+            if isinstance(expr, ScalarDef):
+                uses.add(expr)
 
-    self.visit_tensor_exprs(visit_scalar_def)
+        self.visit_tensor_exprs(visit_scalar_def)
 
-  def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
-    return BinaryFn.add(self, rhs)
+    def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
+        return BinaryFn.add(self, rhs)
 
-  def __mul__(self, rhs) -> "TensorExpression":
-    return BinaryFn.mul(self, rhs)
+    def __mul__(self, rhs) -> "TensorExpression":
+        return BinaryFn.mul(self, rhs)
 
-  def __sub__(self, rhs) -> "TensorExpression":
-    return BinaryFn.sub(self, rhs)
+    def __sub__(self, rhs) -> "TensorExpression":
+        return BinaryFn.sub(self, rhs)
 
-  def __hash__(self):
-    return hash(id(self))
+    def __hash__(self):
+        return hash(id(self))
 
 
 class TensorUse(TensorExpression):
-  """A used tensor represented by its (tensor_name, indices).
-
-  Note that forming a comprehension via direct assignment is performed through
-  __setitem__ on the TensorDef level. However, performing a reduction with
-  compound ops (+=, *=, etc) is done by doing a:
-    TensorDef.__getitem__
-    TensorUse.__iadd__
-    TensorDef.__setitem__
-  """
-
-  def __init__(self, operand_def: "OperandDef",
-               indices: Sequence[AffineExprDef]):
-    self.operand_def = operand_def
-    self.indices = tuple(indices)
-
-  def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarArg(self.tensor_name).expr()
-
-  @property
-  def tensor_name(self) -> str:
-    name = self.operand_def.name
-    assert name is not None, "TensorDef not registered with an op"
-    return name
-
-  def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
-    # Computes the reduction dims for implicit reductions. Assumes that the rhs
-    # is the expression being reduced and self is being reduced into. Any
-    # indices referenced on the rhs and not in self are considered reduction
-    # dims and will be ordered as encountered on the rhs.
-    rhs_dims = set()
-    lhs_dims = set()
-    rhs.collect_dim_uses(rhs_dims)
-    self.collect_dim_uses(lhs_dims)
-    return rhs_dims - lhs_dims
-
-  def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
-    return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs)
-
-  def __repr__(self):
-    return (f"{self.operand_def.name}"
-            f"[{', '.join([repr(i) for i in self.indices])}]")
+    """A used tensor represented by its (tensor_name, indices).
+
+    Note that forming a comprehension via direct assignment is performed through
+    __setitem__ on the TensorDef level. However, performing a reduction with
+    compound ops (+=, *=, etc) is done by doing a:
+      TensorDef.__getitem__
+      TensorUse.__iadd__
+      TensorDef.__setitem__
+    """
+
+    def __init__(self, operand_def: "OperandDef", indices: Sequence[AffineExprDef]):
+        self.operand_def = operand_def
+        self.indices = tuple(indices)
+
+    def to_scalar_expression(self) -> ScalarExpression:
+        return ScalarArg(self.tensor_name).expr()
+
+    @property
+    def tensor_name(self) -> str:
+        name = self.operand_def.name
+        assert name is not None, "TensorDef not registered with an op"
+        return name
+
+    def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
+        # Computes the reduction dims for implicit reductions. Assumes that the rhs
+        # is the expression being reduced and self is being reduced into. Any
+        # indices referenced on the rhs and not in self are considered reduction
+        # dims and will be ordered as encountered on the rhs.
+        rhs_dims = set()
+        lhs_dims = set()
+        rhs.collect_dim_uses(rhs_dims)
+        self.collect_dim_uses(lhs_dims)
+        return rhs_dims - lhs_dims
+
+    def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
+        return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs)
+
+    def __repr__(self):
+        return (
+            f"{self.operand_def.name}" f"[{', '.join([repr(i) for i in self.indices])}]"
+        )
 
 
 class TensorFn(TensorExpression):
-  """Application of a tensor function."""
-
-  def __init__(self, kind: "FunctionKind", name: Optional[str],
-               operand_def: Optional["OperandDef"], type_var: Optional[TypeVar],
-               args: Sequence[TensorExpression]):
-    if bool(name) + bool(operand_def) != 1:
-      raise ValueError("One of 'name', 'operand_def' must be specified")
-    self.name = name
-    self.kind = kind
-    self.operand_def = operand_def
-    self.type_var = type_var
-    self.args = args
-
-  def to_scalar_expression(self) -> ScalarExpression:
-    if self.operand_def:
-      assert self.operand_def.name, "TensorFn not registered with an op"
-    attr_name = self.operand_def.name if self.operand_def else None
-    args = [arg.to_scalar_expression() for arg in self.args]
-    return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr()
-
-  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
-    super().visit_tensor_exprs(callback)
-    for arg in self.args:
-      arg.visit_tensor_exprs(callback)
-
-  def __repr__(self):
-    name = self.operand_def.name if self.operand_def else self.name
-    return (f"{self.kind.name}.{name}(type_var={self.type_var}, "
-            f"args={', '.join(repr(a) for a in self.args)})")
+    """Application of a tensor function."""
+
+    def __init__(
+        self,
+        kind: "FunctionKind",
+        name: Optional[str],
+        operand_def: Optional["OperandDef"],
+        type_var: Optional[TypeVar],
+        args: Sequence[TensorExpression],
+    ):
+        if bool(name) + bool(operand_def) != 1:
+            raise ValueError("One of 'name', 'operand_def' must be specified")
+        self.name = name
+        self.kind = kind
+        self.operand_def = operand_def
+        self.type_var = type_var
+        self.args = args
+
+    def to_scalar_expression(self) -> ScalarExpression:
+        if self.operand_def:
+            assert self.operand_def.name, "TensorFn not registered with an op"
+        attr_name = self.operand_def.name if self.operand_def else None
+        args = [arg.to_scalar_expression() for arg in self.args]
+        return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr()
+
+    def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
+        super().visit_tensor_exprs(callback)
+        for arg in self.args:
+            arg.visit_tensor_exprs(callback)
+
+    def __repr__(self):
+        name = self.operand_def.name if self.operand_def else self.name
+        return (
+            f"{self.kind.name}.{name}(type_var={self.type_var}, "
+            f"args={', '.join(repr(a) for a in self.args)})"
+        )
 
 
 class TensorReduceFn(TensorExpression):
-  """Application of a reduction function.
-
-  This captures the lhs (initial value) separately from the rhs.
-  """
-
-  def __init__(self, reduce_use: "ReduceFnUse",
-               args: Sequence[TensorExpression]):
-    self.reduce_use = reduce_use
-    self.lhs = None  # type: Optional[TensorUse]
-    self.args = args
-
-  def to_scalar_expression(self) -> ScalarExpression:
-    if self.lhs is None:
-      raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been "
-                       f"bound to its lhs: {self}")
-    full_args = [self.lhs.to_scalar_expression()
-                ] + [arg.to_scalar_expression() for arg in self.args]
-    fn_name = None
-    attr_name = None
-    if self.reduce_use.binary_fn:
-      fn_name = self.reduce_use.binary_fn.fn_name
-    if self.reduce_use.binary_attr:
-      attr_name = self.reduce_use.binary_attr.operand_def.name
-    return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None,
-                    full_args).expr()
-
-  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
-    for arg in self.args:
-      arg.visit_tensor_exprs(callback)
-
-  def __repr__(self):
-    return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"
+    """Application of a reduction function.
+
+    This captures the lhs (initial value) separately from the rhs.
+    """
+
+    def __init__(self, reduce_use: "ReduceFnUse", args: Sequence[TensorExpression]):
+        self.reduce_use = reduce_use
+        self.lhs = None  # type: Optional[TensorUse]
+        self.args = args
+
+    def to_scalar_expression(self) -> ScalarExpression:
+        if self.lhs is None:
+            raise ValueError(
+                f"Cannot scalarize a TensorReduceFn that has not been "
+                f"bound to its lhs: {self}"
+            )
+        full_args = [self.lhs.to_scalar_expression()] + [
+            arg.to_scalar_expression() for arg in self.args
+        ]
+        fn_name = None
+        attr_name = None
+        if self.reduce_use.binary_fn:
+            fn_name = self.reduce_use.binary_fn.fn_name
+        if self.reduce_use.binary_attr:
+            attr_name = self.reduce_use.binary_attr.operand_def.name
+        return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None, full_args).expr()
+
+    def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
+        for arg in self.args:
+            arg.visit_tensor_exprs(callback)
+
+    def __repr__(self):
+        return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"
 
 
 class const(TensorExpression):
-  """Returns the given constant floating point or integer value."""
+    """Returns the given constant floating point or integer value."""
 
-  def __init__(self, value: Any):
-    with _ir.Context():
-      if isinstance(value, float):
-        self.value = str(_ir.FloatAttr.get_f64(float(value)))
-      elif isinstance(value, int):
-        self.value = str(
-            _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
-      else:
-        raise ValueError(f"const requires int or float but got {type(value)}")
+    def __init__(self, value: Any):
+        with _ir.Context():
+            if isinstance(value, float):
+                self.value = str(_ir.FloatAttr.get_f64(float(value)))
+            elif isinstance(value, int):
+                self.value = str(
+                    _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))
+                )
+            else:
+                raise ValueError(f"const requires int or float but got {type(value)}")
 
-  def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarConst(self.value).expr()
+    def to_scalar_expression(self) -> ScalarExpression:
+        return ScalarConst(self.value).expr()
 
-  def __repr__(self):
-    return f"const({self.value})"
+    def __repr__(self):
+        return f"const({self.value})"
 
 
 class index(TensorExpression):
-  """Returns the iteration index for a given dimension name.
+    """Returns the iteration index for a given dimension name.
 
-  Resolves the given dimension name to obtain its position in the iteration
-  domain of the operation.
-  """
+    Resolves the given dimension name to obtain its position in the iteration
+    domain of the operation.
+    """
 
-  def __init__(self, dim: DimDef):
-    self.dim_def = dim
-    self.dim = -1
+    def __init__(self, dim: DimDef):
+        self.dim_def = dim
+        self.dim = -1
 
-  def resolve_dimension_name(self, affine_state: AffineBuildState):
-    self.dim = affine_state.get_dim(self.dim_def.dimname)
+    def resolve_dimension_name(self, affine_state: AffineBuildState):
+        self.dim = affine_state.get_dim(self.dim_def.dimname)
 
-  def to_scalar_expression(self) -> ScalarExpression:
-    assert self.dim != -1, "Dimension name not resolved"
-    return ScalarIndex(self.dim).expr()
+    def to_scalar_expression(self) -> ScalarExpression:
+        assert self.dim != -1, "Dimension name not resolved"
+        return ScalarIndex(self.dim).expr()
 
-  def __repr__(self):
-    return f"index({repr(self.dim)})"
+    def __repr__(self):
+        return f"index({repr(self.dim)})"
 
 
 ###############################################################################
@@ -248,155 +257,160 @@ def __repr__(self):
 
 
 class FunctionKind(Enum):
-  UNARY = 0
-  BINARY = 1
-  TYPE = 2
+    UNARY = 0
+    BINARY = 1
+    TYPE = 2
 
 
 class UnaryFnType:
-  """Unary function.
+    """Unary function.
 
-  A unary function takes one tensor expression and returns the
-  function evaluation result.
-  """
+    A unary function takes one tensor expression and returns the
+    function evaluation result.
+    """
 
-  def __init__(self, fn_name: str):
-    self.fn_name = fn_name
+    def __init__(self, fn_name: str):
+        self.fn_name = fn_name
 
-  def __call__(self, arg: TensorExpression) -> "TensorFn":
-    return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg])
+    def __call__(self, arg: TensorExpression) -> "TensorFn":
+        return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg])
 
-  def __repr__(self):
-    return f"{self.fn_name}"
+    def __repr__(self):
+        return f"{self.fn_name}"
 
 
 class UnaryFn:
-  """Unary function namespace."""
-  exp = UnaryFnType("exp")
-  log = UnaryFnType("log")
-  abs = UnaryFnType("abs")
-  ceil = UnaryFnType("ceil")
-  floor = UnaryFnType("floor")
-  negf = UnaryFnType("negf")
+    """Unary function namespace."""
+
+    exp = UnaryFnType("exp")
+    log = UnaryFnType("log")
+    abs = UnaryFnType("abs")
+    ceil = UnaryFnType("ceil")
+    floor = UnaryFnType("floor")
+    negf = UnaryFnType("negf")
 
 
 class BinaryFnType:
-  """Binary function.
+    """Binary function.
 
-  A binary function takes two tensor expressions and returns the
-  function evaluation result.
-  """
+    A binary function takes two tensor expressions and returns the
+    function evaluation result.
+    """
 
-  def __init__(self, fn_name: str):
-    self.fn_name = fn_name
+    def __init__(self, fn_name: str):
+        self.fn_name = fn_name
 
-  def __call__(self, arg0: TensorExpression,
-               arg1: TensorExpression) -> "TensorFn":
-    return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1])
+    def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> "TensorFn":
+        return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1])
 
-  def __repr__(self):
-    return f"{self.fn_name}"
+    def __repr__(self):
+        return f"{self.fn_name}"
 
 
 class BinaryFn:
-  """Binary function namespace.
+    """Binary function namespace.
 
-  As the integer types are signless, signedness is implement by different
-  functions that treat integers as signed or unsigned values.
+    As the integer types are signless, signedness is implement by different
+    functions that treat integers as signed or unsigned values.
+
+    Examples:
+    - max -> `arith.MaxSIOp`
+    - max_unsinged -> `arith.MaxUIOp`
+    """
 
-  Examples:
-  - max -> `arith.MaxSIOp`
-  - max_unsinged -> `arith.MaxUIOp`
-  """
-  add = BinaryFnType("add")
-  sub = BinaryFnType("sub")
-  mul = BinaryFnType("mul")
-  max_signed = BinaryFnType("max_signed")
-  min_signed = BinaryFnType("min_signed")
-  max_unsigned = BinaryFnType("max_unsigned")
-  min_unsigned = BinaryFnType("min_unsigned")
+    add = BinaryFnType("add")
+    sub = BinaryFnType("sub")
+    mul = BinaryFnType("mul")
+    max_signed = BinaryFnType("max_signed")
+    min_signed = BinaryFnType("min_signed")
+    max_unsigned = BinaryFnType("max_unsigned")
+    min_unsigned = BinaryFnType("min_unsigned")
 
 
 class TypeFnType:
-  """Type conversion function.
+    """Type conversion function.
 
-  A type conversion function takes a target type and a tensor expression and
-  returns the casted tensor expression.
-  """
+    A type conversion function takes a target type and a tensor expression and
+    returns the casted tensor expression.
+    """
 
-  def __init__(self, fn_name: str):
-    self.fn_name = fn_name
+    def __init__(self, fn_name: str):
+        self.fn_name = fn_name
 
-  def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
-    return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])
+    def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
+        return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])
 
-  def __repr__(self):
-    return f"{self.fn_name}"
+    def __repr__(self):
+        return f"{self.fn_name}"
 
 
 class TypeFn:
-  """Type conversion function namespace.
+    """Type conversion function namespace.
+
+    As the integer types are signless, signedness is implement by different cast
+    functions that treat integers as signed (`cast_signed`) or unsigned
+    (`cast_unsigned`) values.
 
-  As the integer types are signless, signedness is implement by different cast
-  functions that treat integers as signed (`cast_signed`) or unsigned
-  (`cast_unsigned`) values.
+    Examples:
+    - cast_signed(I32 -> I64) -> `arith.ExtSIOp`
+    - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
+    """
 
-  Examples:
-  - cast_signed(I32 -> I64) -> `arith.ExtSIOp`
-  - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
-  """
-  cast_signed = TypeFnType("cast_signed")
-  cast_unsigned = TypeFnType("cast_unsigned")
+    cast_signed = TypeFnType("cast_signed")
+    cast_unsigned = TypeFnType("cast_unsigned")
 
 
 class ReduceFnUse:
-  """Reduction function use.
+    """Reduction function use.
 
-  A reduction use specifies the reduction function and dimensions.
-  """
+    A reduction use specifies the reduction function and dimensions.
+    """
 
-  def __init__(self, binary_fn: Optional[BinaryFnType],
-               binary_attr: Optional["BinaryFnAttrDef"], *reduce_dims: DimDef):
-    if bool(binary_fn) + bool(binary_attr) != 1:
-      raise ValueError("One of 'binary_fn', 'binary_attr' must be specified")
-    self.binary_fn = binary_fn
-    self.binary_attr = binary_attr
-    self.reduce_dims = reduce_dims
+    def __init__(
+        self,
+        binary_fn: Optional[BinaryFnType],
+        binary_attr: Optional["BinaryFnAttrDef"],
+        *reduce_dims: DimDef,
+    ):
+        if bool(binary_fn) + bool(binary_attr) != 1:
+            raise ValueError("One of 'binary_fn', 'binary_attr' must be specified")
+        self.binary_fn = binary_fn
+        self.binary_attr = binary_attr
+        self.reduce_dims = reduce_dims
 
-  def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
-    return TensorReduceFn(self, args)
+    def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
+        return TensorReduceFn(self, args)
 
-  def __repr__(self):
-    fn = self.binary_fn if self.binary_fn else self.binary_attr
-    return (
-        f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})")
+    def __repr__(self):
+        fn = self.binary_fn if self.binary_fn else self.binary_attr
+        return f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})"
 
 
 class ReduceFnType:
-  """Reduction function.
+    """Reduction function.
 
-  A binary function that reduces its RHS into its LHS.
-  """
+    A binary function that reduces its RHS into its LHS.
+    """
 
-  def __init__(self, binary_fn: BinaryFnType):
-    if not isinstance(binary_fn, BinaryFnType):
-      raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}")
-    self.binary_fn = binary_fn
+    def __init__(self, binary_fn: BinaryFnType):
+        if not isinstance(binary_fn, BinaryFnType):
+            raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}")
+        self.binary_fn = binary_fn
 
-  def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
-    return ReduceFnUse(self.binary_fn, None, *reduce_dims)
+    def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+        return ReduceFnUse(self.binary_fn, None, *reduce_dims)
 
-  def __repr__(self):
-    return f"reduce_{repr(self.binary_fn)}"
+    def __repr__(self):
+        return f"reduce_{repr(self.binary_fn)}"
 
 
 class ReduceFn:
-  add = ReduceFnType(BinaryFn.add)
-  mul = ReduceFnType(BinaryFn.mul)
-  max_signed = ReduceFnType(BinaryFn.max_signed)
-  min_signed = ReduceFnType(BinaryFn.min_signed)
-  max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
-  min_unsigned = ReduceFnType(BinaryFn.min_unsigned)
+    add = ReduceFnType(BinaryFn.add)
+    mul = ReduceFnType(BinaryFn.mul)
+    max_signed = ReduceFnType(BinaryFn.max_signed)
+    min_signed = ReduceFnType(BinaryFn.min_signed)
+    max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
+    min_unsigned = ReduceFnType(BinaryFn.min_unsigned)
 
 
 ###############################################################################
@@ -405,237 +419,265 @@ class ReduceFn:
 
 
 class OperandKind(Enum):
-  INPUT_TENSOR = 0
-  SCALAR = 1
-  OUTPUT_TENSOR = 2
-  INDEX_ATTR = 3
-  UNARY_FN_ATTR = 4
-  BINARY_FN_ATTR = 5
-  TYPE_FN_ATTR = 6
+    INPUT_TENSOR = 0
+    SCALAR = 1
+    OUTPUT_TENSOR = 2
+    INDEX_ATTR = 3
+    UNARY_FN_ATTR = 4
+    BINARY_FN_ATTR = 5
+    TYPE_FN_ATTR = 6
 
 
 class OperandDef:
-  """Definition of an operand passed to an operation.
-
-  Keep the meta information of Tensor, Scalar, and Attribute operands and
-  provide the shared registration functionality.
-  """
-
-  def __init__(self,
-               kind: OperandKind,
-               type_var: Optional[TypeVar] = None,
-               size_exprs: Optional[Sequence[AffineExprDef]] = None,
-               index_dims: Optional[Sequence[DimDef]] = None,
-               default_indices: Optional[Sequence[int]] = None,
-               default_fn: Optional[str] = None):
-    if type_var and not isinstance(type_var, TypeVar):
-      raise ValueError(
-          f"OperandDef requires a TypeVar but got {repr(type_var)}")
-    self.owner = None  # type: Optional["LinalgOpDef"]
-    self.type_var = type_var
-    self.size_exprs = size_exprs
-    self.index_dims = index_dims
-    self.default_indices = default_indices
-    self.default_fn = default_fn
-    self.kind = kind
-    self.name = None  # type: Optional[str]
-    self.registered_index = -1  # type: int
-
-  def attach(self, index: int, name: str, owner: "LinalgOpDef"):
-    if self.owner:
-      raise ValueError(f"OperandDef already registered with an op: {self}")
-    self.registered_index = index
-    self.name = name
-    self.owner = owner
-
-  def is_input(self) -> bool:
-    return (self.kind == OperandKind.SCALAR or
-            self.kind == OperandKind.INPUT_TENSOR)
-
-  def is_tensor(self) -> bool:
-    return (self.kind == OperandKind.INPUT_TENSOR or
-            self.kind == OperandKind.OUTPUT_TENSOR)
-
-  def is_attribute(self) -> bool:
-    return (self.kind == OperandKind.INDEX_ATTR or
-            self.kind == OperandKind.UNARY_FN_ATTR or
-            self.kind == OperandKind.BINARY_FN_ATTR or
-            self.kind == OperandKind.TYPE_FN_ATTR)
-
-  def __hash__(self):
-    return hash(id(self))
-
-  def __repr__(self):
-    return (f"{self.name}:OperandDef(kind={self.kind.name}, "
+    """Definition of an operand passed to an operation.
+
+    Keep the meta information of Tensor, Scalar, and Attribute operands and
+    provide the shared registration functionality.
+    """
+
+    def __init__(
+        self,
+        kind: OperandKind,
+        type_var: Optional[TypeVar] = None,
+        size_exprs: Optional[Sequence[AffineExprDef]] = None,
+        index_dims: Optional[Sequence[DimDef]] = None,
+        default_indices: Optional[Sequence[int]] = None,
+        default_fn: Optional[str] = None,
+    ):
+        if type_var and not isinstance(type_var, TypeVar):
+            raise ValueError(f"OperandDef requires a TypeVar but got {repr(type_var)}")
+        self.owner = None  # type: Optional["LinalgOpDef"]
+        self.type_var = type_var
+        self.size_exprs = size_exprs
+        self.index_dims = index_dims
+        self.default_indices = default_indices
+        self.default_fn = default_fn
+        self.kind = kind
+        self.name = None  # type: Optional[str]
+        self.registered_index = -1  # type: int
+
+    def attach(self, index: int, name: str, owner: "LinalgOpDef"):
+        if self.owner:
+            raise ValueError(f"OperandDef already registered with an op: {self}")
+        self.registered_index = index
+        self.name = name
+        self.owner = owner
+
+    def is_input(self) -> bool:
+        return self.kind == OperandKind.SCALAR or self.kind == OperandKind.INPUT_TENSOR
+
+    def is_tensor(self) -> bool:
+        return (
+            self.kind == OperandKind.INPUT_TENSOR
+            or self.kind == OperandKind.OUTPUT_TENSOR
+        )
+
+    def is_attribute(self) -> bool:
+        return (
+            self.kind == OperandKind.INDEX_ATTR
+            or self.kind == OperandKind.UNARY_FN_ATTR
+            or self.kind == OperandKind.BINARY_FN_ATTR
+            or self.kind == OperandKind.TYPE_FN_ATTR
+        )
+
+    def __hash__(self):
+        return hash(id(self))
+
+    def __repr__(self):
+        return (
+            f"{self.name}:OperandDef(kind={self.kind.name}, "
             f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, "
             f"index_dims={self.index_dims}, "
             f"default_indices={self.default_indices}, "
-            f"default_fn={self.default_fn})")
+            f"default_fn={self.default_fn})"
+        )
 
 
 class TensorDef:
-  """Tensor operand definition.
-
-  Tensor operands are indexed using the associated indexing_map when forwarded
-  to the body of the structured op. A unique name identifies the tensor operands
-  and an index determines their position in the operation's parameter list. A
-  tensor definition takes type, a shape, and an optional flag to mark output
-  tensors. Additionally, a tuple of index dimensions may be used to map the
-  tensor to the loop dimensions of the operation. This mapping is needed to
-  compute the indexing map of shape-only tensors that have no uses.
-  """
-
-  def __init__(self,
-               type_var: TypeVar,
-               *shape: AffineExprDef,
-               index_dims: Optional[Sequence[DimDef]] = None,
-               output: bool = False):
-    if index_dims and len(shape) != len(index_dims):
-      raise ValueError(f"Expected the shape rank {len(shape)} to match the "
-                       f"number of index_dims {len(index_dims)}")
-    if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
-      raise ValueError(f"TensorDef requires index dims of type DimDef but "
-                       f"got {index_dims}")
-    kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR
-    self.operand_def = OperandDef(
-        kind, type_var=type_var, size_exprs=shape, index_dims=index_dims)
-
-  def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
-    assert self.operand_def.owner, "TensorDef is not registered with an op"
-    state = AffineBuildState(
-        global_state=self.operand_def.owner._affine_state,
-        allow_new_symbols=False)
-    if not isinstance(dims, tuple):
-      dims = (dims,)  # Handle single subscript case.
-    # Special case: (None) is a 0d-scalar use.
-    if dims == (None,):
-      dims = ()
-
-    exprs = []
-    for expr_def in dims:
-      if not isinstance(expr_def, AffineExprDef):
-        raise KeyError(
-            "A TensorDef can only be subscripted by a tuple of affine dims")
-      exprs.append(expr_def)
-    return TensorUse(self.operand_def, exprs)
-
-  def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression):
-    """Creates a new 1:1 comprehension by binding this tensor to an expression.
-
-    Note that due to the way assignment works in Python, we have to capture
-    direct assignment as a setitem on the TensorDef.
+    """Tensor operand definition.
+
+    Tensor operands are indexed using the associated indexing_map when forwarded
+    to the body of the structured op. A unique name identifies the tensor operands
+    and an index determines their position in the operation's parameter list. A
+    tensor definition takes type, a shape, and an optional flag to mark output
+    tensors. Additionally, a tuple of index dimensions may be used to map the
+    tensor to the loop dimensions of the operation. This mapping is needed to
+    compute the indexing map of shape-only tensors that have no uses.
     """
-    if not isinstance(value, TensorExpression):
-      raise ValueError(f"Only TensorExpressions can be assigned to TensorDefs. "
-                       f"Got: {repr(value)}")
-    use = self[dims]
-    comp = Comprehension((use, value))
-    self.operand_def.owner.comprehensions.append(comp)
+
+    def __init__(
+        self,
+        type_var: TypeVar,
+        *shape: AffineExprDef,
+        index_dims: Optional[Sequence[DimDef]] = None,
+        output: bool = False,
+    ):
+        if index_dims and len(shape) != len(index_dims):
+            raise ValueError(
+                f"Expected the shape rank {len(shape)} to match the "
+                f"number of index_dims {len(index_dims)}"
+            )
+        if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
+            raise ValueError(
+                f"TensorDef requires index dims of type DimDef but " f"got {index_dims}"
+            )
+        kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR
+        self.operand_def = OperandDef(
+            kind, type_var=type_var, size_exprs=shape, index_dims=index_dims
+        )
+
+    def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
+        assert self.operand_def.owner, "TensorDef is not registered with an op"
+        state = AffineBuildState(
+            global_state=self.operand_def.owner._affine_state, allow_new_symbols=False
+        )
+        if not isinstance(dims, tuple):
+            dims = (dims,)  # Handle single subscript case.
+        # Special case: (None) is a 0d-scalar use.
+        if dims == (None,):
+            dims = ()
+
+        exprs = []
+        for expr_def in dims:
+            if not isinstance(expr_def, AffineExprDef):
+                raise KeyError(
+                    "A TensorDef can only be subscripted by a tuple of affine dims"
+                )
+            exprs.append(expr_def)
+        return TensorUse(self.operand_def, exprs)
+
+    def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression):
+        """Creates a new 1:1 comprehension by binding this tensor to an expression.
+
+        Note that due to the way assignment works in Python, we have to capture
+        direct assignment as a setitem on the TensorDef.
+        """
+        if not isinstance(value, TensorExpression):
+            raise ValueError(
+                f"Only TensorExpressions can be assigned to TensorDefs. "
+                f"Got: {repr(value)}"
+            )
+        use = self[dims]
+        comp = Comprehension((use, value))
+        self.operand_def.owner.comprehensions.append(comp)
 
 
 class ScalarDef(TensorExpression):
-  """Scalar operand definition.
+    """Scalar operand definition.
 
-  Scalar operands are forwarded to the body of the structured op as they are.
-  A unique name identifies the scalars and an index determines their position in
-  the operation's parameter list.
-  """
+    Scalar operands are forwarded to the body of the structured op as they are.
+    A unique name identifies the scalars and an index determines their position in
+    the operation's parameter list.
+    """
 
-  def __init__(self, type_var: TypeVar):
-    self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var)
+    def __init__(self, type_var: TypeVar):
+        self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var)
 
-  @property
-  def scalar_name(self) -> str:
-    name = self.operand_def.name
-    assert name is not None, "ScalarDef not registered with an op"
-    return name
+    @property
+    def scalar_name(self) -> str:
+        name = self.operand_def.name
+        assert name is not None, "ScalarDef not registered with an op"
+        return name
 
-  def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarArg(self.scalar_name).expr()
+    def to_scalar_expression(self) -> ScalarExpression:
+        return ScalarArg(self.scalar_name).expr()
 
 
 class IndexAttrDef:
-  """Index attribute definition.
-
-  Index attributes provide a way to define and set symbols that can be used in
-  indexing expressions. Every attribute specifies a tuple of symbols that at
-  compile-time are replaced by integer values as well as their default values.
-  """
-
-  def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
-    if any(not isinstance(size, SymbolDef) for size in sizes):
-      raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef "
-                       f"but got {sizes}")
-    if any(not isinstance(default_val, int) for default_val in default):
-      raise ValueError(f"IndexAttrDef requires default values of type int "
-                       f"but got {default}")
-    if len(sizes) != len(default):
-      raise ValueError(f"IndexAttrDef expects {len(sizes)} default values "
-                       f"but got {len(default)}")
-    self.operand_def = OperandDef(
-        OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default)
+    """Index attribute definition.
+
+    Index attributes provide a way to define and set symbols that can be used in
+    indexing expressions. Every attribute specifies a tuple of symbols that at
+    compile-time are replaced by integer values as well as their default values.
+    """
+
+    def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
+        if any(not isinstance(size, SymbolDef) for size in sizes):
+            raise ValueError(
+                f"IndexAttrDef requires sizes of type SymbolDef " f"but got {sizes}"
+            )
+        if any(not isinstance(default_val, int) for default_val in default):
+            raise ValueError(
+                f"IndexAttrDef requires default values of type int "
+                f"but got {default}"
+            )
+        if len(sizes) != len(default):
+            raise ValueError(
+                f"IndexAttrDef expects {len(sizes)} default values "
+                f"but got {len(default)}"
+            )
+        self.operand_def = OperandDef(
+            OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default
+        )
 
 
 class UnaryFnAttrDef:
-  """Unary function attribute definition.
+    """Unary function attribute definition.
 
-  Unary function attributes provide a way to make the arithmetic computation
-  parametrizable. Every attribute specifies a default unary function
-  that may be overwritten at operation instantiation time.
-  """
+    Unary function attributes provide a way to make the arithmetic computation
+    parametrizable. Every attribute specifies a default unary function
+    that may be overwritten at operation instantiation time.
+    """
 
-  def __init__(self, default: "UnaryFnType"):
-    if not isinstance(default, UnaryFnType):
-      raise ValueError(f"UnaryFnAttrDef requires default of type UnaryFnType "
-                       f"but got {default}")
-    self.operand_def = OperandDef(
-        OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name)
+    def __init__(self, default: "UnaryFnType"):
+        if not isinstance(default, UnaryFnType):
+            raise ValueError(
+                f"UnaryFnAttrDef requires default of type UnaryFnType "
+                f"but got {default}"
+            )
+        self.operand_def = OperandDef(
+            OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name
+        )
 
-  def __call__(self, arg: TensorExpression) -> TensorFn:
-    return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg])
+    def __call__(self, arg: TensorExpression) -> TensorFn:
+        return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg])
 
 
 class BinaryFnAttrDef:
-  """Binary function attribute definition.
+    """Binary function attribute definition.
 
-  Binary function attributes provide a way to make the arithmetic computation
-  parametrizable. Every attribute specifies a default binary function
-  that may be overwritten at operation instantiation time.
-  """
+    Binary function attributes provide a way to make the arithmetic computation
+    parametrizable. Every attribute specifies a default binary function
+    that may be overwritten at operation instantiation time.
+    """
 
-  def __init__(self, default: "BinaryFnType"):
-    if not isinstance(default, BinaryFnType):
-      raise ValueError(f"BinaryFnAttrDef requires default of type BinaryFnType "
-                       f"but got {default}")
-    self.operand_def = OperandDef(
-        OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name)
+    def __init__(self, default: "BinaryFnType"):
+        if not isinstance(default, BinaryFnType):
+            raise ValueError(
+                f"BinaryFnAttrDef requires default of type BinaryFnType "
+                f"but got {default}"
+            )
+        self.operand_def = OperandDef(
+            OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name
+        )
 
-  def __call__(self, arg0: TensorExpression,
-               arg1: TensorExpression) -> TensorFn:
-    return TensorFn(FunctionKind.BINARY, None, self.operand_def, None,
-                    [arg0, arg1])
+    def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn:
+        return TensorFn(FunctionKind.BINARY, None, self.operand_def, None, [arg0, arg1])
 
-  def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
-    return ReduceFnUse(None, self, *reduce_dims)
+    def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+        return ReduceFnUse(None, self, *reduce_dims)
 
 
 class TypeFnAttrDef:
-  """Type conversion function attribute definition.
+    """Type conversion function attribute definition.
 
-  Type conversion function attributes provide a way to make type conversions
-  parameterizable. Every attribute specifies a default type conversion function
-  that may be overwritten at operation instantiation time.
-  """
+    Type conversion function attributes provide a way to make type conversions
+    parameterizable. Every attribute specifies a default type conversion function
+    that may be overwritten at operation instantiation time.
+    """
 
-  def __init__(self, default: "TypeFnType"):
-    if not isinstance(default, TypeFnType):
-      raise ValueError(f"TypeFnAttrDef requires default of type TypeFnType "
-                       f"but got {default}")
-    self.operand_def = OperandDef(
-        OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name)
+    def __init__(self, default: "TypeFnType"):
+        if not isinstance(default, TypeFnType):
+            raise ValueError(
+                f"TypeFnAttrDef requires default of type TypeFnType "
+                f"but got {default}"
+            )
+        self.operand_def = OperandDef(
+            OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name
+        )
 
-  def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
-    return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])
+    def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
+        return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])
 
 
 ###############################################################################
@@ -644,48 +686,48 @@ def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
 
 
 class Comprehension:
-  """Represents a single comprehension."""
-
-  def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
-    self.definitions = list()  # List[TensorUse]
-    self.values = list()  # List[TensorExpression]
-
-    # Find the lhs to reduction rhs.
-    for assign, value in bindings:
-      if isinstance(value, TensorReduceFn):
-        if value.lhs:
-          raise ValueError(f"Reduction expression already assigns: {value}")
-        value.lhs = assign
-      self.definitions.append(assign)
-      self.values.append(value)
-
-  @property
-  def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
-    """Gets the reduction dims for the comprehension or None."""
-    result = set()
-    for use in self.values:
-      if isinstance(use, TensorReduceFn):
-        result.add(use.reduce_use.reduce_dims)
-      else:
-        result.add(tuple())
-    return result
-
-  def __repr__(self):
-    if len(self.definitions) > 1:
-      defs_repr = f"({', '.join(repr(d) for d in self.definitions)})"
-      values_repr = f"({', '.join(repr(v) for v in self.values)})"
-    else:
-      defs_repr = f"{repr(self.definitions[0])}"
-      values_repr = f"{repr(self.values[0])}"
-
-    return f"{defs_repr} = {values_repr}"
+    """Represents a single comprehension."""
+
+    def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
+        self.definitions = list()  # List[TensorUse]
+        self.values = list()  # List[TensorExpression]
+
+        # Find the lhs to reduction rhs.
+        for assign, value in bindings:
+            if isinstance(value, TensorReduceFn):
+                if value.lhs:
+                    raise ValueError(f"Reduction expression already assigns: {value}")
+                value.lhs = assign
+            self.definitions.append(assign)
+            self.values.append(value)
+
+    @property
+    def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
+        """Gets the reduction dims for the comprehension or None."""
+        result = set()
+        for use in self.values:
+            if isinstance(use, TensorReduceFn):
+                result.add(use.reduce_use.reduce_dims)
+            else:
+                result.add(tuple())
+        return result
+
+    def __repr__(self):
+        if len(self.definitions) > 1:
+            defs_repr = f"({', '.join(repr(d) for d in self.definitions)})"
+            values_repr = f"({', '.join(repr(v) for v in self.values)})"
+        else:
+            defs_repr = f"{repr(self.definitions[0])}"
+            values_repr = f"{repr(self.values[0])}"
+
+        return f"{defs_repr} = {values_repr}"
 
 
 class OpInterfaceDef:
-  """An interface that an op implements."""
+    """An interface that an op implements."""
 
-  def __init__(self, cpp_name: str):
-    self.cpp_name = cpp_name
+    def __init__(self, cpp_name: str):
+        self.cpp_name = cpp_name
 
 
 ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")
@@ -694,86 +736,94 @@ def __init__(self, cpp_name: str):
 
 
 class OpDefinitionDef:
-  """A method that an op implements."""
+    """A method that an op implements."""
 
-  def __init__(self, def_name: str):
-    self.def_name = def_name
+    def __init__(self, def_name: str):
+        self.def_name = def_name
 
 
 Canonicalizer = OpDefinitionDef("hasCanonicalizer")
 
 
 class OpMetadataDef(YAMLObject):
-  """Metadata about the op (generally not behavior impacting)."""
-  yaml_tag = "!LinalgOpMetadata"
-
-  def __init__(self, name: str, cpp_class_name: Optional[str],
-               doc: Optional[str]):
-    self.name = name
-    self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
-    self.doc = doc
-    self.implements = []  # type: List[OpInterfaceDef]
-    self.defines = []  # type: List[OpDefinitionsDef]
-
-  def to_yaml_custom_dict(self):
-    d = dict(
-        name=self.name,
-        cpp_class_name=self.cpp_class_name,
-        doc=self.doc,
-    )
-    if self.implements:
-      d["implements"] = [intr.cpp_name for intr in self.implements]
-    if self.defines:
-      d["defines"] = [defi.def_name for defi in self.defines]
-    return d
+    """Metadata about the op (generally not behavior impacting)."""
+
+    yaml_tag = "!LinalgOpMetadata"
+
+    def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]):
+        self.name = name
+        self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
+        self.doc = doc
+        self.implements = []  # type: List[OpInterfaceDef]
+        self.defines = []  # type: List[OpDefinitionsDef]
+
+    def to_yaml_custom_dict(self):
+        d = dict(
+            name=self.name,
+            cpp_class_name=self.cpp_class_name,
+            doc=self.doc,
+        )
+        if self.implements:
+            d["implements"] = [intr.cpp_name for intr in self.implements]
+        if self.defines:
+            d["defines"] = [defi.def_name for defi in self.defines]
+        return d
 
 
 class LinalgOpDef:
-  """Definition of a linalg op."""
-
-  def __init__(self,
-               name: str,
-               cpp_class_name: Optional[str] = None,
-               doc: Optional[str] = None):
-    self.metadata = OpMetadataDef(
-        name=name, cpp_class_name=cpp_class_name, doc=doc)
-    self.registered_operands = dict()  # type: Dict[str, OperandDef]
-    self.domain = list()  # type: List[DimDef]
-    self.comprehensions = list()  # type: List[Comprehension]
-    self._affine_state = AffineBuildState()
-
-  def add_operand(self, name: str, operand: OperandDef):
-    """Registers an operand."""
-    if name in self.registered_operands:
-      raise ValueError(f"The operand {name} is already registered "
-                       f"to {self.registered_operands['name']}")
-    structured_op_methods = [
-        "inputs", "outputs", "result_tensors", "region", "iterator_types",
-        "indexing_maps", "getRegionBuilder", "getLibraryCallName"
-    ]
-    if operand.is_attribute() and name in structured_op_methods:
-      raise ValueError(f"The attribute name {name} conflicts with a structured "
-                       f"op method name")
-    # Ensure output tensors are registered after input tensors and scalars and
-    # attributes are registered after all other operand types.
-    if operand.is_input() and any(
-        not op_def.is_input() for op_def in self.registered_operands.values()):
-      raise ValueError(f"Input {name} registered after an output or attribute")
-    if operand.kind == OperandKind.OUTPUT_TENSOR and any(
-        op_def.is_attribute() for op_def in self.registered_operands.values()):
-      raise ValueError(f"Output {name} registered after an attribute")
-    operand.attach(len(self.registered_operands), name, self)
-    self.registered_operands[name] = operand
-
-  def __repr__(self):
-    lines = [
-        f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"
-    ]
-    for name, operand in self.registered_operands.items():
-      lines.append(f"  {operand}")
-    if self.comprehensions:
-      lines[-1] += " {"
-      for comprehension in self.comprehensions:
-        lines.append(f"    {comprehension}")
-      lines.append("}")
-    return "\n".join(lines)
+    """Definition of a linalg op."""
+
+    def __init__(
+        self, name: str, cpp_class_name: Optional[str] = None, doc: Optional[str] = None
+    ):
+        self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc)
+        self.registered_operands = dict()  # type: Dict[str, OperandDef]
+        self.domain = list()  # type: List[DimDef]
+        self.comprehensions = list()  # type: List[Comprehension]
+        self._affine_state = AffineBuildState()
+
+    def add_operand(self, name: str, operand: OperandDef):
+        """Registers an operand."""
+        if name in self.registered_operands:
+            raise ValueError(
+                f"The operand {name} is already registered "
+                f"to {self.registered_operands['name']}"
+            )
+        structured_op_methods = [
+            "inputs",
+            "outputs",
+            "result_tensors",
+            "region",
+            "iterator_types",
+            "indexing_maps",
+            "getRegionBuilder",
+            "getLibraryCallName",
+        ]
+        if operand.is_attribute() and name in structured_op_methods:
+            raise ValueError(
+                f"The attribute name {name} conflicts with a structured "
+                f"op method name"
+            )
+        # Ensure output tensors are registered after input tensors and scalars and
+        # attributes are registered after all other operand types.
+        if operand.is_input() and any(
+            not op_def.is_input() for op_def in self.registered_operands.values()
+        ):
+            raise ValueError(f"Input {name} registered after an output or attribute")
+        if operand.kind == OperandKind.OUTPUT_TENSOR and any(
+            op_def.is_attribute() for op_def in self.registered_operands.values()
+        ):
+            raise ValueError(f"Output {name} registered after an attribute")
+        operand.attach(len(self.registered_operands), name, self)
+        self.registered_operands[name] = operand
+
+    def __repr__(self):
+        lines = [f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"]
+        for name, operand in self.registered_operands.items():
+            lines.append(f"  {operand}")
+        if self.comprehensions:
+            lines[-1] += " {"
+            for comprehension in self.comprehensions:
+                lines.append(f"    {comprehension}")
+            lines.append("}")
+        return "\n".join(lines)

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
index 2a0da68295265..d522d5712d253 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -21,422 +21,468 @@
 
 
 def _serialize_affine_map(affine_map: _ir.AffineMap) -> str:
-  with affine_map.context:
-    # Affine map printing/parsing is via an AffineMap attr.
-    attr = _ir.AffineMapAttr.get(affine_map)
-    return str(attr)
+    with affine_map.context:
+        # Affine map printing/parsing is via an AffineMap attr.
+        attr = _ir.AffineMapAttr.get(affine_map)
+        return str(attr)
 
 
 class TensorUseConfig:
-  """Wrapper around a TensorUse with additional context-bound state."""
+    """Wrapper around a TensorUse with additional context-bound state."""
 
-  def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap):
-    self.tensor_use = tensor_use
-    self.indexing_map = indexing_map
+    def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap):
+        self.tensor_use = tensor_use
+        self.indexing_map = indexing_map
 
-  def __repr__(self):
-    return f"Use({self.tensor_use}, indexing_map={self.indexing_map})"
+    def __repr__(self):
+        return f"Use({self.tensor_use}, indexing_map={self.indexing_map})"
 
 
 class OperandDefConfig(YAMLObject):
-  """Wrapper containing an operand definition with additional state."""
-  yaml_tag = "!LinalgOperandDefConfig"
-
-  def __init__(self,
-               operand_def: OperandDef,
-               shape_map: Optional[_ir.AffineMap] = None,
-               index_attr_map: Optional[_ir.AffineMap] = None):
-    self.operand_def = operand_def
-    self.shape_map = shape_map  # type: Optional[_ir.AffineMap]
-    self.index_attr_map = index_attr_map  # type: Optional[_ir.AffineMap]
-    self.indexing_map = None  # type: Optional[_ir.AffineMap]
-
-  @property
-  def name(self) -> str:
-    return self.operand_def.name
-
-  @property
-  def kind(self) -> OperandKind:
-    return self.operand_def.kind
-
-  @property
-  def type_var(self) -> TypeVar:
-    return self.operand_def.type_var
-
-  def to_yaml_custom_dict(self):
-    self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower())
-    if self.type_var:
-      self_dict["type_var"] = self.type_var.name
-    if self.shape_map:
-      self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
-    if self.index_attr_map:
-      self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
-    if self.operand_def.default_indices:
-      self_dict["default_indices"] = self.operand_def.default_indices
-    if self.operand_def.default_fn:
-      self_dict["default_fn"] = self.operand_def.default_fn
-    return self_dict
-
-  def __repr__(self):
-    return (f"OperandDefConfig({self.operand_def}, "
+    """Wrapper containing an operand definition with additional state."""
+
+    yaml_tag = "!LinalgOperandDefConfig"
+
+    def __init__(
+        self,
+        operand_def: OperandDef,
+        shape_map: Optional[_ir.AffineMap] = None,
+        index_attr_map: Optional[_ir.AffineMap] = None,
+    ):
+        self.operand_def = operand_def
+        self.shape_map = shape_map  # type: Optional[_ir.AffineMap]
+        self.index_attr_map = index_attr_map  # type: Optional[_ir.AffineMap]
+        self.indexing_map = None  # type: Optional[_ir.AffineMap]
+
+    @property
+    def name(self) -> str:
+        return self.operand_def.name
+
+    @property
+    def kind(self) -> OperandKind:
+        return self.operand_def.kind
+
+    @property
+    def type_var(self) -> TypeVar:
+        return self.operand_def.type_var
+
+    def to_yaml_custom_dict(self):
+        self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower())
+        if self.type_var:
+            self_dict["type_var"] = self.type_var.name
+        if self.shape_map:
+            self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
+        if self.index_attr_map:
+            self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
+        if self.operand_def.default_indices:
+            self_dict["default_indices"] = self.operand_def.default_indices
+        if self.operand_def.default_fn:
+            self_dict["default_fn"] = self.operand_def.default_fn
+        return self_dict
+
+    def __repr__(self):
+        return (
+            f"OperandDefConfig({self.operand_def}, "
             f"shape_map={self.shape_map}, "
             f"index_attr_map={self.index_attr_map}, "
-            f"indexing_map={self.indexing_map})")
+            f"indexing_map={self.indexing_map})"
+        )
 
 
 class LinalgIndexingMapsConfig(YAMLObject):
-  """Abstracts the style of indexing maps that the op exports.
-
-  Presently only static (tied to the op name) indexing maps are supported. In
-  the future, it is expected that we will have additional variants:
-    - Dynamic based on attributes
-    - Dynamic based on operands
-  Each is expected to require a different variant of specification.
-  """
-  yaml_tag = "!LinalgIndexingMapsConfig"
-
-  def __init__(self,
-               static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None):
-    self.static_indexing_maps = static_indexing_maps
-
-  def to_yaml_custom_dict(self):
-    if self.static_indexing_maps is not None:
-      return dict(static_indexing_maps=[
-          _serialize_affine_map(m) for m in self.static_indexing_maps
-      ])
-    raise ValueError(
-        f"LinalgIndexingMapsConfig must have one type of indexing map"
-        f"(got none)")
+    """Abstracts the style of indexing maps that the op exports.
 
+    Presently only static (tied to the op name) indexing maps are supported. In
+    the future, it is expected that we will have additional variants:
+      - Dynamic based on attributes
+      - Dynamic based on operands
+    Each is expected to require a different variant of specification.
+    """
 
-class LinalgStructuredOpConfig(YAMLObject):
-  """Configuration for metadata sufficient to construct a linalg named op."""
-
-  yaml_tag = "!LinalgStructuredOpConfig"
-
-  def __init__(self,
-               comprehension: Comprehension,
-               domain: Sequence[DimDef],
-               registered_operands: Sequence[OperandDef],
-               context: Optional[_ir.Context] = None):
-    self.context = context if context is not None else _ir.Context()
-    self.affine_state = AffineBuildState()
-    self.writes = list()  # type: List[Tuple[TensorUse, TensorExpression]]
-    self.operands = dict()  # type: Dict[OperandDef, OperandDefConfig]
-    self.uses = dict()  # type: Dict[TensorUse, TensorUseConfig]
-
-    # Compute the ordered set of writes and collect the tensor, capture, dims,
-    # and index uses.
-    collected_tensor_uses = set()
-    collected_scalar_uses = set()
-    collected_dim_uses = set()
-    collected_indices = set()
-    for write_use, read_use in zip(comprehension.definitions,
-                                   comprehension.values):
-      self.writes.append((write_use, read_use))
-
-    for write_use, read_use in self.writes:
-      collected_tensor_uses.add(write_use)
-      read_use.collect_tensor_uses(collected_tensor_uses)
-      read_use.collect_scalar_uses(collected_scalar_uses)
-      read_use.collect_dim_uses(collected_dim_uses)
-      write_use.collect_dim_uses(collected_dim_uses)
-      read_use.collect_indices(collected_indices)
-
-    # Set domain to the sorted list of uses if no domain annotation is given.
-    if not domain:
-      domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname)
-
-    # Verify the domain dimensions match the used dimensions.
-    if (len(domain) != len(collected_dim_uses) or
-        any(dim not in collected_dim_uses for dim in domain)):
-      raise ValueError(f"Expected the annotated domain dimensions {domain} to "
-                       f"match the set of dimension used by the tensor "
-                       f"comprehension {collected_dim_uses}")
-
-    # Instantiate the dimensions in the given order.
-    with self.context:
-      local_state = AffineBuildState(
-          global_state=self.affine_state, allow_new_symbols=False)
-      for dim in domain:
-        dim.build(state=local_state)
-
-    # Collect all attribute definitions.
-    collected_attr_defs = list()
-    for operand in registered_operands:
-      if operand.is_attribute():
-        collected_attr_defs.append(operand)
-
-    # Collect all tensors with manual indexing annotation.
-    collected_index_defs = list()
-    for operand in registered_operands:
-      if operand.index_dims:
-        if any(dim not in collected_dim_uses for dim in operand.index_dims):
-          raise ValueError(f"Expected all index dims {operand.index_dims} of "
-                           f"operand {operand.name} to have uses.")
-        collected_index_defs.append(operand)
-
-    # Collect the operand definitions of all tensor/scalar uses, attributes, and
-    # shape-only tensors.
-    all_operand_defs = list()
-    for use in collected_tensor_uses:
-      all_operand_defs.append(use.operand_def)
-    for use in collected_scalar_uses:
-      all_operand_defs.append(use.operand_def)
-    for definition in collected_attr_defs:
-      all_operand_defs.append(definition)
-    for definition in collected_index_defs:
-      all_operand_defs.append(definition)
-
-    # Add all operands in registration order to ensure the symbols are
-    # registered in the order they appear.
-    all_operand_defs = sorted(
-        all_operand_defs, key=lambda operand_def: operand_def.registered_index)
-    for operand_def in all_operand_defs:
-      self.add_operand(operand_def)
-
-    # Add all shape-only tensor index_dim annotations and all tensor uses.
-    for definition in collected_index_defs:
-      self.add_indexed_operand(definition)
-    for use in collected_tensor_uses:
-      self.add_tensor_use(use)
-
-    # Normalize all shape and indexing maps now that full count of dims and
-    # symbols are known.
-    for cuse in self.uses.values():
-      cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
-    for definition in collected_index_defs:
-      self.operands[definition].indexing_map = self._normalize_affine_map(
-          self.operands[definition].indexing_map)
-    for operand_config in self.operands.values():
-      if operand_config.shape_map:
-        operand_config.shape_map = self._normalize_affine_map(
-            operand_config.shape_map, with_dims=False)
-      if operand_config.index_attr_map:
-        operand_config.index_attr_map = self._normalize_affine_map(
-            operand_config.index_attr_map, with_dims=False)
-
-    # Now for each write use, propagate the indexing maps from the use to the
-    # tensor, ensuring that there are not conflicts.
-    for write_use, _ in self.writes:
-      write_tensor_config = self.operands[write_use.operand_def]
-      if write_tensor_config.indexing_map:
-        raise ValueError(
-            f"Unexpected multi-write to a single tensor: {write_tensor_config}")
-      write_tensor_config.indexing_map = self.uses[write_use].indexing_map
-
-    # For each read use, propagate the indexing maps from the use to the
-    # tensor, ensuring that there are not conflicts.
-    for _, read_expr in self.writes:
-      read_uses = set()  # type: Set[TensorUse]
-      read_expr.collect_tensor_uses(read_uses)
-      for read_use in read_uses:
-        read_operand_config = self.operands[read_use.operand_def]
-        if (read_operand_config.indexing_map and
-            read_operand_config.indexing_map !=
-            self.uses[read_use].indexing_map):
-          raise ValueError(
-              f"Unexpected multi-read of a tensor with different accesses:"
-              f"{read_operand_config} vs {read_use}")
-        read_operand_config.indexing_map = self.uses[read_use].indexing_map
-
-    # Set the indexing map of all scalar uses to the empty map.
-    for operand_config in self.operands.values():
-      if operand_config.operand_def.kind == OperandKind.SCALAR:
-        operand_config.indexing_map = self._get_scalar_map()
-
-    # Check all registered tensor and scalar operands have an indexing map.
-    for operand in registered_operands:
-      if operand.is_attribute():
-        continue
-      if not (operand in self.operands and self.operands[operand].indexing_map):
-        raise ValueError(f"Failed to compute an indexing map for operand "
-                         f"{operand.name}")
-
-    # Collect reduction dims and ensure all the same.
-    all_reduction_dims = set(comprehension.all_reduction_dims)
-    if len(all_reduction_dims) != 1:
-      raise ValueError(
-          f"All writes within a generic must have the same reduction "
-          f"dims. Got: {all_reduction_dims}")
-    self.reduction_dims = next(iter(all_reduction_dims))
-
-    # Check the index dimension exists and resolve.
-    for index in collected_indices:
-      if index.dim_def.dimname not in self.affine_state.all_dims:
+    yaml_tag = "!LinalgIndexingMapsConfig"
+
+    def __init__(self, static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None):
+        self.static_indexing_maps = static_indexing_maps
+
+    def to_yaml_custom_dict(self):
+        if self.static_indexing_maps is not None:
+            return dict(
+                static_indexing_maps=[
+                    _serialize_affine_map(m) for m in self.static_indexing_maps
+                ]
+            )
         raise ValueError(
-            f"The dimension {index.dim_def.dimname} is not part of the "
-            f"iteration domain {self.affine_state.all_dims}")
-      index.resolve_dimension_name(self.affine_state)
-
-    # Generate the scalar assignments (used to build a body).
-    self.assignments = [
-        ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
-        for write_use, read_expr in self.writes
-    ]
-
-  @property
-  def ordered_operands(self) -> Sequence[OperandDefConfig]:
-    return sorted(
-        self.operands.values(),
-        key=lambda operand: operand.operand_def.registered_index)
-
-  @property
-  def ordered_dims(self) -> Sequence[Tuple[str, int]]:
-    """Gets the ordered list of dim bindings (symbolic name, position).
-
-    TODO: The original parser relies on parse ordering to arrive at the
-    iterator types, but that ordering is not defined on the Python side, so
-    this may be ambiguous.
-    """
-    return list(self.affine_state.all_dims.items())
-
-  @property
-  def indexing_maps(self) -> Sequence[_ir.AffineMap]:
-    return [o.indexing_map for o in self.ordered_operands if o.indexing_map]
-
-  @property
-  def iterator_types(self) -> Sequence[str]:
-
-    def get_type(symbolic_name, position):
-      for reduction_dim_expr in self.reduction_dims:
-        if reduction_dim_expr.dimname == symbolic_name:
-          return "reduction"
-      return "parallel"
-
-    return [get_type(*dim) for dim in self.ordered_dims]
-
-  def add_operand(self, operand_def: OperandDef):
-    if operand_def in self.operands:
-      return
-    if not (operand_def.is_tensor() or
-            operand_def.kind == OperandKind.INDEX_ATTR):
-      self.operands[operand_def] = OperandDefConfig(operand_def)
-      return
-    with self.context:
-      local_state = AffineBuildState(
-          global_state=self.affine_state, allow_new_dims=False)
-      exprs = []
-      for expr in operand_def.size_exprs:
-        exprs.append(expr.build(state=local_state))
-      assert local_state.local_dim_count == 0
-      affine_map = _ir.AffineMap.get(
-          dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
-      if operand_def.kind == OperandKind.INDEX_ATTR:
-        self.operands[operand_def] = OperandDefConfig(
-            operand_def, index_attr_map=affine_map)
-      else:
-        self.operands[operand_def] = OperandDefConfig(
-            operand_def, shape_map=affine_map)
-
-  def add_indexed_operand(self, operand_def: OperandDef):
-    with self.context:
-      local_state = AffineBuildState(
-          global_state=self.affine_state, allow_new_symbols=False)
-      exprs = []
-      for expr in operand_def.index_dims:
-        exprs.append(expr.build(state=local_state))
-      self.operands[operand_def].indexing_map = _ir.AffineMap.get(
-          dim_count=local_state.dim_count,
-          symbol_count=local_state.symbol_count,
-          exprs=exprs)
-
-  def add_tensor_use(self, tensor_use: TensorUse):
-    if tensor_use in self.uses:
-      return
-    with self.context:
-      local_state = AffineBuildState(
-          global_state=self.affine_state, allow_new_symbols=False)
-      exprs = []
-      for expr in tensor_use.indices:
-        exprs.append(expr.build(state=local_state))
-      indexing_map = _ir.AffineMap.get(
-          dim_count=local_state.dim_count,
-          symbol_count=local_state.symbol_count,
-          exprs=exprs)
-
-      use_config = TensorUseConfig(tensor_use, indexing_map)
-      self.uses[tensor_use] = use_config
-
-  def _get_scalar_map(self) -> _ir.AffineMap:
-    """Create an empty affine map used to index a scalar."""
-    with self.context:
-      return _ir.AffineMap.get(
-          dim_count=self.affine_state.dim_count,
-          symbol_count=self.affine_state.symbol_count,
-          exprs=list())
-
-  def _normalize_affine_map(self,
-                            affine_map: _ir.AffineMap,
-                            with_dims: bool = True) -> _ir.AffineMap:
-    """Normalizes an indexing map to have the max known symbols and dims."""
-    with self.context:
-      return _ir.AffineMap.get(
-          dim_count=self.affine_state.dim_count if with_dims else 0,
-          symbol_count=self.affine_state.symbol_count,
-          exprs=list(affine_map.results))
-
-  def to_yaml_custom_dict(self):
-    self_dict = dict(args=self.ordered_operands)
-    # TODO: Refactor the hierarchy internally when supporting more
-    # than static (preserving this serialized form).
-    self_dict["indexing_maps"] = LinalgIndexingMapsConfig(
-        static_indexing_maps=self.indexing_maps)
-    self_dict["iterator_types"] = self.iterator_types
-    self_dict["assignments"] = self.assignments
-    return self_dict
-
-  def __repr__(self):
-    lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"]
-    lines.append("operands=[")
-    for def_config in self.ordered_operands:
-      lines.append(f"  {repr(def_config)}")
-    lines.append("], indexing_maps=[")
-    for m in self.indexing_maps:
-      lines.append(f"  {repr(m)}")
-    lines.append(f"], iterator_types=[")
-    for t in self.iterator_types:
-      lines.append(f"  {t}")
-    lines.append("])")
-    return "\n".join(lines)
+            f"LinalgIndexingMapsConfig must have one type of indexing map" f"(got none)"
+        )
+
+
+class LinalgStructuredOpConfig(YAMLObject):
+    """Configuration for metadata sufficient to construct a linalg named op."""
+
+    yaml_tag = "!LinalgStructuredOpConfig"
+
+    def __init__(
+        self,
+        comprehension: Comprehension,
+        domain: Sequence[DimDef],
+        registered_operands: Sequence[OperandDef],
+        context: Optional[_ir.Context] = None,
+    ):
+        self.context = context if context is not None else _ir.Context()
+        self.affine_state = AffineBuildState()
+        self.writes = list()  # type: List[Tuple[TensorUse, TensorExpression]]
+        self.operands = dict()  # type: Dict[OperandDef, OperandDefConfig]
+        self.uses = dict()  # type: Dict[TensorUse, TensorUseConfig]
+
+        # Compute the ordered set of writes and collect the tensor, capture, dims,
+        # and index uses.
+        collected_tensor_uses = set()
+        collected_scalar_uses = set()
+        collected_dim_uses = set()
+        collected_indices = set()
+        for write_use, read_use in zip(comprehension.definitions, comprehension.values):
+            self.writes.append((write_use, read_use))
+
+        for write_use, read_use in self.writes:
+            collected_tensor_uses.add(write_use)
+            read_use.collect_tensor_uses(collected_tensor_uses)
+            read_use.collect_scalar_uses(collected_scalar_uses)
+            read_use.collect_dim_uses(collected_dim_uses)
+            write_use.collect_dim_uses(collected_dim_uses)
+            read_use.collect_indices(collected_indices)
+
+        # Set domain to the sorted list of uses if no domain annotation is given.
+        if not domain:
+            domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname)
+
+        # Verify the domain dimensions match the used dimensions.
+        if len(domain) != len(collected_dim_uses) or any(
+            dim not in collected_dim_uses for dim in domain
+        ):
+            raise ValueError(
+                f"Expected the annotated domain dimensions {domain} to "
+                f"match the set of dimension used by the tensor "
+                f"comprehension {collected_dim_uses}"
+            )
+
+        # Instantiate the dimensions in the given order.
+        with self.context:
+            local_state = AffineBuildState(
+                global_state=self.affine_state, allow_new_symbols=False
+            )
+            for dim in domain:
+                dim.build(state=local_state)
+
+        # Collect all attribute definitions.
+        collected_attr_defs = list()
+        for operand in registered_operands:
+            if operand.is_attribute():
+                collected_attr_defs.append(operand)
+
+        # Collect all tensors with manual indexing annotation.
+        collected_index_defs = list()
+        for operand in registered_operands:
+            if operand.index_dims:
+                if any(dim not in collected_dim_uses for dim in operand.index_dims):
+                    raise ValueError(
+                        f"Expected all index dims {operand.index_dims} of "
+                        f"operand {operand.name} to have uses."
+                    )
+                collected_index_defs.append(operand)
+
+        # Collect the operand definitions of all tensor/scalar uses, attributes, and
+        # shape-only tensors.
+        all_operand_defs = list()
+        for use in collected_tensor_uses:
+            all_operand_defs.append(use.operand_def)
+        for use in collected_scalar_uses:
+            all_operand_defs.append(use.operand_def)
+        for definition in collected_attr_defs:
+            all_operand_defs.append(definition)
+        for definition in collected_index_defs:
+            all_operand_defs.append(definition)
+
+        # Add all operands in registration order to ensure the symbols are
+        # registered in the order they appear.
+        all_operand_defs = sorted(
+            all_operand_defs, key=lambda operand_def: operand_def.registered_index
+        )
+        for operand_def in all_operand_defs:
+            self.add_operand(operand_def)
+
+        # Add all shape-only tensor index_dim annotations and all tensor uses.
+        for definition in collected_index_defs:
+            self.add_indexed_operand(definition)
+        for use in collected_tensor_uses:
+            self.add_tensor_use(use)
+
+        # Normalize all shape and indexing maps now that full count of dims and
+        # symbols are known.
+        for cuse in self.uses.values():
+            cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
+        for definition in collected_index_defs:
+            self.operands[definition].indexing_map = self._normalize_affine_map(
+                self.operands[definition].indexing_map
+            )
+        for operand_config in self.operands.values():
+            if operand_config.shape_map:
+                operand_config.shape_map = self._normalize_affine_map(
+                    operand_config.shape_map, with_dims=False
+                )
+            if operand_config.index_attr_map:
+                operand_config.index_attr_map = self._normalize_affine_map(
+                    operand_config.index_attr_map, with_dims=False
+                )
+
+        # Now for each write use, propagate the indexing maps from the use to the
+        # tensor, ensuring that there are not conflicts.
+        for write_use, _ in self.writes:
+            write_tensor_config = self.operands[write_use.operand_def]
+            if write_tensor_config.indexing_map:
+                raise ValueError(
+                    f"Unexpected multi-write to a single tensor: {write_tensor_config}"
+                )
+            write_tensor_config.indexing_map = self.uses[write_use].indexing_map
+
+        # For each read use, propagate the indexing maps from the use to the
+        # tensor, ensuring that there are not conflicts.
+        for _, read_expr in self.writes:
+            read_uses = set()  # type: Set[TensorUse]
+            read_expr.collect_tensor_uses(read_uses)
+            for read_use in read_uses:
+                read_operand_config = self.operands[read_use.operand_def]
+                if (
+                    read_operand_config.indexing_map
+                    and read_operand_config.indexing_map
+                    != self.uses[read_use].indexing_map
+                ):
+                    raise ValueError(
+                        f"Unexpected multi-read of a tensor with different accesses:"
+                        f"{read_operand_config} vs {read_use}"
+                    )
+                read_operand_config.indexing_map = self.uses[read_use].indexing_map
+
+        # Set the indexing map of all scalar uses to the empty map.
+        for operand_config in self.operands.values():
+            if operand_config.operand_def.kind == OperandKind.SCALAR:
+                operand_config.indexing_map = self._get_scalar_map()
+
+        # Check all registered tensor and scalar operands have an indexing map.
+        for operand in registered_operands:
+            if operand.is_attribute():
+                continue
+            if not (operand in self.operands and self.operands[operand].indexing_map):
+                raise ValueError(
+                    f"Failed to compute an indexing map for operand " f"{operand.name}"
+                )
+
+        # Collect reduction dims and ensure all the same.
+        all_reduction_dims = set(comprehension.all_reduction_dims)
+        if len(all_reduction_dims) != 1:
+            raise ValueError(
+                f"All writes within a generic must have the same reduction "
+                f"dims. Got: {all_reduction_dims}"
+            )
+        self.reduction_dims = next(iter(all_reduction_dims))
+
+        # Check the index dimension exists and resolve.
+        for index in collected_indices:
+            if index.dim_def.dimname not in self.affine_state.all_dims:
+                raise ValueError(
+                    f"The dimension {index.dim_def.dimname} is not part of the "
+                    f"iteration domain {self.affine_state.all_dims}"
+                )
+            index.resolve_dimension_name(self.affine_state)
+
+        # Generate the scalar assignments (used to build a body).
+        self.assignments = [
+            ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
+            for write_use, read_expr in self.writes
+        ]
+
+    @property
+    def ordered_operands(self) -> Sequence[OperandDefConfig]:
+        return sorted(
+            self.operands.values(),
+            key=lambda operand: operand.operand_def.registered_index,
+        )
+
+    @property
+    def ordered_dims(self) -> Sequence[Tuple[str, int]]:
+        """Gets the ordered list of dim bindings (symbolic name, position).
+
+        TODO: The original parser relies on parse ordering to arrive at the
+        iterator types, but that ordering is not defined on the Python side, so
+        this may be ambiguous.
+        """
+        return list(self.affine_state.all_dims.items())
+
+    @property
+    def indexing_maps(self) -> Sequence[_ir.AffineMap]:
+        return [o.indexing_map for o in self.ordered_operands if o.indexing_map]
+
+    @property
+    def iterator_types(self) -> Sequence[str]:
+        def get_type(symbolic_name, position):
+            for reduction_dim_expr in self.reduction_dims:
+                if reduction_dim_expr.dimname == symbolic_name:
+                    return "reduction"
+            return "parallel"
+
+        return [get_type(*dim) for dim in self.ordered_dims]
+
+    def add_operand(self, operand_def: OperandDef):
+        if operand_def in self.operands:
+            return
+        if not (operand_def.is_tensor() or operand_def.kind == OperandKind.INDEX_ATTR):
+            self.operands[operand_def] = OperandDefConfig(operand_def)
+            return
+        with self.context:
+            local_state = AffineBuildState(
+                global_state=self.affine_state, allow_new_dims=False
+            )
+            exprs = []
+            for expr in operand_def.size_exprs:
+                exprs.append(expr.build(state=local_state))
+            assert local_state.local_dim_count == 0
+            affine_map = _ir.AffineMap.get(
+                dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs
+            )
+            if operand_def.kind == OperandKind.INDEX_ATTR:
+                self.operands[operand_def] = OperandDefConfig(
+                    operand_def, index_attr_map=affine_map
+                )
+            else:
+                self.operands[operand_def] = OperandDefConfig(
+                    operand_def, shape_map=affine_map
+                )
+
+    def add_indexed_operand(self, operand_def: OperandDef):
+        with self.context:
+            local_state = AffineBuildState(
+                global_state=self.affine_state, allow_new_symbols=False
+            )
+            exprs = []
+            for expr in operand_def.index_dims:
+                exprs.append(expr.build(state=local_state))
+            self.operands[operand_def].indexing_map = _ir.AffineMap.get(
+                dim_count=local_state.dim_count,
+                symbol_count=local_state.symbol_count,
+                exprs=exprs,
+            )
+
+    def add_tensor_use(self, tensor_use: TensorUse):
+        if tensor_use in self.uses:
+            return
+        with self.context:
+            local_state = AffineBuildState(
+                global_state=self.affine_state, allow_new_symbols=False
+            )
+            exprs = []
+            for expr in tensor_use.indices:
+                exprs.append(expr.build(state=local_state))
+            indexing_map = _ir.AffineMap.get(
+                dim_count=local_state.dim_count,
+                symbol_count=local_state.symbol_count,
+                exprs=exprs,
+            )
+
+            use_config = TensorUseConfig(tensor_use, indexing_map)
+            self.uses[tensor_use] = use_config
+
+    def _get_scalar_map(self) -> _ir.AffineMap:
+        """Create an empty affine map used to index a scalar."""
+        with self.context:
+            return _ir.AffineMap.get(
+                dim_count=self.affine_state.dim_count,
+                symbol_count=self.affine_state.symbol_count,
+                exprs=list(),
+            )
+
+    def _normalize_affine_map(
+        self, affine_map: _ir.AffineMap, with_dims: bool = True
+    ) -> _ir.AffineMap:
+        """Normalizes an indexing map to have the max known symbols and dims."""
+        with self.context:
+            return _ir.AffineMap.get(
+                dim_count=self.affine_state.dim_count if with_dims else 0,
+                symbol_count=self.affine_state.symbol_count,
+                exprs=list(affine_map.results),
+            )
+
+    def to_yaml_custom_dict(self):
+        self_dict = dict(args=self.ordered_operands)
+        # TODO: Refactor the hierarchy internally when supporting more
+        # than static (preserving this serialized form).
+        self_dict["indexing_maps"] = LinalgIndexingMapsConfig(
+            static_indexing_maps=self.indexing_maps
+        )
+        self_dict["iterator_types"] = self.iterator_types
+        self_dict["assignments"] = self.assignments
+        return self_dict
+
+    def __repr__(self):
+        lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"]
+        lines.append("operands=[")
+        for def_config in self.ordered_operands:
+            lines.append(f"  {repr(def_config)}")
+        lines.append("], indexing_maps=[")
+        for m in self.indexing_maps:
+            lines.append(f"  {repr(m)}")
+        lines.append(f"], iterator_types=[")
+        for t in self.iterator_types:
+            lines.append(f"  {t}")
+        lines.append("])")
+        return "\n".join(lines)
 
 
 class LinalgOpConfig(YAMLObject):
-  """Container for any supported linalg op type.
-
-  This includes the concrete type by name for ease of parsing by systems
-  that ignore tags.
-  """
-  yaml_tag = "!LinalgOpConfig"
-
-  def __init__(self,
-               metadata: OpMetadataDef,
-               *,
-               structured_op: Optional[LinalgStructuredOpConfig] = None):
-    self.metadata = metadata
-    self.structured_op = structured_op
-
-  def to_yaml_custom_dict(self):
-    self_dict = dict(metadata=self.metadata,)
-    if self.structured_op:
-      self_dict["structured_op"] = self.structured_op
-    return self_dict
-
-  @staticmethod
-  def from_linalg_op_def(
-      op_def: LinalgOpDef,
-      context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]:
-    """Expands a LinalgOpDef into corresponding Linalg configured ops."""
-    # TODO: Many LinalgOpDef patterns need to expand to multiple generics.
-    assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
-    return [
-        LinalgOpConfig(
-            op_def.metadata,
-            structured_op=LinalgStructuredOpConfig(
-                op_def.comprehensions[0], op_def.domain,
-                op_def.registered_operands.values(), context)),
-    ]
-
-  def __repr__(self):
-    return (f"LinalgOpConfig(metadata={self.metadata},\n"
-            f"structured_op={self.structured_op})")
+    """Container for any supported linalg op type.
+
+    This includes the concrete type by name for ease of parsing by systems
+    that ignore tags.
+    """
+
+    yaml_tag = "!LinalgOpConfig"
+
+    def __init__(
+        self,
+        metadata: OpMetadataDef,
+        *,
+        structured_op: Optional[LinalgStructuredOpConfig] = None,
+    ):
+        self.metadata = metadata
+        self.structured_op = structured_op
+
+    def to_yaml_custom_dict(self):
+        self_dict = dict(
+            metadata=self.metadata,
+        )
+        if self.structured_op:
+            self_dict["structured_op"] = self.structured_op
+        return self_dict
+
+    @staticmethod
+    def from_linalg_op_def(
+        op_def: LinalgOpDef, context: Optional[_ir.Context] = None
+    ) -> Sequence["LinalgOpConfig"]:
+        """Expands a LinalgOpDef into corresponding Linalg configured ops."""
+        # TODO: Many LinalgOpDef patterns need to expand to multiple generics.
+        assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
+        return [
+            LinalgOpConfig(
+                op_def.metadata,
+                structured_op=LinalgStructuredOpConfig(
+                    op_def.comprehensions[0],
+                    op_def.domain,
+                    op_def.registered_operands.values(),
+                    context,
+                ),
+            ),
+        ]
+
+    def __repr__(self):
+        return (
+            f"LinalgOpConfig(metadata={self.metadata},\n"
+            f"structured_op={self.structured_op})"
+        )

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
index 45b8d5ccd13d6..8b8726f8f9a03 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -10,160 +10,192 @@
 import threading
 
 from ..... import ir
-from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ...._ods_common import (
+    get_op_result_or_value as _get_op_result_or_value,
+    get_op_results_or_values as _get_op_results_or_values,
+)
 from .comprehension import *
 from .config import *
 from .emitter import *
 
 _CONTEXT = threading.local()
 
-StructuredOpOuts = Union[ir.Operation, ir.OpView, ir.OpResultList,
-                         Sequence[Union[ir.Value, ir.Operation, ir.OpView]]]
+StructuredOpOuts = Union[
+    ir.Operation,
+    ir.OpView,
+    ir.OpResultList,
+    Sequence[Union[ir.Value, ir.Operation, ir.OpView]],
+]
 
 
 @contextmanager
 def bind_op_def(op_def: LinalgOpDef):
-  if hasattr(_CONTEXT, "current_op_def"):
-    raise ValueError("Cannot recursively define an operation")
-  _CONTEXT.current_op_def = op_def
-  try:
-    yield op_def
-  finally:
-    del _CONTEXT.current_op_def
+    if hasattr(_CONTEXT, "current_op_def"):
+        raise ValueError("Cannot recursively define an operation")
+    _CONTEXT.current_op_def = op_def
+    try:
+        yield op_def
+    finally:
+        del _CONTEXT.current_op_def
 
 
 def current_op_def() -> LinalgOpDef:
-  try:
-    return _CONTEXT.current_op_def
-  except AttributeError:
-    raise ValueError(
-        "Attempt to access the current op definition being defined "
-        "but none is set. Did you mean to call this in an op definition?")
+    try:
+        return _CONTEXT.current_op_def
+    except AttributeError:
+        raise ValueError(
+            "Attempt to access the current op definition being defined "
+            "but none is set. Did you mean to call this in an op definition?"
+        )
 
 
 def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList:
-  if isinstance(outs, (ir.Operation, ir.OpView)):
-    return _get_op_results_or_values(outs)
-  elif isinstance(outs, ir.OpResultList):
-    return outs
+    if isinstance(outs, (ir.Operation, ir.OpView)):
+        return _get_op_results_or_values(outs)
+    elif isinstance(outs, ir.OpResultList):
+        return outs
 
-  return [_get_op_result_or_value(o) for o in outs]
+    return [_get_op_result_or_value(o) for o in outs]
 
 
 class DefinedOpCallable:
-  """Callable that wraps any defined op function."""
-
-  def __init__(self, op_name: str, op_def: LinalgOpDef):
-    self.op_name = op_name
-    self.op_def = op_def
-
-  def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value],
-               outs: StructuredOpOuts, **kwargs):
-    """Emits the corresponding op definition as IR.
-
-    Most arguments are passed through to the underlying emitter. The following
-    keyword argument is interpreted here:
-      emit_generic: Emits a generic form as appropriate (default True). If
-        False, a named form is emitted (which must have been built in to the
-        compiler).
-    """
-    emit_generic = kwargs.pop("emit_generic", False)
-    if not isinstance(emit_generic, bool):
-      raise ValueError(f"The named argument 'emit_generic' needs to be "
-                       f" of type bool but got {type(emit_generic)}")
-
-    op_configs = LinalgOpConfig.from_linalg_op_def(
-        self.op_def, context=ir.Context.current)
-
-    if len(op_configs) != 1:
-      # TODO: Support composite ops.
-      raise NotImplementedError(
-          f"Emission of composite linalg ops not supported: {op_configs}")
-
-    ctx = ir.Context.current
-    linalgDialect = ctx.get_dialect_descriptor("linalg")
-    fully_qualified_name = "linalg." + self.op_name
-    emit_generic = (
-        emit_generic or not ctx.is_registered_operation(fully_qualified_name))
-
-    op_config = op_configs[0]
-    out_values = _prepare_structured_op_outs(outs)
-    in_values = [_get_op_result_or_value(i) for i in ins]
-    if op_config.structured_op:
-      if emit_generic:
-        return emit_generic_structured_op(
-            op_config.structured_op, *in_values, outs=out_values, **kwargs)
-      else:
-        return emit_named_structured_op(
-            op_config.structured_op,
-            self.op_name,
-            self.op_def.metadata.cpp_class_name,
-            *in_values,
-            outs=out_values,
-            **kwargs)
-
-    raise NotImplementedError(
-        f"Emission of linalg op type not supported: {op_config}")
-
-
-def linalg_structured_op(dsl_func=None,
-                         *,
-                         op_name=None,
-                         op_class_name=None) -> DefinedOpCallable:
-  if dsl_func is None:
-    # Curry the keyword args in for delayed application.
-    return functools.partial(
-        linalg_structured_op, op_name=op_name, op_class_name=op_class_name)
-  # Determine default names by introspecting the function.
-  if op_name is None:
-    op_name = dsl_func.__name__
-  if op_class_name is None:
-    # Camel case it.
-    op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"
-
-  op_def = LinalgOpDef(
-      name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func))
-
-  # Extract arguments and TensorDefs from the signature.
-  dsl_func_args = list()
-  sig = inspect.signature(dsl_func)
-  for param_name, param in sig.parameters.items():
-    param_default = param.default
-    if isinstance(param_default,
-                  (TensorDef, ScalarDef, IndexAttrDef, UnaryFnAttrDef,
-                   BinaryFnAttrDef, TypeFnAttrDef)):
-      op_def.add_operand(param_name, param_default.operand_def)
-    else:
-      raise ValueError(
-          f"@linalg_structured_op function parameters must be defaulted as "
-          f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): "
-          f"Found {param_name}: {param_default}")
-    dsl_func_args.append(param_default)
-
-  # Invoke the DSL func to finish populating the op definition.
-  with bind_op_def(op_def):
-    dsl_func(*dsl_func_args)
-
-  # TODO: The returned callable should be an IR emitter but that is not
-  # upstreamed yet.
-  return DefinedOpCallable(op_name, op_def)
+    """Callable that wraps any defined op function."""
+
+    def __init__(self, op_name: str, op_def: LinalgOpDef):
+        self.op_name = op_name
+        self.op_def = op_def
+
+    def __call__(
+        self,
+        *ins: Union[ir.Operation, ir.OpView, ir.Value],
+        outs: StructuredOpOuts,
+        **kwargs,
+    ):
+        """Emits the corresponding op definition as IR.
+
+        Most arguments are passed through to the underlying emitter. The following
+        keyword argument is interpreted here:
+          emit_generic: Emits a generic form as appropriate (default True). If
+            False, a named form is emitted (which must have been built in to the
+            compiler).
+        """
+        emit_generic = kwargs.pop("emit_generic", False)
+        if not isinstance(emit_generic, bool):
+            raise ValueError(
+                f"The named argument 'emit_generic' needs to be "
+                f" of type bool but got {type(emit_generic)}"
+            )
+
+        op_configs = LinalgOpConfig.from_linalg_op_def(
+            self.op_def, context=ir.Context.current
+        )
+
+        if len(op_configs) != 1:
+            # TODO: Support composite ops.
+            raise NotImplementedError(
+                f"Emission of composite linalg ops not supported: {op_configs}"
+            )
+
+        ctx = ir.Context.current
+        linalgDialect = ctx.get_dialect_descriptor("linalg")
+        fully_qualified_name = "linalg." + self.op_name
+        emit_generic = emit_generic or not ctx.is_registered_operation(
+            fully_qualified_name
+        )
+
+        op_config = op_configs[0]
+        out_values = _prepare_structured_op_outs(outs)
+        in_values = [_get_op_result_or_value(i) for i in ins]
+        if op_config.structured_op:
+            if emit_generic:
+                return emit_generic_structured_op(
+                    op_config.structured_op, *in_values, outs=out_values, **kwargs
+                )
+            else:
+                return emit_named_structured_op(
+                    op_config.structured_op,
+                    self.op_name,
+                    self.op_def.metadata.cpp_class_name,
+                    *in_values,
+                    outs=out_values,
+                    **kwargs,
+                )
+
+        raise NotImplementedError(
+            f"Emission of linalg op type not supported: {op_config}"
+        )
+
+
+def linalg_structured_op(
+    dsl_func=None, *, op_name=None, op_class_name=None
+) -> DefinedOpCallable:
+    if dsl_func is None:
+        # Curry the keyword args in for delayed application.
+        return functools.partial(
+            linalg_structured_op, op_name=op_name, op_class_name=op_class_name
+        )
+    # Determine default names by introspecting the function.
+    if op_name is None:
+        op_name = dsl_func.__name__
+    if op_class_name is None:
+        # Camel case it.
+        op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"
+
+    op_def = LinalgOpDef(
+        name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func)
+    )
+
+    # Extract arguments and TensorDefs from the signature.
+    dsl_func_args = list()
+    sig = inspect.signature(dsl_func)
+    for param_name, param in sig.parameters.items():
+        param_default = param.default
+        if isinstance(
+            param_default,
+            (
+                TensorDef,
+                ScalarDef,
+                IndexAttrDef,
+                UnaryFnAttrDef,
+                BinaryFnAttrDef,
+                TypeFnAttrDef,
+            ),
+        ):
+            op_def.add_operand(param_name, param_default.operand_def)
+        else:
+            raise ValueError(
+                f"@linalg_structured_op function parameters must be defaulted as "
+                f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): "
+                f"Found {param_name}: {param_default}"
+            )
+        dsl_func_args.append(param_default)
+
+    # Invoke the DSL func to finish populating the op definition.
+    with bind_op_def(op_def):
+        dsl_func(*dsl_func_args)
+
+    # TODO: The returned callable should be an IR emitter but that is not
+    # upstreamed yet.
+    return DefinedOpCallable(op_name, op_def)
 
 
 def domain(*dimensions: DimDef):
-  if any(not isinstance(d, DimDef) for d in dimensions):
-    raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
-  current_op_def().domain.extend(dimensions)
+    if any(not isinstance(d, DimDef) for d in dimensions):
+        raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
+    current_op_def().domain.extend(dimensions)
 
 
 def implements(*interfaces: OpInterfaceDef):
-  if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
-    raise ValueError(
-        f"Expected interfaces of type OpInterfaceDef but got {interfaces}")
-  current_op_def().metadata.implements.extend(interfaces)
+    if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
+        raise ValueError(
+            f"Expected interfaces of type OpInterfaceDef but got {interfaces}"
+        )
+    current_op_def().metadata.implements.extend(interfaces)
 
 
 def defines(*definitions: OpDefinitionDef):
-  if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
-    raise ValueError(
-        f"Expected definitions of type OpDefinitionDef but got {definitions}")
-  current_op_def().metadata.defines.extend(definitions)
+    if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
+        raise ValueError(
+            f"Expected definitions of type OpDefinitionDef but got {definitions}"
+        )
+    current_op_def().metadata.defines.extend(definitions)

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index b63cb40717aed..62730d9ca4d8e 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -11,7 +11,10 @@
 from .... import math
 from .... import arith
 from .... import complex
-from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ...._ods_common import (
+    get_op_result_or_value as _get_op_result_or_value,
+    get_op_results_or_values as _get_op_results_or_values,
+)
 
 from .scalar_expr import *
 from .config import *
@@ -29,529 +32,618 @@
 
 
 def isa(cls: Type, ty: Type):
-  try:
-    cls(ty)
-    return True
-  except ValueError:
-    return False
-
-
-def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
-                                 *ins: Value, outs: ValueList,
-                                 **attrs: Union[Sequence[int], TypeFnType]):
-  all_arg_defs = op_config.ordered_operands
-  in_arg_defs = [
-      d for d in all_arg_defs
-      if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR]
-  ]
-  out_arg_defs = [
-      d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR
-  ]
-  index_attr_arg_defs = [
-      d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR
-  ]
-  fn_attr_arg_defs = [
-      d for d in all_arg_defs if d.kind in [
-          OperandKind.UNARY_FN_ATTR, OperandKind.BINARY_FN_ATTR,
-          OperandKind.TYPE_FN_ATTR
-      ]
-  ]
-
-  # Verify outs is a sequence or a list of results.
-  if not isinstance(outs, (Sequence, OpResultList)):
-    raise ValueError(f"Expected named argument outs to have type Sequence or "
-                     f"OpResultLis but got {type(outs)}")
-
-  # Arity validation.
-  if len(ins) != len(in_arg_defs):
-    raise ValueError(f"Expected {len(in_arg_defs)} inputs but got "
-                     f"{len(ins)} for {op_config}")
-  if outs and len(outs) != len(out_arg_defs):
-    raise ValueError(f"Expected {len(out_arg_defs)} outputs but got "
-                     f"{len(outs)} for {op_config}")
-
-  # Compute a replacement list for all index attribute symbols.
-  expressions = []  # type: Sequence[AffineExpr]
-  replacements = []  # type: Sequence[AffineExpr]
-  for index_attr in index_attr_arg_defs:
-    index_attr_vals = index_attr.operand_def.default_indices
-    if index_attr.name in attrs:
-      index_attr_vals = attrs.get(index_attr.name)
-    assert index_attr_vals, "Index attribute has no value"
-    if not all(isinstance(value, int) for value in index_attr_vals):
-      raise ValueError(f"Attribute {index_attr.name} needs to be of type "
-                       f"Sequence[int] but got {type(index_attr_vals)}")
-    results = index_attr.index_attr_map.results  # type: AffineExprList
-    if len(index_attr_vals) != len(results):
-      raise ValueError(f"Attribute {index_attr.name} has length {len(results)} "
-                       f"but got {len(index_attr_vals)} values")
-    for expr, value in zip(results, index_attr_vals):
-      expressions.append(expr)
-      replacements.append(AffineConstantExpr.get(value))
-
-  # Replace all index attribute symbols by their value.
-  # TODO: Add support for shape symbols.
-  indexing_maps = []  # type: Sequence[AffineMap]
-  for curr in op_config.indexing_maps:
-    for expression, replacement in zip(expressions, replacements):
-      curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols)
-    indexing_maps.append(curr)
-
-  # TODO: Linalg verification does not currently allow symbols.
-  # Compress them for now and verify none are left.
-  indexing_maps = AffineMap.compress_unused_symbols(indexing_maps,
-                                                    Context.current)
-  if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps):
-    raise ValueError(f"Expected indexing_maps to use no symbols after "
-                     f"replacement and compression but got {indexing_maps}")
-
-  outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
-                                           out_arg_defs, outs)
-
-  result_types = [t for t in out_types if isa(RankedTensorType, t)]
-
-  # Initialize the type dictionary with the predefined types.
-  type_mapping = dict()  # type: Dict[str, Type]
-  type_mapping["F32"] = F32Type.get()
-  type_mapping["F64"] = F64Type.get()
-  type_mapping["I32"] = IntegerType.get_signless(32)
-  type_mapping["I64"] = IntegerType.get_signless(64)
-
-  # Extract type vars for input/output based types.
-  block_arg_types = list()  # type: List[Type]
-  for arg_def, arg_element_type in zip(in_arg_defs + out_arg_defs,
-                                       _get_types_from_values(*ins, *outs)):
-    _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types)
-
-  # Emit the generic op.
-  # TODO: Support emission of pure memref form.
-  indexing_maps_attr = ArrayAttr.get(
-      [AffineMapAttr.get(am) for am in indexing_maps])
-  iterator_types_attr = ArrayAttr.get([
-      Attribute.parse(f"#linalg.iterator_type<{s}>")
-      for s in op_config.iterator_types
-  ])
-
-  # Compute the index attributes used when emitting a named structured op.
-  index_attrs = {}  # type: Dict[str, DenseElementAttr]
-  for index_attr in index_attr_arg_defs:
-    index_attr_vals = attrs.get(index_attr.name)
-    # Only forward attributes set to a non-default value.
-    if index_attr_vals:
-      array = np.array(index_attr_vals, dtype=np.int64)
-      index_attrs[index_attr.name] = DenseElementsAttr.get(array)
-
-  # Compute the function attribute mapping.
-  fn_attr_mapping = {}
-  for fn_attr in fn_attr_arg_defs:
-    attr_val = fn_attr.operand_def.default_fn
-    attr_kind = fn_attr.kind
-    if fn_attr.name in attrs:
-      fn = attrs.get(fn_attr.name)
-      if attr_kind == OperandKind.UNARY_FN_ATTR:
-        if not isinstance(fn, UnaryFnType):
-          raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
-                           f"UnaryFnType but got {type(attr_val)}")
-      elif attr_kind == OperandKind.BINARY_FN_ATTR:
-        if not isinstance(fn, BinaryFnType):
-          raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
-                           f"BinaryFnType but got {type(attr_val)}")
-      else:
-        if not isinstance(fn, TypeFnType):
-          raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
-                           f"TypeFnType but got {type(attr_val)}")
-      attr_val = fn.fn_name
-    assert attr_val, "Function attribute has no value"
-    fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind)
-
-  return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
-          type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs,
-          fn_attr_mapping, block_arg_types)
-
-
-def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
-                               outs: ValueList, **attrs: Sequence[int]):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-  indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \
-  block_arg_types = \
-     prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
-
-  # An operation that accesses only scalars and scalar/rank zero tensors is
-  # rank polymorhpic. We implement rank polymorphism by generating different
-  # indexing maps and iterators that match the rank of the first output tensor.
-  # An operation is rank polymorphic if the iteration domain has rank zero.
-  if not iterator_types_attr:
-    rank = ShapedType(outs[0].type).rank
-    iterator_types_attr = ArrayAttr.get(
-        [Attribute.parse("#linalg.iterator_type<parallel>")] * rank)
-    scalar_map = AffineMap.get(rank, 0, [])
-    tensor_map = AffineMap.get_identity(rank)
-    indexing_maps = []
-    for arg_def in all_arg_defs:
-      if arg_def.operand_def.kind == OperandKind.SCALAR:
-        indexing_maps.append(scalar_map)
-      if arg_def.operand_def.is_tensor():
-        idx = arg_def.operand_def.registered_index
-        if idx < len(ins) and ShapedType(ins[idx].type).rank == 0:
-          indexing_maps.append(scalar_map)
-        else:
-          indexing_maps.append(tensor_map)
-    indexing_maps_attr = ArrayAttr.get(
-        [AffineMapAttr.get(am) for am in indexing_maps])
-
-  generic_op = linalg.GenericOp(
-      result_tensors=result_types,
-      inputs=ins,
-      outputs=outs,
-      indexing_maps=indexing_maps_attr,
-      iterator_types=iterator_types_attr,
-      doc=None,  # TODO: Make optional.
-      library_call=None)  # TODO: Make optional.
-
-  # Construct the body.
-  block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs)
-  block = generic_op.regions[0].blocks.append(*block_arg_types)
-  block_arg_mapping = dict(zip(block_arg_names, block.arguments))
-  with InsertionPoint(block):
-    body_builder = _BodyBuilder(type_mapping, block_arg_mapping,
-                                fn_attr_mapping)
-    for assignment in op_config.assignments:
-      body_builder.assign(assignment)
-    body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))
-
-  if len(result_types) == 1:
-    return generic_op.result
-  else:
-    return generic_op.results
-
-
-def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
-                             op_class_name: str, *ins: Value, outs: ValueList,
-                             **attrs: Sequence[int]):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-  indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \
-  block_arg_types = \
-     prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
-
-  # If we get here, there must exist a builtin class `op_class_name`.
-  ctx = Context.current
-  fully_qualified_name = "linalg." + op_name
-  if (not ctx.is_registered_operation(fully_qualified_name) or
-      not op_class_name in linalg.__dict__.keys()):
-    raise NotImplementedError(
-        f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
-
-  # Set the index attributes used to compute the indexing maps.
-  named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
-  for name, value in index_attrs.items():
-    named_op.operation.attributes[name] = value
-
-  # Compute the function attributes by combining operand kind and function name.
-  for name, (fn_name, kind) in fn_attr_mapping.items():
-    assert kind.name.lower().endswith("_attr")
-    enum_name = kind.name.lower()[:-5]
-    named_op.operation.attributes[name] = Attribute.parse(
-        f"#linalg.{enum_name}<{fn_name}>")
+    try:
+        cls(ty)
+        return True
+    except ValueError:
+        return False
 
-  linalg.fill_builtin_region(named_op.operation)
 
-  if len(result_types) == 1:
-    return named_op.result
-  else:
-    return named_op.results
+def prepare_common_structured_op(
+    op_config: LinalgStructuredOpConfig,
+    *ins: Value,
+    outs: ValueList,
+    **attrs: Union[Sequence[int], TypeFnType],
+):
+    all_arg_defs = op_config.ordered_operands
+    in_arg_defs = [
+        d
+        for d in all_arg_defs
+        if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR]
+    ]
+    out_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR]
+    index_attr_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR]
+    fn_attr_arg_defs = [
+        d
+        for d in all_arg_defs
+        if d.kind
+        in [
+            OperandKind.UNARY_FN_ATTR,
+            OperandKind.BINARY_FN_ATTR,
+            OperandKind.TYPE_FN_ATTR,
+        ]
+    ]
+
+    # Verify outs is a sequence or a list of results.
+    if not isinstance(outs, (Sequence, OpResultList)):
+        raise ValueError(
+            f"Expected named argument outs to have type Sequence or "
+            f"OpResultLis but got {type(outs)}"
+        )
+
+    # Arity validation.
+    if len(ins) != len(in_arg_defs):
+        raise ValueError(
+            f"Expected {len(in_arg_defs)} inputs but got " f"{len(ins)} for {op_config}"
+        )
+    if outs and len(outs) != len(out_arg_defs):
+        raise ValueError(
+            f"Expected {len(out_arg_defs)} outputs but got "
+            f"{len(outs)} for {op_config}"
+        )
+
+    # Compute a replacement list for all index attribute symbols.
+    expressions = []  # type: Sequence[AffineExpr]
+    replacements = []  # type: Sequence[AffineExpr]
+    for index_attr in index_attr_arg_defs:
+        index_attr_vals = index_attr.operand_def.default_indices
+        if index_attr.name in attrs:
+            index_attr_vals = attrs.get(index_attr.name)
+        assert index_attr_vals, "Index attribute has no value"
+        if not all(isinstance(value, int) for value in index_attr_vals):
+            raise ValueError(
+                f"Attribute {index_attr.name} needs to be of type "
+                f"Sequence[int] but got {type(index_attr_vals)}"
+            )
+        results = index_attr.index_attr_map.results  # type: AffineExprList
+        if len(index_attr_vals) != len(results):
+            raise ValueError(
+                f"Attribute {index_attr.name} has length {len(results)} "
+                f"but got {len(index_attr_vals)} values"
+            )
+        for expr, value in zip(results, index_attr_vals):
+            expressions.append(expr)
+            replacements.append(AffineConstantExpr.get(value))
+
+    # Replace all index attribute symbols by their value.
+    # TODO: Add support for shape symbols.
+    indexing_maps = []  # type: Sequence[AffineMap]
+    for curr in op_config.indexing_maps:
+        for expression, replacement in zip(expressions, replacements):
+            curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols)
+        indexing_maps.append(curr)
+
+    # TODO: Linalg verification does not currently allow symbols.
+    # Compress them for now and verify none are left.
+    indexing_maps = AffineMap.compress_unused_symbols(indexing_maps, Context.current)
+    if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps):
+        raise ValueError(
+            f"Expected indexing_maps to use no symbols after "
+            f"replacement and compression but got {indexing_maps}"
+        )
+
+    outs, out_types = _infer_structured_outs(
+        op_config, in_arg_defs, ins, out_arg_defs, outs
+    )
+
+    result_types = [t for t in out_types if isa(RankedTensorType, t)]
+
+    # Initialize the type dictionary with the predefined types.
+    type_mapping = dict()  # type: Dict[str, Type]
+    type_mapping["F32"] = F32Type.get()
+    type_mapping["F64"] = F64Type.get()
+    type_mapping["I32"] = IntegerType.get_signless(32)
+    type_mapping["I64"] = IntegerType.get_signless(64)
+
+    # Extract type vars for input/output based types.
+    block_arg_types = list()  # type: List[Type]
+    for arg_def, arg_element_type in zip(
+        in_arg_defs + out_arg_defs, _get_types_from_values(*ins, *outs)
+    ):
+        _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types)
+
+    # Emit the generic op.
+    # TODO: Support emission of pure memref form.
+    indexing_maps_attr = ArrayAttr.get([AffineMapAttr.get(am) for am in indexing_maps])
+    iterator_types_attr = ArrayAttr.get(
+        [
+            Attribute.parse(f"#linalg.iterator_type<{s}>")
+            for s in op_config.iterator_types
+        ]
+    )
+
+    # Compute the index attributes used when emitting a named structured op.
+    index_attrs = {}  # type: Dict[str, DenseElementAttr]
+    for index_attr in index_attr_arg_defs:
+        index_attr_vals = attrs.get(index_attr.name)
+        # Only forward attributes set to a non-default value.
+        if index_attr_vals:
+            array = np.array(index_attr_vals, dtype=np.int64)
+            index_attrs[index_attr.name] = DenseElementsAttr.get(array)
+
+    # Compute the function attribute mapping.
+    fn_attr_mapping = {}
+    for fn_attr in fn_attr_arg_defs:
+        attr_val = fn_attr.operand_def.default_fn
+        attr_kind = fn_attr.kind
+        if fn_attr.name in attrs:
+            fn = attrs.get(fn_attr.name)
+            if attr_kind == OperandKind.UNARY_FN_ATTR:
+                if not isinstance(fn, UnaryFnType):
+                    raise ValueError(
+                        f"Attribute {fn_attr.name} needs to be of type "
+                        f"UnaryFnType but got {type(attr_val)}"
+                    )
+            elif attr_kind == OperandKind.BINARY_FN_ATTR:
+                if not isinstance(fn, BinaryFnType):
+                    raise ValueError(
+                        f"Attribute {fn_attr.name} needs to be of type "
+                        f"BinaryFnType but got {type(attr_val)}"
+                    )
+            else:
+                if not isinstance(fn, TypeFnType):
+                    raise ValueError(
+                        f"Attribute {fn_attr.name} needs to be of type "
+                        f"TypeFnType but got {type(attr_val)}"
+                    )
+            attr_val = fn.fn_name
+        assert attr_val, "Function attribute has no value"
+        fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind)
+
+    return (
+        all_arg_defs,
+        in_arg_defs,
+        out_arg_defs,
+        outs,
+        result_types,
+        type_mapping,
+        indexing_maps_attr,
+        iterator_types_attr,
+        index_attrs,
+        fn_attr_mapping,
+        block_arg_types,
+    )
+
+
+def emit_generic_structured_op(
+    op_config: LinalgStructuredOpConfig,
+    *ins: Value,
+    outs: ValueList,
+    **attrs: Sequence[int],
+):
+    (
+        all_arg_defs,
+        in_arg_defs,
+        out_arg_defs,
+        outs,
+        result_types,
+        type_mapping,
+        indexing_maps_attr,
+        iterator_types_attr,
+        index_attrs,
+        fn_attr_mapping,
+        block_arg_types,
+    ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs)
+
+    # An operation that accesses only scalars and scalar/rank zero tensors is
+    # rank polymorhpic. We implement rank polymorphism by generating different
+    # indexing maps and iterators that match the rank of the first output tensor.
+    # An operation is rank polymorphic if the iteration domain has rank zero.
+    if not iterator_types_attr:
+        rank = ShapedType(outs[0].type).rank
+        iterator_types_attr = ArrayAttr.get(
+            [Attribute.parse("#linalg.iterator_type<parallel>")] * rank
+        )
+        scalar_map = AffineMap.get(rank, 0, [])
+        tensor_map = AffineMap.get_identity(rank)
+        indexing_maps = []
+        for arg_def in all_arg_defs:
+            if arg_def.operand_def.kind == OperandKind.SCALAR:
+                indexing_maps.append(scalar_map)
+            if arg_def.operand_def.is_tensor():
+                idx = arg_def.operand_def.registered_index
+                if idx < len(ins) and ShapedType(ins[idx].type).rank == 0:
+                    indexing_maps.append(scalar_map)
+                else:
+                    indexing_maps.append(tensor_map)
+        indexing_maps_attr = ArrayAttr.get(
+            [AffineMapAttr.get(am) for am in indexing_maps]
+        )
+
+    generic_op = linalg.GenericOp(
+        result_tensors=result_types,
+        inputs=ins,
+        outputs=outs,
+        indexing_maps=indexing_maps_attr,
+        iterator_types=iterator_types_attr,
+        doc=None,  # TODO: Make optional.
+        library_call=None,
+    )  # TODO: Make optional.
+
+    # Construct the body.
+    block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs)
+    block = generic_op.regions[0].blocks.append(*block_arg_types)
+    block_arg_mapping = dict(zip(block_arg_names, block.arguments))
+    with InsertionPoint(block):
+        body_builder = _BodyBuilder(type_mapping, block_arg_mapping, fn_attr_mapping)
+        for assignment in op_config.assignments:
+            body_builder.assign(assignment)
+        body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))
+
+    if len(result_types) == 1:
+        return generic_op.result
+    else:
+        return generic_op.results
+
+
+def emit_named_structured_op(
+    op_config: LinalgStructuredOpConfig,
+    op_name: str,
+    op_class_name: str,
+    *ins: Value,
+    outs: ValueList,
+    **attrs: Sequence[int],
+):
+    (
+        all_arg_defs,
+        in_arg_defs,
+        out_arg_defs,
+        outs,
+        result_types,
+        type_mapping,
+        indexing_maps_attr,
+        iterator_types_attr,
+        index_attrs,
+        fn_attr_mapping,
+        block_arg_types,
+    ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs)
+
+    # If we get here, there must exist a builtin class `op_class_name`.
+    ctx = Context.current
+    fully_qualified_name = "linalg." + op_name
+    if (
+        not ctx.is_registered_operation(fully_qualified_name)
+        or not op_class_name in linalg.__dict__.keys()
+    ):
+        raise NotImplementedError(
+            f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}"
+        )
+
+    # Set the index attributes used to compute the indexing maps.
+    named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
+    for name, value in index_attrs.items():
+        named_op.operation.attributes[name] = value
+
+    # Compute the function attributes by combining operand kind and function name.
+    for name, (fn_name, kind) in fn_attr_mapping.items():
+        assert kind.name.lower().endswith("_attr")
+        enum_name = kind.name.lower()[:-5]
+        named_op.operation.attributes[name] = Attribute.parse(
+            f"#linalg.{enum_name}<{fn_name}>"
+        )
+
+    linalg.fill_builtin_region(named_op.operation)
+
+    if len(result_types) == 1:
+        return named_op.result
+    else:
+        return named_op.results
 
 
 class _BodyBuilder:
-  """Constructs a structured op body by evaluating assignments."""
-
-  def __init__(self, type_mapping: Dict[str, Type],
-               block_arg_mapping: Dict[str, Value], fn_attr_mapping: Dict[str,
-                                                                          str]):
-    self.type_mapping = type_mapping
-    self.block_arg_mapping = block_arg_mapping
-    self.fn_attr_mapping = fn_attr_mapping
-    self.yield_mapping = dict()  # type: Dict[str, Value]
-
-  def assign(self, assignment: ScalarAssign):
-    if assignment.arg in self.yield_mapping:
-      raise ValueError(
-          f"Multiple assignments to the same argument are forbidden: "
-          f"{assignment}")
-    self.yield_mapping[assignment.arg] = self.expression(assignment.value)
-
-  def expression(self, expr: ScalarExpression) -> Value:
-    if expr.scalar_arg:
-      try:
-        return self.block_arg_mapping[expr.scalar_arg.arg]
-      except KeyError:
-        raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for "
-                         f"this structured op.")
-    elif expr.scalar_const:
-      value_attr = Attribute.parse(expr.scalar_const.value)
-      return arith.ConstantOp(value_attr.type, value_attr).result
-    elif expr.scalar_index:
-      dim_attr = IntegerAttr.get(
-          IntegerType.get_signless(64), expr.scalar_index.dim)
-      return linalg.IndexOp(dim_attr).result
-    elif expr.scalar_fn:
-      kind = expr.scalar_fn.kind.name.lower()
-      fn_name = expr.scalar_fn.fn_name
-      if expr.scalar_fn.attr_name:
-        fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name]
-      fn = self._get_function(f"_{kind}_{fn_name}")
-      operand_values = [
-          self.expression(operand) for operand in expr.scalar_fn.operands
-      ]
-      if expr.scalar_fn.kind == FunctionKind.TYPE:
-        operand_values = [expr.scalar_fn.type_var.name] + operand_values
-      return fn(*operand_values)
-    raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
-
-  def yield_outputs(self, *output_names: str):
-    output_values = []
-    for n in output_names:
-      try:
-        output_values.append(self.yield_mapping[n])
-      except KeyError:
-        raise ValueError(f"Body assignments do not assign all outputs: "
-                         f"missing '{n}'")
-    linalg.YieldOp(output_values)
-
-  def _get_function(self, fn_name: str) -> Callable:
-    try:
-      fn = getattr(self, f"{fn_name}")
-    except AttributeError:
-      raise ValueError(f"Function '{fn_name}' is not a known function")
-    return fn
-
-  def _cast(self,
-            type_var_name: str,
-            operand: Value,
-            is_unsigned_cast: bool = False) -> Value:
-    try:
-      to_type = self.type_mapping[type_var_name]
-    except KeyError:
-      raise ValueError(f"Unbound type variable '{type_var_name}' ("
-                       f"expected one of {self.type_mapping.keys()}")
-    if operand.type == to_type:
-      return operand
-    if _is_integer_type(to_type):
-      return self._cast_to_integer(to_type, operand, is_unsigned_cast)
-    elif _is_floating_point_type(to_type):
-      return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
-
-  def _cast_to_integer(self, to_type: Type, operand: Value,
-                       is_unsigned_cast: bool) -> Value:
-    to_width = IntegerType(to_type).width
-    operand_type = operand.type
-    if _is_floating_point_type(operand_type):
-      if is_unsigned_cast:
-        return arith.FPToUIOp(to_type, operand).result
-      return arith.FPToSIOp(to_type, operand).result
-    if _is_index_type(operand_type):
-      return arith.IndexCastOp(to_type, operand).result
-    # Assume integer.
-    from_width = IntegerType(operand_type).width
-    if to_width > from_width:
-      if is_unsigned_cast:
-        return arith.ExtUIOp(to_type, operand).result
-      return arith.ExtSIOp(to_type, operand).result
-    elif to_width < from_width:
-      return arith.TruncIOp(to_type, operand).result
-    raise ValueError(f"Unable to cast body expression from {operand_type} to "
-                     f"{to_type}")
-
-  def _cast_to_floating_point(self, to_type: Type, operand: Value,
-                              is_unsigned_cast: bool) -> Value:
-    operand_type = operand.type
-    if _is_integer_type(operand_type):
-      if is_unsigned_cast:
-        return arith.UIToFPOp(to_type, operand).result
-      return arith.SIToFPOp(to_type, operand).result
-    # Assume FloatType.
-    to_width = _get_floating_point_width(to_type)
-    from_width = _get_floating_point_width(operand_type)
-    if to_width > from_width:
-      return arith.ExtFOp(to_type, operand).result
-    elif to_width < from_width:
-      return arith.TruncFOp(to_type, operand).result
-    raise ValueError(f"Unable to cast body expression from {operand_type} to "
-                     f"{to_type}")
-
-  def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value:
-    return self._cast(type_var_name, operand, False)
-
-  def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
-    return self._cast(type_var_name, operand, True)
-
-  def _unary_exp(self, x: Value) -> Value:
-    if _is_floating_point_type(x.type):
-      return math.ExpOp(x).result
-    raise NotImplementedError("Unsupported 'exp' operand: {x}")
-
-  def _unary_log(self, x: Value) -> Value:
-    if _is_floating_point_type(x.type):
-      return math.LogOp(x).result
-    raise NotImplementedError("Unsupported 'log' operand: {x}")
-
-  def _unary_abs(self, x: Value) -> Value:
-    if _is_floating_point_type(x.type):
-      return math.AbsFOp(x).result
-    raise NotImplementedError("Unsupported 'abs' operand: {x}")
-
-  def _unary_ceil(self, x: Value) -> Value:
-    if _is_floating_point_type(x.type):
-      return math.CeilOp(x).result
-    raise NotImplementedError("Unsupported 'ceil' operand: {x}")
-
-  def _unary_floor(self, x: Value) -> Value:
-    if _is_floating_point_type(x.type):
-      return math.FloorOp(x).result
-    raise NotImplementedError("Unsupported 'floor' operand: {x}")
-
-  def _unary_negf(self, x: Value) -> Value:
-    if _is_floating_point_type(x.type):
-      return arith.NegFOp(x).result
-    if _is_complex_type(x.type):
-      return complex.NegOp(x).result
-    raise NotImplementedError("Unsupported 'negf' operand: {x}")
-
-  def _binary_add(self, lhs: Value, rhs: Value) -> Value:
-    if _is_floating_point_type(lhs.type):
-      return arith.AddFOp(lhs, rhs).result
-    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.AddIOp(lhs, rhs).result
-    if _is_complex_type(lhs.type):
-      return complex.AddOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
-
-  def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
-    if _is_floating_point_type(lhs.type):
-      return arith.SubFOp(lhs, rhs).result
-    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.SubIOp(lhs, rhs).result
-    if _is_complex_type(lhs.type):
-      return complex.SubOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
-
-  def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
-    if _is_floating_point_type(lhs.type):
-      return arith.MulFOp(lhs, rhs).result
-    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.MulIOp(lhs, rhs).result
-    if _is_complex_type(lhs.type):
-      return complex.MulOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
-
-  def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
-    if _is_floating_point_type(lhs.type):
-      return arith.MaxFOp(lhs, rhs).result
-    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.MaxSIOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
-
-  def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
-    if _is_floating_point_type(lhs.type):
-      return arith.MaxFOp(lhs, rhs).result
-    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.MaxUIOp(lhs, rhs).result
-    raise NotImplementedError(
-        "Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
-
-  def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
-    if _is_floating_point_type(lhs.type):
-      return arith.MinFOp(lhs, rhs).result
-    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.MinSIOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
-
-  def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
-    if _is_floating_point_type(lhs.type):
-      return arith.MinFOp(lhs, rhs).result
-    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.MinUIOp(lhs, rhs).result
-    raise NotImplementedError(
-        "Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
+    """Constructs a structured op body by evaluating assignments."""
+
+    def __init__(
+        self,
+        type_mapping: Dict[str, Type],
+        block_arg_mapping: Dict[str, Value],
+        fn_attr_mapping: Dict[str, str],
+    ):
+        self.type_mapping = type_mapping
+        self.block_arg_mapping = block_arg_mapping
+        self.fn_attr_mapping = fn_attr_mapping
+        self.yield_mapping = dict()  # type: Dict[str, Value]
+
+    def assign(self, assignment: ScalarAssign):
+        if assignment.arg in self.yield_mapping:
+            raise ValueError(
+                f"Multiple assignments to the same argument are forbidden: "
+                f"{assignment}"
+            )
+        self.yield_mapping[assignment.arg] = self.expression(assignment.value)
+
+    def expression(self, expr: ScalarExpression) -> Value:
+        if expr.scalar_arg:
+            try:
+                return self.block_arg_mapping[expr.scalar_arg.arg]
+            except KeyError:
+                raise ValueError(
+                    f"Argument {expr.scalar_arg.arg} is not bound for "
+                    f"this structured op."
+                )
+        elif expr.scalar_const:
+            value_attr = Attribute.parse(expr.scalar_const.value)
+            return arith.ConstantOp(value_attr.type, value_attr).result
+        elif expr.scalar_index:
+            dim_attr = IntegerAttr.get(
+                IntegerType.get_signless(64), expr.scalar_index.dim
+            )
+            return linalg.IndexOp(dim_attr).result
+        elif expr.scalar_fn:
+            kind = expr.scalar_fn.kind.name.lower()
+            fn_name = expr.scalar_fn.fn_name
+            if expr.scalar_fn.attr_name:
+                fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name]
+            fn = self._get_function(f"_{kind}_{fn_name}")
+            operand_values = [
+                self.expression(operand) for operand in expr.scalar_fn.operands
+            ]
+            if expr.scalar_fn.kind == FunctionKind.TYPE:
+                operand_values = [expr.scalar_fn.type_var.name] + operand_values
+            return fn(*operand_values)
+        raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
+
+    def yield_outputs(self, *output_names: str):
+        output_values = []
+        for n in output_names:
+            try:
+                output_values.append(self.yield_mapping[n])
+            except KeyError:
+                raise ValueError(
+                    f"Body assignments do not assign all outputs: " f"missing '{n}'"
+                )
+        linalg.YieldOp(output_values)
+
+    def _get_function(self, fn_name: str) -> Callable:
+        try:
+            fn = getattr(self, f"{fn_name}")
+        except AttributeError:
+            raise ValueError(f"Function '{fn_name}' is not a known function")
+        return fn
+
+    def _cast(
+        self, type_var_name: str, operand: Value, is_unsigned_cast: bool = False
+    ) -> Value:
+        try:
+            to_type = self.type_mapping[type_var_name]
+        except KeyError:
+            raise ValueError(
+                f"Unbound type variable '{type_var_name}' ("
+                f"expected one of {self.type_mapping.keys()}"
+            )
+        if operand.type == to_type:
+            return operand
+        if _is_integer_type(to_type):
+            return self._cast_to_integer(to_type, operand, is_unsigned_cast)
+        elif _is_floating_point_type(to_type):
+            return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
+
+    def _cast_to_integer(
+        self, to_type: Type, operand: Value, is_unsigned_cast: bool
+    ) -> Value:
+        to_width = IntegerType(to_type).width
+        operand_type = operand.type
+        if _is_floating_point_type(operand_type):
+            if is_unsigned_cast:
+                return arith.FPToUIOp(to_type, operand).result
+            return arith.FPToSIOp(to_type, operand).result
+        if _is_index_type(operand_type):
+            return arith.IndexCastOp(to_type, operand).result
+        # Assume integer.
+        from_width = IntegerType(operand_type).width
+        if to_width > from_width:
+            if is_unsigned_cast:
+                return arith.ExtUIOp(to_type, operand).result
+            return arith.ExtSIOp(to_type, operand).result
+        elif to_width < from_width:
+            return arith.TruncIOp(to_type, operand).result
+        raise ValueError(
+            f"Unable to cast body expression from {operand_type} to " f"{to_type}"
+        )
+
+    def _cast_to_floating_point(
+        self, to_type: Type, operand: Value, is_unsigned_cast: bool
+    ) -> Value:
+        operand_type = operand.type
+        if _is_integer_type(operand_type):
+            if is_unsigned_cast:
+                return arith.UIToFPOp(to_type, operand).result
+            return arith.SIToFPOp(to_type, operand).result
+        # Assume FloatType.
+        to_width = _get_floating_point_width(to_type)
+        from_width = _get_floating_point_width(operand_type)
+        if to_width > from_width:
+            return arith.ExtFOp(to_type, operand).result
+        elif to_width < from_width:
+            return arith.TruncFOp(to_type, operand).result
+        raise ValueError(
+            f"Unable to cast body expression from {operand_type} to " f"{to_type}"
+        )
+
+    def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value:
+        return self._cast(type_var_name, operand, False)
+
+    def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
+        return self._cast(type_var_name, operand, True)
+
+    def _unary_exp(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return math.ExpOp(x).result
+        raise NotImplementedError("Unsupported 'exp' operand: {x}")
+
+    def _unary_log(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return math.LogOp(x).result
+        raise NotImplementedError("Unsupported 'log' operand: {x}")
+
+    def _unary_abs(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return math.AbsFOp(x).result
+        raise NotImplementedError("Unsupported 'abs' operand: {x}")
+
+    def _unary_ceil(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return math.CeilOp(x).result
+        raise NotImplementedError("Unsupported 'ceil' operand: {x}")
+
+    def _unary_floor(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return math.FloorOp(x).result
+        raise NotImplementedError("Unsupported 'floor' operand: {x}")
+
+    def _unary_negf(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return arith.NegFOp(x).result
+        if _is_complex_type(x.type):
+            return complex.NegOp(x).result
+        raise NotImplementedError("Unsupported 'negf' operand: {x}")
+
+    def _binary_add(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.AddFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.AddIOp(lhs, rhs).result
+        if _is_complex_type(lhs.type):
+            return complex.AddOp(lhs, rhs).result
+        raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
+
+    def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.SubFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.SubIOp(lhs, rhs).result
+        if _is_complex_type(lhs.type):
+            return complex.SubOp(lhs, rhs).result
+        raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
+
+    def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.MulFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.MulIOp(lhs, rhs).result
+        if _is_complex_type(lhs.type):
+            return complex.MulOp(lhs, rhs).result
+        raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
+
+    def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.MaxFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.MaxSIOp(lhs, rhs).result
+        raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
+
+    def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.MaxFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.MaxUIOp(lhs, rhs).result
+        raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
+
+    def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.MinFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.MinSIOp(lhs, rhs).result
+        raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
+
+    def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.MinFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.MinUIOp(lhs, rhs).result
+        raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
 
 
 def _infer_structured_outs(
     op_config: LinalgStructuredOpConfig,
-    in_arg_defs: Sequence[OperandDefConfig], ins: Sequence[Value],
+    in_arg_defs: Sequence[OperandDefConfig],
+    ins: Sequence[Value],
     out_arg_defs: Sequence[OperandDefConfig],
-    outs: Union[Sequence[Value], OpResultList]) -> Tuple[ValueList, List[Type]]:
-  """Infers implicit outs and output types.
+    outs: Union[Sequence[Value], OpResultList],
+) -> Tuple[ValueList, List[Type]]:
+    """Infers implicit outs and output types.
 
-  Respects existing contents of outs if not empty.
+    Respects existing contents of outs if not empty.
 
-  Returns:
-    normalized outs, output types
-  """
-  # If outs were explicitly provided, we accept them verbatim.
-  if outs:
-    return outs, [out.type for out in outs]
+    Returns:
+      normalized outs, output types
+    """
+    # If outs were explicitly provided, we accept them verbatim.
+    if outs:
+        return outs, [out.type for out in outs]
 
-  raise NotImplementedError(f"Output tensor inference not yet supported for "
-                            "structured ops")
+    raise NotImplementedError(
+        f"Output tensor inference not yet supported for " "structured ops"
+    )
 
 
 def _get_types_from_values(*values: Value) -> Sequence[Type]:
-  types = []
-  for v in values:
-    types.append(v.type)
-  return types
+    types = []
+    for v in values:
+        types.append(v.type)
+    return types
 
 
 def _get_operand_def_names(*operand_configs: OperandDefConfig) -> Sequence[str]:
-  return [odc.operand_def.name for odc in operand_configs]
-
-
-def _add_type_mapping(operand_config: OperandDefConfig, operand_type: Type,
-                      type_mapping: Dict[str, Type],
-                      block_arg_types: Sequence[Type]):
-  element_or_self_type = operand_type
-  # Get the element type for tensor operands and the type itself for scalars.
-  if operand_config.shape_map:
-    try:
-      element_or_self_type = ShapedType(operand_type).element_type
-    except Exception as e:
-      raise ValueError(f"Expected ShapedType but got {operand_type}") from e
-  name = operand_config.type_var.name
-  if name in type_mapping:
-    if type_mapping[name] != element_or_self_type:
-      raise ValueError(f"Cannot overwrite type mapping {name} = "
-                       f"{type_mapping[name]} by type {element_or_self_type}")
-  type_mapping[name] = element_or_self_type
-  block_arg_types.append(element_or_self_type)
+    return [odc.operand_def.name for odc in operand_configs]
+
+
+def _add_type_mapping(
+    operand_config: OperandDefConfig,
+    operand_type: Type,
+    type_mapping: Dict[str, Type],
+    block_arg_types: Sequence[Type],
+):
+    element_or_self_type = operand_type
+    # Get the element type for tensor operands and the type itself for scalars.
+    if operand_config.shape_map:
+        try:
+            element_or_self_type = ShapedType(operand_type).element_type
+        except Exception as e:
+            raise ValueError(f"Expected ShapedType but got {operand_type}") from e
+    name = operand_config.type_var.name
+    if name in type_mapping:
+        if type_mapping[name] != element_or_self_type:
+            raise ValueError(
+                f"Cannot overwrite type mapping {name} = "
+                f"{type_mapping[name]} by type {element_or_self_type}"
+            )
+    type_mapping[name] = element_or_self_type
+    block_arg_types.append(element_or_self_type)
 
 
 def _is_complex_type(t: Type) -> bool:
-  return ComplexType.isinstance(t)
+    return ComplexType.isinstance(t)
 
 
 def _is_floating_point_type(t: Type) -> bool:
-  # TODO: Create a FloatType in the Python API and implement the switch
-  # there.
-  return (F64Type.isinstance(t) or F32Type.isinstance(t) or
-          F16Type.isinstance(t) or BF16Type.isinstance(t))
+    # TODO: Create a FloatType in the Python API and implement the switch
+    # there.
+    return (
+        F64Type.isinstance(t)
+        or F32Type.isinstance(t)
+        or F16Type.isinstance(t)
+        or BF16Type.isinstance(t)
+    )
 
 
 def _is_integer_type(t: Type) -> bool:
-  return IntegerType.isinstance(t)
+    return IntegerType.isinstance(t)
 
 
 def _is_index_type(t: Type) -> bool:
-  return IndexType.isinstance(t)
+    return IndexType.isinstance(t)
 
 
 def _get_floating_point_width(t: Type) -> int:
-  # TODO: Create a FloatType in the Python API and implement the switch
-  # there.
-  if F64Type.isinstance(t):
-    return 64
-  if F32Type.isinstance(t):
-    return 32
-  if F16Type.isinstance(t):
-    return 16
-  if BF16Type.isinstance(t):
-    return 16
-  raise NotImplementedError(f"Unhandled floating point type switch {t}")
+    # TODO: Create a FloatType in the Python API and implement the switch
+    # there.
+    if F64Type.isinstance(t):
+        return 64
+    if F32Type.isinstance(t):
+        return 32
+    if F16Type.isinstance(t):
+        return 16
+    if BF16Type.isinstance(t):
+        return 16
+    raise NotImplementedError(f"Unhandled floating point type switch {t}")

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
index aa894dc10954f..86853994c0a1c 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
@@ -30,123 +30,137 @@
 
 
 class ScalarFn:
-  """A type of ScalarExpression that applies a function."""
-
-  def __init__(self, kind: "FunctionKind", fn_name: Optional[str],
-               attr_name: Optional[str], type_var: Optional["TypeVar"],
-               operands: Sequence["ScalarExpression"]):
-    if bool(fn_name) + bool(attr_name) != 1:
-      raise ValueError("One of 'fn_name', 'attr_name' must be specified")
-    self.kind = kind
-    self.fn_name = fn_name
-    self.attr_name = attr_name
-    self.type_var = type_var
-    self.operands = operands
-
-  def expr(self) -> "ScalarExpression":
-    return ScalarExpression(scalar_fn=self)
-
-  def __repr__(self):
-    name = self.fn_name if self.fn_name else self.attr_name
-    return (f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, "
-            f"operands=[{', '.join(self.operands)}])")
+    """A type of ScalarExpression that applies a function."""
+
+    def __init__(
+        self,
+        kind: "FunctionKind",
+        fn_name: Optional[str],
+        attr_name: Optional[str],
+        type_var: Optional["TypeVar"],
+        operands: Sequence["ScalarExpression"],
+    ):
+        if bool(fn_name) + bool(attr_name) != 1:
+            raise ValueError("One of 'fn_name', 'attr_name' must be specified")
+        self.kind = kind
+        self.fn_name = fn_name
+        self.attr_name = attr_name
+        self.type_var = type_var
+        self.operands = operands
+
+    def expr(self) -> "ScalarExpression":
+        return ScalarExpression(scalar_fn=self)
+
+    def __repr__(self):
+        name = self.fn_name if self.fn_name else self.attr_name
+        return (
+            f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, "
+            f"operands=[{', '.join(self.operands)}])"
+        )
 
 
 class ScalarArg:
-  """A type of ScalarExpression that references a named argument."""
+    """A type of ScalarExpression that references a named argument."""
 
-  def __init__(self, arg: str):
-    self.arg = arg
+    def __init__(self, arg: str):
+        self.arg = arg
 
-  def expr(self) -> "ScalarExpression":
-    return ScalarExpression(scalar_arg=self)
+    def expr(self) -> "ScalarExpression":
+        return ScalarExpression(scalar_arg=self)
 
-  def __repr__(self):
-    return f"(ScalarArg({self.arg})"
+    def __repr__(self):
+        return f"(ScalarArg({self.arg})"
 
 
 class ScalarConst:
-  """A type of ScalarExpression representing a constant."""
+    """A type of ScalarExpression representing a constant."""
 
-  def __init__(self, value: str):
-    self.value = value
+    def __init__(self, value: str):
+        self.value = value
 
-  def expr(self) -> "ScalarExpression":
-    return ScalarExpression(scalar_const=self)
+    def expr(self) -> "ScalarExpression":
+        return ScalarExpression(scalar_const=self)
 
-  def __repr__(self):
-    return f"(ScalarConst({self.value})"
+    def __repr__(self):
+        return f"(ScalarConst({self.value})"
 
 
 class ScalarIndex:
-  """A type of ScalarExpression accessing an iteration index."""
+    """A type of ScalarExpression accessing an iteration index."""
 
-  def __init__(self, dim: int):
-    self.dim = dim
+    def __init__(self, dim: int):
+        self.dim = dim
 
-  def expr(self) -> "ScalarExpression":
-    return ScalarExpression(scalar_index=self)
+    def expr(self) -> "ScalarExpression":
+        return ScalarExpression(scalar_index=self)
 
-  def __repr__(self):
-    return f"(ScalarIndex({self.dim})"
+    def __repr__(self):
+        return f"(ScalarIndex({self.dim})"
 
 
 class ScalarExpression(YAMLObject):
-  """An expression on scalar values.
-
-  Can be one of:
-    - ScalarFn
-    - ScalarArg
-    - ScalarConst
-    - ScalarIndex
-  """
-  yaml_tag = "!ScalarExpression"
-
-  def __init__(self,
-               scalar_fn: Optional[ScalarFn] = None,
-               scalar_arg: Optional[ScalarArg] = None,
-               scalar_const: Optional[ScalarConst] = None,
-               scalar_index: Optional[ScalarIndex] = None):
-    if (bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) +
-        bool(scalar_index)) != 1:
-      raise ValueError("One of 'scalar_fn', 'scalar_arg', 'scalar_const', or "
-                       "'scalar_index' must be specified")
-    self.scalar_fn = scalar_fn
-    self.scalar_arg = scalar_arg
-    self.scalar_const = scalar_const
-    self.scalar_index = scalar_index
-
-  def to_yaml_custom_dict(self):
-    if self.scalar_fn:
-      scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower())
-      if self.scalar_fn.fn_name:
-        scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name
-      if self.scalar_fn.attr_name:
-        scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name
-      if self.scalar_fn.type_var:
-        scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name
-      scalar_fn_dict["operands"] = list(self.scalar_fn.operands)
-      return dict(scalar_fn=scalar_fn_dict)
-    elif self.scalar_arg:
-      return dict(scalar_arg=self.scalar_arg.arg)
-    elif self.scalar_const:
-      return dict(scalar_const=self.scalar_const.value)
-    elif self.scalar_index:
-      return dict(scalar_index=self.scalar_index.dim)
-    else:
-      raise ValueError(f"Unexpected ScalarExpression type: {self}")
+    """An expression on scalar values.
+
+    Can be one of:
+      - ScalarFn
+      - ScalarArg
+      - ScalarConst
+      - ScalarIndex
+    """
+
+    yaml_tag = "!ScalarExpression"
+
+    def __init__(
+        self,
+        scalar_fn: Optional[ScalarFn] = None,
+        scalar_arg: Optional[ScalarArg] = None,
+        scalar_const: Optional[ScalarConst] = None,
+        scalar_index: Optional[ScalarIndex] = None,
+    ):
+        if (
+            bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) + bool(scalar_index)
+        ) != 1:
+            raise ValueError(
+                "One of 'scalar_fn', 'scalar_arg', 'scalar_const', or "
+                "'scalar_index' must be specified"
+            )
+        self.scalar_fn = scalar_fn
+        self.scalar_arg = scalar_arg
+        self.scalar_const = scalar_const
+        self.scalar_index = scalar_index
+
+    def to_yaml_custom_dict(self):
+        if self.scalar_fn:
+            scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower())
+            if self.scalar_fn.fn_name:
+                scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name
+            if self.scalar_fn.attr_name:
+                scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name
+            if self.scalar_fn.type_var:
+                scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name
+            scalar_fn_dict["operands"] = list(self.scalar_fn.operands)
+            return dict(scalar_fn=scalar_fn_dict)
+        elif self.scalar_arg:
+            return dict(scalar_arg=self.scalar_arg.arg)
+        elif self.scalar_const:
+            return dict(scalar_const=self.scalar_const.value)
+        elif self.scalar_index:
+            return dict(scalar_index=self.scalar_index.dim)
+        else:
+            raise ValueError(f"Unexpected ScalarExpression type: {self}")
 
 
 class ScalarAssign(YAMLObject):
-  """An assignment to a named argument (LHS of a comprehension)."""
-  yaml_tag = "!ScalarAssign"
+    """An assignment to a named argument (LHS of a comprehension)."""
+
+    yaml_tag = "!ScalarAssign"
 
-  def __init__(self, arg: str, value: ScalarExpression):
-    self.arg = arg
-    self.value = value
+    def __init__(self, arg: str, value: ScalarExpression):
+        self.arg = arg
+        self.value = value
 
-  def to_yaml_custom_dict(self):
-    return dict(arg=self.arg, value=self.value)
+    def to_yaml_custom_dict(self):
+        return dict(arg=self.arg, value=self.value)
 
-  def __repr__(self):
-    return f"ScalarAssign({self.arg}, {self.value})"
+    def __repr__(self):
+        return f"ScalarAssign({self.arg}, {self.value})"

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py
index ddac87287e617..4f36029b7fb88 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py
@@ -21,13 +21,11 @@
 __all__ = [
     "TypeVar",
     "TV",
-
     # Predefined types.
     "I32",
     "I64",
     "F32",
     "F64",
-
     # TypeVar aliases.
     "T",
     "U",
@@ -36,34 +34,34 @@
 
 
 class TypeVar:
-  """A replaceable type variable.
+    """A replaceable type variable.
 
-  Type variables are uniqued by name.
-  """
-  ALL_TYPEVARS = dict()  # type: Dict[str, "TypeVar"]
+    Type variables are uniqued by name.
+    """
 
-  def __new__(cls, name: str):
-    existing = cls.ALL_TYPEVARS.get(name)
-    if existing is not None:
-      return existing
-    new = super().__new__(cls)
-    new.name = name
-    cls.ALL_TYPEVARS[name] = new
-    return new
+    ALL_TYPEVARS = dict()  # type: Dict[str, "TypeVar"]
 
-  def __repr__(self):
-    return f"TypeVar({self.name})"
+    def __new__(cls, name: str):
+        existing = cls.ALL_TYPEVARS.get(name)
+        if existing is not None:
+            return existing
+        new = super().__new__(cls)
+        new.name = name
+        cls.ALL_TYPEVARS[name] = new
+        return new
 
-  @classmethod
-  def create_expando(cls):
-    """Create an expando class that creates unique type vars on attr access."""
+    def __repr__(self):
+        return f"TypeVar({self.name})"
 
-    class ExpandoTypeVars:
+    @classmethod
+    def create_expando(cls):
+        """Create an expando class that creates unique type vars on attr access."""
 
-      def __getattr__(self, n):
-        return cls(n)
+        class ExpandoTypeVars:
+            def __getattr__(self, n):
+                return cls(n)
 
-    return ExpandoTypeVars()
+        return ExpandoTypeVars()
 
 
 # Expando access via TV.foo

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py
index 1945eea53f80b..1672656b3a1f8 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py
@@ -6,11 +6,12 @@
 import sys
 
 try:
-  import yaml
+    import yaml
 except ModuleNotFoundError as e:
-  raise ModuleNotFoundError(
-      f"This tool requires PyYAML but it was not installed. "
-      f"Recommend: {sys.executable} -m pip install PyYAML") from e
+    raise ModuleNotFoundError(
+        f"This tool requires PyYAML but it was not installed. "
+        f"Recommend: {sys.executable} -m pip install PyYAML"
+    ) from e
 
 __all__ = [
     "yaml_dump",
@@ -20,35 +21,33 @@
 
 
 class YAMLObject(yaml.YAMLObject):
+    @classmethod
+    def to_yaml(cls, dumper, self):
+        """Default to a custom dictionary mapping."""
+        return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict())
 
-  @classmethod
-  def to_yaml(cls, dumper, self):
-    """Default to a custom dictionary mapping."""
-    return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict())
+    def to_yaml_custom_dict(self):
+        raise NotImplementedError()
 
-  def to_yaml_custom_dict(self):
-    raise NotImplementedError()
-
-  def as_linalg_yaml(self):
-    return yaml_dump(self)
+    def as_linalg_yaml(self):
+        return yaml_dump(self)
 
 
 def multiline_str_representer(dumper, data):
-  if len(data.splitlines()) > 1:
-    return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
-  else:
-    return dumper.represent_scalar('tag:yaml.org,2002:str', data)
+    if len(data.splitlines()) > 1:
+        return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
+    else:
+        return dumper.represent_scalar("tag:yaml.org,2002:str", data)
 
 
 yaml.add_representer(str, multiline_str_representer)
 
 
 def yaml_dump(data, sort_keys=False, **kwargs):
-  return yaml.dump(data, sort_keys=sort_keys, **kwargs)
+    return yaml.dump(data, sort_keys=sort_keys, **kwargs)
 
 
 def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs):
-  return yaml.dump_all(data,
-                       sort_keys=sort_keys,
-                       explicit_start=explicit_start,
-                       **kwargs)
+    return yaml.dump_all(
+        data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs
+    )

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 9c96868c10f1a..bac22a2e59db1 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -7,99 +7,113 @@
 
 
 @linalg_structured_op
-def copy(I=TensorDef(T1),
-         O=TensorDef(U, output=True),
-         cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  """Copies the tensor elementwise.
+def copy(
+    I=TensorDef(T1),
+    O=TensorDef(U, output=True),
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    """Copies the tensor elementwise.
 
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  O[None] = cast(U, I[None])
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    O[None] = cast(U, I[None])
 
 
 @linalg_structured_op
-def elemwise_unary(I=TensorDef(T1),
-                   O=TensorDef(U, output=True),
-                   fun=UnaryFnAttrDef(default=UnaryFn.exp),
-                   cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  """Applies the unary function fun elementwise.
+def elemwise_unary(
+    I=TensorDef(T1),
+    O=TensorDef(U, output=True),
+    fun=UnaryFnAttrDef(default=UnaryFn.exp),
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    """Applies the unary function fun elementwise.
 
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  O[None] = fun(cast(U, I[None]))
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    O[None] = fun(cast(U, I[None]))
 
 
 @linalg_structured_op
-def elemwise_binary(lhs=TensorDef(T1),
-                    rhs=TensorDef(T2),
-                    O=TensorDef(U, output=True),
-                    fun=BinaryFnAttrDef(default=BinaryFn.add),
-                    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  """Applies the binary function fun elementwise.
+def elemwise_binary(
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T2),
+    O=TensorDef(U, output=True),
+    fun=BinaryFnAttrDef(default=BinaryFn.add),
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    """Applies the binary function fun elementwise.
 
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
 
 
 @linalg_structured_op
-def matmul(A=TensorDef(T1, S.M, S.K),
-           B=TensorDef(T2, S.K, S.N),
-           C=TensorDef(U, S.M, S.N, output=True),
-           cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  """Performs a matrix multiplication of two 2D inputs.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.m, D.n, D.k)
-  implements(ContractionOpInterface)
-  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+def matmul(
+    A=TensorDef(T1, S.M, S.K),
+    B=TensorDef(T2, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True),
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    """Performs a matrix multiplication of two 2D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.m, D.n, D.k)
+    implements(ContractionOpInterface)
+    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
 @linalg_structured_op
-def matmul_unsigned(A=TensorDef(T1, S.M, S.K),
-                    B=TensorDef(T2, S.K, S.N),
-                    C=TensorDef(U, S.M, S.N, output=True)):
-  """Performs an unsigned matrix multiplication of two 2D inputs.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.m, D.n, D.k)
-  implements(ContractionOpInterface)
-  C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned(
-      U, B[D.k, D.n])
+def matmul_unsigned(
+    A=TensorDef(T1, S.M, S.K),
+    B=TensorDef(T2, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True),
+):
+    """Performs an unsigned matrix multiplication of two 2D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.m, D.n, D.k)
+    implements(ContractionOpInterface)
+    C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned(
+        U, B[D.k, D.n]
+    )
 
 
 @linalg_structured_op
-def quantized_matmul(A=TensorDef(T1, S.M, S.K),
-                     B=TensorDef(T2, S.K, S.N),
-                     AZp=ScalarDef(I32),
-                     BZp=ScalarDef(I32),
-                     C=TensorDef(U, S.M, S.N, output=True)):
-  """Performs a matrix multiplication of two 2D inputs.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. The quantized variant
-  includes zero-point adjustments for the left and right operands of the
-  matmul.
-  """
-  domain(D.m, D.n, D.k)
-  C[D.m,
-    D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) -
-             TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(U, B[D.k, D.n]) -
-                                            TypeFn.cast_signed(U, BZp))
+def quantized_matmul(
+    A=TensorDef(T1, S.M, S.K),
+    B=TensorDef(T2, S.K, S.N),
+    AZp=ScalarDef(I32),
+    BZp=ScalarDef(I32),
+    C=TensorDef(U, S.M, S.N, output=True),
+):
+    """Performs a matrix multiplication of two 2D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. The quantized variant
+    includes zero-point adjustments for the left and right operands of the
+    matmul.
+    """
+    domain(D.m, D.n, D.k)
+    C[D.m, D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * (
+        TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp)
+    )
 
 
 @linalg_structured_op
-def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
-          rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
-          accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True)):
-  """Performs a matrix-matrix-transpose multiplication of two 4D inputs.
+def mmt4d(
+    lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
+    rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
+    accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True),
+):
+    """Performs a matrix-matrix-transpose multiplication of two 4D inputs.
 
     Differences from linalg.matmul:
     * The right hand side is transposed, whence the 't' in 'mmt'.
@@ -108,1132 +122,1201 @@ def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
       whence the 2+2=4 dimensions. The inner tile dimensions are identified with
       '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
       as: MxK tiles, each of shape M0xK0.
-  """
-  domain(D.m, D.n, D.k, D.m0, D.n0, D.k0)
-  implements(ContractionOpInterface)
-  accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed(
-      TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast_signed(
-          TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
+    """
+    domain(D.m, D.n, D.k, D.m0, D.n0, D.k0)
+    implements(ContractionOpInterface)
+    accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed(
+        TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]
+    ) * TypeFn.cast_signed(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
 
 
 @linalg_structured_op
-def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
-                 B=TensorDef(T2, Batch, S.K, S.N),
-                 C=TensorDef(U, Batch, S.M, S.N, output=True)):
-  """Performs a batched matrix multiplication of two 3D inputs.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.b, D.m, D.n, D.k)
-  implements(ContractionOpInterface)
-  C[D.b, D.m,
-    D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
-        U, B[D.b, D.k, D.n])
+def batch_matmul(
+    A=TensorDef(T1, Batch, S.M, S.K),
+    B=TensorDef(T2, Batch, S.K, S.N),
+    C=TensorDef(U, Batch, S.M, S.N, output=True),
+):
+    """Performs a batched matrix multiplication of two 3D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.b, D.m, D.n, D.k)
+    implements(ContractionOpInterface)
+    C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+        U, B[D.b, D.k, D.n]
+    )
 
 
 @linalg_structured_op
-def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
-                           B=TensorDef(T2, Batch, S.K, S.N),
-                           AZp=ScalarDef(I32),
-                           BZp=ScalarDef(I32),
-                           C=TensorDef(U, Batch, S.M, S.N, output=True)):
-  """Performs a batched matrix multiplication of two 3D inputs.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. The quantized variant
-  includes zero-point adjustments for the left and right operands of the
-  matmul.
-  """
-  domain(D.b, D.m, D.n, D.k)
-  C[D.b, D.m, D.n] += (TypeFn.cast_signed(U, A[D.b, D.m, D.k]) -
-                       TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(
-                           U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
+def quantized_batch_matmul(
+    A=TensorDef(T1, Batch, S.M, S.K),
+    B=TensorDef(T2, Batch, S.K, S.N),
+    AZp=ScalarDef(I32),
+    BZp=ScalarDef(I32),
+    C=TensorDef(U, Batch, S.M, S.N, output=True),
+):
+    """Performs a batched matrix multiplication of two 3D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. The quantized variant
+    includes zero-point adjustments for the left and right operands of the
+    matmul.
+    """
+    domain(D.b, D.m, D.n, D.k)
+    C[D.b, D.m, D.n] += (
+        TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp)
+    ) * (TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
 
 
 @linalg_structured_op
-def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K),
-                        B=TensorDef(T2, Batch, S.K, S.N),
-                        C=TensorDef(U, S.M, S.N, output=True)):
-  """Performs a batch-reduce matrix multiplication of two 3D inputs.
-  The partial multiplication results are reduced into a 2D output.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.b, D.m, D.n, D.k)
-  implements(ContractionOpInterface)
-  C[D.m, D.n] += TypeFn.cast_signed(
-      U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n]))
+def batch_reduce_matmul(
+    A=TensorDef(T1, Batch, S.M, S.K),
+    B=TensorDef(T2, Batch, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True),
+):
+    """Performs a batch-reduce matrix multiplication of two 3D inputs.
+    The partial multiplication results are reduced into a 2D output.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.b, D.m, D.n, D.k)
+    implements(ContractionOpInterface)
+    C[D.m, D.n] += TypeFn.cast_signed(
+        U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
+    )
 
 
 @linalg_structured_op
-def matvec(A=TensorDef(T1, S.M, S.N),
-           y=TensorDef(T2, S.N),
-           x=TensorDef(U, S.M, output=True)):
-  """Performs a matrix-vector multiplication.
+def matvec(
+    A=TensorDef(T1, S.M, S.N), y=TensorDef(T2, S.N), x=TensorDef(U, S.M, output=True)
+):
+    """Performs a matrix-vector multiplication.
 
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.m, D.n)
-  implements(ContractionOpInterface)
-  x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n])
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.m, D.n)
+    implements(ContractionOpInterface)
+    x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n])
 
 
 @linalg_structured_op
-def vecmat(y=TensorDef(T1, S.M),
-           A=TensorDef(T2, S.M, S.N),
-           x=TensorDef(U, S.N, output=True)):
-  """Performs a vector-matrix multiplication.
+def vecmat(
+    y=TensorDef(T1, S.M), A=TensorDef(T2, S.M, S.N), x=TensorDef(U, S.N, output=True)
+):
+    """Performs a vector-matrix multiplication.
 
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.n, D.m)
-  implements(ContractionOpInterface)
-  x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n])
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.n, D.m)
+    implements(ContractionOpInterface)
+    x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n])
 
 
 @linalg_structured_op
-def batch_matvec(A=TensorDef(T1, Batch, S.M, S.K),
-                 B=TensorDef(T2, Batch, S.K),
-                 C=TensorDef(U, Batch, S.M, output=True)):
-  """Performs a batched matrix-vector multiplication.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.b, D.m, D.k)
-  implements(ContractionOpInterface)
-  C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
-      U, B[D.b, D.k])
+def batch_matvec(
+    A=TensorDef(T1, Batch, S.M, S.K),
+    B=TensorDef(T2, Batch, S.K),
+    C=TensorDef(U, Batch, S.M, output=True),
+):
+    """Performs a batched matrix-vector multiplication.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.b, D.m, D.k)
+    implements(ContractionOpInterface)
+    C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+        U, B[D.b, D.k]
+    )
 
 
 @linalg_structured_op
-def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U,
-                                                                output=True)):
-  """Performs a dot product of two vectors to a scalar result.
+def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
+    """Performs a dot product of two vectors to a scalar result.
 
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ContractionOpInterface)
-  C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ContractionOpInterface)
+    C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
 
 
 @linalg_structured_op
-def conv_1d(I=TensorDef(T1, S.OW + S.KW),
-            K=TensorDef(T2, S.KW),
-            O=TensorDef(U, S.OW, output=True)):
-  """Performs 1-D convolution with no channels.
+def conv_1d(
+    I=TensorDef(T1, S.OW + S.KW),
+    K=TensorDef(T2, S.KW),
+    O=TensorDef(U, S.OW, output=True),
+):
+    """Performs 1-D convolution with no channels.
 
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.ow, D.kw)
-  O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(
-      U, K[D.kw])
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.ow, D.kw)
+    O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kw])
 
 
 @linalg_structured_op
-def conv_2d(I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW),
-            K=TensorDef(T2, S.KH, S.KW),
-            O=TensorDef(U, S.OH, S.OW, output=True)):
-  """Performs 2-D convolution with no channels.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.oh, D.ow, D.kh, D.kw)
-  O[D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kh, D.kw])
+def conv_2d(
+    I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW),
+    K=TensorDef(T2, S.KH, S.KW),
+    O=TensorDef(U, S.OH, S.OW, output=True),
+):
+    """Performs 2-D convolution with no channels.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.oh, D.ow, D.kh, D.kw)
+    O[D.oh, D.ow] += TypeFn.cast_signed(
+        U, I[D.oh + D.kh, D.ow + D.kw]
+    ) * TypeFn.cast_signed(U, K[D.kh, D.kw])
 
 
 @linalg_structured_op
-def conv_3d(I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW),
-            K=TensorDef(T2, S.KD, S.KH, S.KW),
-            O=TensorDef(U, S.OD, S.OH, S.OW, output=True)):
-  """Performs 3-D convolution with no channels.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw)
-  O[D.od, D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(
-          U, K[D.kd, D.kh, D.kw])
+def conv_3d(
+    I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW),
+    K=TensorDef(T2, S.KD, S.KH, S.KW),
+    O=TensorDef(U, S.OD, S.OH, S.OW, output=True),
+):
+    """Performs 3-D convolution with no channels.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw)
+    O[D.od, D.oh, D.ow] += TypeFn.cast_signed(
+        U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]
+    ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw])
 
 
 @linalg_structured_op
-def conv_1d_nwc_wcf(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
-                    K=TensorDef(T2, S.KW, S.C, S.F),
-                    O=TensorDef(U, S.N, S.OW, S.F, output=True),
-                    strides=IndexAttrDef(S.SW, default=[1]),
-                    dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs 1-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.f, D.kw, D.c)
-  O[D.n, D.ow, D.f] += TypeFn.cast_signed(
-      U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed(
-          U, K[D.kw, D.c, D.f])
+def conv_1d_nwc_wcf(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KW, S.C, S.F),
+    O=TensorDef(U, S.N, S.OW, S.F, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs 1-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.f, D.kw, D.c)
+    O[D.n, D.ow, D.f] += TypeFn.cast_signed(
+        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]
+    ) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f])
 
 
 @linalg_structured_op
-def conv_1d_ncw_fcw(I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
-                    K=TensorDef(T2, S.F, S.C, S.KW),
-                    O=TensorDef(U, S.N, S.F, S.OW, output=True),
-                    strides=IndexAttrDef(S.SW, default=[1]),
-                    dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs 1-D convolution.
-
-  Layout:
-    * Input: NCW.
-    * Kernel: FCW.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.f, D.ow, D.c, D.kw)
-  O[D.n, D.f, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(
-          U, K[D.f, D.c, D.kw])
+def conv_1d_ncw_fcw(
+    I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.F, S.C, S.KW),
+    O=TensorDef(U, S.N, S.F, S.OW, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs 1-D convolution.
+
+    Layout:
+      * Input: NCW.
+      * Kernel: FCW.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.f, D.ow, D.c, D.kw)
+    O[D.n, D.f, D.ow] += TypeFn.cast_signed(
+        U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]
+    ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kw])
 
 
 @linalg_structured_op
-def conv_2d_nhwc_hwcf(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                  S.OW * S.SW + S.KW * S.DW, S.C),
-                      K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
-                      O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
-                      strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                      dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs 2-D convolution.
-
-  Layout:
-    * Input: NHWC.
-    * Kernel: HWCF.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-           D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f])
+def conv_2d_nhwc_hwcf(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D convolution.
+
+    Layout:
+      * Input: NHWC.
+      * Kernel: HWCF.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
+    O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+    ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f])
 
 
 @linalg_structured_op
-def conv_2d_nhwc_fhwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                  S.OW * S.SW + S.KW * S.DW, S.C),
-                      K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
-                      O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
-                      strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                      dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs 2-D convolution.
-
-  Layout:
-    * Input: NHWC.
-    * Kernel: FHWC.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-           D.c]) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c])
+def conv_2d_nhwc_fhwc(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D convolution.
+
+    Layout:
+      * Input: NHWC.
+      * Kernel: FHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
+    O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+    ) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c])
 
 
 @linalg_structured_op
-def conv_2d_nhwc_hwcf_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                    S.OW * S.SW + S.KW * S.DW, S.C),
-                        K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
-                        IZp=ScalarDef(I32),
-                        KZp=ScalarDef(I32),
-                        O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
-                        strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                        dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs 2-D convolution with zero point offsets.
-
-  Layout:
-    * Input: NHWC.
-    * Kernel: HWCF.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. This includes the zero
-  point offsets common to quantized operations.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow,
-    D.f] += (TypeFn.cast_signed(
-        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) -
-             TypeFn.cast_signed(U, IZp)) * (TypeFn.cast_signed(
-                 U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp))
+def conv_2d_nhwc_hwcf_q(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D convolution with zero point offsets.
+
+    Layout:
+      * Input: NHWC.
+      * Kernel: HWCF.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
+    O[D.n, D.oh, D.ow, D.f] += (
+        TypeFn.cast_signed(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+        )
+        - TypeFn.cast_signed(U, IZp)
+    ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp))
 
 
 @linalg_structured_op
-def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
-                                  S.OW * S.SW + S.KW * S.DW),
-                      K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
-                      O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
-                      strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                      dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs 2-D convolution.
-
-  Layout:
-    * Input: NCHW.
-    * Kernel: FCHW.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
-           D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw])
+def conv_2d_nchw_fchw(
+    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D convolution.
+
+    Layout:
+      * Input: NCHW.
+      * Kernel: FCHW.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed(
+        U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+    ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw])
 
 
 @linalg_structured_op
-def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C,
-                                    S.OH * S.SH + S.KH * S.DH,
-                                    S.OW * S.SW + S.KW * S.DW),
-                        K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
-                        O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True),
-                        strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                        dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs 2-D grouped convolution.
-
-  Layout:
-    * Input: NGCHW.
-    * Kernel: FGCHW.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
-           D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
+def conv_2d_ngchw_fgchw(
+    I=TensorDef(
+        T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
+    ),
+    K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D grouped convolution.
+
+    Layout:
+      * Input: NGCHW.
+      * Kernel: FGCHW.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
+        U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+    ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
 
 
 @linalg_structured_op
-def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
-                                    S.OH * S.SH + S.KH * S.DH,
-                                    S.OW * S.SW + S.KW * S.DW, S.C),
-                        K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
-                        O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
-                        strides=IndexAttrDef(S.SD,
-                                             S.SH,
-                                             S.SW,
-                                             default=[1, 1, 1]),
-                        dilations=IndexAttrDef(S.DD,
-                                               S.DH,
-                                               S.DW,
-                                               default=[1, 1, 1])):
-  """Performs 3-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
-  O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed(
-      U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed(
-               U, K[D.kd, D.kh, D.kw, D.c, D.f])
+def conv_3d_ndhwc_dhwcf(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.C,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs 3-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
+    O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed(
+        U,
+        I[
+            D.n,
+            D.od * S.SD + D.kd * S.DD,
+            D.oh * S.SH + D.kh * S.DH,
+            D.ow * S.SW + D.kw * S.DW,
+            D.c,
+        ],
+    ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
 
 
 @linalg_structured_op
-def conv_3d_ndhwc_dhwcf_q(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
-                                      S.OH * S.SH + S.KH * S.DH,
-                                      S.OW * S.SW + S.KW * S.DW, S.C),
-                          K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
-                          IZp=ScalarDef(I32),
-                          KZp=ScalarDef(I32),
-                          O=TensorDef(U,
-                                      S.N,
-                                      S.OD,
-                                      S.OH,
-                                      S.OW,
-                                      S.F,
-                                      output=True),
-                          strides=IndexAttrDef(S.SD,
-                                               S.SH,
-                                               S.SW,
-                                               default=[1, 1, 1]),
-                          dilations=IndexAttrDef(S.DD,
-                                                 S.DH,
-                                                 S.DW,
-                                                 default=[1, 1, 1])):
-  """Performs 3-D convolution with zero point offsets.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. This includes the zero
-  point offsets common to quantized operations.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
-  O[D.n, D.od, D.oh, D.ow, D.f] += (TypeFn.cast_signed(
-      U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW, D.c]) - TypeFn.cast_signed(U, IZp)) * (
-               TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) -
-               TypeFn.cast_signed(U, KZp))
+def conv_3d_ndhwc_dhwcf_q(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.C,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs 3-D convolution with zero point offsets.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
+    O[D.n, D.od, D.oh, D.ow, D.f] += (
+        TypeFn.cast_signed(
+            U,
+            I[
+                D.n,
+                D.od * S.SD + D.kd * S.DD,
+                D.oh * S.SH + D.kh * S.DH,
+                D.ow * S.SW + D.kw * S.DW,
+                D.c,
+            ],
+        )
+        - TypeFn.cast_signed(U, IZp)
+    ) * (
+        TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
+        - TypeFn.cast_signed(U, KZp)
+    )
 
 
 @linalg_structured_op
-def conv_3d_ncdhw_fcdhw(I=TensorDef(T1, S.N, S.C, S.OD * S.SD + S.KD * S.DD,
-                                    S.OH * S.SH + S.KH * S.DH,
-                                    S.OW * S.SW + S.KW * S.DW),
-                        K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW),
-                        O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True),
-                        strides=IndexAttrDef(S.SD,
-                                             S.SH,
-                                             S.SW,
-                                             default=[1, 1, 1]),
-                        dilations=IndexAttrDef(S.DD,
-                                               S.DH,
-                                               S.DW,
-                                               default=[1, 1, 1])):
-  """Performs 3-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
-  O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.c, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(
-               U, K[D.f, D.c, D.kd, D.kh, D.kw])
+def conv_3d_ncdhw_fcdhw(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.C,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+    ),
+    K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs 3-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
+    O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed(
+        U,
+        I[
+            D.n,
+            D.c,
+            D.od * S.SD + D.kd * S.DD,
+            D.oh * S.SH + D.kh * S.DH,
+            D.ow * S.SW + D.kw * S.DW,
+        ],
+    ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
 
 
 @linalg_structured_op
-def depthwise_conv_1d_nwc_wc(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW,
-                                         S.IC),
-                             K=TensorDef(T2, S.KW, S.IC),
-                             O=TensorDef(U, S.N, S.OW, S.IC, output=True),
-                             strides=IndexAttrDef(S.SW, default=[1]),
-                             dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs depth-wise 1-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. Multiplier is set to 1
-  which is a special case for most depthwise convolutions.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.ic, D.kw)
-  O[D.n, D.ow, D.ic] += \
-      TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
-      TypeFn.cast_signed(U, K[D.kw, D.ic])
+def depthwise_conv_1d_nwc_wc(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
+    K=TensorDef(T2, S.KW, S.IC),
+    O=TensorDef(U, S.N, S.OW, S.IC, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs depth-wise 1-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Multiplier is set to 1
+    which is a special case for most depthwise convolutions.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.ic, D.kw)
+    O[D.n, D.ow, D.ic] += TypeFn.cast_signed(
+        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]
+    ) * TypeFn.cast_signed(U, K[D.kw, D.ic])
 
 
 @linalg_structured_op
-def depthwise_conv_1d_ncw_cw(I=TensorDef(T1, S.N, S.IC,
-                                         S.OW * S.SW + S.KW * S.DW),
-                             K=TensorDef(T2, S.IC, S.KW),
-                             O=TensorDef(U, S.N, S.IC, S.OW, output=True),
-                             strides=IndexAttrDef(S.SW, default=[1]),
-                             dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs depth-wise 1-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. Multiplier is set to 1
-  which is a special case for most depthwise convolutions.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.ic, D.kw)
-  O[D.n, D.ic, D.ow] += \
-      TypeFn.cast_signed(U, I[D.n, D.ic, D.ow * S.SW + D.kw * S.DW]) * \
-      TypeFn.cast_signed(U, K[D.ic, D.kw])
+def depthwise_conv_1d_ncw_cw(
+    I=TensorDef(T1, S.N, S.IC, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.IC, S.KW),
+    O=TensorDef(U, S.N, S.IC, S.OW, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs depth-wise 1-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Multiplier is set to 1
+    which is a special case for most depthwise convolutions.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.ic, D.kw)
+    O[D.n, D.ic, D.ow] += TypeFn.cast_signed(
+        U, I[D.n, D.ic, D.ow * S.SW + D.kw * S.DW]
+    ) * TypeFn.cast_signed(U, K[D.ic, D.kw])
 
 
 @linalg_structured_op
-def depthwise_conv_1d_nwc_wcm(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW,
-                                          S.IC),
-                              K=TensorDef(T2, S.KW, S.IC, S.CM),
-                              O=TensorDef(U, S.N, S.OW, S.IC, S.CM,
-                                          output=True),
-                              strides=IndexAttrDef(S.SW, default=[1]),
-                              dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs depth-wise 1-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.ic, D.cm, D.kw)
-  O[D.n, D.ow, D.ic, D.cm] += \
-      TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
-      TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm])
+def depthwise_conv_1d_nwc_wcm(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
+    K=TensorDef(T2, S.KW, S.IC, S.CM),
+    O=TensorDef(U, S.N, S.OW, S.IC, S.CM, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs depth-wise 1-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.ic, D.cm, D.kw)
+    O[D.n, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
+        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]
+    ) * TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm])
 
 
 @linalg_structured_op
-def depthwise_conv_2d_nhwc_hwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                           S.OW * S.SW + S.KW * S.DW, S.IC),
-                               K=TensorDef(T2, S.KH, S.KW, S.IC),
-                               O=TensorDef(U,
-                                           S.N,
-                                           S.OH,
-                                           S.OW,
-                                           S.IC,
-                                           output=True),
-                               strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                               dilations=IndexAttrDef(S.DH,
-                                                      S.DW,
-                                                      default=[1, 1])):
-  """Performs depth-wise 2-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. Multiplier is set to 1
-  which is a special case for most depthwise convolutions.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-           D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic])
+def depthwise_conv_2d_nhwc_hwc(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
+    K=TensorDef(T2, S.KH, S.KW, S.IC),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Multiplier is set to 1
+    which is a special case for most depthwise convolutions.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
+    ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic])
 
 
 @linalg_structured_op
-def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC,
-                                           S.OH * S.SH + S.KH * S.DH,
-                                           S.OW * S.SW + S.KW * S.DW),
-                               K=TensorDef(T2, S.IC, S.KH, S.KW),
-                               O=TensorDef(U,
-                                           S.N,
-                                           S.IC,
-                                           S.OH,
-                                           S.OW,
-                                           output=True),
-                               strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                               dilations=IndexAttrDef(S.DH,
-                                                      S.DW,
-                                                      default=[1, 1])):
-  """Performs depth-wise 2-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. Multiplier is set to 1
-  which is a special case for most depthwise convolutions.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
-  O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
-           D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw])
+def depthwise_conv_2d_nchw_chw(
+    I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.IC, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.IC, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Multiplier is set to 1
+    which is a special case for most depthwise convolutions.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
+    O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed(
+        U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+    ) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw])
 
 
 @linalg_structured_op
-def depthwise_conv_2d_nhwc_hwc_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                             S.OW * S.SW + S.KW * S.DW, S.IC),
-                                 K=TensorDef(T2, S.KH, S.KW, S.IC),
-                                 IZp=ScalarDef(I32),
-                                 KZp=ScalarDef(I32),
-                                 O=TensorDef(U,
-                                             S.N,
-                                             S.OH,
-                                             S.OW,
-                                             S.IC,
-                                             output=True),
-                                 strides=IndexAttrDef(S.SH,
-                                                      S.SW,
-                                                      default=[1, 1]),
-                                 dilations=IndexAttrDef(S.DH,
-                                                        S.DW,
-                                                        default=[1, 1])):
-  """Performs depth-wise 2-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.ic] += ((TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
-                                TypeFn.cast_signed(U, IZp)) *
-                               (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) -
-                                TypeFn.cast_signed(U, KZp)))
+def depthwise_conv_2d_nhwc_hwc_q(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
+    K=TensorDef(T2, S.KH, S.KW, S.IC),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.ic] += (
+        TypeFn.cast_signed(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
+        )
+        - TypeFn.cast_signed(U, IZp)
+    ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast_signed(U, KZp))
 
 
 @linalg_structured_op
-def depthwise_conv_2d_nhwc_hwcm(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                            S.OW * S.SW + S.KW * S.DW, S.IC),
-                                K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
-                                O=TensorDef(U,
-                                            S.N,
-                                            S.OH,
-                                            S.OW,
-                                            S.IC,
-                                            S.CM,
-                                            output=True),
-                                strides=IndexAttrDef(S.SH, S.SW, default=[1,
-                                                                          1]),
-                                dilations=IndexAttrDef(S.DH,
-                                                       S.DW,
-                                                       default=[1, 1])):
-  """Performs depth-wise 2-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-           D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm])
+def depthwise_conv_2d_nhwc_hwcm(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
+    K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
+    ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm])
 
 
 @linalg_structured_op
-def depthwise_conv_2d_nhwc_hwcm_q(I=TensorDef(T1, S.N,
-                                              S.OH * S.SH + S.KH * S.DH,
-                                              S.OW * S.SW + S.KW * S.DW, S.IC),
-                                  K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
-                                  IZp=ScalarDef(I32),
-                                  KZp=ScalarDef(I32),
-                                  O=TensorDef(U,
-                                              S.N,
-                                              S.OH,
-                                              S.OW,
-                                              S.IC,
-                                              S.CM,
-                                              output=True),
-                                  strides=IndexAttrDef(S.SH,
-                                                       S.SW,
-                                                       default=[1, 1]),
-                                  dilations=IndexAttrDef(S.DH,
-                                                         S.DW,
-                                                         default=[1, 1])):
-  """Performs depth-wise 2-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.ic,
-    D.cm] += ((TypeFn.cast_signed(
-        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
-               TypeFn.cast_signed(U, IZp)) *
-              (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) -
-               TypeFn.cast_signed(U, KZp)))
+def depthwise_conv_2d_nhwc_hwcm_q(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
+    K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.ic, D.cm] += (
+        TypeFn.cast_signed(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
+        )
+        - TypeFn.cast_signed(U, IZp)
+    ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast_signed(U, KZp))
 
 
 @linalg_structured_op
-def depthwise_conv_3d_ndhwc_dhwc(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
-                                             S.OH * S.SH + S.KH * S.DH,
-                                             S.OW * S.SW + S.KW * S.DW, S.IC),
-                                 K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC),
-                                 O=TensorDef(U,
-                                             S.N,
-                                             S.OD,
-                                             S.OH,
-                                             S.OW,
-                                             output=True),
-                                 strides=IndexAttrDef(S.SD,
-                                                      S.SH,
-                                                      S.SW,
-                                                      default=[1, 1, 1]),
-                                 dilations=IndexAttrDef(S.DD,
-                                                        S.DH,
-                                                        S.DW,
-                                                        default=[1, 1, 1])):
-  """Performs depth-wise 3-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. Multiplier is set to 1
-  which is a special case for most depthwise convolutions.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
-  O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
-      U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed(
-               U, K[D.kd, D.kh, D.kw, D.ic])
+def depthwise_conv_3d_ndhwc_dhwc(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.IC,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs depth-wise 3-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Multiplier is set to 1
+    which is a special case for most depthwise convolutions.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
+    O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
+        U,
+        I[
+            D.n,
+            D.od * S.SD + D.kd * S.DD,
+            D.oh * S.SH + D.kh * S.DH,
+            D.ow * S.SW + D.kw * S.DW,
+            D.ic,
+        ],
+    ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic])
 
 
 @linalg_structured_op
-def depthwise_conv_3d_ncdhw_cdhw(I=TensorDef(T1, S.N, S.IC,
-                                             S.OD * S.SD + S.KD * S.DD,
-                                             S.OH * S.SH + S.KH * S.DH,
-                                             S.OW * S.SW + S.KW * S.DW),
-                                 K=TensorDef(T2, S.IC, S.KD, S.KH, S.KW),
-                                 O=TensorDef(U,
-                                             S.N,
-                                             S.IC,
-                                             S.OD,
-                                             S.OH,
-                                             S.OW,
-                                             output=True),
-                                 strides=IndexAttrDef(S.SD,
-                                                      S.SH,
-                                                      S.SW,
-                                                      default=[1, 1, 1]),
-                                 dilations=IndexAttrDef(S.DD,
-                                                        S.DH,
-                                                        S.DW,
-                                                        default=[1, 1, 1])):
-  """Performs depth-wise 3-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. Multiplier is set to 1
-  which is a special case for most depthwise convolutions.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
-  O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.ic, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(
-               U, K[D.ic, D.kd, D.kh, D.kw])
+def depthwise_conv_3d_ncdhw_cdhw(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.IC,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+    ),
+    K=TensorDef(T2, S.IC, S.KD, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.IC, S.OD, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs depth-wise 3-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Multiplier is set to 1
+    which is a special case for most depthwise convolutions.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
+    O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed(
+        U,
+        I[
+            D.n,
+            D.ic,
+            D.od * S.SD + D.kd * S.DD,
+            D.oh * S.SH + D.kh * S.DH,
+            D.ow * S.SW + D.kw * S.DW,
+        ],
+    ) * TypeFn.cast_signed(U, K[D.ic, D.kd, D.kh, D.kw])
 
 
 @linalg_structured_op
-def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, S.N,
-                                              S.OD * S.SD + S.KD * S.DD,
-                                              S.OH * S.SH + S.KH * S.DH,
-                                              S.OW * S.SW + S.KW * S.DW, S.IC),
-                                  K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM),
-                                  O=TensorDef(U,
-                                              S.N,
-                                              S.OD,
-                                              S.OH,
-                                              S.OW,
-                                              S.CM,
-                                              output=True),
-                                  strides=IndexAttrDef(S.SD,
-                                                       S.SH,
-                                                       S.SW,
-                                                       default=[1, 1, 1]),
-                                  dilations=IndexAttrDef(S.DD,
-                                                         S.DH,
-                                                         S.DW,
-                                                         default=[1, 1, 1])):
-  """Performs depth-wise 3-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic)
-  O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
-      U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed(
-               U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
+def depthwise_conv_3d_ndhwc_dhwcm(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.IC,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.CM, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs depth-wise 3-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic)
+    O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
+        U,
+        I[
+            D.n,
+            D.od * S.SD + D.kd * S.DD,
+            D.oh * S.SH + D.kh * S.DH,
+            D.ow * S.SW + D.kw * S.DW,
+            D.ic,
+        ],
+    ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
 
 
 @linalg_structured_op
-def pooling_nhwc_sum(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                 S.OW * S.SW + S.KW * S.DW, S.C),
-                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-                     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs sum pooling.
-
-  Layout:
-    * Input: NHWC.
-    * Kernel: HW.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
+def pooling_nhwc_sum(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs sum pooling.
+
+    Layout:
+      * Input: NHWC.
+      * Kernel: HW.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+    )
 
 
 @linalg_structured_op
-def pooling_nchw_sum(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
-                                 S.OW * S.SW + S.KW * S.DW),
-                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-                     O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
-                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs sum pooling.
-
-  Layout:
-    * Input: NCHW.
-    * Kernel: HW.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
-  O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW])
+def pooling_nchw_sum(
+    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs sum pooling.
+
+    Layout:
+      * Input: NCHW.
+      * Kernel: HW.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
+    O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed(
+        U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+    )
 
 
 @linalg_structured_op
-def pooling_nhwc_max(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                 S.OW * S.SW + S.KW * S.DW, S.C),
-                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-                     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs max pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nhwc_max(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](
+        TypeFn.cast_signed(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+        )
+    )
 
 
 @linalg_structured_op
-def pooling_nhwc_max_unsigned(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                          S.OW * S.SW + S.KW * S.DW, S.C),
-                              K=TensorDef(T2,
-                                          S.KH,
-                                          S.KW,
-                                          index_dims=[D.kh, D.kw]),
-                              O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-                              strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                              dilations=IndexAttrDef(S.DH, S.DW, default=[1,
-                                                                          1])):
-  """Performs unsigned max pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow,
-    D.c] = ReduceFn.max_unsigned[D.kh, D.kw](TypeFn.cast_unsigned(
-        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nhwc_max_unsigned(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs unsigned max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
+        TypeFn.cast_unsigned(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+        )
+    )
 
 
 @linalg_structured_op
-def pooling_nchw_max(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
-                                 S.OW * S.SW + S.KW * S.DW),
-                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-                     O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
-                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs max pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
-  O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](TypeFn.cast_signed(
-      U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,]))
+def pooling_nchw_max(
+    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
+    O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](
+        TypeFn.cast_signed(
+            U,
+            I[
+                D.n,
+                D.c,
+                D.oh * S.SH + D.kh * S.DH,
+                D.ow * S.SW + D.kw * S.DW,
+            ],
+        )
+    )
 
 
 @linalg_structured_op
-def pooling_nhwc_min(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                 S.OW * S.SW + S.KW * S.DW, S.C),
-                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-                     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs min pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nhwc_min(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs min pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](
+        TypeFn.cast_signed(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+        )
+    )
 
 
 @linalg_structured_op
-def pooling_nhwc_min_unsigned(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                          S.OW * S.SW + S.KW * S.DW, S.C),
-                              K=TensorDef(T2,
-                                          S.KH,
-                                          S.KW,
-                                          index_dims=[D.kh, D.kw]),
-                              O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-                              strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                              dilations=IndexAttrDef(S.DH, S.DW, default=[1,
-                                                                          1])):
-  """Performs unsigned min pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow,
-    D.c] = ReduceFn.min_unsigned[D.kh, D.kw](TypeFn.cast_unsigned(
-        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nhwc_min_unsigned(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs unsigned min pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
+        TypeFn.cast_unsigned(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+        )
+    )
+
 
 @linalg_structured_op
-def pooling_nwc_sum(I=TensorDef(T1, S.N,
-                                S.OW * S.SW + S.KW * S.DW, S.C),
-                    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
-                    O=TensorDef(U, S.N, S.OW, S.C, output=True),
-                    strides=IndexAttrDef(S.SW, default=[1]),
-                    dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs sum pooling.
-
-  Layout:
-    * Input: NWC.
-    * Kernel: W.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.c, D.kw)
-  O[D.n, D.ow, D.c] += TypeFn.cast_signed(
-      U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+def pooling_nwc_sum(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+    O=TensorDef(U, S.N, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs sum pooling.
+
+    Layout:
+      * Input: NWC.
+      * Kernel: W.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.c, D.kw)
+    O[D.n, D.ow, D.c] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
 
 
 @linalg_structured_op
-def pooling_ncw_sum(I=TensorDef(T1, S.N, S.C,
-                                S.OW * S.SW + S.KW * S.DW),
-                    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
-                    O=TensorDef(U, S.N, S.C, S.OW, output=True),
-                    strides=IndexAttrDef(S.SW, default=[1]),
-                    dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs sum pooling.
-
-  Layout:
-    * Input: NCW.
-    * Kernel: W.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.c, D.ow, D.kw)
-  O[D.n, D.c, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW])
+def pooling_ncw_sum(
+    I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+    O=TensorDef(U, S.N, S.C, S.OW, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs sum pooling.
+
+    Layout:
+      * Input: NCW.
+      * Kernel: W.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.c, D.ow, D.kw)
+    O[D.n, D.c, D.ow] += TypeFn.cast_signed(U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW])
 
 
 @linalg_structured_op
-def pooling_nwc_max(I=TensorDef(T1, S.N,
-                                S.OW * S.SW + S.KW * S.DW, S.C),
-                    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
-                    O=TensorDef(U, S.N, S.OW, S.C, output=True),
-                    strides=IndexAttrDef(S.SW, default=[1]),
-                    dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs max pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.c, D.kw)
-  O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]](TypeFn.cast_signed(
-      U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nwc_max(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+    O=TensorDef(U, S.N, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.c, D.kw)
+    O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]](
+        TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+    )
 
 
 @linalg_structured_op
-def pooling_nwc_max_unsigned(I=TensorDef(T1, S.N,
-                                         S.OW * S.SW + S.KW * S.DW, S.C),
-                             K=TensorDef(T2,
-                                         S.KW,
-                                         index_dims=[D.kw]),
-                             O=TensorDef(U, S.N, S.OW, S.C, output=True),
-                             strides=IndexAttrDef(S.SW, default=[1]),
-                             dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs unsigned max pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.c, D.kw)
-  O[D.n, D.ow,
-    D.c] = ReduceFn.max_unsigned[[D.kw]](TypeFn.cast_unsigned(
-        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nwc_max_unsigned(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+    O=TensorDef(U, S.N, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs unsigned max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.c, D.kw)
+    O[D.n, D.ow, D.c] = ReduceFn.max_unsigned[[D.kw]](
+        TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+    )
 
 
 @linalg_structured_op
-def pooling_ncw_max(I=TensorDef(T1, S.N, S.C,
-                                S.OW * S.SW + S.KW * S.DW),
-                    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
-                    O=TensorDef(U, S.N, S.C, S.OW, output=True),
-                    strides=IndexAttrDef(S.SW, default=[1]),
-                    dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs max pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.c, D.ow, D.kw)
-  O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]](TypeFn.cast_signed(
-      U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW,]))
+def pooling_ncw_max(
+    I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+    O=TensorDef(U, S.N, S.C, S.OW, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.c, D.ow, D.kw)
+    O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]](
+        TypeFn.cast_signed(
+            U,
+            I[
+                D.n,
+                D.c,
+                D.ow * S.SW + D.kw * S.DW,
+            ],
+        )
+    )
 
 
 @linalg_structured_op
-def pooling_nwc_min(I=TensorDef(T1, S.N,
-                                S.OW * S.SW + S.KW * S.DW, S.C),
-                    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
-                    O=TensorDef(U, S.N, S.OW, S.C, output=True),
-                    strides=IndexAttrDef(S.SW, default=[1]),
-                    dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs min pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.c, D.kw)
-  O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]](TypeFn.cast_signed(
-      U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nwc_min(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+    O=TensorDef(U, S.N, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs min pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.c, D.kw)
+    O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]](
+        TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+    )
 
 
 @linalg_structured_op
-def pooling_nwc_min_unsigned(I=TensorDef(T1, S.N,
-                                         S.OW * S.SW + S.KW * S.DW, S.C),
-                             K=TensorDef(T2,
-                                         S.KW,
-                                         index_dims=[D.kw]),
-                             O=TensorDef(U, S.N, S.OW, S.C, output=True),
-                             strides=IndexAttrDef(S.SW, default=[1]),
-                             dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs unsigned min pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.c, D.kw)
-  O[D.n, D.ow,
-    D.c] = ReduceFn.min_unsigned[[D.kw]](TypeFn.cast_unsigned(
-        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
-
+def pooling_nwc_min_unsigned(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+    O=TensorDef(U, S.N, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs unsigned min pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.c, D.kw)
+    O[D.n, D.ow, D.c] = ReduceFn.min_unsigned[[D.kw]](
+        TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+    )
 
 
 @linalg_structured_op
-def pooling_ndhwc_sum(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
-                                  S.OH * S.SH + S.KH * S.DH,
-                                  S.OW * S.SW + S.KW * S.DW, S.C),
-                      K=TensorDef(T2,
-                                  S.KD,
-                                  S.KH,
-                                  S.KW,
-                                  index_dims=[D.kd, D.kh, D.kw]),
-                      O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
-                      strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
-                      dilations=IndexAttrDef(S.DD,
-                                             S.DH,
-                                             S.DW,
-                                             default=[1, 1, 1])):
-  """Performs 3D sum pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
-  O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed(
-      U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW, D.c])
+def pooling_ndhwc_sum(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.C,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs 3D sum pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
+    O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed(
+        U,
+        I[
+            D.n,
+            D.od * S.SD + D.kd * S.DD,
+            D.oh * S.SH + D.kh * S.DH,
+            D.ow * S.SW + D.kw * S.DW,
+            D.c,
+        ],
+    )
 
 
 @linalg_structured_op
-def pooling_ndhwc_max(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
-                                  S.OH * S.SH + S.KH * S.DH,
-                                  S.OW * S.SW + S.KW * S.DW, S.C),
-                      K=TensorDef(T2,
-                                  S.KD,
-                                  S.KH,
-                                  S.KW,
-                                  index_dims=[D.kd, D.kh, D.kw]),
-                      O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
-                      strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
-                      dilations=IndexAttrDef(S.DD,
-                                             S.DH,
-                                             S.DW,
-                                             default=[1, 1, 1])):
-  """Performs 3D max pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
-  O[D.n, D.od, D.oh, D.ow,
-    D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](TypeFn.cast_signed(
-        U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-             D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_ndhwc_max(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.C,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs 3D max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
+    O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](
+        TypeFn.cast_signed(
+            U,
+            I[
+                D.n,
+                D.od * S.SD + D.kd * S.DD,
+                D.oh * S.SH + D.kh * S.DH,
+                D.ow * S.SW + D.kw * S.DW,
+                D.c,
+            ],
+        )
+    )
 
 
 @linalg_structured_op
-def pooling_ndhwc_min(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
-                                  S.OH * S.SH + S.KH * S.DH,
-                                  S.OW * S.SW + S.KW * S.DW, S.C),
-                      K=TensorDef(T2,
-                                  S.KD,
-                                  S.KH,
-                                  S.KW,
-                                  index_dims=[D.kd, D.kh, D.kw]),
-                      O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
-                      strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
-                      dilations=IndexAttrDef(S.DD,
-                                             S.DH,
-                                             S.DW,
-                                             default=[1, 1, 1])):
-  """Performs 3D min pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
-  O[D.n, D.od, D.oh, D.ow,
-    D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](TypeFn.cast_signed(
-        U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-             D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_ndhwc_min(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.C,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs 3D min pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
+    O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](
+        TypeFn.cast_signed(
+            U,
+            I[
+                D.n,
+                D.od * S.SD + D.kd * S.DD,
+                D.oh * S.SH + D.kh * S.DH,
+                D.ow * S.SW + D.kw * S.DW,
+                D.c,
+            ],
+        )
+    )
 
 
 @linalg_structured_op
 def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
-  """Fills the output tensor with the given value.
+    """Fills the output tensor with the given value.
 
-  Works for arbitrary ranked output tensors since the operation performs scalar
-  accesses only and is thus rank polymorphic. Numeric casting is performed on
-  the value operand, promoting it to the same data type as the output.
-  """
-  implements(FillOpInterface)
-  defines(Canonicalizer)
-  O[None] = TypeFn.cast_signed(U, value)
+    Works for arbitrary ranked output tensors since the operation performs scalar
+    accesses only and is thus rank polymorphic. Numeric casting is performed on
+    the value operand, promoting it to the same data type as the output.
+    """
+    implements(FillOpInterface)
+    defines(Canonicalizer)
+    O[None] = TypeFn.cast_signed(U, value)
 
 
 @linalg_structured_op
-def fill_rng_2d(min=ScalarDef(F64),
-                max=ScalarDef(F64),
-                seed=ScalarDef(I32),
-                O=TensorDef(T, S.M, S.N, output=True)):
-  """Fills the output tensor with pseudo random numbers.
-
-  The operation generations pseudo random numbers using a linear congruential
-  generator. It provides no guarantees regarding the distribution of the
-  generated random numbers. Instead of generating the random numbers
-  sequentially, it instantiates one random number generator per data element
-  and runs them in parallel. The seed operand and the indices of the data
-  element seed the random number generation. The min and max operands limit
-  the range of the generated random numbers.
-  """
-  domain(D.m, D.n)
-  multiplier = TypeFn.cast_signed(I32, const(1103515245))
-  increment = TypeFn.cast_signed(I32, const(12345))
-  rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
-  rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
-  inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
-  offset = TypeFn.cast_signed(F64, const(2147483647))
-  scaling = (max - min) * inv_range
-  O[D.m, D.n] = TypeFn.cast_signed(
-      T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min)
+def fill_rng_2d(
+    min=ScalarDef(F64),
+    max=ScalarDef(F64),
+    seed=ScalarDef(I32),
+    O=TensorDef(T, S.M, S.N, output=True),
+):
+    """Fills the output tensor with pseudo random numbers.
+
+    The operation generates pseudo random numbers using a linear congruential
+    generator. It provides no guarantees regarding the distribution of the
+    generated random numbers. Instead of generating the random numbers
+    sequentially, it instantiates one random number generator per data element
+    and runs them in parallel. The seed operand and the indices of the data
+    element seed the random number generation. The min and max operands limit
+    the range of the generated random numbers.
+    """
+    domain(D.m, D.n)
+    multiplier = TypeFn.cast_signed(I32, const(1103515245))
+    increment = TypeFn.cast_signed(I32, const(12345))
+    rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
+    rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
+    inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
+    offset = TypeFn.cast_signed(F64, const(2147483647))
+    scaling = (max - min) * inv_range
+    O[D.m, D.n] = TypeFn.cast_signed(
+        T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min
+    )

diff  --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index ca0d479f1f5fc..980f237b19391 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -5,6 +5,8 @@
 from ._python_test_ops_gen import *
 from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType
 
+
 def register_python_test_dialect(context, load=True):
-  from .._mlir_libs import _mlirPythonTest
-  _mlirPythonTest.register_python_test_dialect(context, load)
+    from .._mlir_libs import _mlirPythonTest
+
+    _mlirPythonTest.register_python_test_dialect(context, load)

diff  --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 78956c4370049..b505a490aeb97 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -6,16 +6,18 @@
 
 
 class FailurePropagationMode(Enum):
-  """Propagation mode for silenceable errors."""
-  PROPAGATE = 1
-  SUPPRESS = 2
+    """Propagation mode for silenceable errors."""
 
-  def _as_int(self):
-    if self is FailurePropagationMode.PROPAGATE:
-      return 1
+    PROPAGATE = 1
+    SUPPRESS = 2
+
+    def _as_int(self):
+        if self is FailurePropagationMode.PROPAGATE:
+            return 1
+
+        assert self is FailurePropagationMode.SUPPRESS
+        return 2
 
-    assert self is FailurePropagationMode.SUPPRESS
-    return 2
 
 from .._transform_ops_gen import *
 from ..._mlir_libs._mlirDialectsTransform import *

diff  --git a/mlir/python/mlir/execution_engine.py b/mlir/python/mlir/execution_engine.py
index 262545b9ce726..4739231c155ba 100644
--- a/mlir/python/mlir/execution_engine.py
+++ b/mlir/python/mlir/execution_engine.py
@@ -7,37 +7,37 @@
 import ctypes
 
 __all__ = [
-  "ExecutionEngine",
+    "ExecutionEngine",
 ]
 
-class ExecutionEngine(_execution_engine.ExecutionEngine):
 
-  def lookup(self, name):
-    """Lookup a function emitted with the `llvm.emit_c_interface`
-    attribute and returns a ctype callable.
-    Raise a RuntimeError if the function isn't found.
-    """
-    func = self.raw_lookup("_mlir_ciface_" + name)
-    if not func:
-      raise RuntimeError("Unknown function " + name)
-    prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
-    return prototype(func)
+class ExecutionEngine(_execution_engine.ExecutionEngine):
+    def lookup(self, name):
+        """Look up a function emitted with the `llvm.emit_c_interface`
+        attribute and return a ctypes callable.
+        Raise a RuntimeError if the function isn't found.
+        """
+        func = self.raw_lookup("_mlir_ciface_" + name)
+        if not func:
+            raise RuntimeError("Unknown function " + name)
+        prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
+        return prototype(func)
 
-  def invoke(self, name, *ctypes_args):
-    """Invoke a function with the list of ctypes arguments.
-    All arguments must be pointers.
-    Raise a RuntimeError if the function isn't found.
-    """
-    func = self.lookup(name)
-    packed_args = (ctypes.c_void_p * len(ctypes_args))()
-    for argNum in range(len(ctypes_args)):
-      packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p)
-    func(packed_args)
+    def invoke(self, name, *ctypes_args):
+        """Invoke a function with the list of ctypes arguments.
+        All arguments must be pointers.
+        Raise a RuntimeError if the function isn't found.
+        """
+        func = self.lookup(name)
+        packed_args = (ctypes.c_void_p * len(ctypes_args))()
+        for argNum in range(len(ctypes_args)):
+            packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p)
+        func(packed_args)
 
-  def register_runtime(self, name, ctypes_callback):
-    """Register a runtime function available to the jitted code
-    under the provided `name`. The `ctypes_callback` must be a
-    `CFuncType` that outlives the execution engine.
-    """
-    callback = ctypes.cast(ctypes_callback, ctypes.c_void_p)
-    self.raw_register_runtime("_mlir_ciface_" + name, callback)
+    def register_runtime(self, name, ctypes_callback):
+        """Register a runtime function available to the jitted code
+        under the provided `name`. The `ctypes_callback` must be a
+        `CFuncType` that outlives the execution engine.
+        """
+        callback = ctypes.cast(ctypes_callback, ctypes.c_void_p)
+        self.raw_register_runtime("_mlir_ciface_" + name, callback)

diff  --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index be065d4639519..99c21ff9aaef2 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -8,124 +8,123 @@
 
 # Convenience decorator for registering user-friendly Attribute builders.
 def register_attribute_builder(kind):
+    def decorator_builder(func):
+        AttrBuilder.insert(kind, func)
+        return func
 
-  def decorator_builder(func):
-    AttrBuilder.insert(kind, func)
-    return func
-
-  return decorator_builder
+    return decorator_builder
 
 
 @register_attribute_builder("BoolAttr")
 def _boolAttr(x, context):
-  return BoolAttr.get(x, context=context)
+    return BoolAttr.get(x, context=context)
 
 
 @register_attribute_builder("IndexAttr")
 def _indexAttr(x, context):
-  return IntegerAttr.get(IndexType.get(context=context), x)
+    return IntegerAttr.get(IndexType.get(context=context), x)
 
 
 @register_attribute_builder("I16Attr")
 def _i16Attr(x, context):
-  return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
+    return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
 
 
 @register_attribute_builder("I32Attr")
 def _i32Attr(x, context):
-  return IntegerAttr.get(IntegerType.get_signless(32, context=context), x)
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), x)
 
 
 @register_attribute_builder("I64Attr")
 def _i64Attr(x, context):
-  return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
 
 
 @register_attribute_builder("SI16Attr")
 def _si16Attr(x, context):
-  return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)
+    return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)
 
 
 @register_attribute_builder("SI32Attr")
 def _si32Attr(x, context):
-  return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)
+    return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)
 
 
 @register_attribute_builder("F32Attr")
 def _f32Attr(x, context):
-  return FloatAttr.get_f32(x, context=context)
+    return FloatAttr.get_f32(x, context=context)
 
 
 @register_attribute_builder("F64Attr")
 def _f64Attr(x, context):
-  return FloatAttr.get_f64(x, context=context)
+    return FloatAttr.get_f64(x, context=context)
 
 
 @register_attribute_builder("StrAttr")
 def _stringAttr(x, context):
-  return StringAttr.get(x, context=context)
+    return StringAttr.get(x, context=context)
 
 
 @register_attribute_builder("SymbolNameAttr")
 def _symbolNameAttr(x, context):
-  return StringAttr.get(x, context=context)
+    return StringAttr.get(x, context=context)
 
 
 @register_attribute_builder("SymbolRefAttr")
 def _symbolRefAttr(x, context):
-  return FlatSymbolRefAttr.get(x, context=context)
+    return FlatSymbolRefAttr.get(x, context=context)
 
 
 @register_attribute_builder("ArrayAttr")
 def _arrayAttr(x, context):
-  return ArrayAttr.get(x, context=context)
+    return ArrayAttr.get(x, context=context)
 
 
 @register_attribute_builder("I32ArrayAttr")
 def _i32ArrayAttr(x, context):
-  return ArrayAttr.get([_i32Attr(v, context) for v in x])
+    return ArrayAttr.get([_i32Attr(v, context) for v in x])
 
 
 @register_attribute_builder("I64ArrayAttr")
 def _i64ArrayAttr(x, context):
-  return ArrayAttr.get([_i64Attr(v, context) for v in x])
+    return ArrayAttr.get([_i64Attr(v, context) for v in x])
 
 
 @register_attribute_builder("F32ArrayAttr")
 def _f32ArrayAttr(x, context):
-  return ArrayAttr.get([_f32Attr(v, context) for v in x])
+    return ArrayAttr.get([_f32Attr(v, context) for v in x])
 
 
 @register_attribute_builder("F64ArrayAttr")
 def _f64ArrayAttr(x, context):
-  return ArrayAttr.get([_f64Attr(v, context) for v in x])
+    return ArrayAttr.get([_f64Attr(v, context) for v in x])
 
 
 @register_attribute_builder("DenseI64ArrayAttr")
 def _denseI64ArrayAttr(x, context):
-  return DenseI64ArrayAttr.get(x, context=context)
+    return DenseI64ArrayAttr.get(x, context=context)
 
 
 @register_attribute_builder("TypeAttr")
 def _typeAttr(x, context):
-  return TypeAttr.get(x, context=context)
+    return TypeAttr.get(x, context=context)
 
 
 @register_attribute_builder("TypeArrayAttr")
 def _typeArrayAttr(x, context):
-  return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context)
+    return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context)
 
 
 try:
-  import numpy as np
+    import numpy as np
 
-  @register_attribute_builder("IndexElementsAttr")
-  def _indexElementsAttr(x, context):
-    return DenseElementsAttr.get(
-        np.array(x, dtype=np.int64),
-        type=IndexType.get(context=context),
-        context=context,
-    )
+    @register_attribute_builder("IndexElementsAttr")
+    def _indexElementsAttr(x, context):
+        return DenseElementsAttr.get(
+            np.array(x, dtype=np.int64),
+            type=IndexType.get(context=context),
+            context=context,
+        )
 
 except ImportError:
-  pass
+    pass

diff  --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index d70967983c45d..51433d75ac4fb 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -9,131 +9,134 @@
 
 
 class C128(ctypes.Structure):
-  """A ctype representation for MLIR's Double Complex."""
-  _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]
+    """A ctype representation for MLIR's Double Complex."""
+
+    _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]
 
 
 class C64(ctypes.Structure):
-  """A ctype representation for MLIR's Float Complex."""
-  _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]
+    """A ctype representation for MLIR's Float Complex."""
+
+    _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]
 
 
 class F16(ctypes.Structure):
-  """A ctype representation for MLIR's Float16."""
-  _fields_ = [("f16", ctypes.c_int16)]
+    """A ctype representation for MLIR's Float16."""
+
+    _fields_ = [("f16", ctypes.c_int16)]
 
 
 # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
 def as_ctype(dtp):
-  """Converts dtype to ctype."""
-  if dtp == np.dtype(np.complex128):
-    return C128
-  if dtp == np.dtype(np.complex64):
-    return C64
-  if dtp == np.dtype(np.float16):
-    return F16
-  return np.ctypeslib.as_ctypes_type(dtp)
+    """Converts dtype to ctype."""
+    if dtp == np.dtype(np.complex128):
+        return C128
+    if dtp == np.dtype(np.complex64):
+        return C64
+    if dtp == np.dtype(np.float16):
+        return F16
+    return np.ctypeslib.as_ctypes_type(dtp)
 
 
 def to_numpy(array):
-  """Converts ctypes array back to numpy dtype array."""
-  if array.dtype == C128:
-    return array.view("complex128")
-  if array.dtype == C64:
-    return array.view("complex64")
-  if array.dtype == F16:
-    return array.view("float16")
-  return array
+    """Converts ctypes array back to numpy dtype array."""
+    if array.dtype == C128:
+        return array.view("complex128")
+    if array.dtype == C64:
+        return array.view("complex64")
+    if array.dtype == F16:
+        return array.view("float16")
+    return array
 
 
 def make_nd_memref_descriptor(rank, dtype):
+    class MemRefDescriptor(ctypes.Structure):
+        """Builds an empty descriptor for the given rank/dtype, where rank>0."""
 
-  class MemRefDescriptor(ctypes.Structure):
-    """Builds an empty descriptor for the given rank/dtype, where rank>0."""
+        _fields_ = [
+            ("allocated", ctypes.c_longlong),
+            ("aligned", ctypes.POINTER(dtype)),
+            ("offset", ctypes.c_longlong),
+            ("shape", ctypes.c_longlong * rank),
+            ("strides", ctypes.c_longlong * rank),
+        ]
 
-    _fields_ = [
-        ("allocated", ctypes.c_longlong),
-        ("aligned", ctypes.POINTER(dtype)),
-        ("offset", ctypes.c_longlong),
-        ("shape", ctypes.c_longlong * rank),
-        ("strides", ctypes.c_longlong * rank),
-    ]
-
-  return MemRefDescriptor
+    return MemRefDescriptor
 
 
 def make_zero_d_memref_descriptor(dtype):
+    class MemRefDescriptor(ctypes.Structure):
+        """Builds an empty descriptor for the given dtype, where rank=0."""
 
-  class MemRefDescriptor(ctypes.Structure):
-    """Builds an empty descriptor for the given dtype, where rank=0."""
-
-    _fields_ = [
-        ("allocated", ctypes.c_longlong),
-        ("aligned", ctypes.POINTER(dtype)),
-        ("offset", ctypes.c_longlong),
-    ]
+        _fields_ = [
+            ("allocated", ctypes.c_longlong),
+            ("aligned", ctypes.POINTER(dtype)),
+            ("offset", ctypes.c_longlong),
+        ]
 
-  return MemRefDescriptor
+    return MemRefDescriptor
 
 
 class UnrankedMemRefDescriptor(ctypes.Structure):
-  """Creates a ctype struct for memref descriptor"""
-  _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
+    """A ctypes struct for an unranked memref descriptor (rank + pointer)."""
+
+    _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
 
 
 def get_ranked_memref_descriptor(nparray):
-  """Returns a ranked memref descriptor for the given numpy array."""
-  ctp = as_ctype(nparray.dtype)
-  if nparray.ndim == 0:
-    x = make_zero_d_memref_descriptor(ctp)()
+    """Returns a ranked memref descriptor for the given numpy array."""
+    ctp = as_ctype(nparray.dtype)
+    if nparray.ndim == 0:
+        x = make_zero_d_memref_descriptor(ctp)()
+        x.allocated = nparray.ctypes.data
+        x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
+        x.offset = ctypes.c_longlong(0)
+        return x
+
+    x = make_nd_memref_descriptor(nparray.ndim, ctp)()
     x.allocated = nparray.ctypes.data
     x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
     x.offset = ctypes.c_longlong(0)
-    return x
+    x.shape = nparray.ctypes.shape
 
-  x = make_nd_memref_descriptor(nparray.ndim, ctp)()
-  x.allocated = nparray.ctypes.data
-  x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
-  x.offset = ctypes.c_longlong(0)
-  x.shape = nparray.ctypes.shape
-
-  # Numpy uses byte quantities to express strides, MLIR OTOH uses the
-  # torch abstraction which specifies strides in terms of elements.
-  strides_ctype_t = ctypes.c_longlong * nparray.ndim
-  x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides])
-  return x
+    # NumPy expresses strides in bytes, whereas MLIR uses the torch
+    # abstraction, which specifies strides in terms of elements.
+    strides_ctype_t = ctypes.c_longlong * nparray.ndim
+    x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides])
+    return x
 
 
 def get_unranked_memref_descriptor(nparray):
-  """Returns a generic/unranked memref descriptor for the given numpy array."""
-  d = UnrankedMemRefDescriptor()
-  d.rank = nparray.ndim
-  x = get_ranked_memref_descriptor(nparray)
-  d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
-  return d
+    """Returns a generic/unranked memref descriptor for the given numpy array."""
+    d = UnrankedMemRefDescriptor()
+    d.rank = nparray.ndim
+    x = get_ranked_memref_descriptor(nparray)
+    d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
+    return d
 
 
 def unranked_memref_to_numpy(unranked_memref, np_dtype):
-  """Converts unranked memrefs to numpy arrays."""
-  ctp = as_ctype(np_dtype)
-  descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
-  val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
-  np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
-  strided_arr = np.lib.stride_tricks.as_strided(
-      np_arr,
-      np.ctypeslib.as_array(val[0].shape),
-      np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
-  )
-  return to_numpy(strided_arr)
+    """Converts unranked memrefs to numpy arrays."""
+    ctp = as_ctype(np_dtype)
+    descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
+    val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
+    np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
+    strided_arr = np.lib.stride_tricks.as_strided(
+        np_arr,
+        np.ctypeslib.as_array(val[0].shape),
+        np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
+    )
+    return to_numpy(strided_arr)
 
 
 def ranked_memref_to_numpy(ranked_memref):
-  """Converts ranked memrefs to numpy arrays."""
-  np_arr = np.ctypeslib.as_array(
-      ranked_memref[0].aligned, shape=ranked_memref[0].shape)
-  strided_arr = np.lib.stride_tricks.as_strided(
-      np_arr,
-      np.ctypeslib.as_array(ranked_memref[0].shape),
-      np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
-  )
-  return to_numpy(strided_arr)
+    """Converts ranked memrefs to numpy arrays."""
+    np_arr = np.ctypeslib.as_array(
+        ranked_memref[0].aligned, shape=ranked_memref[0].shape
+    )
+    strided_arr = np.lib.stride_tricks.as_strided(
+        np_arr,
+        np.ctypeslib.as_array(ranked_memref[0].shape),
+        np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
+    )
+    return to_numpy(strided_arr)

diff  --git a/mlir/test/CAPI/lit.local.cfg b/mlir/test/CAPI/lit.local.cfg
index f08a0de488ddd..bb0c17cdbada7 100644
--- a/mlir/test/CAPI/lit.local.cfg
+++ b/mlir/test/CAPI/lit.local.cfg
@@ -1 +1 @@
-config.suffixes.add('.c')
+config.suffixes.add(".c")

diff  --git a/mlir/test/Conversion/GPUToCUDA/lit.local.cfg b/mlir/test/Conversion/GPUToCUDA/lit.local.cfg
index 847c3efbf24ac..bc470ccc5733a 100644
--- a/mlir/test/Conversion/GPUToCUDA/lit.local.cfg
+++ b/mlir/test/Conversion/GPUToCUDA/lit.local.cfg
@@ -1,2 +1,2 @@
 if not config.run_cuda_tests:
-  config.unsupported = True
\ No newline at end of file
+    config.unsupported = True

diff  --git a/mlir/test/Conversion/GPUToROCm/lit.local.cfg b/mlir/test/Conversion/GPUToROCm/lit.local.cfg
index 6eb561783b3fb..2f5cc9f3bad97 100644
--- a/mlir/test/Conversion/GPUToROCm/lit.local.cfg
+++ b/mlir/test/Conversion/GPUToROCm/lit.local.cfg
@@ -1,2 +1,2 @@
 if not config.run_rocm_tests:
-  config.unsupported = True
+    config.unsupported = True

diff  --git a/mlir/test/Examples/Toy/Ch6/lit.local.cfg b/mlir/test/Examples/Toy/Ch6/lit.local.cfg
index c5aeb13c427c5..0d9aa1006cf8b 100644
--- a/mlir/test/Examples/Toy/Ch6/lit.local.cfg
+++ b/mlir/test/Examples/Toy/Ch6/lit.local.cfg
@@ -1,5 +1,3 @@
 # Requires native execution.
-if 'host-supports-jit' not in config.available_features:
+if "host-supports-jit" not in config.available_features:
     config.unsupported = True
-
-

diff  --git a/mlir/test/Examples/Toy/Ch7/lit.local.cfg b/mlir/test/Examples/Toy/Ch7/lit.local.cfg
index c5aeb13c427c5..0d9aa1006cf8b 100644
--- a/mlir/test/Examples/Toy/Ch7/lit.local.cfg
+++ b/mlir/test/Examples/Toy/Ch7/lit.local.cfg
@@ -1,5 +1,3 @@
 # Requires native execution.
-if 'host-supports-jit' not in config.available_features:
+if "host-supports-jit" not in config.available_features:
     config.unsupported = True
-
-

diff  --git a/mlir/test/Examples/lit.local.cfg b/mlir/test/Examples/lit.local.cfg
index 97db322f29e29..1a51296e2b60a 100644
--- a/mlir/test/Examples/lit.local.cfg
+++ b/mlir/test/Examples/lit.local.cfg
@@ -1,2 +1,2 @@
 if not config.build_examples:
-  config.unsupported = True
+    config.unsupported = True

diff  --git a/mlir/test/Examples/standalone/lit.local.cfg b/mlir/test/Examples/standalone/lit.local.cfg
index cf7c8ff4d54f7..fe8397c6b9a10 100644
--- a/mlir/test/Examples/standalone/lit.local.cfg
+++ b/mlir/test/Examples/standalone/lit.local.cfg
@@ -1,13 +1,12 @@
 # Disable with sanitizers for now, this require some more setup apparently.
-for san in ['asan', 'msan', 'ubsan']:
-   if (san in config.available_features):
-      config.unsupported = True
+for san in ["asan", "msan", "ubsan"]:
+    if san in config.available_features:
+        config.unsupported = True
 
 config.substitutions.append(("%cmake_exe", config.host_cmake))
 config.substitutions.append(("%cmake_generator", config.host_cmake_generator))
 config.substitutions.append(("%host_cxx", config.host_cxx))
 config.substitutions.append(("%host_cc", config.host_cc))
 config.substitutions.append(("%enable_libcxx", config.enable_libcxx))
-config.substitutions.append(
-    ("%mlir_cmake_dir", config.mlir_cmake_dir))
+config.substitutions.append(("%mlir_cmake_dir", config.mlir_cmake_dir))
 config.substitutions.append(("%llvm_use_linker", config.llvm_use_linker))

diff  --git a/mlir/test/Integration/Dialect/Async/CPU/lit.local.cfg b/mlir/test/Integration/Dialect/Async/CPU/lit.local.cfg
index 7215edacf7a83..073f6373d8581 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/Async/CPU/lit.local.cfg
@@ -1,5 +1,5 @@
 import sys
 
 # Windows does not have aligned_alloc
-if sys.platform == 'win32':
+if sys.platform == "win32":
     config.unsupported = True

diff  --git a/mlir/test/Integration/Dialect/LLVMIR/CPU/X86/lit.local.cfg b/mlir/test/Integration/Dialect/LLVMIR/CPU/X86/lit.local.cfg
index 263c8f8a12283..071a13c27769e 100644
--- a/mlir/test/Integration/Dialect/LLVMIR/CPU/X86/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/LLVMIR/CPU/X86/lit.local.cfg
@@ -1,4 +1,4 @@
 import platform
 
-if platform.machine() != 'x86_64':
+if platform.machine() != "x86_64":
     config.unsupported = True

diff  --git a/mlir/test/Integration/Dialect/LLVMIR/CPU/lit.local.cfg b/mlir/test/Integration/Dialect/LLVMIR/CPU/lit.local.cfg
index 7d1e494eb96e9..3214a1101f084 100644
--- a/mlir/test/Integration/Dialect/LLVMIR/CPU/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/LLVMIR/CPU/lit.local.cfg
@@ -1,18 +1,22 @@
 import sys
 
-lli_cmd = 'lli'
+lli_cmd = "lli"
 if config.riscv_emulator_lli_executable:
     lli_cmd = config.riscv_emulator_lli_executable
 
-config.substitutions.append(('%mlir_native_utils_lib_dir',
-    config.riscv_emulator_utils_lib_dir or config.mlir_lib_dir))
+config.substitutions.append(
+    (
+        "%mlir_native_utils_lib_dir",
+        config.riscv_emulator_utils_lib_dir or config.mlir_lib_dir,
+    )
+)
 
 if config.riscv_vector_emulator_executable:
     # Run test in qemu emulator.
     emulation_cmd = config.riscv_vector_emulator_executable
     if config.riscv_vector_emulator_options:
-        emulation_cmd = emulation_cmd + ' ' + config.riscv_vector_emulator_options
-    emulation_cmd = emulation_cmd + ' ' + lli_cmd + ' --march=riscv64 -mattr=+v '
-    config.substitutions.append(('%lli', emulation_cmd))
+        emulation_cmd = emulation_cmd + " " + config.riscv_vector_emulator_options
+    emulation_cmd = emulation_cmd + " " + lli_cmd + " --march=riscv64 -mattr=+v "
+    config.substitutions.append(("%lli", emulation_cmd))
 else:
-    config.substitutions.append(('%lli', lli_cmd))
+    config.substitutions.append(("%lli", lli_cmd))

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/lit.local.cfg b/mlir/test/Integration/Dialect/SparseTensor/CPU/lit.local.cfg
index 9bf49cc8246ec..6e07eb87987a4 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/lit.local.cfg
@@ -2,14 +2,16 @@ import sys
 from lit.llvm import llvm_config
 
 # FIXME: %mlir_native_utils_lib_dir is set incorrectly on Windows
-if sys.platform == 'win32':
+if sys.platform == "win32":
     config.unsupported = True
 
 # ArmSVE tests must be enabled via build flag.
 if config.mlir_run_arm_sve_tests:
-    config.substitutions.append(('%ENABLE_VLA', 'true'))
-    config.substitutions.append(('%VLA_ARCH_ATTR_OPTIONS', '--march=aarch64 --mattr="+sve"'))
+    config.substitutions.append(("%ENABLE_VLA", "true"))
+    config.substitutions.append(
+        ("%VLA_ARCH_ATTR_OPTIONS", '--march=aarch64 --mattr="+sve"')
+    )
 else:
-    config.substitutions.append(('%ENABLE_VLA', 'false'))
-    config.substitutions.append(('%VLA_ARCH_ATTR_OPTIONS', ''))
-    config.substitutions.append(('%mlir_native_utils_lib_dir', config.mlir_lib_dir))
+    config.substitutions.append(("%ENABLE_VLA", "false"))
+    config.substitutions.append(("%VLA_ARCH_ATTR_OPTIONS", ""))
+    config.substitutions.append(("%mlir_native_utils_lib_dir", config.mlir_lib_dir))

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg
index c586aae6475bd..6788ccea3a222 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg
@@ -1,2 +1,2 @@
 if not config.enable_cuda_runner or not config.mlir_run_cuda_sm80_tests:
-  config.unsupported = True
+    config.unsupported = True

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/python/lit.local.cfg b/mlir/test/Integration/Dialect/SparseTensor/python/lit.local.cfg
index cf04454dea6ef..361b657dd2d83 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/lit.local.cfg
@@ -1,5 +1,5 @@
 # Disable ASAN's leak detection for python OpsDSL tests.
-config.environment['ASAN_OPTIONS'] = 'detect_leaks=0'
+config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
 # Only run when python bindings are enabled.
 if not config.enable_bindings_python:
-  config.unsupported = True
+    config.unsupported = True

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
index 958aa86752a07..1f9b636038318 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
@@ -18,42 +18,45 @@
 sys.path.append(_SCRIPT_PATH)
 from tools import sparse_compiler
 
+
 @dsl.linalg_structured_op
 def sddmm_dsl(
     A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
     B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
     S=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N),
-    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)):
-  C[dsl.D.m,
-    dsl.D.n] += S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
+    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
+):
+    C[dsl.D.m, dsl.D.n] += (
+        S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
+    )
 
 
 def build_SDDMM(attr: st.EncodingAttr):
-  """Build SDDMM kernel.
+    """Build SDDMM kernel.
 
-  This method generates a linalg op with for matrix multiplication using
-  just the Python API. Effectively, a generic linalg op is constructed
-  that computes C(i,j) += S(i,j) SUM_k A(i,k) B(k,j) for sparse S.
-  """
-  module = ir.Module.create()
-  f64 = ir.F64Type.get()
-  a = ir.RankedTensorType.get([8, 8], f64)
-  b = ir.RankedTensorType.get([8, 8], f64)
-  c = ir.RankedTensorType.get([8, 8], f64)
-  s = ir.RankedTensorType.get([8, 8], f64, attr)
-  arguments = [a, b, s, c]
-  with ir.InsertionPoint(module.body):
+    This method generates a linalg op with for matrix multiplication using
+    just the Python API. Effectively, a generic linalg op is constructed
+    that computes C(i,j) += S(i,j) SUM_k A(i,k) B(k,j) for sparse S.
+    """
+    module = ir.Module.create()
+    f64 = ir.F64Type.get()
+    a = ir.RankedTensorType.get([8, 8], f64)
+    b = ir.RankedTensorType.get([8, 8], f64)
+    c = ir.RankedTensorType.get([8, 8], f64)
+    s = ir.RankedTensorType.get([8, 8], f64, attr)
+    arguments = [a, b, s, c]
+    with ir.InsertionPoint(module.body):
 
-    @func.FuncOp.from_py_func(*arguments)
-    def sddmm(*args):
-      return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]])
+        @func.FuncOp.from_py_func(*arguments)
+        def sddmm(*args):
+            return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]])
 
-  return module
+    return module
 
 
 def boilerplate(attr: st.EncodingAttr):
-  """Returns boilerplate code for main driver."""
-  return f"""
+    """Returns boilerplate code for main driver."""
+    return f"""
 func.func @main(%a: tensor<8x8xf64>,
            %b: tensor<8x8xf64>,
            %c: tensor<8x8xf64>) -> tensor<8x8xf64> attributes {{ llvm.emit_c_interface }} {{
@@ -69,92 +72,100 @@ def boilerplate(attr: st.EncodingAttr):
 
 
 def build_compile_and_run_SDDMMM(attr: st.EncodingAttr, compiler):
-  # Build.
-  module = build_SDDMM(attr)
-  func = str(module.operation.regions[0].blocks[0].operations[0].operation)
-  module = ir.Module.parse(func + boilerplate(attr))
-
-  # Compile.
-  engine = compiler.compile_and_jit(module)
-
-  # Set up numpy input and buffer for output.
-  a = np.array([[1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
-                [1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2],
-                [1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3],
-                [1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4],
-                [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
-                [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6],
-                [1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7],
-                [1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8]], np.float64)
-  b = np.ones((8, 8), np.float64)
-  c = np.zeros((8, 8), np.float64)
-
-  mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
-  mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
-  mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
-
-  # Allocate a MemRefDescriptor to receive the output tensor.
-  # The buffer itself is allocated inside the MLIR code generation.
-  ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
-  mem_out = ctypes.pointer(ctypes.pointer(ref_out))
-
-  # Invoke the kernel and get numpy output.
-  # Built-in bufferization uses in-out buffers.
-  # TODO: replace with inplace comprehensive bufferization.
-  engine.invoke('main', mem_out, mem_a, mem_b, mem_c)
-
-  # Sanity check on computed result. Only a few elements
-  # are sampled from the full dense matrix multiplication.
-  full_matmul = np.matmul(a, b)
-  expected = np.zeros((8, 8), np.float64)
-  expected[0, 0] = 1.0 * full_matmul[0, 0]
-  expected[0, 2] = 2.0 * full_matmul[0, 2]
-  expected[4, 1] = 3.0 * full_matmul[4, 1]
-  c = rt.ranked_memref_to_numpy(mem_out[0])
-  if np.allclose(c, expected):
-    pass
-  else:
-    quit(f'FAILURE')
+    # Build.
+    module = build_SDDMM(attr)
+    func = str(module.operation.regions[0].blocks[0].operations[0].operation)
+    module = ir.Module.parse(func + boilerplate(attr))
+
+    # Compile.
+    engine = compiler.compile_and_jit(module)
+
+    # Set up numpy input and buffer for output.
+    a = np.array(
+        [
+            [1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
+            [1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2],
+            [1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3],
+            [1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4],
+            [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
+            [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6],
+            [1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7],
+            [1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8],
+        ],
+        np.float64,
+    )
+    b = np.ones((8, 8), np.float64)
+    c = np.zeros((8, 8), np.float64)
+
+    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+
+    # Allocate a MemRefDescriptor to receive the output tensor.
+    # The buffer itself is allocated inside the MLIR code generation.
+    ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
+    mem_out = ctypes.pointer(ctypes.pointer(ref_out))
+
+    # Invoke the kernel and get numpy output.
+    # Built-in bufferization uses in-out buffers.
+    # TODO: replace with inplace comprehensive bufferization.
+    engine.invoke("main", mem_out, mem_a, mem_b, mem_c)
+
+    # Sanity check on computed result. Only a few elements
+    # are sampled from the full dense matrix multiplication.
+    full_matmul = np.matmul(a, b)
+    expected = np.zeros((8, 8), np.float64)
+    expected[0, 0] = 1.0 * full_matmul[0, 0]
+    expected[0, 2] = 2.0 * full_matmul[0, 2]
+    expected[4, 1] = 3.0 * full_matmul[4, 1]
+    c = rt.ranked_memref_to_numpy(mem_out[0])
+    if np.allclose(c, expected):
+        pass
+    else:
+        quit(f"FAILURE")
 
 
 def main():
-  support_lib = os.getenv('SUPPORT_LIB')
-  assert support_lib is not None, 'SUPPORT_LIB is undefined'
-  if not os.path.exists(support_lib):
-    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
-                            support_lib)
-
-  # CHECK-LABEL: TEST: testSDDMMM
-  print('\nTEST: testSDDMMM')
-  with ir.Context() as ctx, ir.Location.unknown():
-    count = 0
-    # Loop over various ways to compile and annotate the SDDMM kernel with
-    # a *single* sparse tensor. Note that we deliberate do not exhaustively
-    # search the full state space to reduce runtime of the test. It is
-    # straightforward to adapt the code below to explore more combinations.
-    levels = [[st.DimLevelType.dense, st.DimLevelType.dense],
-              [st.DimLevelType.dense, st.DimLevelType.compressed],
-              [st.DimLevelType.compressed, st.DimLevelType.dense],
-              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
-    orderings = [
-        ir.AffineMap.get_permutation([0, 1]),
-        ir.AffineMap.get_permutation([1, 0])
-    ]
-    for level in levels:
-      for ordering in orderings:
-        for pwidth in [32]:
-          for iwidth in [32]:
-            for e in [True]:
-              attr = st.EncodingAttr.get(level, ordering, None, pwidth,
-                                         iwidth)
-              opt = (f'parallelization-strategy=none')
-              compiler = sparse_compiler.SparseCompiler(
-                  options=opt, opt_level=0, shared_libs=[support_lib])
-              build_compile_and_run_SDDMMM(attr, compiler)
-              count = count + 1
-  # CHECK: Passed 8 tests
-  print('Passed ', count, 'tests')
-
-
-if __name__ == '__main__':
-  main()
+    support_lib = os.getenv("SUPPORT_LIB")
+    assert support_lib is not None, "SUPPORT_LIB is undefined"
+    if not os.path.exists(support_lib):
+        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+
+    # CHECK-LABEL: TEST: testSDDMMM
+    print("\nTEST: testSDDMMM")
+    with ir.Context() as ctx, ir.Location.unknown():
+        count = 0
+        # Loop over various ways to compile and annotate the SDDMM kernel with
+        # a *single* sparse tensor. Note that we deliberate do not exhaustively
+        # search the full state space to reduce runtime of the test. It is
+        # straightforward to adapt the code below to explore more combinations.
+        levels = [
+            [st.DimLevelType.dense, st.DimLevelType.dense],
+            [st.DimLevelType.dense, st.DimLevelType.compressed],
+            [st.DimLevelType.compressed, st.DimLevelType.dense],
+            [st.DimLevelType.compressed, st.DimLevelType.compressed],
+        ]
+        orderings = [
+            ir.AffineMap.get_permutation([0, 1]),
+            ir.AffineMap.get_permutation([1, 0]),
+        ]
+        for level in levels:
+            for ordering in orderings:
+                for pwidth in [32]:
+                    for iwidth in [32]:
+                        for e in [True]:
+                            attr = st.EncodingAttr.get(
+                                level, ordering, None, pwidth, iwidth
+                            )
+                            opt = f"parallelization-strategy=none"
+                            compiler = sparse_compiler.SparseCompiler(
+                                options=opt, opt_level=0, shared_libs=[support_lib]
+                            )
+                            build_compile_and_run_SDDMMM(attr, compiler)
+                            count = count + 1
+    # CHECK: Passed 8 tests
+    print("Passed ", count, "tests")
+
+
+if __name__ == "__main__":
+    main()

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
index 97954ce08ced1..69f6cdcea967f 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
@@ -18,45 +18,47 @@
 sys.path.append(_SCRIPT_PATH)
 from tools import sparse_compiler
 
+
 @dsl.linalg_structured_op
 def matmul_dsl(
     A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
     B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
-    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)):
-  C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
+    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
+):
+    C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
 
 
 def build_SpMM(attr: st.EncodingAttr):
-  """Build SpMM kernel.
+    """Build SpMM kernel.
 
-  This method generates a linalg op with for matrix multiplication using
-  just the Python API. Effectively, a generic linalg op is constructed
-  that computes C(i,j) += A(i,k) * B(k,j) for annotated matrix A.
-  """
-  module = ir.Module.create()
-  f64 = ir.F64Type.get()
-  a = ir.RankedTensorType.get([3, 4], f64, attr)
-  b = ir.RankedTensorType.get([4, 2], f64)
-  c = ir.RankedTensorType.get([3, 2], f64)
-  arguments = [a, b, c]
-  with ir.InsertionPoint(module.body):
+    This method generates a linalg op with for matrix multiplication using
+    just the Python API. Effectively, a generic linalg op is constructed
+    that computes C(i,j) += A(i,k) * B(k,j) for annotated matrix A.
+    """
+    module = ir.Module.create()
+    f64 = ir.F64Type.get()
+    a = ir.RankedTensorType.get([3, 4], f64, attr)
+    b = ir.RankedTensorType.get([4, 2], f64)
+    c = ir.RankedTensorType.get([3, 2], f64)
+    arguments = [a, b, c]
+    with ir.InsertionPoint(module.body):
 
-    @func.FuncOp.from_py_func(*arguments)
-    def spMxM(*args):
-      return matmul_dsl(args[0], args[1], outs=[args[2]])
+        @func.FuncOp.from_py_func(*arguments)
+        def spMxM(*args):
+            return matmul_dsl(args[0], args[1], outs=[args[2]])
 
-  return module
+    return module
 
 
 def boilerplate(attr: st.EncodingAttr):
-  """Returns boilerplate main method.
-
-  This method sets up a boilerplate main method that takes three tensors
-  (a, b, c), converts the first tensor a into s sparse tensor, and then
-  calls the sparse kernel for matrix multiplication. For convenience,
-  this part is purely done as string input.
-  """
-  return f"""
+    """Returns boilerplate main method.
+
+    This method sets up a boilerplate main method that takes three tensors
+    (a, b, c), converts the first tensor a into s sparse tensor, and then
+    calls the sparse kernel for matrix multiplication. For convenience,
+    this part is purely done as string input.
+    """
+    return f"""
 func.func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>) -> tensor<3x2xf64>
   attributes {{ llvm.emit_c_interface }} {{
   %a = sparse_tensor.convert %ad : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
@@ -69,82 +71,87 @@ def boilerplate(attr: st.EncodingAttr):
 
 
 def build_compile_and_run_SpMM(attr: st.EncodingAttr, compiler):
-  # Build.
-  module = build_SpMM(attr)
-  func = str(module.operation.regions[0].blocks[0].operations[0].operation)
-  module = ir.Module.parse(func + boilerplate(attr))
-
-  # Compile.
-  engine = compiler.compile_and_jit(module)
-
-  # Set up numpy input and buffer for output.
-  a = np.array(
-      [[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]],
-      np.float64)
-  b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
-  c = np.zeros((3, 2), np.float64)
-
-  mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
-  mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
-  mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
-  # Allocate a MemRefDescriptor to receive the output tensor.
-  # The buffer itself is allocated inside the MLIR code generation.
-  ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
-  mem_out = ctypes.pointer(ctypes.pointer(ref_out))
-
-  # Invoke the kernel and get numpy output.
-  # Built-in bufferization uses in-out buffers.
-  # TODO: replace with inplace comprehensive bufferization.
-  engine.invoke('main', mem_out, mem_a, mem_b, mem_c)
-
-  # Sanity check on computed result.
-  expected = np.matmul(a, b);
-  c = rt.ranked_memref_to_numpy(mem_out[0])
-  if np.allclose(c, expected):
-    pass
-  else:
-    quit(f'FAILURE')
+    # Build.
+    module = build_SpMM(attr)
+    func = str(module.operation.regions[0].blocks[0].operations[0].operation)
+    module = ir.Module.parse(func + boilerplate(attr))
+
+    # Compile.
+    engine = compiler.compile_and_jit(module)
+
+    # Set up numpy input and buffer for output.
+    a = np.array(
+        [[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]], np.float64
+    )
+    b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
+    c = np.zeros((3, 2), np.float64)
+
+    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+    # Allocate a MemRefDescriptor to receive the output tensor.
+    # The buffer itself is allocated inside the MLIR code generation.
+    ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
+    mem_out = ctypes.pointer(ctypes.pointer(ref_out))
+
+    # Invoke the kernel and get numpy output.
+    # Built-in bufferization uses in-out buffers.
+    # TODO: replace with inplace comprehensive bufferization.
+    engine.invoke("main", mem_out, mem_a, mem_b, mem_c)
+
+    # Sanity check on computed result.
+    expected = np.matmul(a, b)
+    c = rt.ranked_memref_to_numpy(mem_out[0])
+    if np.allclose(c, expected):
+        pass
+    else:
+        quit(f"FAILURE")
 
 
 def main():
-  support_lib = os.getenv('SUPPORT_LIB')
-  assert support_lib is not None, 'SUPPORT_LIB is undefined'
-  if not os.path.exists(support_lib):
-    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
-
-  # CHECK-LABEL: TEST: testSpMM
-  print('\nTEST: testSpMM')
-  with ir.Context() as ctx, ir.Location.unknown():
-    count = 0
-    # Loop over various ways to compile and annotate the SpMM kernel with
-    # a *single* sparse tensor. Note that we deliberate do not exhaustively
-    # search the full state space to reduce runtime of the test. It is
-    # straightforward to adapt the code below to explore more combinations.
-
-    vl = 1
-    e = False
-    opt = (f'parallelization-strategy=none')
-    levels = [[st.DimLevelType.dense, st.DimLevelType.dense],
-              [st.DimLevelType.dense, st.DimLevelType.compressed],
-              [st.DimLevelType.compressed, st.DimLevelType.dense],
-              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
-    orderings = [
-        ir.AffineMap.get_permutation([0, 1]),
-        ir.AffineMap.get_permutation([1, 0])
-    ]
-    bitwidths = [0]
-    compiler = sparse_compiler.SparseCompiler(
-        options=opt, opt_level=0, shared_libs=[support_lib])
-    for level in levels:
-      for ordering in orderings:
-        for pwidth in bitwidths:
-          for iwidth in bitwidths:
-            attr = st.EncodingAttr.get(level, ordering, None, pwidth, iwidth)
-            build_compile_and_run_SpMM(attr, compiler)
-            count = count + 1
-    # CHECK: Passed 8 tests
-    print('Passed ', count, 'tests')
-
-
-if __name__ == '__main__':
-  main()
+    support_lib = os.getenv("SUPPORT_LIB")
+    assert support_lib is not None, "SUPPORT_LIB is undefined"
+    if not os.path.exists(support_lib):
+        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+
+    # CHECK-LABEL: TEST: testSpMM
+    print("\nTEST: testSpMM")
+    with ir.Context() as ctx, ir.Location.unknown():
+        count = 0
+        # Loop over various ways to compile and annotate the SpMM kernel with
+        # a *single* sparse tensor. Note that we deliberate do not exhaustively
+        # search the full state space to reduce runtime of the test. It is
+        # straightforward to adapt the code below to explore more combinations.
+
+        vl = 1
+        e = False
+        opt = f"parallelization-strategy=none"
+        levels = [
+            [st.DimLevelType.dense, st.DimLevelType.dense],
+            [st.DimLevelType.dense, st.DimLevelType.compressed],
+            [st.DimLevelType.compressed, st.DimLevelType.dense],
+            [st.DimLevelType.compressed, st.DimLevelType.compressed],
+        ]
+        orderings = [
+            ir.AffineMap.get_permutation([0, 1]),
+            ir.AffineMap.get_permutation([1, 0]),
+        ]
+        bitwidths = [0]
+        compiler = sparse_compiler.SparseCompiler(
+            options=opt, opt_level=0, shared_libs=[support_lib]
+        )
+        for level in levels:
+            for ordering in orderings:
+                for pwidth in bitwidths:
+                    for iwidth in bitwidths:
+                        attr = st.EncodingAttr.get(
+                            level, ordering, None, pwidth, iwidth
+                        )
+                        build_compile_and_run_SpMM(attr, compiler)
+                        count = count + 1
+        # CHECK: Passed 8 tests
+        print("Passed ", count, "tests")
+
+
+if __name__ == "__main__":
+    main()

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
index b29b029c7a331..a41bde1ee2d34 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
@@ -57,49 +57,52 @@
 
 
 def _run_test(support_lib, kernel):
-  """Compiles, runs and checks results."""
-  compiler = sparse_compiler.SparseCompiler(
-      options='', opt_level=2, shared_libs=[support_lib])
-  module = ir.Module.parse(kernel)
-  engine = compiler.compile_and_jit(module)
-
-  # Set up numpy inputs and buffer for output.
-  a = np.array(
-      [[1.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 6.6, 0.0]],
-      np.float64)
-  b = np.array(
-      [[1.1, 0.0, 0.0, 2.8], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
-      np.float64)
-
-  mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
-  mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
-
-  # The sparse tensor output is a pointer to pointer of char.
-  out = ctypes.c_char(0)
-  mem_out = ctypes.pointer(ctypes.pointer(out))
-
-  # Invoke the kernel.
-  engine.invoke('main', mem_a, mem_b, mem_out)
-
-  # Retrieve and check the result.
-  rank, nse, shape, values, indices = test_tools.sparse_tensor_to_coo_tensor(
-      support_lib, mem_out[0], np.float64)
-
-  # CHECK: PASSED
-  if np.allclose(values, [2.2, 2.8, 6.6]) and np.allclose(
-      indices, [[0, 0], [0, 3], [2, 2]]):
-    print('PASSED')
-  else:
-    quit('FAILURE')
+    """Compiles, runs and checks results."""
+    compiler = sparse_compiler.SparseCompiler(
+        options="", opt_level=2, shared_libs=[support_lib]
+    )
+    module = ir.Module.parse(kernel)
+    engine = compiler.compile_and_jit(module)
+
+    # Set up numpy inputs and buffer for output.
+    a = np.array(
+        [[1.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 6.6, 0.0]], np.float64
+    )
+    b = np.array(
+        [[1.1, 0.0, 0.0, 2.8], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], np.float64
+    )
+
+    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+
+    # The sparse tensor output is a pointer to pointer of char.
+    out = ctypes.c_char(0)
+    mem_out = ctypes.pointer(ctypes.pointer(out))
+
+    # Invoke the kernel.
+    engine.invoke("main", mem_a, mem_b, mem_out)
+
+    # Retrieve and check the result.
+    rank, nse, shape, values, indices = test_tools.sparse_tensor_to_coo_tensor(
+        support_lib, mem_out[0], np.float64
+    )
+
+    # CHECK: PASSED
+    if np.allclose(values, [2.2, 2.8, 6.6]) and np.allclose(
+        indices, [[0, 0], [0, 3], [2, 2]]
+    ):
+        print("PASSED")
+    else:
+        quit("FAILURE")
 
 
 def test_elementwise_add():
-  # Obtain path to runtime support library.
-  support_lib = os.getenv('SUPPORT_LIB')
-  assert support_lib is not None, 'SUPPORT_LIB is undefined'
-  assert os.path.exists(support_lib), f'{support_lib} does not exist'
-  with ir.Context() as ctx, ir.Location.unknown():
-    _run_test(support_lib, _KERNEL_STR)
+    # Obtain path to runtime support library.
+    support_lib = os.getenv("SUPPORT_LIB")
+    assert support_lib is not None, "SUPPORT_LIB is undefined"
+    assert os.path.exists(support_lib), f"{support_lib} does not exist"
+    with ir.Context() as ctx, ir.Location.unknown():
+        _run_test(support_lib, _KERNEL_STR)
 
 
 test_elementwise_add()

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index 7d57b1c901948..7d77490080205 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -18,8 +18,8 @@
 
 # TODO: move more into actual IR building.
 def boilerplate(attr: st.EncodingAttr):
-  """Returns boilerplate main method."""
-  return f"""
+    """Returns boilerplate main method."""
+    return f"""
 func.func @main(%p : !llvm.ptr<i8>) -> () attributes {{ llvm.emit_c_interface }} {{
   %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],
                              [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>
@@ -31,13 +31,13 @@ def boilerplate(attr: st.EncodingAttr):
 
 
 def expected():
-  """Returns expected contents of output.
+    """Returns expected contents of output.
 
-  Regardless of the dimension ordering, compression, and bitwidths that are
-  used in the sparse tensor, the output is always lexicographically sorted
-  by natural index order.
-  """
-  return f"""; extended FROSTT format
+    Regardless of the dimension ordering, compression, and bitwidths that are
+    used in the sparse tensor, the output is always lexicographically sorted
+    by natural index order.
+    """
+    return f"""; extended FROSTT format
 2 5
 10 10
 1 1 1
@@ -49,53 +49,55 @@ def expected():
 
 
 def build_compile_and_run_output(attr: st.EncodingAttr, compiler):
-  # Build and Compile.
-  module = ir.Module.parse(boilerplate(attr))
-  engine = compiler.compile_and_jit(module)
+    # Build and Compile.
+    module = ir.Module.parse(boilerplate(attr))
+    engine = compiler.compile_and_jit(module)
 
-  # Invoke the kernel and compare output.
-  with tempfile.TemporaryDirectory() as test_dir:
-    out = os.path.join(test_dir, 'out.tns')
-    buf = out.encode('utf-8')
-    mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
-    engine.invoke('main', mem_a)
+    # Invoke the kernel and compare output.
+    with tempfile.TemporaryDirectory() as test_dir:
+        out = os.path.join(test_dir, "out.tns")
+        buf = out.encode("utf-8")
+        mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
+        engine.invoke("main", mem_a)
 
-    actual = open(out).read()
-    if actual != expected():
-      quit('FAILURE')
+        actual = open(out).read()
+        if actual != expected():
+            quit("FAILURE")
 
 
 def main():
-  support_lib = os.getenv('SUPPORT_LIB')
-  assert support_lib is not None, 'SUPPORT_LIB is undefined'
-  if not os.path.exists(support_lib):
-    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
-                            support_lib)
-
-  # CHECK-LABEL: TEST: test_output
-  print('\nTEST: test_output')
-  count = 0
-  with ir.Context() as ctx, ir.Location.unknown():
-    # Loop over various sparse types: CSR, DCSR, CSC, DCSC.
-    levels = [[st.DimLevelType.dense, st.DimLevelType.compressed],
-              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
-    orderings = [
-        ir.AffineMap.get_permutation([0, 1]),
-        ir.AffineMap.get_permutation([1, 0])
-    ]
-    bitwidths = [8, 16, 32, 64]
-    compiler = sparse_compiler.SparseCompiler(
-        options='', opt_level=2, shared_libs=[support_lib])
-    for level in levels:
-      for ordering in orderings:
-        for bwidth in bitwidths:
-          attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth)
-          build_compile_and_run_output(attr, compiler)
-          count = count + 1
-
-  # CHECK: Passed 16 tests
-  print('Passed', count, 'tests')
-
-
-if __name__ == '__main__':
-  main()
+    support_lib = os.getenv("SUPPORT_LIB")
+    assert support_lib is not None, "SUPPORT_LIB is undefined"
+    if not os.path.exists(support_lib):
+        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+
+    # CHECK-LABEL: TEST: test_output
+    print("\nTEST: test_output")
+    count = 0
+    with ir.Context() as ctx, ir.Location.unknown():
+        # Loop over various sparse types: CSR, DCSR, CSC, DCSC.
+        levels = [
+            [st.DimLevelType.dense, st.DimLevelType.compressed],
+            [st.DimLevelType.compressed, st.DimLevelType.compressed],
+        ]
+        orderings = [
+            ir.AffineMap.get_permutation([0, 1]),
+            ir.AffineMap.get_permutation([1, 0]),
+        ]
+        bitwidths = [8, 16, 32, 64]
+        compiler = sparse_compiler.SparseCompiler(
+            options="", opt_level=2, shared_libs=[support_lib]
+        )
+        for level in levels:
+            for ordering in orderings:
+                for bwidth in bitwidths:
+                    attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth)
+                    build_compile_and_run_output(attr, compiler)
+                    count = count + 1
+
+    # CHECK: Passed 16 tests
+    print("Passed", count, "tests")
+
+
+if __name__ == "__main__":
+    main()

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
index 3a04e5b9ab5ca..373f7457e0b5f 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
@@ -28,216 +28,241 @@
 # TODO: move this boilerplate to its own module, so it can be used by
 # other tests and programs.
 class TypeConverter:
-  """Converter between NumPy types and MLIR types."""
-
-  def __init__(self, context: ir.Context):
-    # Note 1: these are numpy "scalar types" (i.e., the values of
-    # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
-    #
-    # Note 2: we must construct the MLIR types in the same context as the
-    # types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
-    # otherwise, those methods will raise a KeyError.
-    types_list = [
-      (np.float64, ir.F64Type.get(context=context)),
-      (np.float32, ir.F32Type.get(context=context)),
-      (np.int64, ir.IntegerType.get_signless(64, context=context)),
-      (np.int32, ir.IntegerType.get_signless(32, context=context)),
-      (np.int16, ir.IntegerType.get_signless(16, context=context)),
-      (np.int8, ir.IntegerType.get_signless(8, context=context)),
-    ]
-    self._sc2ir = dict(types_list)
-    self._ir2sc = dict(( (ir,sc) for sc,ir in types_list ))
-
-  def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
-    """Returns the MLIR equivalent of a NumPy dtype."""
-    try:
-      return self.sctype_to_irtype(dtype.type)
-    except KeyError as e:
-      raise KeyError(f'Unknown dtype: {dtype}') from e
-
-  def sctype_to_irtype(self, sctype) -> ir.Type:
-    """Returns the MLIR equivalent of a NumPy scalar type."""
-    if sctype in self._sc2ir:
-      return self._sc2ir[sctype]
-    else:
-      raise KeyError(f'Unknown sctype: {sctype}')
-
-  def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
-    """Returns the NumPy dtype equivalent of an MLIR type."""
-    return np.dtype(self.irtype_to_sctype(tp))
-
-  def irtype_to_sctype(self, tp: ir.Type):
-    """Returns the NumPy scalar-type equivalent of an MLIR type."""
-    if tp in self._ir2sc:
-      return self._ir2sc[tp]
-    else:
-      raise KeyError(f'Unknown ir.Type: {tp}')
-
-  def get_RankedTensorType_of_nparray(self, nparray: np.ndarray) -> ir.RankedTensorType:
-    """Returns the ir.RankedTensorType of a NumPy array.  Note that NumPy
-    arrays can only be converted to/from dense tensors, not sparse tensors."""
-    # TODO: handle strides as well?
-    return ir.RankedTensorType.get(nparray.shape,
-                                   self.dtype_to_irtype(nparray.dtype))
+    """Converter between NumPy types and MLIR types."""
+
+    def __init__(self, context: ir.Context):
+        # Note 1: these are numpy "scalar types" (i.e., the values of
+        # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
+        #
+        # Note 2: we must construct the MLIR types in the same context as the
+        # types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
+        # otherwise, those methods will raise a KeyError.
+        types_list = [
+            (np.float64, ir.F64Type.get(context=context)),
+            (np.float32, ir.F32Type.get(context=context)),
+            (np.int64, ir.IntegerType.get_signless(64, context=context)),
+            (np.int32, ir.IntegerType.get_signless(32, context=context)),
+            (np.int16, ir.IntegerType.get_signless(16, context=context)),
+            (np.int8, ir.IntegerType.get_signless(8, context=context)),
+        ]
+        self._sc2ir = dict(types_list)
+        self._ir2sc = dict(((ir, sc) for sc, ir in types_list))
+
+    def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
+        """Returns the MLIR equivalent of a NumPy dtype."""
+        try:
+            return self.sctype_to_irtype(dtype.type)
+        except KeyError as e:
+            raise KeyError(f"Unknown dtype: {dtype}") from e
+
+    def sctype_to_irtype(self, sctype) -> ir.Type:
+        """Returns the MLIR equivalent of a NumPy scalar type."""
+        if sctype in self._sc2ir:
+            return self._sc2ir[sctype]
+        else:
+            raise KeyError(f"Unknown sctype: {sctype}")
+
+    def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
+        """Returns the NumPy dtype equivalent of an MLIR type."""
+        return np.dtype(self.irtype_to_sctype(tp))
+
+    def irtype_to_sctype(self, tp: ir.Type):
+        """Returns the NumPy scalar-type equivalent of an MLIR type."""
+        if tp in self._ir2sc:
+            return self._ir2sc[tp]
+        else:
+            raise KeyError(f"Unknown ir.Type: {tp}")
+
+    def get_RankedTensorType_of_nparray(
+        self, nparray: np.ndarray
+    ) -> ir.RankedTensorType:
+        """Returns the ir.RankedTensorType of a NumPy array.  Note that NumPy
+        arrays can only be converted to/from dense tensors, not sparse tensors."""
+        # TODO: handle strides as well?
+        return ir.RankedTensorType.get(
+            nparray.shape, self.dtype_to_irtype(nparray.dtype)
+        )
+
 
 # ===----------------------------------------------------------------------=== #
 
+
 class StressTest:
-  def __init__(self, tyconv: TypeConverter):
-    self._tyconv = tyconv
-    self._roundtripTp = None
-    self._module = None
-    self._engine = None
-
-  def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
-    assert self._roundtripTp is not None, \
-        'StressTest: uninitialized roundtrip type'
-    if tp != self._roundtripTp:
-      raise AssertionError(
-          f"Type is not equal to the roundtrip type.\n"
-          f"\tExpected: {self._roundtripTp}\n"
-          f"\tFound:    {tp}\n")
-
-  def build(self, types: List[ir.Type]):
-    """Builds the ir.Module.  The module has only the @main function,
-    which will convert the input through the list of types and then back
-    to the initial type.  The roundtrip type must be a dense tensor."""
-    assert self._module is None, 'StressTest: must not call build() repeatedly'
-    self._module = ir.Module.create()
-    with ir.InsertionPoint(self._module.body):
-      tp0 = types.pop(0)
-      self._roundtripTp = tp0
-      # TODO: assert dense? assert element type is recognised by the TypeConverter?
-      types.append(tp0)
-      funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
-      funcOp = func.FuncOp(name='main', type=funcTp)
-      funcOp.attributes['llvm.emit_c_interface'] = ir.UnitAttr.get()
-      with ir.InsertionPoint(funcOp.add_entry_block()):
-        arg0 = funcOp.entry_block.arguments[0]
-        self._assertEqualsRoundtripTp(arg0.type)
-        v = st.ConvertOp(types.pop(0), arg0)
-        for tp in types:
-          w = st.ConvertOp(tp, v)
-          # Release intermediate tensors before they fall out of scope.
-          bufferization.DeallocTensorOp(v.result)
-          v = w
-        self._assertEqualsRoundtripTp(v.result.type)
-        func.ReturnOp(v)
-    return self
-
-  def writeTo(self, filename):
-    """Write the ir.Module to the given file.  If the file already exists,
-    then raises an error.  If the filename is None, then is a no-op."""
-    assert self._module is not None, \
-        'StressTest: must call build() before writeTo()'
-    if filename is None:
-      # Silent no-op, for convenience.
-      return self
-    if os.path.exists(filename):
-      raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), filename)
-    with open(filename, 'w') as f:
-      f.write(str(self._module))
-    return self
-
-  def compile(self, compiler):
-    """Compile the ir.Module."""
-    assert self._module is not None, \
-        'StressTest: must call build() before compile()'
-    assert self._engine is None, \
-        'StressTest: must not call compile() repeatedly'
-    self._engine = compiler.compile_and_jit(self._module)
-    return self
-
-  def run(self, np_arg0: np.ndarray) -> np.ndarray:
-    """Runs the test on the given numpy array, and returns the resulting
-    numpy array."""
-    assert self._engine is not None, \
-        'StressTest: must call compile() before run()'
-    self._assertEqualsRoundtripTp(
-        self._tyconv.get_RankedTensorType_of_nparray(np_arg0))
-    np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
-    self._assertEqualsRoundtripTp(
-        self._tyconv.get_RankedTensorType_of_nparray(np_out))
-    mem_arg0 = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0)))
-    mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_out)))
-    self._engine.invoke('main', mem_out, mem_arg0)
-    return rt.ranked_memref_to_numpy(mem_out[0])
+    def __init__(self, tyconv: TypeConverter):
+        self._tyconv = tyconv
+        self._roundtripTp = None
+        self._module = None
+        self._engine = None
+
+    def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
+        assert self._roundtripTp is not None, "StressTest: uninitialized roundtrip type"
+        if tp != self._roundtripTp:
+            raise AssertionError(
+                f"Type is not equal to the roundtrip type.\n"
+                f"\tExpected: {self._roundtripTp}\n"
+                f"\tFound:    {tp}\n"
+            )
+
+    def build(self, types: List[ir.Type]):
+        """Builds the ir.Module.  The module has only the @main function,
+        which will convert the input through the list of types and then back
+        to the initial type.  The roundtrip type must be a dense tensor."""
+        assert self._module is None, "StressTest: must not call build() repeatedly"
+        self._module = ir.Module.create()
+        with ir.InsertionPoint(self._module.body):
+            tp0 = types.pop(0)
+            self._roundtripTp = tp0
+            # TODO: assert dense? assert element type is recognised by the TypeConverter?
+            types.append(tp0)
+            funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
+            funcOp = func.FuncOp(name="main", type=funcTp)
+            funcOp.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+            with ir.InsertionPoint(funcOp.add_entry_block()):
+                arg0 = funcOp.entry_block.arguments[0]
+                self._assertEqualsRoundtripTp(arg0.type)
+                v = st.ConvertOp(types.pop(0), arg0)
+                for tp in types:
+                    w = st.ConvertOp(tp, v)
+                    # Release intermediate tensors before they fall out of scope.
+                    bufferization.DeallocTensorOp(v.result)
+                    v = w
+                self._assertEqualsRoundtripTp(v.result.type)
+                func.ReturnOp(v)
+        return self
+
+    def writeTo(self, filename):
+        """Write the ir.Module to the given file.  If the file already exists,
+        then raises an error.  If the filename is None, then is a no-op."""
+        assert (
+            self._module is not None
+        ), "StressTest: must call build() before writeTo()"
+        if filename is None:
+            # Silent no-op, for convenience.
+            return self
+        if os.path.exists(filename):
+            raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), filename)
+        with open(filename, "w") as f:
+            f.write(str(self._module))
+        return self
+
+    def compile(self, compiler):
+        """Compile the ir.Module."""
+        assert (
+            self._module is not None
+        ), "StressTest: must call build() before compile()"
+        assert self._engine is None, "StressTest: must not call compile() repeatedly"
+        self._engine = compiler.compile_and_jit(self._module)
+        return self
+
+    def run(self, np_arg0: np.ndarray) -> np.ndarray:
+        """Runs the test on the given numpy array, and returns the resulting
+        numpy array."""
+        assert self._engine is not None, "StressTest: must call compile() before run()"
+        self._assertEqualsRoundtripTp(
+            self._tyconv.get_RankedTensorType_of_nparray(np_arg0)
+        )
+        np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
+        self._assertEqualsRoundtripTp(
+            self._tyconv.get_RankedTensorType_of_nparray(np_out)
+        )
+        mem_arg0 = ctypes.pointer(
+            ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0))
+        )
+        mem_out = ctypes.pointer(
+            ctypes.pointer(rt.get_ranked_memref_descriptor(np_out))
+        )
+        self._engine.invoke("main", mem_out, mem_arg0)
+        return rt.ranked_memref_to_numpy(mem_out[0])
+
 
 # ===----------------------------------------------------------------------=== #
 
+
 def main():
-  """
-  USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]
-
-  The environment variable SUPPORT_LIB must be set to point to the
-  libmlir_c_runner_utils shared library.  There are two optional
-  arguments, for debugging purposes.  The first argument specifies where
-  to write out the raw/generated ir.Module.  The second argument specifies
-  where to write out the compiled version of that ir.Module.
-  """
-  support_lib = os.getenv('SUPPORT_LIB')
-  assert support_lib is not None, 'SUPPORT_LIB is undefined'
-  if not os.path.exists(support_lib):
-    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
-
-  # CHECK-LABEL: TEST: test_stress
-  print("\nTEST: test_stress")
-  with ir.Context() as ctx, ir.Location.unknown():
-    # Disable direct sparse2sparse conversion, because it doubles the time!
-    # TODO: While direct s2s is far too slow for per-commit testing,
-    # we should have some framework ensure that we run this test with
-    # `s2s=0` on a regular basis, to ensure that it does continue to work.
-    # TODO: be sure to test s2s=0 together with singletons.
-    s2s = 1
-    sparsification_options = (
-        f'parallelization-strategy=none '
-        f's2s-strategy={s2s}')
-    compiler = sparse_compiler.SparseCompiler(
-        options=sparsification_options, opt_level=0, shared_libs=[support_lib])
-    f64 = ir.F64Type.get()
-    # Be careful about increasing this because
-    #     len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2
-    shape = range(2, 3)
-    rank = len(shape)
-    # All combinations.
-    # TODO: add singleton here too; which requires updating how `np_arg0`
-    # is initialized below.
-    levels = list(itertools.product(*itertools.repeat(
-      [st.DimLevelType.dense, st.DimLevelType.compressed], rank)))
-    # All permutations.
-    orderings = list(map(ir.AffineMap.get_permutation,
-      itertools.permutations(range(rank))))
-    bitwidths = [0]
-    # The first type must be a dense tensor for numpy conversion to work.
-    types = [ir.RankedTensorType.get(shape, f64)]
-    for level in levels:
-      for ordering in orderings:
-        for pwidth in bitwidths:
-          for iwidth in bitwidths:
-            attr = st.EncodingAttr.get(level, ordering, None, pwidth, iwidth)
-            types.append(ir.RankedTensorType.get(shape, f64, attr))
-    #
-    # For exhaustiveness we should have one or more StressTest, such
-    # that their paths cover all 2*n*(n-1) directed pairwise combinations
-    # of the `types` set.  However, since n is already superexponential,
-    # such exhaustiveness would be prohibitive for a test that runs on
-    # every commit.  So for now we'll just pick one particular path that
-    # at least hits all n elements of the `types` set.
-    #
-    tyconv = TypeConverter(ctx)
-    size = 1
-    for d in shape:
-      size *= d
-    np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
-    np_out = (
-        StressTest(tyconv).build(types).writeTo(
-            sys.argv[1] if len(sys.argv) > 1 else None).compile(compiler)
-        .writeTo(sys.argv[2] if len(sys.argv) > 2 else None).run(np_arg0))
-    # CHECK: Passed
-    if np.allclose(np_out, np_arg0):
-      print('Passed')
-    else:
-      sys.exit('FAILURE')
-
-if __name__ == '__main__':
-  main()
+    """
+    USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]
+
+    The environment variable SUPPORT_LIB must be set to point to the
+    libmlir_c_runner_utils shared library.  There are two optional
+    arguments, for debugging purposes.  The first argument specifies where
+    to write out the raw/generated ir.Module.  The second argument specifies
+    where to write out the compiled version of that ir.Module.
+    """
+    support_lib = os.getenv("SUPPORT_LIB")
+    assert support_lib is not None, "SUPPORT_LIB is undefined"
+    if not os.path.exists(support_lib):
+        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+
+    # CHECK-LABEL: TEST: test_stress
+    print("\nTEST: test_stress")
+    with ir.Context() as ctx, ir.Location.unknown():
+        # Disable direct sparse2sparse conversion, because it doubles the time!
+        # TODO: While direct s2s is far too slow for per-commit testing,
+        # we should have some framework ensure that we run this test with
+        # `s2s=0` on a regular basis, to ensure that it does continue to work.
+        # TODO: be sure to test s2s=0 together with singletons.
+        s2s = 1
+        sparsification_options = f"parallelization-strategy=none " f"s2s-strategy={s2s}"
+        compiler = sparse_compiler.SparseCompiler(
+            options=sparsification_options, opt_level=0, shared_libs=[support_lib]
+        )
+        f64 = ir.F64Type.get()
+        # Be careful about increasing this because
+        #     len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2
+        shape = range(2, 3)
+        rank = len(shape)
+        # All combinations.
+        # TODO: add singleton here too; which requires updating how `np_arg0`
+        # is initialized below.
+        levels = list(
+            itertools.product(
+                *itertools.repeat(
+                    [st.DimLevelType.dense, st.DimLevelType.compressed], rank
+                )
+            )
+        )
+        # All permutations.
+        orderings = list(
+            map(ir.AffineMap.get_permutation, itertools.permutations(range(rank)))
+        )
+        bitwidths = [0]
+        # The first type must be a dense tensor for numpy conversion to work.
+        types = [ir.RankedTensorType.get(shape, f64)]
+        for level in levels:
+            for ordering in orderings:
+                for pwidth in bitwidths:
+                    for iwidth in bitwidths:
+                        attr = st.EncodingAttr.get(
+                            level, ordering, None, pwidth, iwidth
+                        )
+                        types.append(ir.RankedTensorType.get(shape, f64, attr))
+        #
+        # For exhaustiveness we should have one or more StressTest, such
+        # that their paths cover all 2*n*(n-1) directed pairwise combinations
+        # of the `types` set.  However, since n is already superexponential,
+        # such exhaustiveness would be prohibitive for a test that runs on
+        # every commit.  So for now we'll just pick one particular path that
+        # at least hits all n elements of the `types` set.
+        #
+        tyconv = TypeConverter(ctx)
+        size = 1
+        for d in shape:
+            size *= d
+        np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
+        np_out = (
+            StressTest(tyconv)
+            .build(types)
+            .writeTo(sys.argv[1] if len(sys.argv) > 1 else None)
+            .compile(compiler)
+            .writeTo(sys.argv[2] if len(sys.argv) > 2 else None)
+            .run(np_arg0)
+        )
+        # CHECK: Passed
+        if np.allclose(np_out, np_arg0):
+            print("Passed")
+        else:
+            sys.exit("FAILURE")
+
+
+if __name__ == "__main__":
+    main()

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py b/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py
index f5b0ab60e85e9..785d42cadbbe9 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py
@@ -11,65 +11,71 @@
 
 @functools.lru_cache()
 def _get_c_shared_lib(lib_name: str):
-  """Loads and returns the requested C shared library.
+    """Loads and returns the requested C shared library.
 
-  Args:
-    lib_name: A string representing the C shared library.
+    Args:
+      lib_name: A string representing the C shared library.
 
-  Returns:
-    The C shared library.
+    Returns:
+      The C shared library.
 
-  Raises:
-    OSError: If there is any problem in loading the shared library.
-    ValueError:  If the shared library doesn't contain the needed routine.
-  """
-  # This raises OSError exception if there is any problem in loading the shared
-  # library.
-  c_lib = ctypes.CDLL(lib_name)
+    Raises:
+      OSError: If there is any problem in loading the shared library.
+      ValueError:  If the shared library doesn't contain the needed routine.
+    """
+    # This raises OSError exception if there is any problem in loading the shared
+    # library.
+    c_lib = ctypes.CDLL(lib_name)
 
-  try:
-    c_lib.convertFromMLIRSparseTensorF64.restype = ctypes.c_void_p
-  except Exception as e:
-    raise ValueError('Missing function convertFromMLIRSparseTensorF64 from '
-                     f'the C shared library: {e} ') from e
+    try:
+        c_lib.convertFromMLIRSparseTensorF64.restype = ctypes.c_void_p
+    except Exception as e:
+        raise ValueError(
+            "Missing function convertFromMLIRSparseTensorF64 from "
+            f"the C shared library: {e} "
+        ) from e
 
-  return c_lib
+    return c_lib
 
 
 def sparse_tensor_to_coo_tensor(support_lib, sparse, dtype):
-  """Converts a sparse tensor to COO-flavored format.
+    """Converts a sparse tensor to COO-flavored format.
 
-  Args:
-     support_lib: A string for the supporting C shared library.
-     sparse: A ctypes.pointer to the sparse tensor descriptor.
-     dtype: The numpy data type for the tensor elements.
+    Args:
+       support_lib: A string for the supporting C shared library.
+       sparse: A ctypes.pointer to the sparse tensor descriptor.
+       dtype: The numpy data type for the tensor elements.
 
-  Returns:
-    A tuple that contains the following values:
-    rank: An integer for the rank of the tensor.
-    nse: An integer for the number of non-zero values in the tensor.
-    shape: A 1D numpy array of integers, for the shape of the tensor.
-    values: A 1D numpy array, for the non-zero values in the tensor.
-    indices: A 2D numpy array of integers, representing the indices for the
-      non-zero values in the tensor.
+    Returns:
+      A tuple that contains the following values:
+      rank: An integer for the rank of the tensor.
+      nse: An integer for the number of non-zero values in the tensor.
+      shape: A 1D numpy array of integers, for the shape of the tensor.
+      values: A 1D numpy array, for the non-zero values in the tensor.
+      indices: A 2D numpy array of integers, representing the indices for the
+        non-zero values in the tensor.
 
-  Raises:
-    OSError: If there is any problem in loading the shared library.
-    ValueError:  If the shared library doesn't contain the needed routine.
-  """
-  c_lib = _get_c_shared_lib(support_lib)
+    Raises:
+      OSError: If there is any problem in loading the shared library.
+      ValueError:  If the shared library doesn't contain the needed routine.
+    """
+    c_lib = _get_c_shared_lib(support_lib)
 
-  rank = ctypes.c_ulonglong(0)
-  nse = ctypes.c_ulonglong(0)
-  shape = ctypes.POINTER(ctypes.c_ulonglong)()
-  values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
-  indices = ctypes.POINTER(ctypes.c_ulonglong)()
-  c_lib.convertFromMLIRSparseTensorF64(sparse, ctypes.byref(rank),
-                                       ctypes.byref(nse), ctypes.byref(shape),
-                                       ctypes.byref(values),
-                                       ctypes.byref(indices))
-  # Convert the returned values to the corresponding numpy types.
-  shape = np.ctypeslib.as_array(shape, shape=[rank.value])
-  values = np.ctypeslib.as_array(values, shape=[nse.value])
-  indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
-  return rank, nse, shape, values, indices
+    rank = ctypes.c_ulonglong(0)
+    nse = ctypes.c_ulonglong(0)
+    shape = ctypes.POINTER(ctypes.c_ulonglong)()
+    values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
+    indices = ctypes.POINTER(ctypes.c_ulonglong)()
+    c_lib.convertFromMLIRSparseTensorF64(
+        sparse,
+        ctypes.byref(rank),
+        ctypes.byref(nse),
+        ctypes.byref(shape),
+        ctypes.byref(values),
+        ctypes.byref(indices),
+    )
+    # Convert the returned values to the corresponding numpy types.
+    shape = np.ctypeslib.as_array(shape, shape=[rank.value])
+    values = np.ctypeslib.as_array(values, shape=[nse.value])
+    indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
+    return rank, nse, shape, values, indices

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py b/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py
index 25004f9492dbc..d549a9a0954c6 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py
@@ -9,30 +9,31 @@
 from mlir import passmanager
 from typing import Sequence
 
+
 class SparseCompiler:
-  """Sparse compiler class for compiling and building MLIR modules."""
-
-  def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
-    pipeline = f'builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})'
-    self.pipeline = pipeline
-    self.opt_level = opt_level
-    self.shared_libs = shared_libs
-
-  def __call__(self, module: ir.Module):
-    """Convenience application method."""
-    self.compile(module)
-
-  def compile(self, module: ir.Module):
-    """Compiles the module by invoking the sparse copmiler pipeline."""
-    passmanager.PassManager.parse(self.pipeline).run(module.operation)
-
-  def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
-    """Wraps the module in a JIT execution engine."""
-    return execution_engine.ExecutionEngine(
-        module, opt_level=self.opt_level, shared_libs=self.shared_libs)
-
-  def compile_and_jit(self,
-                      module: ir.Module) -> execution_engine.ExecutionEngine:
-    """Compiles and jits the module."""
-    self.compile(module)
-    return self.jit(module)
+    """Sparse compiler class for compiling and building MLIR modules."""
+
+    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+        pipeline = f"builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})"
+        self.pipeline = pipeline
+        self.opt_level = opt_level
+        self.shared_libs = shared_libs
+
+    def __call__(self, module: ir.Module):
+        """Convenience application method."""
+        self.compile(module)
+
+    def compile(self, module: ir.Module):
+        """Compiles the module by invoking the sparse copmiler pipeline."""
+        passmanager.PassManager.parse(self.pipeline).run(module.operation)
+
+    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Wraps the module in a JIT execution engine."""
+        return execution_engine.ExecutionEngine(
+            module, opt_level=self.opt_level, shared_libs=self.shared_libs
+        )
+
+    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Compiles and jits the module."""
+        self.compile(module)
+        return self.jit(module)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/lit.local.cfg b/mlir/test/Integration/Dialect/SparseTensor/taco/lit.local.cfg
index 7137d0fba95f8..f1bbcf486bc27 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/lit.local.cfg
@@ -1,5 +1,5 @@
 # Disable ASAN's leak detection for python taco tests.
-config.environment['ASAN_OPTIONS'] = 'detect_leaks=0'
+config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
 # Only run when python bindings are enabled.
 if not config.enable_bindings_python:
-  config.unsupported = True
+    config.unsupported = True

diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py
index 88b13ae161b0d..2d558f8d6ddff 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py
@@ -46,10 +46,10 @@
 
 # Perform the MTTKRP computation and write the result to file.
 with tempfile.TemporaryDirectory() as test_dir:
-  golden_file = os.path.join(_SCRIPT_PATH, "data/gold_A.tns")
-  out_file = os.path.join(test_dir, "A.tns")
-  pt.write(out_file, A)
-  #
-  # CHECK: Compare result True
-  #
-  print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
+    golden_file = os.path.join(_SCRIPT_PATH, "data/gold_A.tns")
+    out_file = os.path.join(test_dir, "A.tns")
+    pt.write(out_file, A)
+    #
+    # CHECK: Compare result True
+    #
+    print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")

diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
index ba4ea9c1e6d0a..ef94ea9900fe4 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
@@ -46,13 +46,13 @@
 
 # Force evaluation of the kernels by writing out X and Y.
 with tempfile.TemporaryDirectory() as test_dir:
-  x_file = os.path.join(test_dir, "X.tns")
-  y_file = os.path.join(test_dir, "Y.tns")
-  pt.write(x_file, X)
-  pt.write(y_file, Y)
-  #
-  # CHECK: Compare result True True
-  #
-  x_data = utils.file_as_string(x_file)
-  y_data = utils.file_as_string(y_file)
-  print(f"Compare result {x_data == expected} {y_data == expected}")
+    x_file = os.path.join(test_dir, "X.tns")
+    y_file = os.path.join(test_dir, "Y.tns")
+    pt.write(x_file, X)
+    pt.write(y_file, Y)
+    #
+    # CHECK: Compare result True True
+    #
+    x_data = utils.file_as_string(x_file)
+    y_data = utils.file_as_string(y_file)
+    print(f"Compare result {x_data == expected} {y_data == expected}")

diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py
index 10309cb706f78..02bbbc096e7a3 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py
@@ -26,10 +26,10 @@
 
 # Force evaluation of the kernel by writing out C.
 with tempfile.TemporaryDirectory() as test_dir:
-  golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns")
-  out_file = os.path.join(test_dir, "C.tns")
-  pt.write(out_file, C)
-  #
-  # CHECK: Compare result True
-  #
-  print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
+    golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns")
+    out_file = os.path.join(test_dir, "C.tns")
+    pt.write(out_file, C)
+    #
+    # CHECK: Compare result True
+    #
+    print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")

diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py
index de150ea68efd5..2038a473ae530 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py
@@ -47,10 +47,10 @@
 
 # Perform the SpMV computation and write the result to file
 with tempfile.TemporaryDirectory() as test_dir:
-  golden_file = os.path.join(_SCRIPT_PATH, "data/gold_y.tns")
-  out_file = os.path.join(test_dir, "y.tns")
-  pt.write(out_file, y)
-  #
-  # CHECK: Compare result True
-  #
-  print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
+    golden_file = os.path.join(_SCRIPT_PATH, "data/gold_y.tns")
+    out_file = os.path.join(test_dir, "y.tns")
+    pt.write(out_file, y)
+    #
+    # CHECK: Compare result True
+    #
+    print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")

diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_Tensor.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_Tensor.py
index c1e6c87940a02..cd24e0dbb0a43 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_Tensor.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_Tensor.py
@@ -18,11 +18,10 @@
 alpha = pt.tensor(42.0)
 
 # Set up some sparse tensors with different dim annotations and ordering.
-S = pt.tensor([8, 8, 8],
-              pt.format([pt.compressed, pt.dense, pt.compressed], [1, 0, 2]))
-X = pt.tensor([8, 8, 8],
-              pt.format([pt.compressed, pt.compressed, pt.compressed],
-                        [1, 0, 2]))
+S = pt.tensor([8, 8, 8], pt.format([pt.compressed, pt.dense, pt.compressed], [1, 0, 2]))
+X = pt.tensor(
+    [8, 8, 8], pt.format([pt.compressed, pt.compressed, pt.compressed], [1, 0, 2])
+)
 S.insert([0, 0, 0], 2.0)
 S.insert([1, 1, 1], 3.0)
 S.insert([4, 4, 4], 4.0)
@@ -32,16 +31,14 @@
 
 # Set up tensors with a dense last dimension. This results in a full
 # enveloping storage of all last "rows" with one or more nonzeros.
-T = pt.tensor([1, 2, 3, 4, 5],
-              pt.format([
-                  pt.compressed, pt.compressed, pt.compressed, pt.compressed,
-                  pt.dense
-              ]))
-Y = pt.tensor([1, 2, 3, 4, 5],
-              pt.format([
-                  pt.compressed, pt.compressed, pt.compressed, pt.compressed,
-                  pt.dense
-              ]))
+T = pt.tensor(
+    [1, 2, 3, 4, 5],
+    pt.format([pt.compressed, pt.compressed, pt.compressed, pt.compressed, pt.dense]),
+)
+Y = pt.tensor(
+    [1, 2, 3, 4, 5],
+    pt.format([pt.compressed, pt.compressed, pt.compressed, pt.compressed, pt.dense]),
+)
 T.insert([0, 1, 2, 3, 4], -2.0)
 
 Y[i, j, k, l, m] = alpha[0] * T[i, j, k, l, m]
@@ -85,18 +82,18 @@
 
 # Force evaluation of the kernel by writing out X.
 with tempfile.TemporaryDirectory() as test_dir:
-  x_file = os.path.join(test_dir, 'X.tns')
-  pt.write(x_file, X)
-  y_file = os.path.join(test_dir, 'Y.tns')
-  pt.write(y_file, Y)
-  z_file = os.path.join(test_dir, 'Z.tns')
-  pt.write(z_file, Z)
-  #
-  # CHECK: Compare result True True True
-  #
-  x_data = utils.file_as_string(x_file)
-  y_data = utils.file_as_string(y_file)
-  z_data = utils.file_as_string(z_file)
-  print(
-      f'Compare result {x_data == x_expected} {y_data == y_expected} {z_data == z_expected}'
-  )
+    x_file = os.path.join(test_dir, "X.tns")
+    pt.write(x_file, X)
+    y_file = os.path.join(test_dir, "Y.tns")
+    pt.write(y_file, Y)
+    z_file = os.path.join(test_dir, "Z.tns")
+    pt.write(z_file, Z)
+    #
+    # CHECK: Compare result True True True
+    #
+    x_data = utils.file_as_string(x_file)
+    y_data = utils.file_as_string(y_file)
+    z_data = utils.file_as_string(z_file)
+    print(
+        f"Compare result {x_data == x_expected} {y_data == y_expected} {z_data == z_expected}"
+    )

diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
index 60b91de42157e..206ffa9316d48 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
@@ -12,7 +12,7 @@
 
 i, j = pt.get_index_vars(2)
 A = pt.tensor([2, 3])
-S = pt.tensor(3) # S is a scalar tensor.
+S = pt.tensor(3)  # S is a scalar tensor.
 B = pt.tensor([2, 3], compressed)
 A.insert([0, 1], 10)
 A.insert([1, 2], 40)
@@ -26,11 +26,11 @@
 
 # Sum all the values in A.
 S[0] = A[i, j]
-passed += (S.get_scalar_value() == 50.0)
+passed += S.get_scalar_value() == 50.0
 
 indices, values = S.get_coordinates_and_values()
-passed += (len(indices)==0)
-passed += (values == 50.0)
+passed += len(indices) == 0
+passed += values == 50.0
 
 # CHECK: Number of passed: 5
 print("Number of passed:", passed)

diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py
index 8fd545b91710f..b0fed50f8b5db 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py
@@ -12,20 +12,20 @@
 passed = 0
 all_types = [pt.complex64, pt.complex128]
 for t in all_types:
-  i, j = pt.get_index_vars(2)
-  A = pt.tensor([2, 3], dtype=t)
-  B = pt.tensor([2, 3], dtype=t)
-  C = pt.tensor([2, 3], compressed, dtype=t)
-  A.insert([0, 1], 10 + 20j)
-  A.insert([1, 2], 40 + 0.5j)
-  B.insert([0, 0], 20)
-  B.insert([1, 2], 30 + 15j)
-  C[i, j] = A[i, j] + B[i, j]
+    i, j = pt.get_index_vars(2)
+    A = pt.tensor([2, 3], dtype=t)
+    B = pt.tensor([2, 3], dtype=t)
+    C = pt.tensor([2, 3], compressed, dtype=t)
+    A.insert([0, 1], 10 + 20j)
+    A.insert([1, 2], 40 + 0.5j)
+    B.insert([0, 0], 20)
+    B.insert([1, 2], 30 + 15j)
+    C[i, j] = A[i, j] + B[i, j]
 
-  indices, values = C.get_coordinates_and_values()
-  passed += isinstance(values[0], t.value)
-  passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
-  passed += np.allclose(values, [20, 10 + 20j, 70 + 15.5j])
+    indices, values = C.get_coordinates_and_values()
+    passed += isinstance(values[0], t.value)
+    passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
+    passed += np.allclose(values, [20, 10 + 20j, 70 + 15.5j])
 
 # CHECK: Number of passed: 6
 print("Number of passed:", passed)

diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py
index cec687ff4de5c..4ba2836dd4616 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py
@@ -12,24 +12,22 @@
 dense = pt.dense
 
 passed = 0
-all_types = [
-    pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64
-]
+all_types = [pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64]
 for t in all_types:
-  i, j = pt.get_index_vars(2)
-  A = pt.tensor([2, 3], dtype=t)
-  B = pt.tensor([2, 3], dtype=t)
-  C = pt.tensor([2, 3], compressed, dtype=t)
-  A.insert([0, 1], 10)
-  A.insert([1, 2], 40)
-  B.insert([0, 0], 20)
-  B.insert([1, 2], 30)
-  C[i, j] = A[i, j] + B[i, j]
+    i, j = pt.get_index_vars(2)
+    A = pt.tensor([2, 3], dtype=t)
+    B = pt.tensor([2, 3], dtype=t)
+    C = pt.tensor([2, 3], compressed, dtype=t)
+    A.insert([0, 1], 10)
+    A.insert([1, 2], 40)
+    B.insert([0, 0], 20)
+    B.insert([1, 2], 30)
+    C[i, j] = A[i, j] + B[i, j]
 
-  indices, values = C.get_coordinates_and_values()
-  passed += isinstance(values[0], t.value)
-  passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
-  passed += np.allclose(values, [20.0, 10.0, 70.0])
+    indices, values = C.get_coordinates_and_values()
+    passed += isinstance(values[0], t.value)
+    passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
+    passed += np.allclose(values, [20.0, 10.0, 70.0])
 
 # CHECK: Number of passed: 21
 print("Number of passed:", passed)

diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_true_dense_tensor_algebra.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_true_dense_tensor_algebra.py
index a138678a02e86..78bce344e3b6f 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_true_dense_tensor_algebra.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_true_dense_tensor_algebra.py
@@ -10,8 +10,8 @@
 
 i, j = pt.get_index_vars(2)
 # Both tensors are true dense tensors.
-A = pt.from_array(np.full([2,3], 1, dtype=np.float64))
-B = pt.from_array(np.full([2,3], 2, dtype=np.float64))
+A = pt.from_array(np.full([2, 3], 1, dtype=np.float64))
+B = pt.from_array(np.full([2, 3], 2, dtype=np.float64))
 # Define the result tensor as a true dense tensor. The parameter is_dense=True
 # is an MLIR-PyTACO extension.
 C = pt.tensor([2, 3], dtype=pt.float64, is_dense=True)

diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index 44d28b08d8b30..b3194f7edecd5 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -65,19 +65,20 @@
 
 
 class Type(enum.Enum):
-  """The data types supported by TACO.
+    """The data types supported by TACO.
 
-  We use numpy data types to implement the enum data types.
-  """
-  INT8 = np.int8
-  INT16 = np.int16
-  INT32 = np.int32
-  INT64 = np.int64
-  FLOAT16 = np.float16
-  FLOAT32 = np.float32
-  FLOAT64 = np.float64
-  COMPLEX64 = np.complex64
-  COMPLEX128 = np.complex128
+    We use numpy data types to implement the enum data types.
+    """
+
+    INT8 = np.int8
+    INT16 = np.int16
+    INT32 = np.int32
+    INT64 = np.int64
+    FLOAT16 = np.float16
+    FLOAT32 = np.float32
+    FLOAT64 = np.float64
+    COMPLEX64 = np.complex64
+    COMPLEX128 = np.complex128
 
 
 # All floating point type enums.
@@ -88,1732 +89,1810 @@ class Type(enum.Enum):
 _COMPLEX_TYPES = (Type.COMPLEX64, Type.COMPLEX128)
 # Type alias for any numpy type used to implement the runtime support for the
 # enum data types.
-_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float16,
-                        np.float32, np.float64, np.complex64, np.complex128]
+_AnyRuntimeType = Union[
+    np.int8,
+    np.int16,
+    np.int32,
+    np.int64,
+    np.float16,
+    np.float32,
+    np.float64,
+    np.complex64,
+    np.complex128,
+]
 
 
 @dataclasses.dataclass(frozen=True)
 class DType:
-  """The data type class.
+    """The data type class.
 
-  We support the TACO API dtype class with an alias of this class.
+    We support the TACO API dtype class with an alias of this class.
 
-  The following methods are defined by the TACO API:
-    is_float: Returns whether the data type represents a floating point value.
-    is_int:   Returns whether the data type represents an integral value.
+    The following methods are defined by the TACO API:
+      is_float: Returns whether the data type represents a floating point value.
+      is_int:   Returns whether the data type represents an integral value.
 
-  Attributes:
-    kind: A Type enum representing the data type.
-    value: The numpy data type for the TACO data type.
-  """
-  kind: Type = Type.FLOAT32
+    Attributes:
+      kind: A Type enum representing the data type.
+      value: The numpy data type for the TACO data type.
+    """
 
-  def is_float(self) -> bool:
-    """Returns whether the data type represents a floating point value."""
-    return self.kind in _FLOAT_TYPES
+    kind: Type = Type.FLOAT32
 
-  def is_int(self) -> bool:
-    """Returns whether the data type represents an integral value."""
-    return self.kind in _INT_TYPES
+    def is_float(self) -> bool:
+        """Returns whether the data type represents a floating point value."""
+        return self.kind in _FLOAT_TYPES
 
-  def is_complex(self) -> bool:
-    """Returns whether the data type represents a complex value."""
-    return self.kind in _COMPLEX_TYPES
+    def is_int(self) -> bool:
+        """Returns whether the data type represents an integral value."""
+        return self.kind in _INT_TYPES
 
-  @property
-  def value(self) -> _AnyRuntimeType:
-    """Returns the numpy dtype for the data type."""
-    return self.kind.value
+    def is_complex(self) -> bool:
+        """Returns whether the data type represents a complex value."""
+        return self.kind in _COMPLEX_TYPES
+
+    @property
+    def value(self) -> _AnyRuntimeType:
+        """Returns the numpy dtype for the data type."""
+        return self.kind.value
 
 
 def _dtype_to_mlir_str(dtype: DType) -> str:
-  """Returns the MLIR string for the given dtype."""
-  dtype_to_str = {
-      Type.INT16: "i8",
-      Type.INT16: "i16",
-      Type.INT32: "i32",
-      Type.INT64: "i64",
-      Type.FLOAT16: "f16",
-      Type.FLOAT32: "f32",
-      Type.FLOAT64: "f64",
-      Type.COMPLEX64: "complex<f32>",
-      Type.COMPLEX128: "complex<f64>"
-  }
-  return dtype_to_str[dtype.kind]
+    """Returns the MLIR string for the given dtype."""
+    dtype_to_str = {
+        Type.INT16: "i8",
+        Type.INT16: "i16",
+        Type.INT32: "i32",
+        Type.INT64: "i64",
+        Type.FLOAT16: "f16",
+        Type.FLOAT32: "f32",
+        Type.FLOAT64: "f64",
+        Type.COMPLEX64: "complex<f32>",
+        Type.COMPLEX128: "complex<f64>",
+    }
+    return dtype_to_str[dtype.kind]
 
 
 def _nptype_to_taco_type(ty: np.dtype) -> DType:
-  """Returns the TACO type for the given numpy type."""
-  nptype_to_dtype = {
-      np.int8: Type.INT8,
-      np.int16: Type.INT16,
-      np.int32: Type.INT32,
-      np.int64: Type.INT64,
-      np.float16: Type.FLOAT16,
-      np.float32: Type.FLOAT32,
-      np.float64: Type.FLOAT64,
-      np.complex64: Type.COMPLEX64,
-      np.complex128: Type.COMPLEX128
-  }
-  return DType(nptype_to_dtype[ty])
+    """Returns the TACO type for the given numpy type."""
+    nptype_to_dtype = {
+        np.int8: Type.INT8,
+        np.int16: Type.INT16,
+        np.int32: Type.INT32,
+        np.int64: Type.INT64,
+        np.float16: Type.FLOAT16,
+        np.float32: Type.FLOAT32,
+        np.float64: Type.FLOAT64,
+        np.complex64: Type.COMPLEX64,
+        np.complex128: Type.COMPLEX128,
+    }
+    return DType(nptype_to_dtype[ty])
 
 
 def _mlir_type_from_taco_type(dtype: DType) -> ir.Type:
-  """Returns the MLIR type corresponding to the given TACO type."""
-  dtype_to_irtype = {
-      Type.INT8: ir.IntegerType.get_signless(8),
-      Type.INT16: ir.IntegerType.get_signless(16),
-      Type.INT32: ir.IntegerType.get_signless(32),
-      Type.INT64: ir.IntegerType.get_signless(64),
-      Type.FLOAT16: ir.F16Type.get(),
-      Type.FLOAT32: ir.F32Type.get(),
-      Type.FLOAT64: ir.F64Type.get(),
-      Type.COMPLEX64: ir.ComplexType.get(ir.F32Type.get()),
-      Type.COMPLEX128: ir.ComplexType.get(ir.F64Type.get())
-  }
-  return dtype_to_irtype[dtype.kind]
+    """Returns the MLIR type corresponding to the given TACO type."""
+    dtype_to_irtype = {
+        Type.INT8: ir.IntegerType.get_signless(8),
+        Type.INT16: ir.IntegerType.get_signless(16),
+        Type.INT32: ir.IntegerType.get_signless(32),
+        Type.INT64: ir.IntegerType.get_signless(64),
+        Type.FLOAT16: ir.F16Type.get(),
+        Type.FLOAT32: ir.F32Type.get(),
+        Type.FLOAT64: ir.F64Type.get(),
+        Type.COMPLEX64: ir.ComplexType.get(ir.F32Type.get()),
+        Type.COMPLEX128: ir.ComplexType.get(ir.F64Type.get()),
+    }
+    return dtype_to_irtype[dtype.kind]
+
 
 def _ctype_pointer_from_array(array: np.ndarray) -> ctypes.pointer:
-  """Returns the ctype pointer for the given numpy array."""
-  return ctypes.pointer(
-      ctypes.pointer(runtime.get_ranked_memref_descriptor(array)))
+    """Returns the ctype pointer for the given numpy array."""
+    return ctypes.pointer(ctypes.pointer(runtime.get_ranked_memref_descriptor(array)))
 
 
 class ModeFormat(enum.Enum):
-  """The tensor dimension storage format class.
+    """The tensor dimension storage format class.
 
-  We support the TACO API mode_format class with an alias of this class.
+    We support the TACO API mode_format class with an alias of this class.
+
+    In TACO, a tensor dimension is called a mode and the storage format for a
+    tensor dimension is called a mode format.
+    """
 
-  In TACO, a tensor dimension is called a mode and the storage format for a
-  tensor dimension is called a mode format.
-  """
-  DENSE = sparse_tensor.DimLevelType.dense
-  COMPRESSED = sparse_tensor.DimLevelType.compressed
+    DENSE = sparse_tensor.DimLevelType.dense
+    COMPRESSED = sparse_tensor.DimLevelType.compressed
 
 
-def _mode_format_operation(a: ModeFormat, b: ModeFormat,
-                           op: _LogicalOp) -> ModeFormat:
-  """Implements the given operator on ModeFormat."""
-  return (ModeFormat.COMPRESSED
-          if op(a == ModeFormat.COMPRESSED, b == ModeFormat.COMPRESSED) else
-          ModeFormat.DENSE)
+def _mode_format_operation(a: ModeFormat, b: ModeFormat, op: _LogicalOp) -> ModeFormat:
+    """Implements the given operator on ModeFormat."""
+    return (
+        ModeFormat.COMPRESSED
+        if op(a == ModeFormat.COMPRESSED, b == ModeFormat.COMPRESSED)
+        else ModeFormat.DENSE
+    )
 
 
 def _mode_format_estimator(op: _BinaryOp) -> _ModeFormatOp:
-  """Produces a ModeFormat operator for the given binary operator.
+    """Produces a ModeFormat operator for the given binary operator.
 
-  The ModeFormat operator is used as a heuristic to derive the destination
-  dimension sparsity from the source dimension sparsity. In particular, if the
-  binary operator produces a disjunction of the zero values from its source
-  operands, such as the MUL operator, we return a ModeFormat operator that
-  uses operator.or_. That is, we estimate that a dimension for the MUL
-  operation result to be sparse if either of its source operands is sparse.
+    The ModeFormat operator is used as a heuristic to derive the destination
+    dimension sparsity from the source dimension sparsity. In particular, if the
+    binary operator produces a disjunction of the zero values from its source
+    operands, such as the MUL operator, we return a ModeFormat operator that
+    uses operator.or_. That is, we estimate that a dimension for the MUL
+    operation result to be sparse if either of its source operands is sparse.
 
-  On the other hand, if the binary operator produces a conjunction of the
-  zero values from its source operands, such as the ADD operator, we return
-  a ModeFormat operator that uses operator.and_. In this case, we estimate
-  that a dimension for the ADD operation result to be sparse if both of its
-  source operands are sparse.
+    On the other hand, if the binary operator produces a conjunction of the
+    zero values from its source operands, such as the ADD operator, we return
+    a ModeFormat operator that uses operator.and_. In this case, we estimate
+    that a dimension for the ADD operation result to be sparse if both of its
+    source operands are sparse.
 
-  Args:
-    op: A _BinaryOp object representing a supporting operator on tensors.
+    Args:
+      op: A _BinaryOp object representing a supporting operator on tensors.
 
-  Returns:
-    A ModeFormatOp for estimating the destination dimension sparsity from
-    the source dimension sparsity.
-  """
-  conjunction = functools.partial(_mode_format_operation, op=operator.and_)
-  disjunction = functools.partial(_mode_format_operation, op=operator.or_)
-  return conjunction if op(0, 1) != 0 else disjunction
+    Returns:
+      A ModeFormatOp for estimating the destination dimension sparsity from
+      the source dimension sparsity.
+    """
+    conjunction = functools.partial(_mode_format_operation, op=operator.and_)
+    disjunction = functools.partial(_mode_format_operation, op=operator.or_)
+    return conjunction if op(0, 1) != 0 else disjunction
 
 
 def _all_instance_of(collection: Iterable, cls: Any) -> bool:
-  """Returns true if all elements of the iterable is an instance of cls."""
-  return all(isinstance(e, cls) for e in collection)
+    """Returns true if all elements of the iterable is an instance of cls."""
+    return all(isinstance(e, cls) for e in collection)
 
 
 def _identity_ordering(rank: int) -> List[int]:
-  """Returns the identity ordering for tensor of given rank."""
-  return list(range(rank))
+    """Returns the identity ordering for tensor of given rank."""
+    return list(range(rank))
 
 
 @dataclasses.dataclass(frozen=True)
 class ModeOrdering:
-  """The tensor dimension ordering class.
-
-  We support the TACO API mode_ordering class with an alias of this class.
+    """The tensor dimension ordering class.
 
-  Attributes:
-    ordering: A list of integers representing the ordering of the tensor
-      dimensions.
-  """
-  ordering: List[int]
+    We support the TACO API mode_ordering class with an alias of this class.
 
-  def __post_init__(self) -> None:
-    """Verifies the value in ordering.
-
-    Raises:
-       ValueError: If ordering is not a list of integers.
+    Attributes:
+      ordering: A list of integers representing the ordering of the tensor
+        dimensions.
     """
-    if (not isinstance(self.ordering, list) or
-        not _all_instance_of(self.ordering, int)):
-      raise ValueError("Ordering must be a list of integers: "
-                       f"{self.ordering}")
-    # Check that ordering is a permutation of the dimension numbers.
-    if sorted(self.ordering) != _identity_ordering(self.rank()):
-      raise ValueError(f"Invalid ordering: {self.ordering} != "
-                       f"permutation{_identity_ordering(self.rank())}.")
-
-  def rank(self) -> int:
-    """Returns the number of dimensions represented by the ordering."""
-    return len(self.ordering)
 
+    ordering: List[int]
 
-@dataclasses.dataclass(frozen=True)
-class ModeFormatPack:
-  """The tensor dimension format class.
+    def __post_init__(self) -> None:
+        """Verifies the value in ordering.
 
-  We support the TACO API mode_format_pack class with an alias of this class.
+        Raises:
+           ValueError: If ordering is not a list of integers.
+        """
+        if not isinstance(self.ordering, list) or not _all_instance_of(
+            self.ordering, int
+        ):
+            raise ValueError("Ordering must be a list of integers: " f"{self.ordering}")
+        # Check that ordering is a permutation of the dimension numbers.
+        if sorted(self.ordering) != _identity_ordering(self.rank()):
+            raise ValueError(
+                f"Invalid ordering: {self.ordering} != "
+                f"permutation{_identity_ordering(self.rank())}."
+            )
 
-  The storage format of a tensor contains one mode_format for each tensor
-  dimension.
+    def rank(self) -> int:
+        """Returns the number of dimensions represented by the ordering."""
+        return len(self.ordering)
 
-  Attributes:
-    formats: A list of ModeFormat representing the storage format for each of
-      the tensor dimension.
-  """
-  formats: List[ModeFormat]
 
-  def __post_init__(self) -> None:
-    """Verifies the value in formats.
-
-    Raises:
-       ValueError: If formats is not a list of ModeFormats.
-    """
-    if (not isinstance(self.formats, list) or
-        not _all_instance_of(self.formats, ModeFormat)):
-      raise ValueError("Formats must be a list of ModeFormat: "
-                       f"{self.formats}")
-
-  def rank(self) -> int:
-    """Returns the number of dimensions represented by the format pack."""
-    return len(self.formats)
+@dataclasses.dataclass(frozen=True)
+class ModeFormatPack:
+    """The tensor dimension format class.
 
+    We support the TACO API mode_format_pack class with an alias of this class.
 
-@dataclasses.dataclass
-class Format:
-  """The tensor format class defined by the TACO API.
-
-  Attributes:
-    format_pack: A ModeFormatPack representing the storage format for the tensor
-      dimensions.
-    ordering: A ModeOrdering representing the tensor dimension ordering in the
-      storage.
-  """
-  format_pack: ModeFormatPack
-  ordering: Optional[ModeOrdering] = None
-
-  def __post_init__(self) -> None:
-    """Verifies and fixes up the values in format_pack and ordering.
-
-    Verifies and fixes up the values in format_pack and ordering to supports the
-    initializer syntax defined by the TACO API. If format_pack is a list of
-    ModeFormat, replaces it with ModeFormatPack constructed from the list. If
-    ordering is not provided, set ordering to the natural ordering for the rank
-    corresponding to format_pack.
+    The storage format of a tensor contains one mode_format for each tensor
+    dimension.
 
-    Raises:
-       ValueError: If format_pack is not an instance of ModeFormatPack or if
-         ordering is not an instance of ModeOrdering.
+    Attributes:
+      formats: A list of ModeFormat representing the storage format for each of
+        the tensor dimension.
     """
-    if isinstance(self.format_pack, list):
-      if not _all_instance_of(self.format_pack, ModeFormat):
-        raise ValueError(f"Expected a list of ModeFormat: {self.format_pack}")
-      self.format_pack = ModeFormatPack(self.format_pack)
-    if not isinstance(self.format_pack, ModeFormatPack):
-      raise ValueError(f"Expected ModeFormatpack: {self.format_pack}")
-
-    if self.ordering is None:
-      self.ordering = ModeOrdering(list(range(self.rank())))
-    if isinstance(self.ordering, list):
-      if not _all_instance_of(self.ordering, int):
-        raise ValueError(f"Expected a list of integer: {self.ordering}")
-      self.ordering = ModeOrdering(self.ordering)
-    if not isinstance(self.ordering, ModeOrdering):
-      raise ValueError(f"Expected ModeOrdering: {self.ordering}")
-
-    if self.format_pack.rank() != self.ordering.rank():
-      raise ValueError("Inconsistent ModeFormatPack and ModeOrdering: "
-                       f"len({self.format_pack}) != "
-                       f"len({self.ordering})")
-
-  def rank(self) -> int:
-    """Returns the number of dimensions represented by the format."""
-    return self.format_pack.rank()
-
-  def get_permutation_and_sparsity(self) -> Tuple[np.ndarray, np.ndarray]:
-    """Constructs the numpy arrays for the permutation and sparsity."""
-    perm = np.array(self.ordering.ordering, dtype=np.ulonglong)
-    a = [f.value for f in self.format_pack.formats]
-    sparse = np.array(a, dtype=np.uint8)
-    return (perm, sparse)
-
-  def mlir_tensor_attr(self) -> Optional[sparse_tensor.EncodingAttr]:
-    """Constructs the MLIR attributes for the tensor format."""
-    order = (
-        range(self.rank()) if
-        (self.ordering is None) else self.ordering.ordering)
-    mlir_storage_format = [f.value for f in self.format_pack.formats]
-    return sparse_tensor.EncodingAttr.get(mlir_storage_format,
-                                          ir.AffineMap.get_permutation(order),
-                                          None, _POS_WIDTH, _CRD_WIDTH)
-
-
-def _make_format(formats: List[ModeFormat],
-                 ordering: Optional[List[int]] = None) -> Format:
-  """Constructs a format from a list of ModeFormat and an optional ordering.
-
-  Args:
-    formats: A list of ModeFormat, one for each dimension of a tensor.
-    ordering: An optional list of integer, for the ordering of the tensor
-      dimensions. When an ordering is not given, the identity ordering is used.
-
-  Returns:
-    A tensor format object.
-
-  Raises:
-    ValueError: If formats is not a list of ModeFormat or the length of formats
-      is not consistent with the len of ordering.
-  """
-  ordering = ordering or _identity_ordering(len(formats))
-  return Format(ModeFormatPack(formats), ModeOrdering(ordering))
 
+    formats: List[ModeFormat]
 
-class IndexExpr(abc.ABC):
-  """The index notation base class.
+    def __post_init__(self) -> None:
+        """Verifies the value in formats.
 
-  We support the TACO API index_expression class with an alias of this class.
-  """
+        Raises:
+           ValueError: If formats is not a list of ModeFormats.
+        """
+        if not isinstance(self.formats, list) or not _all_instance_of(
+            self.formats, ModeFormat
+        ):
+            raise ValueError("Formats must be a list of ModeFormat: " f"{self.formats}")
 
-  def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr":
-    """Verifies the RHS operand and returns a binary expression.
+    def rank(self) -> int:
+        """Returns the number of dimensions represented by the format pack."""
+        return len(self.formats)
 
-    Args:
-      rhs: The RHS of the binary operation, which could be any Python object
-        from user inputs.
-      op: A _BinaryOp object representing the binary operator.
 
-    Raises:
-      ValueError: If rhs is not an IndexExpr.
-    """
-    if not isinstance(rhs, IndexExpr):
-      raise ValueError(f"Expected IndexExpr: {rhs}")
-    return _BinaryExpr(op, self, rhs)
-
-  def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
-    """Build a unary expression.
+ at dataclasses.dataclass
+class Format:
+    """The tensor format class defined by the TACO API.
 
-    Args:
-      op: A _UnaryOp object representing the unary operation.
+    Attributes:
+      format_pack: A ModeFormatPack representing the storage format for the tensor
+        dimensions.
+      ordering: A ModeOrdering representing the tensor dimension ordering in the
+        storage.
     """
-    return _UnaryExpr(op, self)
 
-  def __add__(self, rhs) -> "_BinaryExpr":
-    """Defines the operator +.
+    format_pack: ModeFormatPack
+    ordering: Optional[ModeOrdering] = None
+
+    def __post_init__(self) -> None:
+        """Verifies and fixes up the values in format_pack and ordering.
+
+        Verifies and fixes up the values in format_pack and ordering to supports the
+        initializer syntax defined by the TACO API. If format_pack is a list of
+        ModeFormat, replaces it with ModeFormatPack constructed from the list. If
+        ordering is not provided, set ordering to the natural ordering for the rank
+        corresponding to format_pack.
+
+        Raises:
+           ValueError: If format_pack is not an instance of ModeFormatPack or if
+             ordering is not an instance of ModeOrdering.
+        """
+        if isinstance(self.format_pack, list):
+            if not _all_instance_of(self.format_pack, ModeFormat):
+                raise ValueError(f"Expected a list of ModeFormat: {self.format_pack}")
+            self.format_pack = ModeFormatPack(self.format_pack)
+        if not isinstance(self.format_pack, ModeFormatPack):
+            raise ValueError(f"Expected ModeFormatpack: {self.format_pack}")
+
+        if self.ordering is None:
+            self.ordering = ModeOrdering(list(range(self.rank())))
+        if isinstance(self.ordering, list):
+            if not _all_instance_of(self.ordering, int):
+                raise ValueError(f"Expected a list of integer: {self.ordering}")
+            self.ordering = ModeOrdering(self.ordering)
+        if not isinstance(self.ordering, ModeOrdering):
+            raise ValueError(f"Expected ModeOrdering: {self.ordering}")
+
+        if self.format_pack.rank() != self.ordering.rank():
+            raise ValueError(
+                "Inconsistent ModeFormatPack and ModeOrdering: "
+                f"len({self.format_pack}) != "
+                f"len({self.ordering})"
+            )
+
+    def rank(self) -> int:
+        """Returns the number of dimensions represented by the format."""
+        return self.format_pack.rank()
+
+    def get_permutation_and_sparsity(self) -> Tuple[np.ndarray, np.ndarray]:
+        """Constructs the numpy arrays for the permutation and sparsity."""
+        perm = np.array(self.ordering.ordering, dtype=np.ulonglong)
+        a = [f.value for f in self.format_pack.formats]
+        sparse = np.array(a, dtype=np.uint8)
+        return (perm, sparse)
+
+    def mlir_tensor_attr(self) -> Optional[sparse_tensor.EncodingAttr]:
+        """Constructs the MLIR attributes for the tensor format."""
+        order = (
+            range(self.rank()) if (self.ordering is None) else self.ordering.ordering
+        )
+        mlir_storage_format = [f.value for f in self.format_pack.formats]
+        return sparse_tensor.EncodingAttr.get(
+            mlir_storage_format,
+            ir.AffineMap.get_permutation(order),
+            None,
+            _POS_WIDTH,
+            _CRD_WIDTH,
+        )
+
+
+def _make_format(
+    formats: List[ModeFormat], ordering: Optional[List[int]] = None
+) -> Format:
+    """Constructs a format from a list of ModeFormat and an optional ordering.
 
     Args:
-      rhs: The value being added, which could be any Python object from user
-        inputs.
+      formats: A list of ModeFormat, one for each dimension of a tensor.
+      ordering: An optional list of integer, for the ordering of the tensor
+        dimensions. When an ordering is not given, the identity ordering is used.
 
     Returns:
-      A _BinaryExpr object representing the operation.
+      A tensor format object.
 
     Raises:
-      ValueError: If rhs is not an IndexExpr.
+      ValueError: If formats is not a list of ModeFormat or the length of formats
+        is not consistent with the len of ordering.
     """
-    return self._verify_operand_and_build_expr(rhs, operator.add)
+    ordering = ordering or _identity_ordering(len(formats))
+    return Format(ModeFormatPack(formats), ModeOrdering(ordering))
 
-  def __mul__(self, rhs) -> "_BinaryExpr":
-    """Defines the operator *.
-
-    Args:
-      rhs: The value being multiplied, which could be any Python object from
-        user inputs.
 
-    Returns:
-      A _BinaryExpr object representing the operation.
+class IndexExpr(abc.ABC):
+    """The index notation base class.
 
-    Raises:
-      ValueError: If rhs is not an IndexExpr.
+    We support the TACO API index_expression class with an alias of this class.
     """
-    return self._verify_operand_and_build_expr(rhs, operator.mul)
 
-  def __abs__(self) -> "_UnaryExpr":
-    """Defines the operator abs.
-
-    Returns:
-      A _UnaryExpr object representing the operation.
-    """
-    return self._build_unary_expr(operator.abs)
+    def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr":
+        """Verifies the RHS operand and returns a binary expression.
+
+        Args:
+          rhs: The RHS of the binary operation, which could be any Python object
+            from user inputs.
+          op: A _BinaryOp object representing the binary operator.
+
+        Raises:
+          ValueError: If rhs is not an IndexExpr.
+        """
+        if not isinstance(rhs, IndexExpr):
+            raise ValueError(f"Expected IndexExpr: {rhs}")
+        return _BinaryExpr(op, self, rhs)
+
+    def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
+        """Build a unary expression.
+
+        Args:
+          op: A _UnaryOp object representing the unary operation.
+        """
+        return _UnaryExpr(op, self)
+
+    def __add__(self, rhs) -> "_BinaryExpr":
+        """Defines the operator +.
+
+        Args:
+          rhs: The value being added, which could be any Python object from user
+            inputs.
+
+        Returns:
+          A _BinaryExpr object representing the operation.
+
+        Raises:
+          ValueError: If rhs is not an IndexExpr.
+        """
+        return self._verify_operand_and_build_expr(rhs, operator.add)
+
+    def __mul__(self, rhs) -> "_BinaryExpr":
+        """Defines the operator *.
+
+        Args:
+          rhs: The value being multiplied, which could be any Python object from
+            user inputs.
+
+        Returns:
+          A _BinaryExpr object representing the operation.
+
+        Raises:
+          ValueError: If rhs is not an IndexExpr.
+        """
+        return self._verify_operand_and_build_expr(rhs, operator.mul)
+
+    def __abs__(self) -> "_UnaryExpr":
+        """Defines the operator abs.
+
+        Returns:
+          A _UnaryExpr object representing the operation.
+        """
+        return self._build_unary_expr(operator.abs)
+
+    def __neg__(self) -> "_UnaryExpr":
+        """Defines the operator neg.
+
+        Returns:
+          A _UnaryExpr object representing the operation.
+        """
+        return self._build_unary_expr(operator.neg)
+
+    def __sub__(self, rhs) -> "_BinaryExpr":
+        """Defines the operator -.
+
+        Args:
+          rhs: The value being subtracted, which could be any Python object from
+            user inputs.
+
+        Returns:
+          A _BinaryExpr object representing the operation.
+
+        Raises:
+          ValueError: If rhs is not an IndexExpr.
+        """
+        return self._verify_operand_and_build_expr(rhs, operator.sub)
+
+    @abc.abstractmethod
+    def _visit(
+        self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None
+    ) -> None:
+        """A post-order visitor.
+
+        Args:
+          func: A callable applied to each node in the expression tree.
+          args: The variable-length arguments passed to the callable. These
+            arguments are grouped as an iterable and will be unpacked before passing
+            to the callable. This is to enable the keyword argument only syntax
+            after this argument.
+          leaf_checker: A callable object to identify nodes that should be treated
+            as leaf nodes to support partial tree visiting.
+        """
+        pass
+
+    @abc.abstractmethod
+    def _emit_expression(
+        self,
+        expr_to_opnd: Dict["IndexExpr", lang.OperandDef],
+        expr_to_info: _ExprInfoDict,
+    ) -> lang.ScalarExpression:
+        """Emits MLIR for the expression tree.
+
+        Args:
+          expr_to_opnd: A dictionary for looking up structured op input operands for
+            the input nodes of the structured op.
+          expr_to_info: A dictionary for looking up code generation information for
+            expressions.
+
+        Returns:
+          A linalg dialect ScalarExpression for the expression.
+        """
+        pass
+
+    @abc.abstractmethod
+    def dtype(self) -> DType:
+        """Returns the data type for the result of the expression."""
+        pass
+
+    def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None:
+        """Emits a structured op in the linalg dialect for the expression tree.
+
+        We define a DefineOpcallable in the domain specific language for the linalg
+        dialect and execute the callable to generate the structured op. Self is the
+        root of the expression tree for the structured op.
+
+        Args:
+          expr_to_info: A dictionary for looking up code generation information for
+            expressions.
+        """
+        op_info = expr_to_info[self].structop_info
+        op_name = op_info.dst_name
+        op_def = lang.LinalgOpDef(name=op_name)
+        op_callable = lang.DefinedOpCallable(op_name, op_def)
+
+        # Collect the input expression nodes for the structured op.
+        expr_inputs = []
+        self._visit(
+            _gather_structured_op_input,
+            (self, expr_to_info, expr_inputs),
+            leaf_checker=_is_structured_op_leaf,
+        )
+
+        # Create a linalg structured op operand for each input expression node and
+        # build a dictionary for looking up the information.
+        expr_to_input_opnd = {
+            e: _emit_structured_op_input(e, expr_to_info, op_def) for e in expr_inputs
+        }
+
+        # Emit the expression tree, which produces the value assigned to the
+        # destination tensor.
+        value = self._emit_expression(expr_to_input_opnd, expr_to_info)
+        # Emit the structured op representation for the destination tensor.
+        dst_opnd = _emit_operand(
+            op_def,
+            op_info.dst_indices,
+            op_info.dst_name,
+            lang.OperandKind.OUTPUT_TENSOR,
+        )
+        dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
+        dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
+
+        expr_info = expr_to_info[self]
+        # If the structured op reduces some indices, explicitly represent the
+        # reduction. This is done by generating a ReduceFn for the dimensions being
+        # reduced in the linalg dialect and calling the function with the value
+        # being reduced. We only support add reduction currently.
+        if expr_info.reduce_indices:
+            reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices)
+            value = lang.ReduceFn.add[reduce_dims](value)
+
+        # Emit the assignment as a comprehension in the linalg dialect.
+        comp = lang.Comprehension((dst_use, value))
+        op_def.comprehensions.append(comp)
+
+        # The structured op in the linalg dialect requires an explicit
+        # initialization for the destination tensor. Emit MLIR to initialize the
+        # destination tensor.
+        init = op_info.emit_tensor_init()
+
+        # Collect MLIR values for the linalg input operands, with the assumption
+        # that dictionary preserves the insertion order.
+        args = [
+            expr_to_info[expr].mlir_value for expr, opnd in expr_to_input_opnd.items()
+        ]
+        # Execute the DefineOpcallable object for the linalg dialect operation to
+        # emit MLIR for the linalg structured op.
+        expr_info.mlir_value = op_callable(*args, outs=[init])
+
+    def _identify_structured_ops(
+        self,
+        expr_to_info: _ExprInfoDict,
+        dst: "Tensor",
+        dst_indices: Tuple["IndexVar", ...],
+    ) -> List["IndexExpr"]:
+        """Returns expression nodes for the roots of the identified structured ops.
+
+        A structured op in the linalg dialect only supports reduction performed on
+        the whole expression. If the expression tree contains reduction that are
+        performed on part of the expression tree, the expression tree needs to be
+        implemented with multiple structured ops. This routine identifies all the
+        expression nodes that contain reduction as the root of structured ops in the
+        linalg dialect.
+
+        Args:
+          expr_to_info: A dictionary for looking up code generation information for
+            expressions.
+          dst: A destination Tensor that accepts the value of the expression tree.
+          dst_indices: The indices used by the destination index expression.
+
+        Returns:
+          An ordered list of IndexExpr for the root expressions of the structured
+          ops, where child expressions go before parent expressions that use their
+          results.
+        """
+        reduce_indices = tuple(set(expr_to_info[self].src_indices) - set(dst_indices))
+        for reduce_index in reduce_indices:
+            _mark_structured_op_root(self, reduce_index, expr_to_info)
+
+        self._visit(_accumulate_reduce_indices, (expr_to_info,))
+        structop_roots = []
+        self._visit(_gather_structured_op, (expr_to_info, structop_roots))
+
+        # Handle the root of the top level expression.
+        if not structop_roots or structop_roots[-1] != self:
+            # The top level expression is not a reduction. Add the top level
+            # expression as a structured op root.
+            structop_roots.append(self)
+
+        # Use user specified information for the destination tensor to build an
+        # _StructOpInfo for the top level expression.
+        expr_to_info[self].structop_info = _StructOpInfo(
+            dst_indices, tuple(dst.shape), dst.dtype, dst.name, dst.format
+        )
+
+        return structop_roots
+
+    def _validate_and_collect_expr_info(
+        self,
+        dst: "Tensor",
+        dst_indices: Tuple["IndexVar", ...],
+    ) -> _ExprInfoDict:
+        """Propagates expression information for validation.
+
+        Propagates the indices used by child expression nodes to parent expression
+        nodes. Also collects and validates the sizes for the dimensions
+        corresponding to the indices.
+
+        Args:
+          dst: A destination Tensor that accepts the value of the expression tree.
+          dst_indices: The indices used by the destination index expression.
+
+        Raises:
+          ValueError if there is any inconsistency in indices or dimensional
+          values.
+
+        Returns:
+          A dictionary of (IndexExpr, _ExprInfo).
+        """
+        expr_to_info = {}
+        # Validate the expression tree and construct expression information.
+        self._visit(_validate_and_collect_expr_info, (expr_to_info,))
+
+        # Validate the destination dimension information.
+        info = expr_to_info[self]
+        index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)}
+        for (
+            i,
+            d,
+        ) in zip(dst_indices, dst.shape):
+            if i not in index_to_dim_info:
+                raise ValueError(
+                    "Destination IndexVar not used in the " f"source expression: {i}"
+                )
+            else:
+                if d != index_to_dim_info[i].dim and index_to_dim_info[i].dim != -1:
+                    raise ValueError(
+                        f"Inconsistent destination dimension for {i}: "
+                        f"{d} vs {index_to_dim_info[i].dim}"
+                    )
+
+        return expr_to_info
+
+    def _emit_assignment(
+        self,
+        module: ir.Module,
+        dst: "Tensor",
+        dst_indices: Tuple["IndexVar", ...],
+        expr_to_info: _ExprInfoDict,
+        input_accesses: List["Access"],
+    ) -> None:
+        """Emits an MLIR function for assigning the expression to a tensor."""
+        input_types = [a.tensor.mlir_tensor_type() for a in input_accesses]
+
+        # Build the kernel for the operations.
+        with ir.InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME)
+            def linalg_funcop(*args):
+                # Set up the mapping from the Access nodes to their MLIR values.
+                for e, mlir in zip(input_accesses, args):
+                    expr_to_info[e].mlir_value = mlir
+
+                # Emit structured ops in the linalg dialect to implement the assignment.
+                for structop_root in self._identify_structured_ops(
+                    expr_to_info, dst, dst_indices
+                ):
+                    structop_root._emit_structured_op(expr_to_info)
+                    dst._record_stats(expr_to_info[structop_root].structop_info)
+
+                # The function returns the MLIR value of the root expression.
+                return expr_to_info[self].mlir_value
+
+            linalg_funcop.func_op.attributes[
+                "llvm.emit_c_interface"
+            ] = ir.UnitAttr.get()
+
+    def get_input_accesses(self) -> List["Access"]:
+        """Compute the list of input accesses for the expression."""
+        input_accesses = []
+        self._visit(_gather_input_accesses_index_vars, (input_accesses,))
+        return input_accesses
+
+    def compile(
+        self,
+        dst: "Tensor",
+        dst_indices: Tuple["IndexVar", ...],
+    ) -> execution_engine.ExecutionEngine:
+        """Compiles the tensor assignment dst[dst_indices] = expression.
+
+        Args:
+          dst: The destination tensor.
+          dst_indices: The tuple of IndexVar used to access the destination tensor.
+
+        Returns:
+          The execution engine for the tensor assignment.
+
+        Raises:
+          ValueError: If the expression is not proper or not supported.
+        """
+        expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
+        input_accesses = self.get_input_accesses()
+
+        # Build and compile the module to produce the execution engine.
+        with ir.Context(), ir.Location.unknown():
+            module = ir.Module.create()
+            self._emit_assignment(
+                module, dst, dst_indices, expr_to_info, input_accesses
+            )
+            engine = utils.compile_and_build_engine(module)
+
+        return engine
 
-  def __neg__(self) -> "_UnaryExpr":
-    """Defines the operator neg.
 
-    Returns:
-      A _UnaryExpr object representing the operation.
-    """
-    return self._build_unary_expr(operator.neg)
+class _AtomicCounter:
+    """An atomic counter."""
 
-  def __sub__(self, rhs) -> "_BinaryExpr":
-    """Defines the operator -.
+    def __init__(self):
+        self._counter = 0
+        self._counter_lock = threading.Lock()
 
-    Args:
-      rhs: The value being subtracted, which could be any Python object from
-        user inputs.
+    def increment(self) -> int:
+        """Increments the counter by one and returns the old value."""
+        old_value = self._counter
+        with self._counter_lock:
+            self._counter = self._counter + 1
+        return old_value
 
-    Returns:
-      A _BinaryExpr object representing the operation.
 
-    Raises:
-      ValueError: If rhs is not an IndexExpr.
-    """
-    return self._verify_operand_and_build_expr(rhs, operator.sub)
-
-  @abc.abstractmethod
-  def _visit(self,
-             func: _ExprVisitor,
-             args,
-             *,
-             leaf_checker: _SubtreeLeafChecker = None) -> None:
-    """A post-order visitor.
-
-    Args:
-      func: A callable applied to each node in the expression tree.
-      args: The variable-length arguments passed to the callable. These
-        arguments are grouped as an iterable and will be unpacked before passing
-        to the callable. This is to enable the keyword argument only syntax
-        after this argument.
-      leaf_checker: A callable object to identify nodes that should be treated
-        as leaf nodes to support partial tree visiting.
-    """
-    pass
+class IndexVar(IndexExpr):
+    """The tensor index class.
 
-  @abc.abstractmethod
-  def _emit_expression(
-      self,
-      expr_to_opnd: Dict["IndexExpr", lang.OperandDef],
-      expr_to_info: _ExprInfoDict,
-  ) -> lang.ScalarExpression:
-    """Emits MLIR for the expression tree.
+    We support the TACO API index_var class with an alias of this class.
 
-    Args:
-      expr_to_opnd: A dictionary for looking up structured op input operands for
-        the input nodes of the structured op.
-      expr_to_info: A dictionary for looking up code generation information for
-        expressions.
+    An IndexVar object represents an index variable in tensor index notation.
 
-    Returns:
-      A linalg dialect ScalarExpression for the expression.
+    Attributes:
+      name: A unique string name of the IndexVar.
     """
-    pass
 
-  @abc.abstractmethod
-  def dtype(self) -> DType:
-    """Returns the data type for the result of the expression."""
-    pass
+    _counter = _AtomicCounter()
 
-  def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None:
-    """Emits a structured op in the linalg dialect for the expression tree.
+    def __init__(self):
+        id = self._counter.increment()
+        self._name = f"{_TACO_INDEX_PREFIX}{id}"
 
-    We define a DefineOpcallable in the domain specific language for the linalg
-    dialect and execute the callable to generate the structured op. Self is the
-    root of the expression tree for the structured op.
+    def __repr__(self) -> str:
+        return f"IndexVar(name={repr(self._name)})"
 
-    Args:
-      expr_to_info: A dictionary for looking up code generation information for
-        expressions.
-    """
-    op_info = expr_to_info[self].structop_info
-    op_name = op_info.dst_name
-    op_def = lang.LinalgOpDef(name=op_name)
-    op_callable = lang.DefinedOpCallable(op_name, op_def)
-
-    # Collect the input expression nodes for the structured op.
-    expr_inputs = []
-    self._visit(
-        _gather_structured_op_input,
-        (self, expr_to_info, expr_inputs),
-        leaf_checker=_is_structured_op_leaf,
-    )
+    @property
+    def name(self) -> str:
+        """Returns the name of the IndexVar."""
+        return self._name
 
-    # Create a linalg structured op operand for each input expression node and
-    # build a dictionary for looking up the information.
-    expr_to_input_opnd = {
-        e: _emit_structured_op_input(e, expr_to_info, op_def)
-        for e in expr_inputs
-    }
+    def _visit(
+        self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None
+    ) -> None:
+        """A post-order visitor."""
+        if leaf_checker:
+            assert leaf_checker(self, *args)
+        func(self, *args)
 
-    # Emit the expression tree, which produces the value assigned to the
-    # destination tensor.
-    value = self._emit_expression(expr_to_input_opnd, expr_to_info)
-    # Emit the structured op representation for the destination tensor.
-    dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name,
-                             lang.OperandKind.OUTPUT_TENSOR)
-    dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
-    dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
-
-    expr_info = expr_to_info[self]
-    # If the structured op reduces some indices, explicitly represent the
-    # reduction. This is done by generating a ReduceFn for the dimensions being
-    # reduced in the linalg dialect and calling the function with the value
-    # being reduced. We only support add reduction currently.
-    if expr_info.reduce_indices:
-      reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices)
-      value = lang.ReduceFn.add[reduce_dims](value)
-
-    # Emit the assignment as a comprehension in the linalg dialect.
-    comp = lang.Comprehension((dst_use, value))
-    op_def.comprehensions.append(comp)
-
-    # The structured op in the linalg dialect requires an explicit
-    # initialization for the destination tensor. Emit MLIR to initialize the
-    # destination tensor.
-    init = op_info.emit_tensor_init()
-
-    # Collect MLIR values for the linalg input operands, with the assumption
-    # that dictionary preserves the insertion order.
-    args = [
-        expr_to_info[expr].mlir_value
-        for expr, opnd in expr_to_input_opnd.items()
-    ]
-    # Execute the DefineOpcallable object for the linalg dialect operation to
-    # emit MLIR for the linalg structured op.
-    expr_info.mlir_value = op_callable(*args, outs=[init])
-
-  def _identify_structured_ops(
-      self,
-      expr_to_info: _ExprInfoDict,
-      dst: "Tensor",
-      dst_indices: Tuple["IndexVar", ...],
-  ) -> List["IndexExpr"]:
-    """Returns expression nodes for the roots of the identified structured ops.
-
-    A structured op in the linalg dialect only supports reduction performed on
-    the whole expression. If the expression tree contains reduction that are
-    performed on part of the expression tree, the expression tree needs to be
-    implemented with multiple structured ops. This routine identifies all the
-    expression nodes that contain reduction as the root of structured ops in the
-    linalg dialect.
+    def _emit_expression(
+        self,
+        expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+        expr_to_info: _ExprInfoDict,
+    ) -> lang.ScalarExpression:
+        """Emits a index value casted to the data type of the tensor expression."""
+        dim = getattr(lang.D, self.name)
+        index = lang.index(dim)
+        int_value = lang.TypeFn.cast_unsigned(lang.TV.I64, index)
+        return lang.TypeFn.cast_unsigned(lang.T, int_value)
 
-    Args:
-      expr_to_info: A dictionary for looking up code generation information for
-        expressions.
-      dst: A destination Tensor that accepts the value of the expression tree.
-      dst_indices: The indices used by the destination index expression.
+    def dtype(self) -> DType:
+        """Returns the data type for the index value.
 
-    Returns:
-      An ordered list of IndexExpr for the root expressions of the structured
-      ops, where child expressions go before parent expressions that use their
-      results.
-    """
-    reduce_indices = tuple(
-        set(expr_to_info[self].src_indices) - set(dst_indices))
-    for reduce_index in reduce_indices:
-      _mark_structured_op_root(self, reduce_index, expr_to_info)
-
-    self._visit(_accumulate_reduce_indices, (expr_to_info,))
-    structop_roots = []
-    self._visit(_gather_structured_op, (expr_to_info, structop_roots))
-
-    # Handle the root of the top level expression.
-    if not structop_roots or structop_roots[-1] != self:
-      # The top level expression is not a reduction. Add the top level
-      # expression as a structured op root.
-      structop_roots.append(self)
-
-    # Use user specified information for the destination tensor to build an
-    # _StructOpInfo for the top level expression.
-    expr_to_info[self].structop_info = _StructOpInfo(dst_indices,
-                                                     tuple(dst.shape),
-                                                     dst.dtype, dst.name,
-                                                     dst.format)
-
-    return structop_roots
-
-  def _validate_and_collect_expr_info(
-      self,
-      dst: "Tensor",
-      dst_indices: Tuple["IndexVar", ...],
-  ) -> _ExprInfoDict:
-    """Propagates expression information for validation.
-
-    Propagates the indices used by child expression nodes to parent expression
-    nodes. Also collects and validates the sizes for the dimensions
-    corresponding to the indices.
+        This is unreachable for IndexVar.
+        """
+        assert 0
 
-    Args:
-      dst: A destination Tensor that accepts the value of the expression tree.
-      dst_indices: The indices used by the destination index expression.
 
-    Raises:
-      ValueError if there is any inconsistency in indices or dimensional
-      values.
+def get_index_vars(n: int) -> List[IndexVar]:
+    """Returns a list of n IndexVar.
 
-    Returns:
-      A dictionary of (IndexExpr, _ExprInfo).
-    """
-    expr_to_info = {}
-    # Validate the expression tree and construct expression information.
-    self._visit(_validate_and_collect_expr_info, (expr_to_info,))
-
-    # Validate the destination dimension information.
-    info = expr_to_info[self]
-    index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)}
-    for i, d, in zip(dst_indices, dst.shape):
-      if i not in index_to_dim_info:
-        raise ValueError("Destination IndexVar not used in the "
-                         f"source expression: {i}")
-      else:
-        if d != index_to_dim_info[i].dim and index_to_dim_info[i].dim != -1:
-          raise ValueError(f"Inconsistent destination dimension for {i}: "
-                           f"{d} vs {index_to_dim_info[i].dim}")
-
-    return expr_to_info
-
-  def _emit_assignment(
-      self,
-      module: ir.Module,
-      dst: "Tensor",
-      dst_indices: Tuple["IndexVar", ...],
-      expr_to_info: _ExprInfoDict,
-      input_accesses: List["Access"],
-  ) -> None:
-    """Emits an MLIR function for assigning the expression to a tensor."""
-    input_types = [a.tensor.mlir_tensor_type() for a in input_accesses]
-
-    # Build the kernel for the operations.
-    with ir.InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME)
-      def linalg_funcop(*args):
-        # Set up the mapping from the Access nodes to their MLIR values.
-        for e, mlir in zip(input_accesses, args):
-          expr_to_info[e].mlir_value = mlir
-
-        # Emit structured ops in the linalg dialect to implement the assignment.
-        for structop_root in self._identify_structured_ops(
-            expr_to_info, dst, dst_indices):
-          structop_root._emit_structured_op(expr_to_info)
-          dst._record_stats(expr_to_info[structop_root].structop_info)
-
-        # The function returns the MLIR value of the root expression.
-        return expr_to_info[self].mlir_value
-
-      linalg_funcop.func_op.attributes[
-          "llvm.emit_c_interface"] = ir.UnitAttr.get()
-
-  def get_input_accesses(self) -> List["Access"]:
-    """Compute the list of input accesses for the expression."""
-    input_accesses = []
-    self._visit(_gather_input_accesses_index_vars, (input_accesses,))
-    return input_accesses
-
-  def compile(
-      self,
-      dst: "Tensor",
-      dst_indices: Tuple["IndexVar", ...],
-  ) -> execution_engine.ExecutionEngine:
-    """Compiles the tensor assignment dst[dst_indices] = expression.
+    This routine is defined by the TACO API.
 
     Args:
-      dst: The destination tensor.
-      dst_indices: The tuple of IndexVar used to access the destination tensor.
+      n: An integer representing the number of IndexVar to get.
 
     Returns:
-      The execution engine for the tensor assignment.
+      A list of IndexVar.
 
     Raises:
-      ValueError: If the expression is not proper or not supported.
+      ValueError: if n is not a positive integer.
     """
-    expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
-    input_accesses = self.get_input_accesses()
-
-    # Build and compile the module to produce the execution engine.
-    with ir.Context(), ir.Location.unknown():
-      module = ir.Module.create()
-      self._emit_assignment(module, dst, dst_indices, expr_to_info,
-                            input_accesses)
-      engine = utils.compile_and_build_engine(module)
-
-    return engine
-
-
-class _AtomicCounter:
-  """An atomic counter."""
-
-  def __init__(self):
-    self._counter = 0
-    self._counter_lock = threading.Lock()
-
-  def increment(self) -> int:
-    """Increments the counter by one and returns the old value."""
-    old_value = self._counter
-    with self._counter_lock:
-      self._counter = self._counter + 1
-    return old_value
-
-
-class IndexVar(IndexExpr):
-  """The tensor index class.
-
-  We support the TACO API index_var class with an alias of this class.
-
-  An IndexVar object represents an index variable in tensor index notation.
-
-  Attributes:
-    name: A unique string name of the IndexVar.
-  """
-  _counter = _AtomicCounter()
-
-  def __init__(self):
-    id = self._counter.increment()
-    self._name = f"{_TACO_INDEX_PREFIX}{id}"
-
-  def __repr__(self) -> str:
-    return f"IndexVar(name={repr(self._name)})"
-
-  @property
-  def name(self) -> str:
-    """Returns the name of the IndexVar."""
-    return self._name
-
-  def _visit(self,
-             func: _ExprVisitor,
-             args,
-             *,
-             leaf_checker: _SubtreeLeafChecker = None) -> None:
-    """A post-order visitor."""
-    if leaf_checker:
-      assert leaf_checker(self, *args)
-    func(self, *args)
-
-  def _emit_expression(
-      self,
-      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
-      expr_to_info: _ExprInfoDict,
-  ) -> lang.ScalarExpression:
-    """Emits a index value casted to the data type of the tensor expression."""
-    dim = getattr(lang.D, self.name)
-    index = lang.index(dim)
-    int_value = lang.TypeFn.cast_unsigned(lang.TV.I64, index)
-    return lang.TypeFn.cast_unsigned(lang.T, int_value)
-
-  def dtype(self) -> DType:
-    """Returns the data type for the index value.
-
-    This is unreachable for IndexVar.
-    """
-    assert 0
-
-
-def get_index_vars(n: int) -> List[IndexVar]:
-  """Returns a list of n IndexVar.
-
-  This routine is defined by the TACO API.
-
-  Args:
-    n: An integer representing the number of IndexVar to get.
-
-  Returns:
-    A list of IndexVar.
-
-  Raises:
-    ValueError: if n is not a positive integer.
-  """
-  if not isinstance(n, int) or n <= 0:
-    raise ValueError(f"Expected an integer: {n}.")
-  # If lock contention ever becomes an issue, we could implement a bulk getter
-  # that returns a range by only claiming the lock once.
-  return [IndexVar() for i in range(n)]
+    if not isinstance(n, int) or n <= 0:
+        raise ValueError(f"Expected an integer: {n}.")
+    # If lock contention ever becomes an issue, we could implement a bulk getter
+    # that returns a range by only claiming the lock once.
+    return [IndexVar() for i in range(n)]
 
 
 def _mlir_symbols_from_index_vars(
-    index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.SymbolDef, ...]:
-  """Returns a tuple of MLIR symbols for the given tuple of index_var."""
-  return tuple(getattr(lang.S, i.name) for i in index_vars)
+    index_vars: Tuple[IndexVar, ...]
+) -> Tuple[lang.SymbolDef, ...]:
+    """Returns a tuple of MLIR symbols for the given tuple of index_var."""
+    return tuple(getattr(lang.S, i.name) for i in index_vars)
 
 
 def _mlir_dimensions_from_index_vars(
-    index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.DimDef, ...]:
-  """Returns a tuple of MLIR dimensions for the given tuple of index_var."""
-  return tuple(getattr(lang.D, i.name) for i in index_vars)
+    index_vars: Tuple[IndexVar, ...]
+) -> Tuple[lang.DimDef, ...]:
+    """Returns a tuple of MLIR dimensions for the given tuple of index_var."""
+    return tuple(getattr(lang.D, i.name) for i in index_vars)
 
 
 def _mlir_tensor_type(
-    dtype: DType, shape: Tuple[int, ...],
-    attr: Optional[sparse_tensor.EncodingAttr]) -> ir.RankedTensorType:
-  """Returns an MLIR tensor type.
-
-  Args:
-    dtype: An DType object for the element data type of the tensor.
-    shape: A tuple of integer for the shape of the tensor.
-    attr: An optional MLIR sparse tensor attribute, only provided if the tensor
-      is a sparse tensor.
-
-  Returns:
-    An MLIR ranked tensor type.
-  """
-  ir_type = _mlir_type_from_taco_type(dtype)
-  return ir.RankedTensorType.get(shape, ir_type, attr)
-
-
-@dataclasses.dataclass(frozen=True)
-class _StructOpInfo:
-  """Information for generating a structured op in the linalg dialect.
-
-  This information is associated with an expression node that serves as the
-  root for an expression subtree implemented with a structured op.
-
-  Attributes:
-    dst_indices: A tuple of IndexVar, representing the result dimensions of the
-      structured op. This is used to construct the temporary variable for the
-      tensor to hold the structured op result.
-    dst_dims: A tuple of int, representing the result shape of the structured
-      op.
-    dst_dtype: A DType representing the data type of the structured op result.
-    dst_name: A string representing the name of the structured op result.
-    dst_format: An optional Format object representing the destination tensor
-      format. None represents a true dense tensor.
-  """
-  dst_indices: Tuple[IndexVar, ...]
-  dst_dims: Tuple[int, ...]
-  dst_dtype: DType
-  dst_name: str
-  dst_format: Optional[Format]
-
-  def __post_init__(self) -> None:
-    """Verifies the integrity of the attribute values."""
-    assert len(self.dst_indices) == len(self.dst_dims)
-
-  def emit_tensor_init(self) -> ir.RankedTensorType:
-    """Returns an initialization for the destination tensor."""
-    if self.dst_format is None or self.dst_format.rank() == 0:
-      # Initialize the dense tensor.
-      ir_type = _mlir_type_from_taco_type(self.dst_dtype)
-      empty = tensor.EmptyOp(self.dst_dims, ir_type).result
-      zero = arith.ConstantOp(ir_type, 0.0)
-      return linalg.fill(zero, outs=[empty])
-
-    # Initialize the sparse tensor.
-    mlir_type = _mlir_tensor_type(self.dst_dtype, self.dst_dims,
-                                  self.dst_format.mlir_tensor_attr())
-    index_type = ir.IndexType.get()
-    return bufferization.AllocTensorOp(mlir_type, [], None, None, None)
-
-
-class _Stats:
-  """Information to describe how a tensor expression is implemented.
-
-  Currently, we only record the temporary tensors introduced for splitting the
-  original expression.
-  """
-
-  def __init__(self):
-    self._temps = []
-
-  def __repr__(self) -> str:
-    return f"_Stats({repr(self._temps)})"
-
-  def add_element(self, structop: _StructOpInfo):
-    """Adds a temporary tensor."""
-    self._temps.append(structop)
-
-  def get_total(self) -> int:
-    """Gets the total number of temporary tensors."""
-    return len(self._temps)
-
-  def _get_element(self, idx: int) -> _StructOpInfo:
-    """Gets the ith temporary tensor."""
-    assert idx < self.get_total()
-    return self._temps[idx]
-
-  def get_dimensions(self, idx: int) -> Tuple[int]:
-    """Gets the dimensions for the ith temporary tensor."""
-    return self._get_element(idx).dst_dims
-
-  def get_formats(self, idx: int) -> Tuple[ModeFormat]:
-    """Gets the ModeFormats for the ith temporary tensor."""
-    return tuple(self._get_element(idx).dst_format.format_pack.formats)
-
-
-class _SparseValueInfo(enum.Enum):
-  """Describes how a sparse tensor value is stored.
-  _UNPACKED: The sparse tensor value is stored as (coordnates, values) in
-    Python.
-  _PACKED: The sparse tensor value is stored as a C pointer to a packed MLIR
-    sparse tensor.
-  """
-  _UNPACKED = 0
-  _PACKED = 1
-
-
-@dataclasses.dataclass(frozen=True)
-class _Assignment:
-  """Records an assignment to a tensor T as T[indices] = expression."""
-  indices: Tuple["IndexVar", ...]
-  expression: "IndexExpr"
-
-
-class Tensor:
-  """The tensor class.
-
-  We support the TACO API tensor class with an alias of this class.
-
-  This class is part of the TACO API with the following methods:
-    insert: Inserts a value to the given coordinate in the tensor.
-    to_array: Returns a numpy ndarray for the tensor.
-
-  TACO API also defines the following arrtibutes for the class:
-    dtype: A dtype object representing the data type of the tensor.
-    format: A format object representing the storage format of the tensor.
-    name: A string object representing the name of the tensor.
-    order: An integral rank of the tensor.
-    shape: A list of integers representing the shape of the tensor.
-
-  We currently ignore the tensor dimension ordering for dense tensor.
-  """
-  _counter = _AtomicCounter()
-
-  def _get_unique_name(self) -> str:
-    """Returns a unique name for creating a new Tensor."""
-    return f"{_TACO_TENSOR_PREFIX}{self._counter.increment()}"
-
-  def _init_format(self, fmt: Union[ModeFormat, List[ModeFormat],
-                                    Format]) -> None:
-    """Process the fmt argument for the Tensor constructor.
-
-    Args:
-      fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If
-        this argument is a ModeFormat, uses this ModeFormat for all the tensor
-        dimensions. If this argument is a list of ModeFormat, the len of the
-        list should equal to the rank of the tensor. If this argument is a
-        format, uses it for the format of the tensor.
-
-    Raises:
-      ValueError: If fmt is not one of the expected type or is inconsistent
-        with the rank of the tensor. This is because fmt could be an users
-        input.
-    """
-    if isinstance(fmt, ModeFormat):
-      self._format = _make_format([fmt] * self.order)
-    elif isinstance(fmt, list):
-      if len(fmt) == self.order and isinstance(fmt[0], ModeFormat):
-        self._format = _make_format(fmt)
-      else:
-        raise ValueError("Inconsistent shape and format: "
-                         f"{self._shape}, {fmt}.")
-    elif isinstance(fmt, Format):
-      if fmt.rank() != self.order:
-        raise ValueError("Inconsistent shape and format: "
-                         f"{self._shape}, {fmt}.")
-      else:
-        self._format = fmt
-    else:
-      raise ValueError(f"Invalid format argument: {fmt}.")
-
-  def __init__(self,
-               value_or_shape: Optional[Union[List[int], Tuple[int, ...],
-                                              complex, float, int]] = None,
-               fmt: Optional[Union[ModeFormat, List[ModeFormat],
-                                   Format]] = None,
-               dtype: Optional[DType] = None,
-               name: Optional[str] = None,
-               is_dense: bool = False):
-    """The tensor constructor interface defined by TACO API.
-
-    Args:
-      value_or_shape: This argument is optional and can be int, float,
-        List[int], or Tuple[int, ...]. If this argument is an int or float,
-        creates a scalar tensor and initializes it with the value. If this
-        argument is a list or tuple of int, uses it as the shape to create a
-        tensor.
-      fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If
-        this argument is a ModeFormat, uses this ModeFormat for all the tensor
-        dimensions. If this argument is a list of ModeFormat, the len of the
-        list should equal to the rank of the tensor. If this argument is a
-        format, uses it for the format of the tensor.
-      dtype: An object of dtype, representing the data type of the tensor.
-      name: A string name of the tensor. If a name is not given, creates a
-        unique name for the tensor.
-      is_dense: A boolean variable to indicate whether the tensor is a dense
-        tensor without any sparsity annotation.
-
-    Raises:
-      ValueError: If there is any inconsistency among the input arguments.
-    """
-    # Take care of the argument default values common to both sparse tensors
-    # and dense tensors.
-    dtype = dtype or DType(Type.FLOAT32)
-    self._name = name or self._get_unique_name()
-    self._assignment = None
-    self._engine = None
-    self._sparse_value_location = _SparseValueInfo._UNPACKED
-    self._dense_storage = None
-    self._dtype = dtype
-
-    if is_dense:
-      assert (fmt is None)
-      assert (isinstance(value_or_shape, tuple) or isinstance(
-          value_or_shape, list)) and _all_instance_of(value_or_shape, int)
-      self._shape = value_or_shape
-      self._format = None
-      return
-
-    fmt = fmt or ModeFormat.COMPRESSED
-    # We currently use _coords and _values to host the sparse tensor value with
-    # COO format, and _dense_storage to host the dense tensor value. We don't
-    # support the conversion between the two storages.
-    self._coords = []
-    self._values = []
-    self._stats = _Stats()
-    if value_or_shape is None or isinstance(value_or_shape, int) or isinstance(
-        value_or_shape, float) or isinstance(value_or_shape, complex):
-      # Create a scalar tensor and ignore the fmt parameter.
-      self._shape = []
-      self._format = _make_format([], [])
-      if value_or_shape is not None:
-        self._dense_storage = np.array(value_or_shape, dtype=self._dtype.value)
-    elif (isinstance(value_or_shape, tuple) or isinstance(
-        value_or_shape, list)) and _all_instance_of(value_or_shape, int):
-      # Create a tensor with the specified shape and format.
-      self._shape = list(value_or_shape)
-      self._init_format(fmt)
-    else:
-      raise ValueError("Invalid first argument. "
-                       "Must be a tuple or list for a shape or a single value"
-                       f"if initializing a scalar tensor: {value_or_shape}.")
-
-  def _set_packed_sparse_tensor(self, pointer: ctypes.c_void_p) -> None:
-    """Records the MLIR sparse tensor pointer."""
-    self._sparse_value_location = _SparseValueInfo._PACKED
-    self._packed_sparse_value = pointer
-
-  def is_unpacked(self) -> bool:
-    """Returns true if the tensor value is not packed as MLIR sparse tensor."""
-    return (self._sparse_value_location == _SparseValueInfo._UNPACKED)
-
-  def unpack(self) -> None:
-    """Unpacks the MLIR sparse tensor representation."""
-    if self.is_dense() or self.is_unpacked():
-      return
-
-    # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
-    # values and verify the values.
-    rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor(
-        self._packed_sparse_value, self._dtype.value)
-    assert rank == self.order
-    assert np.array_equal(self.shape, shape)
-    assert nse == len(values)
-    self._coords = indices
-    self._values = values
-    self._sparse_value_location = _SparseValueInfo._UNPACKED
-
-  def __repr__(self) -> str:
-    self._sync_value()
-    self.unpack()
-    value_str = (f"{repr(self._dense_storage)})" if self.is_dense() else
-                 f"{repr(self._coords)} {repr(self._values)})")
-    return (f"Tensor(_name={repr(self._name)} "
-            f"_dtype={repr(self._dtype)} : ") + value_str
-
-  def insert(self, coords: List[int], val: Union[complex, float, int]) -> None:
-    """Inserts a value to the given coordinate.
+    dtype: DType, shape: Tuple[int, ...], attr: Optional[sparse_tensor.EncodingAttr]
+) -> ir.RankedTensorType:
+    """Returns an MLIR tensor type.
 
     Args:
-      coords: A list of integer coordinates. The length of the list must be the
-        same as the rank of the tensor.
-      val: A value being inserted. It is either an integral or a floating point
-        value. This value will be converted to the data type of the tensor.
-
-    Raises:
-      ValueError: When there is any problem in the parameters.
-    """
-    if self.is_dense():
-      raise ValueError("Insert method is not supported for dense tensors.")
-    if self._assignment != None or not self.is_unpacked():
-      raise ValueError(
-          "Can't use Insert method for a tensor constructed from a file.")
-    if not isinstance(coords, list):
-      raise ValueError(f"Non list coordinate detected: {coords}.")
-    if not _all_instance_of(coords, int):
-      raise ValueError(f"Non integer coordinate detected: {coords}.")
-    if (len(coords) != self.order or
-        any([c < 0 or c >= self._shape[i] for i, c in enumerate(coords)])):
-      raise ValueError("Invalid coordinate for rank: "
-                       f"{self.order}, {coords}.")
-
-    if not isinstance(val, int) and not isinstance(
-        val, float) and not isinstance(val, complex):
-      raise ValueError(f"Value is neither int nor float: {val}.")
-
-    self._coords.append(tuple(coords))
-    self._values.append(self._dtype.value(val))
-
-  def is_dense(self) -> bool:
-    """Returns true if the tensor doesn't have sparsity annotation."""
-    return self.order == 0 or self._format is None
-
-  def to_array(self) -> np.ndarray:
-    """Returns the numpy array for the Tensor.
-
-    This is currenly only implemented for dense Tensor.
-    """
-    if not self.is_dense():
-      raise ValueError("Conversion from non-dense Tensor "
-                       "to numpy array not supported yet.")
-
-    self._sync_value()
-
-    return self._dense_storage
-
-  @staticmethod
-  def from_array(array: np.ndarray) -> "Tensor":
-    """Returns a dense tensor with the value copied from the input array.
-
-    We currently only support the conversion of float32 and float64 numpy arrays
-    to Tensor.
-
-    Args:
-      array: The numpy array that provides the data type, shape and value for
-        the tensor.
+      dtype: An DType object for the element data type of the tensor.
+      shape: A tuple of integer for the shape of the tensor.
+      attr: An optional MLIR sparse tensor attribute, only provided if the tensor
+        is a sparse tensor.
 
     Returns:
-      A Tensor object.
-
-    Raises:
-      ValueError if the data type of the numpy array is not supported.
+      An MLIR ranked tensor type.
     """
-    if array.dtype != np.float32 and array.dtype != np.float64:
-      raise ValueError(f"Expected floating point value type: {array.dtype}.")
-    t = Tensor(
-        array.shape,
-        dtype=_nptype_to_taco_type(array.dtype.type),
-        is_dense=True)
-    t._dense_storage = np.copy(array)
-    return t
-
-  @staticmethod
-  def from_coo(
-      coordinates: List[Tuple[int, ...]],
-      values: List[_AnyRuntimeType],
-      fmt: Format,
-      dtype: DType,
-  ) -> "Tensor":
-    """Converts coordinates and values to a sparse tensor representation.
+    ir_type = _mlir_type_from_taco_type(dtype)
+    return ir.RankedTensorType.get(shape, ir_type, attr)
 
-    Args:
-      coordinates: A list of coordinates with non-zero values.
-      values: The non-zero values.
-      fmt: The tensor storage format.
-      dtype: The tensor element data type.
 
-    Returns:
-      A tensor with the given non-zero values and storage format. The shape of
-      the tensor has the minimum size for each dimension to make the given
-      coordinates valid.
-    """
-    assert (isinstance(coordinates, List) and
-            _all_instance_of(coordinates, Tuple))
-    assert (isinstance(values, List) and _all_instance_of(values, dtype.value))
-    assert isinstance(fmt, Format)
-
-    rank = fmt.rank()
-    assert all(len(c) == rank and _all_instance_of(c, int) for c in coordinates)
-
-    # Find the maximum coordinate value for each dimension.
-    max_coordinate = list(map(max, zip(*coordinates)))
-    # The size of each dimension is one more that such a maximum coordinate
-    # value.
-    shape = [c + 1 for c in max_coordinate]
-    t = Tensor(shape, fmt, dtype=dtype)
-    t._coords = coordinates
-    t._values = values
-
-    return tensor
-
-  @staticmethod
-  def from_file(
-      filename: str,
-      fmt: Format,
-      dtype: DType,
-  ) -> "Tensor":
-    """Constructs a sparse tensor using the COO-flavored values from a file.
-
-    Args:
-      filename: A string for the name of the file that contains the sparse
-        tensor data.
-      fmt: The tensor storage format.
-      dtype: The tensor element data type.
-
-    Returns:
-      A tensor with the given non-zero values and storage format. The tensor
-      value is stored as an MLIR sparse tensor.
+@dataclasses.dataclass(frozen=True)
+class _StructOpInfo:
+    """Information for generating a structured op in the linalg dialect.
+
+    This information is associated with an expression node that serves as the
+    root for an expression subtree implemented with a structured op.
+
+    Attributes:
+      dst_indices: A tuple of IndexVar, representing the result dimensions of the
+        structured op. This is used to construct the temporary variable for the
+        tensor to hold the structured op result.
+      dst_dims: A tuple of int, representing the result shape of the structured
+        op.
+      dst_dtype: A DType representing the data type of the structured op result.
+      dst_name: A string representing the name of the structured op result.
+      dst_format: An optional Format object representing the destination tensor
+        format. None represents a true dense tensor.
     """
-    sparse_tensor, shape = utils.create_sparse_tensor(filename,
-                                                      fmt.format_pack.formats,
-                                                      _dtype_to_mlir_str(dtype))
-    t = Tensor(shape.tolist(), fmt, dtype=dtype)
-    t._set_packed_sparse_tensor(sparse_tensor)
-
-    return t
 
-  def to_file(self, filename: str) -> None:
-    """Output the tensor value to a file.
+    dst_indices: Tuple[IndexVar, ...]
+    dst_dims: Tuple[int, ...]
+    dst_dtype: DType
+    dst_name: str
+    dst_format: Optional[Format]
+
+    def __post_init__(self) -> None:
+        """Verifies the integrity of the attribute values."""
+        assert len(self.dst_indices) == len(self.dst_dims)
+
+    def emit_tensor_init(self) -> ir.RankedTensorType:
+        """Returns an initialization for the destination tensor."""
+        if self.dst_format is None or self.dst_format.rank() == 0:
+            # Initialize the dense tensor.
+            ir_type = _mlir_type_from_taco_type(self.dst_dtype)
+            empty = tensor.EmptyOp(self.dst_dims, ir_type).result
+            zero = arith.ConstantOp(ir_type, 0.0)
+            return linalg.fill(zero, outs=[empty])
+
+        # Initialize the sparse tensor.
+        mlir_type = _mlir_tensor_type(
+            self.dst_dtype, self.dst_dims, self.dst_format.mlir_tensor_attr()
+        )
+        index_type = ir.IndexType.get()
+        return bufferization.AllocTensorOp(mlir_type, [], None, None, None)
 
-    This method evaluates any pending assignment to the tensor and outputs the
-    tensor value.
 
-    Args:
-      filename: A string file name.
+class _Stats:
+    """Information to describe how a tensor expression is implemented.
 
-    Raises:
-       ValueError: If the tensor is dense, or an unpacked sparse tensor.
+    Currently, we only record the temporary tensors introduced for splitting the
+    original expression.
     """
-    self._sync_value()
-
-    if self.is_dense():
-      raise ValueError("Writing dense tensors without sparsity annotation to "
-                       "file is not supported.")
 
-    if self.is_unpacked():
-      raise ValueError("Writing unpacked sparse tensors to file is not "
-                       "supported.")
+    def __init__(self):
+        self._temps = []
 
-    utils.output_sparse_tensor(self._packed_sparse_value, filename,
-                               self._format.format_pack.formats,
-                               _dtype_to_mlir_str(self._dtype))
+    def __repr__(self) -> str:
+        return f"_Stats({repr(self._temps)})"
 
-  @property
-  def dtype(self) -> DType:
-    """Returns the data type for the Tensor."""
-    return self._dtype
+    def add_element(self, structop: _StructOpInfo):
+        """Adds a temporary tensor."""
+        self._temps.append(structop)
 
-  @property
-  def format(self) -> Format:
-    """Returns the storage format for the Tensor."""
-    return self._format
+    def get_total(self) -> int:
+        """Gets the total number of temporary tensors."""
+        return len(self._temps)
 
-  @property
-  def name(self) -> str:
-    """Returns the name for the Tensor."""
-    return self._name
+    def _get_element(self, idx: int) -> _StructOpInfo:
+        """Gets the ith temporary tensor."""
+        assert idx < self.get_total()
+        return self._temps[idx]
 
-  @property
-  def order(self) -> int:
-    """Returns the rank of the Tensor."""
-    return len(self._shape)
+    def get_dimensions(self, idx: int) -> Tuple[int]:
+        """Gets the dimensions for the ith temporary tensor."""
+        return self._get_element(idx).dst_dims
 
-  @property
-  def shape(self) -> List[int]:
-    """Returns the shape of the Tensor."""
-    return self._shape
+    def get_formats(self, idx: int) -> Tuple[ModeFormat]:
+        """Gets the ModeFormats for the ith temporary tensor."""
+        return tuple(self._get_element(idx).dst_format.format_pack.formats)
 
-  def _verify_and_normalize_indices(self, indices) -> Tuple[IndexVar, ...]:
-    """Verifies and normalizes the indices to access the tensor.
 
-    Args:
-      indices: The index expression used to access a tensor, which could be any
-        Python object from user inputs.
-
-    Returns:
-      A tuple of IndexVar.
-
-    Raises:
-      ValueError: If indices is not 0 for scalar tensors, or not an IndexVar or
-        a tuple of IndexVar for other tensors.
+class _SparseValueInfo(enum.Enum):
+    """Describes how a sparse tensor value is stored.
+    _UNPACKED: The sparse tensor value is stored as (coordnates, values) in
+      Python.
+    _PACKED: The sparse tensor value is stored as a C pointer to a packed MLIR
+      sparse tensor.
     """
-    if self.order == 0:
-      if not isinstance(indices, int) or indices != 0:
-        raise ValueError(f"Expected 0 to index scalar tensors: {indices}")
-      return ()
-
-    if isinstance(indices, IndexVar):
-      return (indices,)
-    elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
-      return indices
 
-    raise ValueError(f"Expected IndexVars: {indices}")
+    _UNPACKED = 0
+    _PACKED = 1
 
-  def __getitem__(self, key) -> "Access":
-    """Verifies and processes a tensor access.
 
-    In the tensor index notation, a tensor access T[i, j] is represented as
-    retrieving a value with key (i, j) from the tensor object T in Python. This
-    routine verifies the key for the tensor access and returns a tensor access
-    object.
-
-    Args:
-      key: The key used to access the tensor, which could be any Python object
-        from user inputs.
+@dataclasses.dataclass(frozen=True)
+class _Assignment:
+    """Records an assignment to a tensor T as T[indices] = expression."""
 
-    Returns:
-      The corresponding tensor access object.
+    indices: Tuple["IndexVar", ...]
+    expression: "IndexExpr"
 
-    Raises:
-      ValueError: If key is not an IndexVar or a tuple of IndexVar.
-    """
-    indices = self._verify_and_normalize_indices(key)
-    return Access(self, indices)
 
-  def __setitem__(self, key, value) -> None:
-    """Verifies and processes a tensor assignment.
+class Tensor:
+    """The tensor class.
 
-    In the tensor index notation, a tensor assignment "T[i, j] = ..." is
-    represented as setting a value for a tensor object T via key (i, j) in
-    Python. This routine verifies the key, evaluates the value, and assigns the
-    value to the tensor.
+    We support the TACO API tensor class with an alias of this class.
 
-    We only support assignment of dense tensor currently.
+    This class is part of the TACO API with the following methods:
+      insert: Inserts a value to the given coordinate in the tensor.
+      to_array: Returns a numpy ndarray for the tensor.
 
-    Args:
-      key: The key used to access the tensor, which could be any Python object
-        from user inputs.
-      value: The value assigned to the tensor, which could be any Python object
-        from user inputs.
+    TACO API also defines the following arrtibutes for the class:
+      dtype: A dtype object representing the data type of the tensor.
+      format: A format object representing the storage format of the tensor.
+      name: A string object representing the name of the tensor.
+      order: An integral rank of the tensor.
+      shape: A list of integers representing the shape of the tensor.
 
-    Raises:
-      ValueError: If tensor is not a dense tensor, or the key is not an IndexVar
-        or a tuple of IndexVar, or the length of the indices is not the same as
-        the rank of the tensor.
+    We currently ignore the tensor dimension ordering for dense tensor.
     """
-    indices = self._verify_and_normalize_indices(key)
-    if len(indices) != self.order:
-      raise ValueError("Mismatch between indices and tensor rank: "
-                       f"len({indices}) != {self.order}.")
 
-    self._assignment = _Assignment(indices, value)
-    self._engine = None
-
-  def compile(self, force_recompile: bool = False) -> None:
-    """Compiles the tensor assignment to an execution engine.
-
-    Calling compile the second time does not do anything unless
-    force_recompile is True.
+    _counter = _AtomicCounter()
+
+    def _get_unique_name(self) -> str:
+        """Returns a unique name for creating a new Tensor."""
+        return f"{_TACO_TENSOR_PREFIX}{self._counter.increment()}"
+
+    def _init_format(self, fmt: Union[ModeFormat, List[ModeFormat], Format]) -> None:
+        """Process the fmt argument for the Tensor constructor.
+
+        Args:
+          fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If
+            this argument is a ModeFormat, uses this ModeFormat for all the tensor
+            dimensions. If this argument is a list of ModeFormat, the len of the
+            list should equal to the rank of the tensor. If this argument is a
+            format, uses it for the format of the tensor.
+
+        Raises:
+          ValueError: If fmt is not one of the expected type or is inconsistent
+            with the rank of the tensor. This is because fmt could be an users
+            input.
+        """
+        if isinstance(fmt, ModeFormat):
+            self._format = _make_format([fmt] * self.order)
+        elif isinstance(fmt, list):
+            if len(fmt) == self.order and isinstance(fmt[0], ModeFormat):
+                self._format = _make_format(fmt)
+            else:
+                raise ValueError(
+                    "Inconsistent shape and format: " f"{self._shape}, {fmt}."
+                )
+        elif isinstance(fmt, Format):
+            if fmt.rank() != self.order:
+                raise ValueError(
+                    "Inconsistent shape and format: " f"{self._shape}, {fmt}."
+                )
+            else:
+                self._format = fmt
+        else:
+            raise ValueError(f"Invalid format argument: {fmt}.")
+
+    def __init__(
+        self,
+        value_or_shape: Optional[
+            Union[List[int], Tuple[int, ...], complex, float, int]
+        ] = None,
+        fmt: Optional[Union[ModeFormat, List[ModeFormat], Format]] = None,
+        dtype: Optional[DType] = None,
+        name: Optional[str] = None,
+        is_dense: bool = False,
+    ):
+        """The tensor constructor interface defined by TACO API.
+
+        Args:
+          value_or_shape: This argument is optional and can be int, float,
+            List[int], or Tuple[int, ...]. If this argument is an int or float,
+            creates a scalar tensor and initializes it with the value. If this
+            argument is a list or tuple of int, uses it as the shape to create a
+            tensor.
+          fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If
+            this argument is a ModeFormat, uses this ModeFormat for all the tensor
+            dimensions. If this argument is a list of ModeFormat, the len of the
+            list should equal to the rank of the tensor. If this argument is a
+            format, uses it for the format of the tensor.
+          dtype: An object of dtype, representing the data type of the tensor.
+          name: A string name of the tensor. If a name is not given, creates a
+            unique name for the tensor.
+          is_dense: A boolean variable to indicate whether the tensor is a dense
+            tensor without any sparsity annotation.
+
+        Raises:
+          ValueError: If there is any inconsistency among the input arguments.
+        """
+        # Take care of the argument default values common to both sparse tensors
+        # and dense tensors.
+        dtype = dtype or DType(Type.FLOAT32)
+        self._name = name or self._get_unique_name()
+        self._assignment = None
+        self._engine = None
+        self._sparse_value_location = _SparseValueInfo._UNPACKED
+        self._dense_storage = None
+        self._dtype = dtype
+
+        if is_dense:
+            assert fmt is None
+            assert (
+                isinstance(value_or_shape, tuple) or isinstance(value_or_shape, list)
+            ) and _all_instance_of(value_or_shape, int)
+            self._shape = value_or_shape
+            self._format = None
+            return
+
+        fmt = fmt or ModeFormat.COMPRESSED
+        # We currently use _coords and _values to host the sparse tensor value with
+        # COO format, and _dense_storage to host the dense tensor value. We don't
+        # support the conversion between the two storages.
+        self._coords = []
+        self._values = []
+        self._stats = _Stats()
+        if (
+            value_or_shape is None
+            or isinstance(value_or_shape, int)
+            or isinstance(value_or_shape, float)
+            or isinstance(value_or_shape, complex)
+        ):
+            # Create a scalar tensor and ignore the fmt parameter.
+            self._shape = []
+            self._format = _make_format([], [])
+            if value_or_shape is not None:
+                self._dense_storage = np.array(value_or_shape, dtype=self._dtype.value)
+        elif (
+            isinstance(value_or_shape, tuple) or isinstance(value_or_shape, list)
+        ) and _all_instance_of(value_or_shape, int):
+            # Create a tensor with the specified shape and format.
+            self._shape = list(value_or_shape)
+            self._init_format(fmt)
+        else:
+            raise ValueError(
+                "Invalid first argument. "
+                "Must be a tuple or list for a shape or a single value"
+                f"if initializing a scalar tensor: {value_or_shape}."
+            )
+
+    def _set_packed_sparse_tensor(self, pointer: ctypes.c_void_p) -> None:
+        """Records the MLIR sparse tensor pointer."""
+        self._sparse_value_location = _SparseValueInfo._PACKED
+        self._packed_sparse_value = pointer
+
+    def is_unpacked(self) -> bool:
+        """Returns true if the tensor value is not packed as MLIR sparse tensor."""
+        return self._sparse_value_location == _SparseValueInfo._UNPACKED
+
+    def unpack(self) -> None:
+        """Unpacks the MLIR sparse tensor representation."""
+        if self.is_dense() or self.is_unpacked():
+            return
+
+        # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
+        # values and verify the values.
+        rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor(
+            self._packed_sparse_value, self._dtype.value
+        )
+        assert rank == self.order
+        assert np.array_equal(self.shape, shape)
+        assert nse == len(values)
+        self._coords = indices
+        self._values = values
+        self._sparse_value_location = _SparseValueInfo._UNPACKED
+
+    def __repr__(self) -> str:
+        self._sync_value()
+        self.unpack()
+        value_str = (
+            f"{repr(self._dense_storage)})"
+            if self.is_dense()
+            else f"{repr(self._coords)} {repr(self._values)})"
+        )
+        return (
+            f"Tensor(_name={repr(self._name)} " f"_dtype={repr(self._dtype)} : "
+        ) + value_str
+
+    def insert(self, coords: List[int], val: Union[complex, float, int]) -> None:
+        """Inserts a value to the given coordinate.
+
+        Args:
+          coords: A list of integer coordinates. The length of the list must be the
+            same as the rank of the tensor.
+          val: A value being inserted. It is either an integral or a floating point
+            value. This value will be converted to the data type of the tensor.
+
+        Raises:
+          ValueError: When there is any problem in the parameters.
+        """
+        if self.is_dense():
+            raise ValueError("Insert method is not supported for dense tensors.")
+        if self._assignment != None or not self.is_unpacked():
+            raise ValueError(
+                "Can't use Insert method for a tensor constructed from a file."
+            )
+        if not isinstance(coords, list):
+            raise ValueError(f"Non list coordinate detected: {coords}.")
+        if not _all_instance_of(coords, int):
+            raise ValueError(f"Non integer coordinate detected: {coords}.")
+        if len(coords) != self.order or any(
+            [c < 0 or c >= self._shape[i] for i, c in enumerate(coords)]
+        ):
+            raise ValueError("Invalid coordinate for rank: " f"{self.order}, {coords}.")
+
+        if (
+            not isinstance(val, int)
+            and not isinstance(val, float)
+            and not isinstance(val, complex)
+        ):
+            raise ValueError(f"Value is neither int nor float: {val}.")
+
+        self._coords.append(tuple(coords))
+        self._values.append(self._dtype.value(val))
+
+    def is_dense(self) -> bool:
+        """Returns true if the tensor doesn't have sparsity annotation."""
+        return self.order == 0 or self._format is None
+
+    def to_array(self) -> np.ndarray:
+        """Returns the numpy array for the Tensor.
+
+        This is currenly only implemented for dense Tensor.
+        """
+        if not self.is_dense():
+            raise ValueError(
+                "Conversion from non-dense Tensor " "to numpy array not supported yet."
+            )
+
+        self._sync_value()
+
+        return self._dense_storage
+
+    @staticmethod
+    def from_array(array: np.ndarray) -> "Tensor":
+        """Returns a dense tensor with the value copied from the input array.
+
+        We currently only support the conversion of float32 and float64 numpy arrays
+        to Tensor.
+
+        Args:
+          array: The numpy array that provides the data type, shape and value for
+            the tensor.
+
+        Returns:
+          A Tensor object.
+
+        Raises:
+          ValueError if the data type of the numpy array is not supported.
+        """
+        if array.dtype != np.float32 and array.dtype != np.float64:
+            raise ValueError(f"Expected floating point value type: {array.dtype}.")
+        t = Tensor(
+            array.shape, dtype=_nptype_to_taco_type(array.dtype.type), is_dense=True
+        )
+        t._dense_storage = np.copy(array)
+        return t
+
+    @staticmethod
+    def from_coo(
+        coordinates: List[Tuple[int, ...]],
+        values: List[_AnyRuntimeType],
+        fmt: Format,
+        dtype: DType,
+    ) -> "Tensor":
+        """Converts coordinates and values to a sparse tensor representation.
+
+        Args:
+          coordinates: A list of coordinates with non-zero values.
+          values: The non-zero values.
+          fmt: The tensor storage format.
+          dtype: The tensor element data type.
+
+        Returns:
+          A tensor with the given non-zero values and storage format. The shape of
+          the tensor has the minimum size for each dimension to make the given
+          coordinates valid.
+        """
+        assert isinstance(coordinates, List) and _all_instance_of(coordinates, Tuple)
+        assert isinstance(values, List) and _all_instance_of(values, dtype.value)
+        assert isinstance(fmt, Format)
+
+        rank = fmt.rank()
+        assert all(len(c) == rank and _all_instance_of(c, int) for c in coordinates)
+
+        # Find the maximum coordinate value for each dimension.
+        max_coordinate = list(map(max, zip(*coordinates)))
+        # The size of each dimension is one more that such a maximum coordinate
+        # value.
+        shape = [c + 1 for c in max_coordinate]
+        t = Tensor(shape, fmt, dtype=dtype)
+        t._coords = coordinates
+        t._values = values
+
+        return tensor
+
+    @staticmethod
+    def from_file(
+        filename: str,
+        fmt: Format,
+        dtype: DType,
+    ) -> "Tensor":
+        """Constructs a sparse tensor using the COO-flavored values from a file.
+
+        Args:
+          filename: A string for the name of the file that contains the sparse
+            tensor data.
+          fmt: The tensor storage format.
+          dtype: The tensor element data type.
+
+        Returns:
+          A tensor with the given non-zero values and storage format. The tensor
+          value is stored as an MLIR sparse tensor.
+        """
+        sparse_tensor, shape = utils.create_sparse_tensor(
+            filename, fmt.format_pack.formats, _dtype_to_mlir_str(dtype)
+        )
+        t = Tensor(shape.tolist(), fmt, dtype=dtype)
+        t._set_packed_sparse_tensor(sparse_tensor)
+
+        return t
+
+    def to_file(self, filename: str) -> None:
+        """Output the tensor value to a file.
+
+        This method evaluates any pending assignment to the tensor and outputs the
+        tensor value.
+
+        Args:
+          filename: A string file name.
+
+        Raises:
+           ValueError: If the tensor is dense, or an unpacked sparse tensor.
+        """
+        self._sync_value()
+
+        if self.is_dense():
+            raise ValueError(
+                "Writing dense tensors without sparsity annotation to "
+                "file is not supported."
+            )
+
+        if self.is_unpacked():
+            raise ValueError(
+                "Writing unpacked sparse tensors to file is not " "supported."
+            )
+
+        utils.output_sparse_tensor(
+            self._packed_sparse_value,
+            filename,
+            self._format.format_pack.formats,
+            _dtype_to_mlir_str(self._dtype),
+        )
+
+    @property
+    def dtype(self) -> DType:
+        """Returns the data type for the Tensor."""
+        return self._dtype
+
+    @property
+    def format(self) -> Format:
+        """Returns the storage format for the Tensor."""
+        return self._format
+
+    @property
+    def name(self) -> str:
+        """Returns the name for the Tensor."""
+        return self._name
+
+    @property
+    def order(self) -> int:
+        """Returns the rank of the Tensor."""
+        return len(self._shape)
+
+    @property
+    def shape(self) -> List[int]:
+        """Returns the shape of the Tensor."""
+        return self._shape
+
+    def _verify_and_normalize_indices(self, indices) -> Tuple[IndexVar, ...]:
+        """Verifies and normalizes the indices to access the tensor.
+
+        Args:
+          indices: The index expression used to access a tensor, which could be any
+            Python object from user inputs.
+
+        Returns:
+          A tuple of IndexVar.
+
+        Raises:
+          ValueError: If indices is not 0 for scalar tensors, or not an IndexVar or
+            a tuple of IndexVar for other tensors.
+        """
+        if self.order == 0:
+            if not isinstance(indices, int) or indices != 0:
+                raise ValueError(f"Expected 0 to index scalar tensors: {indices}")
+            return ()
+
+        if isinstance(indices, IndexVar):
+            return (indices,)
+        elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
+            return indices
+
+        raise ValueError(f"Expected IndexVars: {indices}")
+
+    def __getitem__(self, key) -> "Access":
+        """Verifies and processes a tensor access.
+
+        In the tensor index notation, a tensor access T[i, j] is represented as
+        retrieving a value with key (i, j) from the tensor object T in Python. This
+        routine verifies the key for the tensor access and returns a tensor access
+        object.
+
+        Args:
+          key: The key used to access the tensor, which could be any Python object
+            from user inputs.
+
+        Returns:
+          The corresponding tensor access object.
+
+        Raises:
+          ValueError: If key is not an IndexVar or a tuple of IndexVar.
+        """
+        indices = self._verify_and_normalize_indices(key)
+        return Access(self, indices)
+
+    def __setitem__(self, key, value) -> None:
+        """Verifies and processes a tensor assignment.
+
+        In the tensor index notation, a tensor assignment "T[i, j] = ..." is
+        represented as setting a value for a tensor object T via key (i, j) in
+        Python. This routine verifies the key, evaluates the value, and assigns the
+        value to the tensor.
+
+        We only support assignment of dense tensor currently.
+
+        Args:
+          key: The key used to access the tensor, which could be any Python object
+            from user inputs.
+          value: The value assigned to the tensor, which could be any Python object
+            from user inputs.
+
+        Raises:
+          ValueError: If tensor is not a dense tensor, or the key is not an IndexVar
+            or a tuple of IndexVar, or the length of the indices is not the same as
+            the rank of the tensor.
+        """
+        indices = self._verify_and_normalize_indices(key)
+        if len(indices) != self.order:
+            raise ValueError(
+                "Mismatch between indices and tensor rank: "
+                f"len({indices}) != {self.order}."
+            )
+
+        self._assignment = _Assignment(indices, value)
+        self._engine = None
+
+    def compile(self, force_recompile: bool = False) -> None:
+        """Compiles the tensor assignment to an execution engine.
+
+        Calling compile the second time does not do anything unless
+        force_recompile is True.
+
+        Args:
+          force_recompile: A boolean value to enable recompilation, such as for the
+            purpose of timing.
+
+        Raises:
+          ValueError: If the assignment is not proper or not supported.
+        """
+        if self._assignment is None or (
+            self._engine is not None and not force_recompile
+        ):
+            return
+
+        self._engine = self._assignment.expression.compile(
+            self, self._assignment.indices
+        )
+
+    def compute(self) -> None:
+        """Executes the engine for the tensor assignment.
+
+        Raises:
+          ValueError: If the assignment hasn't been compiled yet.
+        """
+        if self._assignment is None:
+            return
+
+        if self._engine is None:
+            raise ValueError("Need to invoke compile() before invoking compute().")
+
+        input_accesses = self._assignment.expression.get_input_accesses()
+        # Gather the pointers for the input buffers.
+        input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
+        if self.is_dense():
+            # The pointer to receive dense output is the first argument to the
+            # execution engine.
+            arg_pointers = [self.dense_dst_ctype_pointer()] + input_pointers
+        else:
+            # The pointer to receive the sparse tensor output is the last argument
+            # to the execution engine and is a pointer to pointer of char.
+            arg_pointers = input_pointers + [
+                ctypes.pointer(ctypes.pointer(ctypes.c_char(0)))
+            ]
+
+        # Invoke the execution engine to run the module.
+        self._engine.invoke(_ENTRY_NAME, *arg_pointers)
+
+        # Retrieve the result.
+        if self.is_dense():
+            result = runtime.ranked_memref_to_numpy(arg_pointers[0][0])
+            assert isinstance(result, np.ndarray)
+            self._dense_storage = result
+        else:
+            self._set_packed_sparse_tensor(arg_pointers[-1][0])
+
+        self._assignment = None
+        self._engine = None
+
+    def evaluate(self) -> None:
+        """Evaluates the tensor assignment."""
+        self.compile()
+        self.compute()
+
+    def _sync_value(self) -> None:
+        """Updates the tensor value by evaluating the pending assignment."""
+        if self._assignment is not None:
+            self.evaluate()
+
+    def mlir_tensor_type(self) -> ir.RankedTensorType:
+        """Returns the MLIR type for the tensor."""
+        mlir_attr = (
+            None
+            if (self._format is None or self.order == 0)
+            else self._format.mlir_tensor_attr()
+        )
+        return _mlir_tensor_type(self._dtype, tuple(self._shape), mlir_attr)
+
+    def dense_dst_ctype_pointer(self) -> ctypes.pointer:
+        """Returns the ctypes pointer for the pointer to an MemRefDescriptor.
+
+        For a dense tensor output, the MLIR compiler allocates the storage for
+        the tensor. This routine returns the pointer to an MLIR MemRefDescriptor for
+        receiving the tensor.
+        """
+        assert self.is_dense()
+        mem_ref_desc = runtime.make_nd_memref_descriptor(
+            self.order, np.ctypeslib.as_ctypes_type(self.dtype.value)
+        )()
+        return ctypes.pointer(ctypes.pointer(mem_ref_desc))
+
+    def ctype_pointer(self) -> ctypes.pointer:
+        """Returns the ctypes pointer for the pointer to the input tensor."""
+        if self.is_dense():
+            if self._dense_storage is None:
+                self._dense_storage = np.zeros(self._shape, self._dtype.value)
+            return _ctype_pointer_from_array(self._dense_storage)
+
+        if self.is_unpacked():
+            shape = np.array(self._shape, np.int64)
+            indices = np.array(self._coords, np.int64)
+            values = np.array(self._values, self._dtype.value)
+            perm, sparse = self.format.get_permutation_and_sparsity()
+            ptr = utils.coo_tensor_to_sparse_tensor(
+                shape, values, indices, perm, sparse
+            )
+        else:
+            ptr = self._packed_sparse_value
+
+        return ctypes.pointer(ctypes.cast(ptr, ctypes.c_void_p))
+
+    def get_scalar_value(self) -> _AnyRuntimeType:
+        """Returns the value for the scalar tensor.
+
+        This method also evaluates the assignment to the tensor.
+
+        Raises:
+          ValueError: If the tensor is not a scalar.
+        """
+        if self.order != 0:
+            raise ValueError(f"Expected a scalar tensor, got: rank={self.order}")
+
+        self._sync_value()
+        return self._dense_storage
+
+    def get_coordinates_and_values(
+        self,
+    ) -> Tuple[List[Tuple[int, ...]], List[_AnyRuntimeType]]:
+        """Returns the coordinates and values for the non-zero elements.
+
+        This method also evaluates the assignment to the tensor and unpack the
+        sparse tensor.
+        """
+        self._sync_value()
+
+        if not self.is_dense():
+            self.unpack()
+            return (self._coords, self._values)
+
+        if self.order == 0:
+            return ([], self._dense_storage)
+
+        # Coordinates for non-zero elements, grouped by dimensions.
+        coords_by_dims = self._dense_storage.nonzero()
+        # Coordinates for non-zero elements, grouped by elements.
+        coords = np.transpose(coords_by_dims)
+        values = self._dense_storage[coords_by_dims]
+        return (coords, values)
+
+    def _record_stats(self, structop: "_StructOpInfo"):
+        """Collects information for temporary tensors."""
+        # Exclude user specified destination tensors.
+        if structop.dst_name == self.name:
+            return
+
+        self._stats.add_element(structop)
+
+
+def _emit_operand(
+    op_def: lang.LinalgOpDef,
+    indices: Tuple[IndexVar, ...],
+    name: str,
+    kind: lang.OperandKind,
+) -> lang.OperandDef:
+    """Emits an operand for a tensor access in the current linalg operation.
 
     Args:
-      force_recompile: A boolean value to enable recompilation, such as for the
-        purpose of timing.
-
-    Raises:
-      ValueError: If the assignment is not proper or not supported.
-    """
-    if self._assignment is None or (self._engine is not None and
-                                    not force_recompile):
-      return
-
-    self._engine = self._assignment.expression.compile(self,
-                                                       self._assignment.indices)
+      op_def: A LinalgOpDef representing the current linalg dialect operation.
+      indices: A tuple of IndexVar used to access the tensor.
+      name: A unique string name of the tensor.
+      kind: An OperandKind for the operand.
 
-  def compute(self) -> None:
-    """Executes the engine for the tensor assignment.
-
-    Raises:
-      ValueError: If the assignment hasn't been compiled yet.
-    """
-    if self._assignment is None:
-      return
-
-    if self._engine is None:
-      raise ValueError("Need to invoke compile() before invoking compute().")
-
-    input_accesses = self._assignment.expression.get_input_accesses()
-    # Gather the pointers for the input buffers.
-    input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
-    if self.is_dense():
-      # The pointer to receive dense output is the first argument to the
-      # execution engine.
-      arg_pointers = [self.dense_dst_ctype_pointer()] + input_pointers
-    else:
-      # The pointer to receive the sparse tensor output is the last argument
-      # to the execution engine and is a pointer to pointer of char.
-      arg_pointers = input_pointers + [
-          ctypes.pointer(ctypes.pointer(ctypes.c_char(0)))
-      ]
-
-    # Invoke the execution engine to run the module.
-    self._engine.invoke(_ENTRY_NAME, *arg_pointers)
-
-    # Retrieve the result.
-    if self.is_dense():
-      result = runtime.ranked_memref_to_numpy(arg_pointers[0][0])
-      assert isinstance(result, np.ndarray)
-      self._dense_storage = result
-    else:
-      self._set_packed_sparse_tensor(arg_pointers[-1][0])
-
-    self._assignment = None
-    self._engine = None
-
-  def evaluate(self) -> None:
-    """Evaluates the tensor assignment."""
-    self.compile()
-    self.compute()
-
-  def _sync_value(self) -> None:
-    """Updates the tensor value by evaluating the pending assignment."""
-    if self._assignment is not None:
-      self.evaluate()
-
-  def mlir_tensor_type(self) -> ir.RankedTensorType:
-    """Returns the MLIR type for the tensor."""
-    mlir_attr = (None if (self._format is None or self.order == 0) else
-                 self._format.mlir_tensor_attr())
-    return _mlir_tensor_type(self._dtype, tuple(self._shape), mlir_attr)
-
-  def dense_dst_ctype_pointer(self) -> ctypes.pointer:
-    """Returns the ctypes pointer for the pointer to an MemRefDescriptor.
-
-    For a dense tensor output, the MLIR compiler allocates the storage for
-    the tensor. This routine returns the pointer to an MLIR MemRefDescriptor for
-    receiving the tensor.
-    """
-    assert self.is_dense()
-    mem_ref_desc = runtime.make_nd_memref_descriptor(
-        self.order, np.ctypeslib.as_ctypes_type(self.dtype.value))()
-    return ctypes.pointer(ctypes.pointer(mem_ref_desc))
-
-  def ctype_pointer(self) -> ctypes.pointer:
-    """Returns the ctypes pointer for the pointer to the input tensor."""
-    if self.is_dense():
-      if self._dense_storage is None:
-        self._dense_storage = np.zeros(self._shape, self._dtype.value)
-      return _ctype_pointer_from_array(self._dense_storage)
-
-    if self.is_unpacked():
-      shape = np.array(self._shape, np.int64)
-      indices = np.array(self._coords, np.int64)
-      values = np.array(self._values, self._dtype.value)
-      perm, sparse = self.format.get_permutation_and_sparsity()
-      ptr = utils.coo_tensor_to_sparse_tensor(shape, values, indices, perm,
-                                              sparse)
-    else:
-      ptr = self._packed_sparse_value
-
-    return ctypes.pointer(ctypes.cast(ptr, ctypes.c_void_p))
-
-  def get_scalar_value(self) -> _AnyRuntimeType:
-    """Returns the value for the scalar tensor.
-
-    This method also evaluates the assignment to the tensor.
-
-    Raises:
-      ValueError: If the tensor is not a scalar.
-    """
-    if self.order != 0:
-      raise ValueError(f"Expected a scalar tensor, got: rank={self.order}")
-
-    self._sync_value()
-    return self._dense_storage
-
-
-  def get_coordinates_and_values(
-      self) -> Tuple[List[Tuple[int, ...]], List[_AnyRuntimeType]]:
-    """Returns the coordinates and values for the non-zero elements.
-
-    This method also evaluates the assignment to the tensor and unpack the
-    sparse tensor.
+    Returns:
+      An OperandDef representing the operand.
     """
-    self._sync_value()
-
-    if not self.is_dense():
-      self.unpack()
-      return (self._coords, self._values)
-
-    if self.order == 0:
-      return ([], self._dense_storage)
-
-    # Coordinates for non-zero elements, grouped by dimensions.
-    coords_by_dims = self._dense_storage.nonzero()
-    # Coordinates for non-zero elements, grouped by elements.
-    coords = np.transpose(coords_by_dims)
-    values = self._dense_storage[coords_by_dims]
-    return (coords, values)
-
-  def _record_stats(self, structop: "_StructOpInfo"):
-    """Collects information for temporary tensors."""
-    # Exclude user specified destination tensors.
-    if structop.dst_name == self.name:
-      return
-
-    self._stats.add_element(structop)
-
-
-def _emit_operand(op_def: lang.LinalgOpDef, indices: Tuple[IndexVar, ...],
-                  name: str, kind: lang.OperandKind) -> lang.OperandDef:
-  """Emits an operand for a tensor access in the current linalg operation.
-
-  Args:
-    op_def: A LinalgOpDef representing the current linalg dialect operation.
-    indices: A tuple of IndexVar used to access the tensor.
-    name: A unique string name of the tensor.
-    kind: An OperandKind for the operand.
-
-  Returns:
-    An OperandDef representing the operand.
-  """
-  dim_sym = _mlir_symbols_from_index_vars(indices)
-  opnd = lang.OperandDef(kind, lang.T, dim_sym)
-  op_def.add_operand(name, opnd)
-  return opnd
+    dim_sym = _mlir_symbols_from_index_vars(indices)
+    opnd = lang.OperandDef(kind, lang.T, dim_sym)
+    op_def.add_operand(name, opnd)
+    return opnd
 
 
 @dataclasses.dataclass(frozen=True)
 class _DimInfo:
-  """Information for an operand dimension.
+    """Information for an operand dimension.
 
-  Attributes:
-    dim: An integer for the size of the dimension.
-    mode_format: A ModeFormat for the dimension sparsity.
-  """
-  dim: int
-  mode_format: ModeFormat
+    Attributes:
+      dim: An integer for the size of the dimension.
+      mode_format: A ModeFormat for the dimension sparsity.
+    """
+
+    dim: int
+    mode_format: ModeFormat
 
 
 def _get_dummy_dim_info() -> _DimInfo:
-  """Constructs the _DimInfo for an index used in tensor expressions."""
-  return _DimInfo(-1, ModeFormat.DENSE)
+    """Constructs the _DimInfo for an index used in tensor expressions."""
+    return _DimInfo(-1, ModeFormat.DENSE)
 
 
 @dataclasses.dataclass()
 class _ExprInfo:
-  """Expression information for validation and code generation.
-
-  Attributes:
-    src_indices: A tuple of IndexVar for the indices used by the tensors in the
-      expression tree.
-    dim_infos: A tuple of _DimInfo, representing the dimension information
-      corresponding to the src_indices.
-    reduce_indices: A set of IndexVar for the indices reduced by the expression.
-    acc_reduce_indices: An accumulated set of IndexVar for the indices reduced
-      by the expression and its children.
-    structop_info: Information to support the code generation for a structured
-      op in the linalg dialect, if the corresponding expression node is the root
-      of a subtree for a structured op.
-    mlir_value: The MLIR value generated for the structured op.
-  """
-  src_indices: Tuple[IndexVar, ...]
-  dim_infos: Tuple[_DimInfo, ...]
-  reduce_indices: Optional[Set[IndexVar]] = None
-  acc_reduce_indices: Optional[Set[IndexVar]] = None
-  structop_info: Optional[_StructOpInfo] = None
-  mlir_value: Optional[ir.Value] = None
-
-  def __post_init__(self) -> None:
-    """Verifies and fix up attribute values.
-
-    Verifies the consistency of the attributes and modifies the default values
-    to support convenient initializer syntax.
+    """Expression information for validation and code generation.
+
+    Attributes:
+      src_indices: A tuple of IndexVar for the indices used by the tensors in the
+        expression tree.
+      dim_infos: A tuple of _DimInfo, representing the dimension information
+        corresponding to the src_indices.
+      reduce_indices: A set of IndexVar for the indices reduced by the expression.
+      acc_reduce_indices: An accumulated set of IndexVar for the indices reduced
+        by the expression and its children.
+      structop_info: Information to support the code generation for a structured
+        op in the linalg dialect, if the corresponding expression node is the root
+        of a subtree for a structured op.
+      mlir_value: The MLIR value generated for the structured op.
     """
-    assert len(self.src_indices) == len(self.dim_infos)
-    self.reduce_indices = self.reduce_indices or set()
-    self.acc_reduce_indices = self.acc_reduce_indices or set()
 
+    src_indices: Tuple[IndexVar, ...]
+    dim_infos: Tuple[_DimInfo, ...]
+    reduce_indices: Optional[Set[IndexVar]] = None
+    acc_reduce_indices: Optional[Set[IndexVar]] = None
+    structop_info: Optional[_StructOpInfo] = None
+    mlir_value: Optional[ir.Value] = None
 
-@dataclasses.dataclass(frozen=True)
-class Access(IndexExpr):
-  """The tensor access class.
+    def __post_init__(self) -> None:
+        """Verifies and fix up attribute values.
+
+        Verifies the consistency of the attributes and modifies the default values
+        to support convenient initializer syntax.
+        """
+        assert len(self.src_indices) == len(self.dim_infos)
+        self.reduce_indices = self.reduce_indices or set()
+        self.acc_reduce_indices = self.acc_reduce_indices or set()
 
-  We support the TACO API access class with an alias of this class.
 
-  Attributes:
-    tensor: A Tensor being accessed.
-    indices: A tuple of IndexVar, representing the indices used to access the
-      Tensor.
-  """
-  tensor: Tensor
-  indices: Tuple[IndexVar, ...]
+@dataclasses.dataclass(frozen=True)
+class Access(IndexExpr):
+    """The tensor access class.
 
-  def __post_init__(self) -> None:
-    """Verifies the tensor and indices for a tensor access.
+    We support the TACO API access class with an alias of this class.
 
-    Raises:
-       ValueError: If indices is not a list of IndexVar or the len of indices
-       doesn't equal to the rank of the tensor.
+    Attributes:
+      tensor: A Tensor being accessed.
+      indices: A tuple of IndexVar, representing the indices used to access the
+        Tensor.
     """
-    if (not isinstance(self.indices, tuple) or
-        not _all_instance_of(self.indices, IndexVar)):
-      raise ValueError(f"Indices contain non IndexVar: {str(self.indices)}.")
-    if self.tensor.order != len(self.indices):
-      raise ValueError("Invalid indices for rank: "
-                       f"str{self.tensor.order} != len({str(self.indices)}).")
-
-  def __repr__(self) -> str:
-    # The Tensor __repr__ method evaluates the pending assignment to the tensor.
-    # We want to define the __repr__ method here to avoid such evaluation of the
-    # tensor assignment.
-    indices_str = ", ".join(map(lambda i: i.name, self.indices))
-    return (f"Tensor({self.tensor.name}) " f"Indices({indices_str})")
-
-  def _emit_expression(
-      self,
-      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
-      expr_to_info: _ExprInfoDict,
-  ) -> lang.ScalarExpression:
-    """Emits a linalg dialect TensorUse expression for the tensor access."""
-    assert self in expr_to_opnd
-    dims = _mlir_dimensions_from_index_vars(self.indices)
-    return lang.TensorUse(expr_to_opnd[self], dims)
-
-  def _visit(self,
-             func: _ExprVisitor,
-             args,
-             *,
-             leaf_checker: _SubtreeLeafChecker = None) -> None:
-    if leaf_checker:
-      assert leaf_checker(self, *args)
-    func(self, *args)
-
-  def dtype(self) -> DType:
-    return self.tensor.dtype
+
+    tensor: Tensor
+    indices: Tuple[IndexVar, ...]
+
+    def __post_init__(self) -> None:
+        """Verifies the tensor and indices for a tensor access.
+
+        Raises:
+           ValueError: If indices is not a list of IndexVar or the len of indices
+           doesn't equal to the rank of the tensor.
+        """
+        if not isinstance(self.indices, tuple) or not _all_instance_of(
+            self.indices, IndexVar
+        ):
+            raise ValueError(f"Indices contain non IndexVar: {str(self.indices)}.")
+        if self.tensor.order != len(self.indices):
+            raise ValueError(
+                "Invalid indices for rank: "
+                f"str{self.tensor.order} != len({str(self.indices)})."
+            )
+
+    def __repr__(self) -> str:
+        # The Tensor __repr__ method evaluates the pending assignment to the tensor.
+        # We want to define the __repr__ method here to avoid such evaluation of the
+        # tensor assignment.
+        indices_str = ", ".join(map(lambda i: i.name, self.indices))
+        return f"Tensor({self.tensor.name}) " f"Indices({indices_str})"
+
+    def _emit_expression(
+        self,
+        expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+        expr_to_info: _ExprInfoDict,
+    ) -> lang.ScalarExpression:
+        """Emits a linalg dialect TensorUse expression for the tensor access."""
+        assert self in expr_to_opnd
+        dims = _mlir_dimensions_from_index_vars(self.indices)
+        return lang.TensorUse(expr_to_opnd[self], dims)
+
+    def _visit(
+        self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None
+    ) -> None:
+        if leaf_checker:
+            assert leaf_checker(self, *args)
+        func(self, *args)
+
+    def dtype(self) -> DType:
+        return self.tensor.dtype
 
 
 def _gather_input_accesses_index_vars(
     expr: IndexExpr,
     input_accesses: List[Access],
 ) -> None:
-  """Collects Access nodes."""
-  if isinstance(expr, Access) and expr not in input_accesses:
-    input_accesses.append(expr)
+    """Collects Access nodes."""
+    if isinstance(expr, Access) and expr not in input_accesses:
+        input_accesses.append(expr)
 
 
 def _op_ceil(__a: Any) -> Any:
-  """A _UnaryOp object for operation ceil."""
-  pass
+    """A _UnaryOp object for operation ceil."""
+    pass
 
 
 def _op_floor(__a: Any) -> Any:
-  """A _UnaryOp object for operation floor."""
-  pass
+    """A _UnaryOp object for operation floor."""
+    pass
 
 
 def _op_unary_to_callable(op: _UnaryOp) -> lang.UnaryFnType:
-  """Returns the linalg dialect function object for the given operation."""
-  op_to_callable = {
-      operator.abs: lang.UnaryFn.abs,
-      operator.neg: lang.UnaryFn.negf,
-      _op_ceil: lang.UnaryFn.ceil,
-      _op_floor: lang.UnaryFn.floor,
-  }
-  return op_to_callable[op]
+    """Returns the linalg dialect function object for the given operation."""
+    op_to_callable = {
+        operator.abs: lang.UnaryFn.abs,
+        operator.neg: lang.UnaryFn.negf,
+        _op_ceil: lang.UnaryFn.ceil,
+        _op_floor: lang.UnaryFn.floor,
+    }
+    return op_to_callable[op]
 
 
 @dataclasses.dataclass(frozen=True)
 class _UnaryExpr(IndexExpr):
-  """The representation for a Unary operation.
-
-  Attributes:
-  op: A _UnaryOp representing the operation.
-  a: An IndexExpr representing the operand for the operation.
-  """
-  op: _BinaryOp
-  a: IndexExpr
-
-  def __post_init__(self) -> None:
-    """Verifies that the operand being added is an IndexExpr."""
-    assert isinstance(self.a, IndexExpr)
-
-  def _emit_expression(
-      self,
-      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
-      expr_to_info: _ExprInfoDict,
-  ) -> lang.ScalarExpression:
-    """Emits the expression tree and returns the expression."""
-    # The current expression node is an internal node of the structured op.
-    if self not in expr_to_opnd:
-      a = self.a._emit_expression(expr_to_opnd, expr_to_info)
-      return _op_unary_to_callable(self.op)(a)
-
-    # The current expression is a leaf node of the structured op. That is, it is
-    # a temporary tensor generated by its child structured op.
-    op_info = expr_to_info[self].structop_info
-    assert op_info is not None
-    dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
-    return lang.TensorUse(expr_to_opnd[self], dims)
-
-  def _visit(self,
-             func: _ExprVisitor,
-             args,
-             *,
-             leaf_checker: _SubtreeLeafChecker = None) -> None:
-    """A post-order visitor."""
-    if leaf_checker is None or not leaf_checker(self, *args):
-      self.a._visit(func, args, leaf_checker=leaf_checker)
-    func(self, *args)
-
-  def dtype(self) -> DType:
-    """Returns the data type of the operation."""
-    return self.a.dtype()
+    """The representation for a Unary operation.
+
+    Attributes:
+    op: A _UnaryOp representing the operation.
+    a: An IndexExpr representing the operand for the operation.
+    """
+
+    op: _BinaryOp
+    a: IndexExpr
+
+    def __post_init__(self) -> None:
+        """Verifies that the operand being added is an IndexExpr."""
+        assert isinstance(self.a, IndexExpr)
+
+    def _emit_expression(
+        self,
+        expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+        expr_to_info: _ExprInfoDict,
+    ) -> lang.ScalarExpression:
+        """Emits the expression tree and returns the expression."""
+        # The current expression node is an internal node of the structured op.
+        if self not in expr_to_opnd:
+            a = self.a._emit_expression(expr_to_opnd, expr_to_info)
+            return _op_unary_to_callable(self.op)(a)
+
+        # The current expression is a leaf node of the structured op. That is, it is
+        # a temporary tensor generated by its child structured op.
+        op_info = expr_to_info[self].structop_info
+        assert op_info is not None
+        dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
+        return lang.TensorUse(expr_to_opnd[self], dims)
+
+    def _visit(
+        self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None
+    ) -> None:
+        """A post-order visitor."""
+        if leaf_checker is None or not leaf_checker(self, *args):
+            self.a._visit(func, args, leaf_checker=leaf_checker)
+        func(self, *args)
+
+    def dtype(self) -> DType:
+        """Returns the data type of the operation."""
+        return self.a.dtype()
 
 
 def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType:
-  """Returns the linalg dialect function object for the given operation."""
-  op_to_callable = {
-      operator.add: lang.BinaryFn.add,
-      operator.sub: lang.BinaryFn.sub,
-      operator.mul: lang.BinaryFn.mul,
-  }
-  return op_to_callable[op]
+    """Returns the linalg dialect function object for the given operation."""
+    op_to_callable = {
+        operator.add: lang.BinaryFn.add,
+        operator.sub: lang.BinaryFn.sub,
+        operator.mul: lang.BinaryFn.mul,
+    }
+    return op_to_callable[op]
+
 
 @dataclasses.dataclass(frozen=True)
 class _BinaryExpr(IndexExpr):
-  """The representation for a binary operation.
-
-  Attributes:
-  op: A _BinaryOp representing the binary operation.
-  a: An IndexExpr representing the first operand of the operation.
-  b: An IndexExpr representing the second operand of the operation.
-  """
-  op: _BinaryOp
-  a: IndexExpr
-  b: IndexExpr
-
-  def __post_init__(self) -> None:
-    """Verifies that the operands being added are IndexExpr."""
-    assert isinstance(self.a, IndexExpr) and isinstance(self.b, IndexExpr)
-
-  def _emit_expression(
-      self,
-      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
-      expr_to_info: _ExprInfoDict,
-  ) -> lang.ScalarExpression:
-    """Emits the expression tree and returns the expression."""
-    # The current expression node is an internal node of the structured op.
-    if self not in expr_to_opnd:
-      a = self.a._emit_expression(expr_to_opnd, expr_to_info)
-      b = self.b._emit_expression(expr_to_opnd, expr_to_info)
-      return _op_to_callable(self.op)(a, b)
-
-    # The current expression is a leaf node of the structured op. That is, it is
-    # a temporary tensor generated by its child structured op.
-    op_info = expr_to_info[self].structop_info
-    assert op_info is not None
-    dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
-    return lang.TensorUse(expr_to_opnd[self], dims)
-
-  def _visit(self,
-             func: _ExprVisitor,
-             args,
-             *,
-             leaf_checker: _SubtreeLeafChecker = None) -> None:
-    """A post-order visitor."""
-    if leaf_checker is None or not leaf_checker(self, *args):
-      self.a._visit(func, args, leaf_checker=leaf_checker)
-      self.b._visit(func, args, leaf_checker=leaf_checker)
-    func(self, *args)
-
-  def dtype(self) -> DType:
-    """Returns the data type of the binary operation."""
-    return self.a.dtype()
+    """The representation for a binary operation.
+
+    Attributes:
+    op: A _BinaryOp representing the binary operation.
+    a: An IndexExpr representing the first operand of the operation.
+    b: An IndexExpr representing the second operand of the operation.
+    """
+
+    op: _BinaryOp
+    a: IndexExpr
+    b: IndexExpr
+
+    def __post_init__(self) -> None:
+        """Verifies that the operands being added are IndexExpr."""
+        assert isinstance(self.a, IndexExpr) and isinstance(self.b, IndexExpr)
+
+    def _emit_expression(
+        self,
+        expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+        expr_to_info: _ExprInfoDict,
+    ) -> lang.ScalarExpression:
+        """Emits the expression tree and returns the expression."""
+        # The current expression node is an internal node of the structured op.
+        if self not in expr_to_opnd:
+            a = self.a._emit_expression(expr_to_opnd, expr_to_info)
+            b = self.b._emit_expression(expr_to_opnd, expr_to_info)
+            return _op_to_callable(self.op)(a, b)
+
+        # The current expression is a leaf node of the structured op. That is, it is
+        # a temporary tensor generated by its child structured op.
+        op_info = expr_to_info[self].structop_info
+        assert op_info is not None
+        dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
+        return lang.TensorUse(expr_to_opnd[self], dims)
+
+    def _visit(
+        self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None
+    ) -> None:
+        """A post-order visitor."""
+        if leaf_checker is None or not leaf_checker(self, *args):
+            self.a._visit(func, args, leaf_checker=leaf_checker)
+            self.b._visit(func, args, leaf_checker=leaf_checker)
+        func(self, *args)
+
+    def dtype(self) -> DType:
+        """Returns the data type of the binary operation."""
+        return self.a.dtype()
 
 
 def _validate_and_collect_dim_info(
@@ -1822,105 +1901,104 @@ def _validate_and_collect_dim_info(
     dim_infos: Tuple[_DimInfo, ...],
     expr: _BinaryExpr,
 ) -> None:
-  """Validates and collects the dimension information for an index notation.
-
-  Validates (indices, dim_infos) against the information collected from other
-  source operands and is represented by index_to_dim_info. In particular, we
-  ensure that each IndexVar corresponds to only one dimension size. We also
-  aggregate the new information represented in (indices, dim_infos) to
-  index_to_dim_info.
-
-  Args:
-    index_to_dim: A dictionary of (IndexVar, _DimInfo) collected from the
-      previous operands.
-    indices: The IndexVars to be validated.
-    dim_infos: The dimension information for the IndexVars to be validated.
-    expr: The binary expression where (indices, dim_infos) is used.
-
-  Raises:
-    ValueError if there is any problem in the IndexVars or dimensional values.
-  """
-  assert len(indices) == len(dim_infos)
-  for i, d in zip(indices, dim_infos):
-    if i not in index_to_dim_info:
-      index_to_dim_info[i] = d
-    else:
-      dim = index_to_dim_info[i].dim
-      if dim == -1 or d.dim == -1:
-        dim = dim if dim != -1 else d.dim
-      elif dim != d.dim:
-        raise ValueError(f"Inconsistent source dimension for {i}: "
-                         f"{d.dim} vs {dim}")
-      mode_format = _mode_format_estimator(expr.op)(
-          index_to_dim_info[i].mode_format, d.mode_format)
-      index_to_dim_info[i] = _DimInfo(d.dim, mode_format)
+    """Validates and collects the dimension information for an index notation.
+
+    Validates (indices, dim_infos) against the information collected from other
+    source operands and is represented by index_to_dim_info. In particular, we
+    ensure that each IndexVar corresponds to only one dimension size. We also
+    aggregate the new information represented in (indices, dim_infos) to
+    index_to_dim_info.
+
+    Args:
+      index_to_dim: A dictionary of (IndexVar, _DimInfo) collected from the
+        previous operands.
+      indices: The IndexVars to be validated.
+      dim_infos: The dimension information for the IndexVars to be validated.
+      expr: The binary expression where (indices, dim_infos) is used.
+
+    Raises:
+      ValueError if there is any problem in the IndexVars or dimensional values.
+    """
+    assert len(indices) == len(dim_infos)
+    for i, d in zip(indices, dim_infos):
+        if i not in index_to_dim_info:
+            index_to_dim_info[i] = d
+        else:
+            dim = index_to_dim_info[i].dim
+            if dim == -1 or d.dim == -1:
+                dim = dim if dim != -1 else d.dim
+            elif dim != d.dim:
+                raise ValueError(
+                    f"Inconsistent source dimension for {i}: " f"{d.dim} vs {dim}"
+                )
+            mode_format = _mode_format_estimator(expr.op)(
+                index_to_dim_info[i].mode_format, d.mode_format
+            )
+            index_to_dim_info[i] = _DimInfo(d.dim, mode_format)
 
 
 def _validate_and_collect_expr_info(
     expr: IndexExpr,
     expr_to_info: _ExprInfoDict,
 ) -> None:
-  """Validates dimension information and constructs _ExprInfo.
-
-  Validates that dimensional values for the same IndexVar are the same. Collects
-  a list of IndexVar used by the expression and their corresponding dimensional
-  values. Constructs an _ExprInfo object to record the information for the
-  IndexExpr.
-
-  This routine is passed to the post-order visitor as an _ExprVisitor object.
-
-  Args:
-    expr: The IndexExpr being validated.
-    expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
-      expression information.
-
-  Raises:
-    ValueError if there is any problem in the IndexVars or dimensional values.
-  """
-  # Objects of class Access can be shared by different expressions. Avoid
-  # processing Access objects multiple times by skipping the processing if expr
-  # is already in the dictionary.
-  if expr in expr_to_info:
-    return
-
-  if isinstance(expr, IndexVar):
-    src_indices = expr,  # A tuple with one element.
-    dim_infos = _get_dummy_dim_info(),  # A tuple with one element.
-  elif isinstance(expr, Access):
-    src_indices = expr.indices
-    src_dims = tuple(expr.tensor.shape)
-    if expr.tensor.format is None:
-      # Treat each dimension of a dense tensor as DENSE for the purpose of
-      # calculating temporary tensor storage format.
-      mode_formats = tuple([ModeFormat.DENSE] * len(src_dims))
+    """Validates dimension information and constructs _ExprInfo.
+
+    Validates that dimensional values for the same IndexVar are the same. Collects
+    a list of IndexVar used by the expression and their corresponding dimensional
+    values. Constructs an _ExprInfo object to record the information for the
+    IndexExpr.
+
+    This routine is passed to the post-order visitor as an _ExprVisitor object.
+
+    Args:
+      expr: The IndexExpr being validated.
+      expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
+        expression information.
+
+    Raises:
+      ValueError if there is any problem in the IndexVars or dimensional values.
+    """
+    # Objects of class Access can be shared by different expressions. Avoid
+    # processing Access objects multiple times by skipping the processing if expr
+    # is already in the dictionary.
+    if expr in expr_to_info:
+        return
+
+    if isinstance(expr, IndexVar):
+        src_indices = (expr,)  # A tuple with one element.
+        dim_infos = (_get_dummy_dim_info(),)  # A tuple with one element.
+    elif isinstance(expr, Access):
+        src_indices = expr.indices
+        src_dims = tuple(expr.tensor.shape)
+        if expr.tensor.format is None:
+            # Treat each dimension of a dense tensor as DENSE for the purpose of
+            # calculating temporary tensor storage format.
+            mode_formats = tuple([ModeFormat.DENSE] * len(src_dims))
+        else:
+            mode_formats = tuple(expr.tensor.format.format_pack.formats)
+        assert len(src_dims) == len(mode_formats)
+        dim_infos = tuple([_DimInfo(d, m) for d, m in zip(src_dims, mode_formats)])
+    elif isinstance(expr, _UnaryExpr):
+        a_info = expr_to_info[expr.a]
+        index_to_dim_info = {i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)}
+        # Here we rely on the fact that dictionaries keep the insertion order for
+        # keys and values.
+        src_indices = tuple(index_to_dim_info.keys())
+        dim_infos = tuple(index_to_dim_info.values())
     else:
-      mode_formats = tuple(expr.tensor.format.format_pack.formats)
-    assert len(src_dims) == len(mode_formats)
-    dim_infos = tuple([_DimInfo(d, m) for d, m in zip(src_dims, mode_formats)])
-  elif isinstance(expr, _UnaryExpr):
-    a_info = expr_to_info[expr.a]
-    index_to_dim_info = {
-        i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)
-    }
-    # Here we rely on the fact that dictionaries keep the insertion order for
-    # keys and values.
-    src_indices = tuple(index_to_dim_info.keys())
-    dim_infos = tuple(index_to_dim_info.values())
-  else:
-    assert isinstance(expr, _BinaryExpr)
-    a_info = expr_to_info[expr.a]
-    index_to_dim_info = {
-        i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)
-    }
-    b_info = expr_to_info[expr.b]
-    _validate_and_collect_dim_info(index_to_dim_info, b_info.src_indices,
-                                   b_info.dim_infos, expr)
-    # Here we rely on the fact that dictionaries keep the insertion order for
-    # keys and values.
-    src_indices = tuple(index_to_dim_info.keys())
-    dim_infos = tuple(index_to_dim_info.values())
+        assert isinstance(expr, _BinaryExpr)
+        a_info = expr_to_info[expr.a]
+        index_to_dim_info = {i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)}
+        b_info = expr_to_info[expr.b]
+        _validate_and_collect_dim_info(
+            index_to_dim_info, b_info.src_indices, b_info.dim_infos, expr
+        )
+        # Here we rely on the fact that dictionaries keep the insertion order for
+        # keys and values.
+        src_indices = tuple(index_to_dim_info.keys())
+        dim_infos = tuple(index_to_dim_info.values())
 
-  expr_to_info[expr] = _ExprInfo(src_indices, dim_infos)
+    expr_to_info[expr] = _ExprInfo(src_indices, dim_infos)
 
 
 def _mark_structured_op_root(
@@ -1928,90 +2006,92 @@ def _mark_structured_op_root(
     reduce_index: IndexVar,
     expr_to_info: _ExprInfoDict,
 ) -> None:
-  """Identifies the root expression for a structured op in the linalg dialect.
-
-  An linalg structured op can only perform reduction on the whole expression.
-  For a TACO tensor algebra expression, the reduction on an IndexVar is done at
-  the smallest expression that contains all the uses of the IndexVar. If such an
-  expression is only part of the whole expression, we need to split this
-  sub-expression tree out from its parent and implement the sub-expression as a
-  structured op.
-
-  This routine identifies the root expression node for performing a reduction on
-  the given IndexVar. If the reduction of the given IndexVar should be performed
-  on expression X, then the IndexVar is added to expr_to_info[X].reduce_indices
-
-  Args:
-    expr: The root IndexExpr for the tensor algebra expression.
-    reduce_index: The IndexVar which we want to find out the proper expression
-      to perform a reduction.
-    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
-
-  Raises:
-      ValueError: If the expression is not proper or not supported.
-  """
-  expr_info = expr_to_info[expr]
-  if isinstance(expr, Access):
-    # Handle simple reduction expression in the format of A[i] = B[i, j].
-    if reduce_index in expr_info.src_indices:
-      expr_info.reduce_indices.add(reduce_index)
-    return
-  elif isinstance(expr, IndexVar):
-    # A[i] = B[i] + j is not allowed.
-    raise ValueError(f"IndexVar is not part of the iteration domain: {expr}.")
-
-  assert (isinstance(expr, _BinaryExpr))
-  a_info = expr_to_info[expr.a]
-  b_info = expr_to_info[expr.b]
-
-  if reduce_index in a_info.src_indices and reduce_index in b_info.src_indices:
-    expr_info.reduce_indices.add(reduce_index)
-    return
-
-  if reduce_index in a_info.src_indices:
-    _mark_structured_op_root(expr.a, reduce_index, expr_to_info)
-  elif reduce_index in b_info.src_indices:
-    _mark_structured_op_root(expr.b, reduce_index, expr_to_info)
-  else:
-    assert False, "Unreachable path"
+    """Identifies the root expression for a structured op in the linalg dialect.
 
+    An linalg structured op can only perform reduction on the whole expression.
+    For a TACO tensor algebra expression, the reduction on an IndexVar is done at
+    the smallest expression that contains all the uses of the IndexVar. If such an
+    expression is only part of the whole expression, we need to split this
+    sub-expression tree out from its parent and implement the sub-expression as a
+    structured op.
 
-def _accumulate_reduce_indices(
-    expr: IndexExpr,
-    expr_to_info: _ExprInfoDict,
-) -> None:
-  """Propagates reduction indices from child expressions to parent expressions.
+    This routine identifies the root expression node for performing a reduction on
+    the given IndexVar. If the reduction of the given IndexVar should be performed
+    on expression X, then the IndexVar is added to expr_to_info[X].reduce_indices
 
-  This routine is passed to the post-order visitor as an _ExprVisitor object.
+    Args:
+      expr: The root IndexExpr for the tensor algebra expression.
+      reduce_index: The IndexVar which we want to find out the proper expression
+        to perform a reduction.
+      expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
 
-  Args:
-    expr: The IndexExpr being visited.
-    expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
-      expression information.
-  """
-  assert expr in expr_to_info
-  expr_info = expr_to_info[expr]
+    Raises:
+        ValueError: If the expression is not proper or not supported.
+    """
+    expr_info = expr_to_info[expr]
+    if isinstance(expr, Access):
+        # Handle simple reduction expression in the format of A[i] = B[i, j].
+        if reduce_index in expr_info.src_indices:
+            expr_info.reduce_indices.add(reduce_index)
+        return
+    elif isinstance(expr, IndexVar):
+        # A[i] = B[i] + j is not allowed.
+        raise ValueError(f"IndexVar is not part of the iteration domain: {expr}.")
 
-  if isinstance(expr, _BinaryExpr):
+    assert isinstance(expr, _BinaryExpr)
     a_info = expr_to_info[expr.a]
     b_info = expr_to_info[expr.b]
-    expr_info.acc_reduce_indices = (
-        a_info.acc_reduce_indices | b_info.acc_reduce_indices
-        | expr_info.reduce_indices)
-  elif isinstance(expr, _UnaryExpr):
-    a_info = expr_to_info[expr.a]
-    expr_info.acc_reduce_indices = (
-        a_info.acc_reduce_indices | expr_info.reduce_indices)
-  elif isinstance(expr, IndexVar):
-    # If an IndexVar is reducing itself, it means the IndexVar is outside the
-    # iteration domain. This usage is now allowed and we should emit an error
-    # before reaching here.
-    assert not expr_info.reduce_indices
-  else:
-    assert isinstance(expr, Access)
-    # Handle simple reduction expression in the format of A[i] = B[i, j].
-    expr_info.acc_reduce_indices = expr_info.reduce_indices
 
+    if reduce_index in a_info.src_indices and reduce_index in b_info.src_indices:
+        expr_info.reduce_indices.add(reduce_index)
+        return
+
+    if reduce_index in a_info.src_indices:
+        _mark_structured_op_root(expr.a, reduce_index, expr_to_info)
+    elif reduce_index in b_info.src_indices:
+        _mark_structured_op_root(expr.b, reduce_index, expr_to_info)
+    else:
+        assert False, "Unreachable path"
+
+
+def _accumulate_reduce_indices(
+    expr: IndexExpr,
+    expr_to_info: _ExprInfoDict,
+) -> None:
+    """Propagates reduction indices from child expressions to parent expressions.
+
+    This routine is passed to the post-order visitor as an _ExprVisitor object.
+
+    Args:
+      expr: The IndexExpr being visited.
+      expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
+        expression information.
+    """
+    assert expr in expr_to_info
+    expr_info = expr_to_info[expr]
+
+    if isinstance(expr, _BinaryExpr):
+        a_info = expr_to_info[expr.a]
+        b_info = expr_to_info[expr.b]
+        expr_info.acc_reduce_indices = (
+            a_info.acc_reduce_indices
+            | b_info.acc_reduce_indices
+            | expr_info.reduce_indices
+        )
+    elif isinstance(expr, _UnaryExpr):
+        a_info = expr_to_info[expr.a]
+        expr_info.acc_reduce_indices = (
+            a_info.acc_reduce_indices | expr_info.reduce_indices
+        )
+    elif isinstance(expr, IndexVar):
+        # If an IndexVar is reducing itself, it means the IndexVar is outside the
+        # iteration domain. This usage is now allowed and we should emit an error
+        # before reaching here.
+        assert not expr_info.reduce_indices
+    else:
+        assert isinstance(expr, Access)
+        # Handle simple reduction expression in the format of A[i] = B[i, j].
+        expr_info.acc_reduce_indices = expr_info.reduce_indices
 
 
 def _gather_structured_op(
@@ -2019,42 +2099,42 @@ def _gather_structured_op(
     expr_to_info: _ExprInfoDict,
     structop_roots: List[IndexExpr],
 ) -> None:
-  """Adds structured op root expression information to structop_roots.
-
-  This routine is passed to the post-order visitor as an _ExprVisitor object.
-
-  Args:
-    expr: The IndexExpr being visited.
-    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
-    structop_roots: The resulting list of IndexExpr that are the roots for
-      linalg structured ops.
-  """
-  if not expr_to_info[expr].reduce_indices:
-    return
-
-  # If the expression is the root for reducing some indices, collect the indices
-  # and dimensions for the reduction result.
-  dst_indices = []
-  dst_dims = []
-  mode_fmts = []
-  for i, d in zip(expr_to_info[expr].src_indices, expr_to_info[expr].dim_infos):
-    if i not in expr_to_info[expr].acc_reduce_indices:
-      dst_indices.append(i)
-      dst_dims.append(d.dim)
-      mode_fmts.append(d.mode_format)
-
-  # Add the information to the dictionary.
-  op_info = _StructOpInfo(
-      tuple(dst_indices),
-      tuple(dst_dims),
-      expr.dtype(),
-      f"temp{len(structop_roots)}",
-      _make_format(mode_fmts),
-  )
-  expr_to_info[expr].structop_info = op_info
-
-  # Add the expression to the list of structured op roots.
-  structop_roots.append(expr)
+    """Adds structured op root expression information to structop_roots.
+
+    This routine is passed to the post-order visitor as an _ExprVisitor object.
+
+    Args:
+      expr: The IndexExpr being visited.
+      expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
+      structop_roots: The resulting list of IndexExpr that are the roots for
+        linalg structured ops.
+    """
+    if not expr_to_info[expr].reduce_indices:
+        return
+
+    # If the expression is the root for reducing some indices, collect the indices
+    # and dimensions for the reduction result.
+    dst_indices = []
+    dst_dims = []
+    mode_fmts = []
+    for i, d in zip(expr_to_info[expr].src_indices, expr_to_info[expr].dim_infos):
+        if i not in expr_to_info[expr].acc_reduce_indices:
+            dst_indices.append(i)
+            dst_dims.append(d.dim)
+            mode_fmts.append(d.mode_format)
+
+    # Add the information to the dictionary.
+    op_info = _StructOpInfo(
+        tuple(dst_indices),
+        tuple(dst_dims),
+        expr.dtype(),
+        f"temp{len(structop_roots)}",
+        _make_format(mode_fmts),
+    )
+    expr_to_info[expr].structop_info = op_info
+
+    # Add the expression to the list of structured op roots.
+    structop_roots.append(expr)
 
 
 def _is_structured_op_leaf(
@@ -2063,29 +2143,31 @@ def _is_structured_op_leaf(
     expr_to_info: _ExprInfoDict,
     *unused_args,
 ) -> bool:
-  """Returns true iff the expression is a leaf node for a structured op.
+    """Returns true iff the expression is a leaf node for a structured op.
 
-  The root of a structured op is a leaf of its parent structured op that uses
-  its result. An expression node is a leaf node for the current structured op if
-  it is an Access node or the root for a structured op that is not the current
-  structured op.
+    The root of a structured op is a leaf of its parent structured op that uses
+    its result. An expression node is a leaf node for the current structured op if
+    it is an Access node or the root for a structured op that is not the current
+    structured op.
 
-  This routine is passed to the post-order visitor as a _SubtreeLeafChecker
-  object. Because the post-order visitor passes the same parameters to both
-  _SubtreeLeafChecker and _ExprVisitor, this routine may receive unused
-  parameters.
+    This routine is passed to the post-order visitor as a _SubtreeLeafChecker
+    object. Because the post-order visitor passes the same parameters to both
+    _SubtreeLeafChecker and _ExprVisitor, this routine may receive unused
+    parameters.
 
-  Args:
-    expr: The IndexExpr being visited.
-    root: The root of the current structured op.
-    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
+    Args:
+      expr: The IndexExpr being visited.
+      root: The root of the current structured op.
+      expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
 
-  Returns:
-    True if the current IndexExpr is a leaf for the current structured op.
-  """
-  return (expr != root and
-          expr_to_info[expr].structop_info is not None) or isinstance(
-              expr, Access) or isinstance(expr, IndexVar)
+    Returns:
+      True if the current IndexExpr is a leaf for the current structured op.
+    """
+    return (
+        (expr != root and expr_to_info[expr].structop_info is not None)
+        or isinstance(expr, Access)
+        or isinstance(expr, IndexVar)
+    )
 
 
 def _gather_structured_op_input(
@@ -2094,26 +2176,28 @@ def _gather_structured_op_input(
     expr_to_info: _ExprInfoDict,
     structop_inputs: List[IndexExpr],
 ) -> None:
-  """Adds the IndexExpr to structop_inputs if it is an input.
+    """Adds the IndexExpr to structop_inputs if it is an input.
 
-  If the current IndexExpr is an input for the current structured op, adds it to
-  structop_inputs. The current IndexExpr is an input if it is an Access node or
-  if it is the root for a structured op that is not the current structured op.
+    If the current IndexExpr is an input for the current structured op, adds it to
+    structop_inputs. The current IndexExpr is an input if it is an Access node or
+    if it is the root for a structured op that is not the current structured op.
 
-  This routine is passed to the post-order visitor as an _ExprVisitor object.
+    This routine is passed to the post-order visitor as an _ExprVisitor object.
 
-  Args:
-    expr: The IndexExpr being visited.
-    root: The root of the current structured op.
-    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
-    structop_inputs: The resulting list of IndexExpr that provide input to the
-      current structured op.
-  """
-  if ((expr != root or isinstance(expr, Access)) and
-      expr not in structop_inputs) and (isinstance(expr, Access) or
-                                        (expr in expr_to_info and
-                                         expr_to_info[expr].structop_info)):
-    structop_inputs.append(expr)
+    Args:
+      expr: The IndexExpr being visited.
+      root: The root of the current structured op.
+      expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
+      structop_inputs: The resulting list of IndexExpr that provide input to the
+        current structured op.
+    """
+    if (
+        (expr != root or isinstance(expr, Access)) and expr not in structop_inputs
+    ) and (
+        isinstance(expr, Access)
+        or (expr in expr_to_info and expr_to_info[expr].structop_info)
+    ):
+        structop_inputs.append(expr)
 
 
 def _emit_structured_op_input(
@@ -2121,35 +2205,35 @@ def _emit_structured_op_input(
     expr_to_info: _ExprInfoDict,
     op_def: lang.LinalgOpDef,
 ) -> lang.OperandDef:
-  """Emits OperandDef in the linalg dialect for the input IndexExpr.
-
-  Args:
-    expr: The input IndexExpr for the current structured op.
-    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
-    op_def: The linalg operation for the current structured op.
-
-  Returns:
-    An OperandDef in the linalg dialect for the input IndexExpr.
-  """
-  op_info = expr_to_info[expr].structop_info
-  if op_info and not isinstance(expr, Access):
-    # The input is a temporary tensor produced by another structured op.
-    indices = op_info.dst_indices
-    name = op_info.dst_name
-  else:
-    # The input is a user provided tensor.
-    assert isinstance(expr, Access)
-    indices = expr.indices
-    name = expr.tensor.name
-
-  dim_sym = _mlir_symbols_from_index_vars(indices)
-  opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym)
-  op_def.add_operand(name, opnd)
-  return opnd
+    """Emits OperandDef in the linalg dialect for the input IndexExpr.
+
+    Args:
+      expr: The input IndexExpr for the current structured op.
+      expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
+      op_def: The linalg operation for the current structured op.
+
+    Returns:
+      An OperandDef in the linalg dialect for the input IndexExpr.
+    """
+    op_info = expr_to_info[expr].structop_info
+    if op_info and not isinstance(expr, Access):
+        # The input is a temporary tensor produced by another structured op.
+        indices = op_info.dst_indices
+        name = op_info.dst_name
+    else:
+        # The input is a user provided tensor.
+        assert isinstance(expr, Access)
+        indices = expr.indices
+        name = expr.tensor.name
+
+    dim_sym = _mlir_symbols_from_index_vars(indices)
+    opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym)
+    op_def.add_operand(name, opnd)
+    return opnd
 
 
 def _check_and_build_unary(a: Access, op: _UnaryOp) -> "_UnaryExpr":
-  """Build a unary operation.
+    """Build a unary operation.
 
     Args:
       a: The operand, which could be any Python object from user inputs.
@@ -2161,13 +2245,13 @@ def _check_and_build_unary(a: Access, op: _UnaryOp) -> "_UnaryExpr":
     Raises:
       ValueError: If a is not an IndexExpr.
     """
-  if not isinstance(a, Access):
-    raise ValueError(f"Expected an Access Operand: {a}")
-  return a._build_unary_expr(op)
+    if not isinstance(a, Access):
+        raise ValueError(f"Expected an Access Operand: {a}")
+    return a._build_unary_expr(op)
 
 
 def ceil(a: Access) -> "_UnaryExpr":
-  """Defines the operation ceil.
+    """Defines the operation ceil.
 
     Args:
       a: The operand, which could be any Python object from user inputs.
@@ -2178,11 +2262,11 @@ def ceil(a: Access) -> "_UnaryExpr":
     Raises:
       ValueError: If a is not an IndexExpr.
     """
-  return _check_and_build_unary(a, _op_ceil)
+    return _check_and_build_unary(a, _op_ceil)
 
 
 def floor(a: Access) -> "_UnaryExpr":
-  """Defines the operation floor.
+    """Defines the operation floor.
 
     Args:
       a: The operand, which could be any Python object from user inputs.
@@ -2193,4 +2277,4 @@ def floor(a: Access) -> "_UnaryExpr":
     Raises:
       ValueError: If a is not an IndexExpr.
     """
-  return _check_and_build_unary(a, _op_floor)
+    return _check_and_build_unary(a, _op_floor)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py
index e6a7d8e1b4b85..785401c25dc87 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py
@@ -31,50 +31,52 @@
 _TNS_FILENAME_SUFFIX = ".tns"
 
 
-def read(filename: str, fmt: Format,
-         dtype: DType = DType(Type.FLOAT32)) -> Tensor:
-  """Inputs a tensor from a given file.
-
-  The name suffix of the file specifies the format of the input tensor. We
-  currently only support .mtx format for support sparse tensors.
-
-  Args:
-    filename: A string input filename.
-    fmt: The storage format of the tensor.
-    dtype: The data type, default to float32.
-
-  Raises:
-    ValueError: If filename doesn't end with .mtx or .tns, or fmt is not an
-    instance of Format or fmt is not a sparse tensor.
-  """
-  if (not isinstance(filename, str) or
-      (not filename.endswith(_MTX_FILENAME_SUFFIX) and
-       not filename.endswith(_TNS_FILENAME_SUFFIX))):
-    raise ValueError("Expected string filename ends with "
-                     f"{_MTX_FILENAME_SUFFIX} or {_TNS_FILENAME_SUFFIX}: "
-                     f"{filename}.")
-
-  return Tensor.from_file(filename, fmt, dtype)
+def read(filename: str, fmt: Format, dtype: DType = DType(Type.FLOAT32)) -> Tensor:
+    """Inputs a tensor from a given file.
+
+    The name suffix of the file specifies the format of the input tensor. We
+    currently only support .mtx format for support sparse tensors.
+
+    Args:
+      filename: A string input filename.
+      fmt: The storage format of the tensor.
+      dtype: The data type, default to float32.
+
+    Raises:
+      ValueError: If filename doesn't end with .mtx or .tns, or fmt is not an
+      instance of Format or fmt is not a sparse tensor.
+    """
+    if not isinstance(filename, str) or (
+        not filename.endswith(_MTX_FILENAME_SUFFIX)
+        and not filename.endswith(_TNS_FILENAME_SUFFIX)
+    ):
+        raise ValueError(
+            "Expected string filename ends with "
+            f"{_MTX_FILENAME_SUFFIX} or {_TNS_FILENAME_SUFFIX}: "
+            f"{filename}."
+        )
+
+    return Tensor.from_file(filename, fmt, dtype)
 
 
 def write(filename: str, tensor: Tensor) -> None:
-  """Outputs a tensor to a given file.
-
-  The name suffix of the file specifies the format of the output. We currently
-  only support .tns format.
-
-  Args:
-    filename: A string output filename.
-    tensor: The tensor to output.
-
-  Raises:
-    ValueError: If filename doesn't end with .tns or tensor is not a Tensor.
-  """
-  if (not isinstance(filename, str) or
-      not filename.endswith(_TNS_FILENAME_SUFFIX)):
-    raise ValueError("Expected string filename ends with"
-                     f" {_TNS_FILENAME_SUFFIX}: {filename}.")
-  if not isinstance(tensor, Tensor):
-    raise ValueError(f"Expected a Tensor object: {tensor}.")
-
-  tensor.to_file(filename)
+    """Outputs a tensor to a given file.
+
+    The name suffix of the file specifies the format of the output. We currently
+    only support .tns format.
+
+    Args:
+      filename: A string output filename.
+      tensor: The tensor to output.
+
+    Raises:
+      ValueError: If filename doesn't end with .tns or tensor is not a Tensor.
+    """
+    if not isinstance(filename, str) or not filename.endswith(_TNS_FILENAME_SUFFIX):
+        raise ValueError(
+            "Expected string filename ends with" f" {_TNS_FILENAME_SUFFIX}: {filename}."
+        )
+    if not isinstance(tensor, Tensor):
+        raise ValueError(f"Expected a Tensor object: {tensor}.")
+
+    tensor.to_file(filename)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
index 988c57b3b33f2..1e1061b8b858d 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
@@ -36,190 +36,234 @@
 
 @functools.lru_cache()
 def _get_support_lib_name() -> str:
-  """Gets the string name for the supporting C shared library."""
-  return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
+    """Gets the string name for the supporting C shared library."""
+    return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
 
 
 @functools.lru_cache()
 def _get_sparse_compiler() -> mlir_sparse_compiler.SparseCompiler:
-  """Gets the MLIR sparse compiler with default setting."""
-  return mlir_sparse_compiler.SparseCompiler(
-      options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()])
+    """Gets the MLIR sparse compiler with default setting."""
+    return mlir_sparse_compiler.SparseCompiler(
+        options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()]
+    )
 
 
 def _record_support_funcs(
-    ty: np.dtype, to_func: _SupportFunc, from_func: _SupportFunc,
-    ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]]) -> None:
-  """Records the two supporting functions for a given data type."""
-  to_func.restype = ctypes.c_void_p
-  from_func.restype = ctypes.c_void_p
-  ty_to_funcs[ty] = (to_func, from_func)
+    ty: np.dtype,
+    to_func: _SupportFunc,
+    from_func: _SupportFunc,
+    ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]],
+) -> None:
+    """Records the two supporting functions for a given data type."""
+    to_func.restype = ctypes.c_void_p
+    from_func.restype = ctypes.c_void_p
+    ty_to_funcs[ty] = (to_func, from_func)
 
 
 @functools.lru_cache()
 def _get_support_func_locator() -> _SupportFuncLocator:
-  """Constructs a function to locate the supporting functions for a data type.
-
-  Loads the supporting C shared library with the needed routines. Constructs a
-  dictionary from the supported data types to the routines for the data types,
-  and then a function to look up the dictionary for a given data type.
-
-  The name of the supporting C shared library is either provided by an
-  an environment variable or a default value.
-
-  Returns:
-    The function to look up the supporting functions for a given data type.
-
-  Raises:
-    OSError: If there is any problem in loading the shared library.
-    ValueError: If the shared library doesn't contain the needed routines.
-  """
-  # This raises OSError exception if there is any problem in loading the shared
-  # library.
-  c_lib = ctypes.CDLL(_get_support_lib_name())
-
-  type_to_funcs = {}
-  try:
-    support_types = [(np.int8, c_lib.convertToMLIRSparseTensorI8,
-                      c_lib.convertFromMLIRSparseTensorI8),
-                     (np.int16, c_lib.convertToMLIRSparseTensorI16,
-                      c_lib.convertFromMLIRSparseTensorI16),
-                     (np.int32, c_lib.convertToMLIRSparseTensorI32,
-                      c_lib.convertFromMLIRSparseTensorI32),
-                     (np.int64, c_lib.convertToMLIRSparseTensorI64,
-                      c_lib.convertFromMLIRSparseTensorI64),
-                     (np.float16, c_lib.convertToMLIRSparseTensorF16,
-                      c_lib.convertFromMLIRSparseTensorF16),
-                     (np.float32, c_lib.convertToMLIRSparseTensorF32,
-                      c_lib.convertFromMLIRSparseTensorF32),
-                     (np.float64, c_lib.convertToMLIRSparseTensorF64,
-                      c_lib.convertFromMLIRSparseTensorF64),
-                     (np.complex64, c_lib.convertToMLIRSparseTensorC32,
-                      c_lib.convertFromMLIRSparseTensorC32),
-                     (np.complex128, c_lib.convertToMLIRSparseTensorC64,
-                      c_lib.convertFromMLIRSparseTensorC64)]
-  except Exception as e:
-    raise ValueError(f"Missing supporting function: {e}") from e
-  for i, info in enumerate(support_types):
-    _record_support_funcs(info[0], info[1], info[2], type_to_funcs)
-
-  def get_support_funcs(ty: np.dtype):
-    funcs = type_to_funcs[ty]
-    assert funcs is not None
-    return funcs
-
-  return get_support_funcs
+    """Constructs a function to locate the supporting functions for a data type.
+
+    Loads the supporting C shared library with the needed routines. Constructs a
+    dictionary from the supported data types to the routines for the data types,
+    and then a function to look up the dictionary for a given data type.
+
+    The name of the supporting C shared library is either provided by an
+    an environment variable or a default value.
+
+    Returns:
+      The function to look up the supporting functions for a given data type.
+
+    Raises:
+      OSError: If there is any problem in loading the shared library.
+      ValueError: If the shared library doesn't contain the needed routines.
+    """
+    # This raises OSError exception if there is any problem in loading the shared
+    # library.
+    c_lib = ctypes.CDLL(_get_support_lib_name())
+
+    type_to_funcs = {}
+    try:
+        support_types = [
+            (
+                np.int8,
+                c_lib.convertToMLIRSparseTensorI8,
+                c_lib.convertFromMLIRSparseTensorI8,
+            ),
+            (
+                np.int16,
+                c_lib.convertToMLIRSparseTensorI16,
+                c_lib.convertFromMLIRSparseTensorI16,
+            ),
+            (
+                np.int32,
+                c_lib.convertToMLIRSparseTensorI32,
+                c_lib.convertFromMLIRSparseTensorI32,
+            ),
+            (
+                np.int64,
+                c_lib.convertToMLIRSparseTensorI64,
+                c_lib.convertFromMLIRSparseTensorI64,
+            ),
+            (
+                np.float16,
+                c_lib.convertToMLIRSparseTensorF16,
+                c_lib.convertFromMLIRSparseTensorF16,
+            ),
+            (
+                np.float32,
+                c_lib.convertToMLIRSparseTensorF32,
+                c_lib.convertFromMLIRSparseTensorF32,
+            ),
+            (
+                np.float64,
+                c_lib.convertToMLIRSparseTensorF64,
+                c_lib.convertFromMLIRSparseTensorF64,
+            ),
+            (
+                np.complex64,
+                c_lib.convertToMLIRSparseTensorC32,
+                c_lib.convertFromMLIRSparseTensorC32,
+            ),
+            (
+                np.complex128,
+                c_lib.convertToMLIRSparseTensorC64,
+                c_lib.convertFromMLIRSparseTensorC64,
+            ),
+        ]
+    except Exception as e:
+        raise ValueError(f"Missing supporting function: {e}") from e
+    for i, info in enumerate(support_types):
+        _record_support_funcs(info[0], info[1], info[2], type_to_funcs)
+
+    def get_support_funcs(ty: np.dtype):
+        funcs = type_to_funcs[ty]
+        assert funcs is not None
+        return funcs
+
+    return get_support_funcs
 
 
 def sparse_tensor_to_coo_tensor(
     sparse_tensor: ctypes.c_void_p,
     dtype: np.dtype,
 ) -> Tuple[int, int, np.ndarray, np.ndarray, np.ndarray]:
-  """Converts an MLIR sparse tensor to a COO-flavored format tensor.
-
-  Args:
-     sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor.
-     dtype: The numpy data type for the tensor elements.
-
-  Returns:
-    A tuple that contains the following values for the COO-flavored format
-    tensor:
-    rank: An integer for the rank of the tensor.
-    nse: An integer for the number of non-zero values in the tensor.
-    shape: A 1D numpy array of integers, for the shape of the tensor.
-    values: A 1D numpy array, for the non-zero values in the tensor.
-    indices: A 2D numpy array of integers, representing the indices for the
-      non-zero values in the tensor.
-
-  Raises:
-    OSError: If there is any problem in loading the shared library.
-    ValueError: If the shared library doesn't contain the needed routines.
-  """
-  convert_from = _get_support_func_locator()(dtype)[1]
-  rank = ctypes.c_ulonglong(0)
-  nse = ctypes.c_ulonglong(0)
-  shape = ctypes.POINTER(ctypes.c_ulonglong)()
-
-  values = ctypes.POINTER(runtime.as_ctype(np.dtype(dtype)))()
-  indices = ctypes.POINTER(ctypes.c_ulonglong)()
-  convert_from(sparse_tensor, ctypes.byref(rank), ctypes.byref(nse),
-               ctypes.byref(shape), ctypes.byref(values), ctypes.byref(indices))
-
-  # Convert the returned values to the corresponding numpy types.
-  shape = np.ctypeslib.as_array(shape, shape=[rank.value])
-  values = runtime.to_numpy(np.ctypeslib.as_array(values, shape=[nse.value]))
-  indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
-  return rank.value, nse.value, shape, values, indices
-
-
-def coo_tensor_to_sparse_tensor(np_shape: np.ndarray, np_values: np.ndarray,
-                                np_indices: np.ndarray, np_perm: np.ndarray,
-                                np_sparse: np.ndarray) -> int:
-  """Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.
-
-  Args:
-     np_shape: A 1D numpy array of integers, for the shape of the tensor.
-     np_values: A 1D numpy array, for the non-zero values in the tensor.
-     np_indices: A 2D numpy array of integers, representing the indices for the
-       non-zero values in the tensor.
-     np_perm: A 1D numpy array of integers, representing the storage ordering
-       for the dimensions.
-     np_sparse: A 1D numpy array of uint8, representing the sparsity values
-       for the dimensions.
-
-  Returns:
-     An integer for the non-null ctypes.c_void_p to the MLIR sparse tensor
-     descriptor.
-
-  Raises:
-    OSError: If there is any problem in loading the shared library.
-    ValueError: If the shared library doesn't contain the needed routines.
-  """
-
-  r = len(np_shape)
-  rank = ctypes.c_ulonglong(r)
-  nse = ctypes.c_ulonglong(len(np_values))
-  shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
-  values = np_values.ctypes.data_as(
-      ctypes.POINTER(runtime.as_ctype(np.dtype(np_values.dtype))))
-  indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
-
-  perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
-  sparse = np_sparse.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
-
-  convert_to = _get_support_func_locator()(np_values.dtype.type)[0]
-  ptr = convert_to(rank, nse, shape, values, indices, perm, sparse)
-  assert ptr is not None, "Problem with calling convertToMLIRSparseTensorF64"
-  return ptr
-
-
-def compile_and_build_engine(
-    module: ir.Module) -> execution_engine.ExecutionEngine:
-  """Compiles an MLIR module and builds a JIT execution engine.
-
-  Args:
-    module: The MLIR module.
-
-  Returns:
-    A JIT execution engine for the MLIR module.
-
-  """
-  return _get_sparse_compiler().compile_and_jit(module)
+    """Converts an MLIR sparse tensor to a COO-flavored format tensor.
+
+    Args:
+       sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor.
+       dtype: The numpy data type for the tensor elements.
+
+    Returns:
+      A tuple that contains the following values for the COO-flavored format
+      tensor:
+      rank: An integer for the rank of the tensor.
+      nse: An integer for the number of non-zero values in the tensor.
+      shape: A 1D numpy array of integers, for the shape of the tensor.
+      values: A 1D numpy array, for the non-zero values in the tensor.
+      indices: A 2D numpy array of integers, representing the indices for the
+        non-zero values in the tensor.
+
+    Raises:
+      OSError: If there is any problem in loading the shared library.
+      ValueError: If the shared library doesn't contain the needed routines.
+    """
+    convert_from = _get_support_func_locator()(dtype)[1]
+    rank = ctypes.c_ulonglong(0)
+    nse = ctypes.c_ulonglong(0)
+    shape = ctypes.POINTER(ctypes.c_ulonglong)()
+
+    values = ctypes.POINTER(runtime.as_ctype(np.dtype(dtype)))()
+    indices = ctypes.POINTER(ctypes.c_ulonglong)()
+    convert_from(
+        sparse_tensor,
+        ctypes.byref(rank),
+        ctypes.byref(nse),
+        ctypes.byref(shape),
+        ctypes.byref(values),
+        ctypes.byref(indices),
+    )
+
+    # Convert the returned values to the corresponding numpy types.
+    shape = np.ctypeslib.as_array(shape, shape=[rank.value])
+    values = runtime.to_numpy(np.ctypeslib.as_array(values, shape=[nse.value]))
+    indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
+    return rank.value, nse.value, shape, values, indices
+
+
+def coo_tensor_to_sparse_tensor(
+    np_shape: np.ndarray,
+    np_values: np.ndarray,
+    np_indices: np.ndarray,
+    np_perm: np.ndarray,
+    np_sparse: np.ndarray,
+) -> int:
+    """Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.
+
+    Args:
+       np_shape: A 1D numpy array of integers, for the shape of the tensor.
+       np_values: A 1D numpy array, for the non-zero values in the tensor.
+       np_indices: A 2D numpy array of integers, representing the indices for the
+         non-zero values in the tensor.
+       np_perm: A 1D numpy array of integers, representing the storage ordering
+         for the dimensions.
+       np_sparse: A 1D numpy array of uint8, representing the sparsity values
+         for the dimensions.
+
+    Returns:
+       An integer for the non-null ctypes.c_void_p to the MLIR sparse tensor
+       descriptor.
+
+    Raises:
+      OSError: If there is any problem in loading the shared library.
+      ValueError: If the shared library doesn't contain the needed routines.
+    """
+
+    r = len(np_shape)
+    rank = ctypes.c_ulonglong(r)
+    nse = ctypes.c_ulonglong(len(np_values))
+    shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
+    values = np_values.ctypes.data_as(
+        ctypes.POINTER(runtime.as_ctype(np.dtype(np_values.dtype)))
+    )
+    indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
+
+    perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
+    sparse = np_sparse.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
+
+    convert_to = _get_support_func_locator()(np_values.dtype.type)[0]
+    ptr = convert_to(rank, nse, shape, values, indices, perm, sparse)
+    assert ptr is not None, "Problem with calling convertToMLIRSparseTensorF64"
+    return ptr
+
+
+def compile_and_build_engine(module: ir.Module) -> execution_engine.ExecutionEngine:
+    """Compiles an MLIR module and builds a JIT execution engine.
+
+    Args:
+      module: The MLIR module.
+
+    Returns:
+      A JIT execution engine for the MLIR module.
+
+    """
+    return _get_sparse_compiler().compile_and_jit(module)
 
 
 class _SparseTensorDescriptor(ctypes.Structure):
-  """A C structure for an MLIR sparse tensor."""
-  _fields_ = [
-      # A pointer for the MLIR sparse tensor storage.
-      ("storage", ctypes.POINTER(ctypes.c_ulonglong)),
-      # An MLIR MemRef descriptor for the shape of the sparse tensor.
-      ("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)),
-  ]
+    """A C structure for an MLIR sparse tensor."""
+
+    _fields_ = [
+        # A pointer for the MLIR sparse tensor storage.
+        ("storage", ctypes.POINTER(ctypes.c_ulonglong)),
+        # An MLIR MemRef descriptor for the shape of the sparse tensor.
+        ("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)),
+    ]
 
 
 def _output_one_dim(dim: int, rank: int, shape: str, type: str) -> str:
-  """Produces the MLIR text code to output the size for the given dimension."""
-  return f"""
+    """Produces the MLIR text code to output the size for the given dimension."""
+    return f"""
   %c{dim} = arith.constant {dim} : index
   %d{dim} = tensor.dim %t, %c{dim} : tensor<{shape}x{type}, #enc>
   memref.store %d{dim}, %b[%c{dim}] : memref<{rank}xindex>
@@ -233,26 +277,29 @@ def _output_one_dim(dim: int, rank: int, shape: str, type: str) -> str:
 # (2) Use scf.for instead of an unrolled loop to write out the dimension sizes
 #     when tensor.dim supports non-constant dimension value.
 def _get_create_sparse_tensor_kernel(
-    sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str) -> str:
-  """Creates an MLIR text kernel to construct a sparse tensor from a file.
+    sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str
+) -> str:
+    """Creates an MLIR text kernel to construct a sparse tensor from a file.
 
-  The kernel returns a _SparseTensorDescriptor structure.
-  """
-  rank = len(sparsity_codes)
+    The kernel returns a _SparseTensorDescriptor structure.
+    """
+    rank = len(sparsity_codes)
 
-  # Use ? to represent a dimension in the dynamic shape string representation.
-  shape = "x".join(map(lambda d: "?", range(rank)))
+    # Use ? to represent a dimension in the dynamic shape string representation.
+    shape = "x".join(map(lambda d: "?", range(rank)))
 
-  # Convert the encoded sparsity values to a string representation.
-  sparsity = ", ".join(
-      map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes))
+    # Convert the encoded sparsity values to a string representation.
+    sparsity = ", ".join(
+        map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes)
+    )
 
-  # Get the MLIR text code to write the dimension sizes to the output buffer.
-  output_dims = "\n".join(
-      map(lambda d: _output_one_dim(d, rank, shape, type), range(rank)))
+    # Get the MLIR text code to write the dimension sizes to the output buffer.
+    output_dims = "\n".join(
+        map(lambda d: _output_one_dim(d, rank, shape, type), range(rank))
+    )
 
-  # Return the MLIR text kernel.
-  return f"""
+    # Return the MLIR text kernel.
+    return f"""
 !Ptr = !llvm.ptr<i8>
 #enc = #sparse_tensor.encoding<{{
   lvlTypes = [ {sparsity} ]
@@ -266,69 +313,69 @@ def _get_create_sparse_tensor_kernel(
 }}"""
 
 
-def create_sparse_tensor(filename: str,
-                         sparsity: Sequence[sparse_tensor.DimLevelType],
-                         type: str) -> Tuple[ctypes.c_void_p, np.ndarray]:
-  """Creates an MLIR sparse tensor from the input file.
+def create_sparse_tensor(
+    filename: str, sparsity: Sequence[sparse_tensor.DimLevelType], type: str
+) -> Tuple[ctypes.c_void_p, np.ndarray]:
+    """Creates an MLIR sparse tensor from the input file.
 
-  Args:
-    filename: A string for the name of the file that contains the tensor data in
-      a COO-flavored format.
-    sparsity: A sequence of DimLevelType values, one for each dimension of the
-      tensor.
+    Args:
+      filename: A string for the name of the file that contains the tensor data in
+        a COO-flavored format.
+      sparsity: A sequence of DimLevelType values, one for each dimension of the
+        tensor.
 
-  Returns:
-    A Tuple containing the following values:
-    storage: A ctypes.c_void_p for the MLIR sparse tensor storage.
-    shape: A 1D numpy array of integers, for the shape of the tensor.
+    Returns:
+      A Tuple containing the following values:
+      storage: A ctypes.c_void_p for the MLIR sparse tensor storage.
+      shape: A 1D numpy array of integers, for the shape of the tensor.
 
-  Raises:
-    OSError: If there is any problem in loading the supporting C shared library.
-    ValueError:  If the shared library doesn't contain the needed routine.
-  """
-  with ir.Context() as ctx, ir.Location.unknown():
-    module = _get_create_sparse_tensor_kernel(sparsity, type)
-    module = ir.Module.parse(module)
-    engine = compile_and_build_engine(module)
+    Raises:
+      OSError: If there is any problem in loading the supporting C shared library.
+      ValueError:  If the shared library doesn't contain the needed routine.
+    """
+    with ir.Context() as ctx, ir.Location.unknown():
+        module = _get_create_sparse_tensor_kernel(sparsity, type)
+        module = ir.Module.parse(module)
+        engine = compile_and_build_engine(module)
 
-  # A sparse tensor descriptor to receive the kernel result.
-  c_tensor_desc = _SparseTensorDescriptor()
-  # Convert the filename to a byte stream.
-  c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
+    # A sparse tensor descriptor to receive the kernel result.
+    c_tensor_desc = _SparseTensorDescriptor()
+    # Convert the filename to a byte stream.
+    c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
 
-  arg_pointers = [
-      ctypes.byref(ctypes.pointer(c_tensor_desc)),
-      ctypes.byref(c_filename)
-  ]
+    arg_pointers = [
+        ctypes.byref(ctypes.pointer(c_tensor_desc)),
+        ctypes.byref(c_filename),
+    ]
 
-  # Invoke the execution engine to run the module and return the result.
-  engine.invoke(_ENTRY_NAME, *arg_pointers)
-  shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape))
-  return c_tensor_desc.storage, shape
+    # Invoke the execution engine to run the module and return the result.
+    engine.invoke(_ENTRY_NAME, *arg_pointers)
+    shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape))
+    return c_tensor_desc.storage, shape
 
 
 # TODO: With better support from MLIR, we may improve the current implementation
 # by using Python code to generate the kernel instead of doing MLIR text code
 # stitching.
 def _get_output_sparse_tensor_kernel(
-        sparsity_codes: Sequence[sparse_tensor.DimLevelType],
-        type: str) -> str:
-  """Creates an MLIR text kernel to output a sparse tensor to a file.
+    sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str
+) -> str:
+    """Creates an MLIR text kernel to output a sparse tensor to a file.
 
-  The kernel returns void.
-  """
-  rank = len(sparsity_codes)
+    The kernel returns void.
+    """
+    rank = len(sparsity_codes)
 
-  # Use ? to represent a dimension in the dynamic shape string representation.
-  shape = "x".join(map(lambda d: "?", range(rank)))
+    # Use ? to represent a dimension in the dynamic shape string representation.
+    shape = "x".join(map(lambda d: "?", range(rank)))
 
-  # Convert the encoded sparsity values to a string representation.
-  sparsity = ", ".join(
-      map(lambda s: '"compressed"'
-          if s.value else '"dense"', sparsity_codes))
+    # Convert the encoded sparsity values to a string representation.
+    sparsity = ", ".join(
+        map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes)
+    )
 
-  # Return the MLIR text kernel.
-  return f"""
+    # Return the MLIR text kernel.
+    return f"""
 !Ptr = !llvm.ptr<i8>
 #enc = #sparse_tensor.encoding<{{
   lvlTypes = [ {sparsity} ]
@@ -340,35 +387,38 @@ def _get_output_sparse_tensor_kernel(
 }}"""
 
 
-def output_sparse_tensor(tensor: ctypes.c_void_p, filename: str,
-                         sparsity: Sequence[sparse_tensor.DimLevelType],
-                         type: str) -> None:
-  """Outputs an MLIR sparse tensor to the given file.
-
-  Args:
-    tensor: A C pointer to the MLIR sparse tensor.
-    filename: A string for the name of the file that contains the tensor data in
-      a COO-flavored format.
-    sparsity: A sequence of DimLevelType values, one for each dimension of the
-      tensor.
-    type: The MLIR string for the data type.
-
-  Raises:
-    OSError: If there is any problem in loading the supporting C shared library.
-    ValueError:  If the shared library doesn't contain the needed routine.
-  """
-  with ir.Context() as ctx, ir.Location.unknown():
-    module = _get_output_sparse_tensor_kernel(sparsity, type)
-    module = ir.Module.parse(module)
-    engine = compile_and_build_engine(module)
-
-  # Convert the filename to a byte stream.
-  c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
-
-  arg_pointers = [
-      ctypes.byref(ctypes.cast(tensor, ctypes.c_void_p)),
-      ctypes.byref(c_filename)
-  ]
-
-  # Invoke the execution engine to run the module and return the result.
-  engine.invoke(_ENTRY_NAME, *arg_pointers)
+def output_sparse_tensor(
+    tensor: ctypes.c_void_p,
+    filename: str,
+    sparsity: Sequence[sparse_tensor.DimLevelType],
+    type: str,
+) -> None:
+    """Outputs an MLIR sparse tensor to the given file.
+
+    Args:
+      tensor: A C pointer to the MLIR sparse tensor.
+      filename: A string for the name of the file that contains the tensor data in
+        a COO-flavored format.
+      sparsity: A sequence of DimLevelType values, one for each dimension of the
+        tensor.
+      type: The MLIR string for the data type.
+
+    Raises:
+      OSError: If there is any problem in loading the supporting C shared library.
+      ValueError:  If the shared library doesn't contain the needed routine.
+    """
+    with ir.Context() as ctx, ir.Location.unknown():
+        module = _get_output_sparse_tensor_kernel(sparsity, type)
+        module = ir.Module.parse(module)
+        engine = compile_and_build_engine(module)
+
+    # Convert the filename to a byte stream.
+    c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
+
+    arg_pointers = [
+        ctypes.byref(ctypes.cast(tensor, ctypes.c_void_p)),
+        ctypes.byref(c_filename),
+    ]
+
+    # Invoke the execution engine to run the module and return the result.
+    engine.invoke(_ENTRY_NAME, *arg_pointers)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py
index 69db28d4bccd5..8f193b81bb07c 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py
@@ -13,29 +13,29 @@
 
 
 class SparseCompiler:
-  """Sparse compiler class for compiling and building MLIR modules."""
-
-  def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
-    pipeline = f'builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})'
-    self.pipeline = pipeline
-    self.opt_level = opt_level
-    self.shared_libs = shared_libs
-
-  def __call__(self, module: ir.Module):
-    """Convenience application method."""
-    self.compile(module)
-
-  def compile(self, module: ir.Module):
-    """Compiles the module by invoking the sparse copmiler pipeline."""
-    passmanager.PassManager.parse(self.pipeline).run(module.operation)
-
-  def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
-    """Wraps the module in a JIT execution engine."""
-    return execution_engine.ExecutionEngine(
-        module, opt_level=self.opt_level, shared_libs=self.shared_libs)
-
-  def compile_and_jit(self,
-                      module: ir.Module) -> execution_engine.ExecutionEngine:
-    """Compiles and jits the module."""
-    self.compile(module)
-    return self.jit(module)
+    """Sparse compiler class for compiling and building MLIR modules."""
+
+    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+        pipeline = f"builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})"
+        self.pipeline = pipeline
+        self.opt_level = opt_level
+        self.shared_libs = shared_libs
+
+    def __call__(self, module: ir.Module):
+        """Convenience application method."""
+        self.compile(module)
+
+    def compile(self, module: ir.Module):
+        """Compiles the module by invoking the sparse copmiler pipeline."""
+        passmanager.PassManager.parse(self.pipeline).run(module.operation)
+
+    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Wraps the module in a JIT execution engine."""
+        return execution_engine.ExecutionEngine(
+            module, opt_level=self.opt_level, shared_libs=self.shared_libs
+        )
+
+    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Compiles and jits the module."""
+        self.compile(module)
+        return self.jit(module)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py
index 466c9df042984..1be88fa8bd709 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py
@@ -8,38 +8,40 @@
 
 
 def compare_sparse_tns(expected: str, actual: str, rtol: float = 0.0001) -> bool:
-  """Compares sparse tensor actual output file with expected output file.
+    """Compares sparse tensor actual output file with expected output file.
 
-  This routine assumes the input files are in FROSTT format. See
-  http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format.
+    This routine assumes the input files are in FROSTT format. See
+    http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format.
 
-  It also assumes the first line in the output file is a comment line.
+    It also assumes the first line in the output file is a comment line.
 
-  """
-  with open(actual, "r") as actual_f:
-    with open(expected, "r") as expected_f:
-      # Skip the first comment line.
-      _ = actual_f.readline()
-      _ = expected_f.readline()
+    """
+    with open(actual, "r") as actual_f:
+        with open(expected, "r") as expected_f:
+            # Skip the first comment line.
+            _ = actual_f.readline()
+            _ = expected_f.readline()
 
-      # Compare the two lines of meta data
-      if (actual_f.readline() != expected_f.readline() or
-          actual_f.readline() != expected_f.readline()):
-        return FALSE
+            # Compare the two lines of meta data
+            if (
+                actual_f.readline() != expected_f.readline()
+                or actual_f.readline() != expected_f.readline()
+            ):
+                return FALSE
 
-  actual_data = np.loadtxt(actual, np.float64, skiprows=3)
-  expected_data = np.loadtxt(expected, np.float64, skiprows=3)
-  return np.allclose(actual_data, expected_data, rtol=rtol)
+    actual_data = np.loadtxt(actual, np.float64, skiprows=3)
+    expected_data = np.loadtxt(expected, np.float64, skiprows=3)
+    return np.allclose(actual_data, expected_data, rtol=rtol)
 
 
 def file_as_string(file: str) -> str:
-  """Returns contents of file as string."""
-  with open(file, "r") as f:
-    return f.read()
+    """Returns contents of file as string."""
+    with open(file, "r") as f:
+        return f.read()
 
 
 def run_test(f):
-  """Prints the test name and runs the test."""
-  print(f.__name__)
-  f()
-  return f
+    """Prints the test name and runs the test."""
+    print(f.__name__)
+    f()
+    return f

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
index 5b7e648b97957..45ce446478dee 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
@@ -18,509 +18,630 @@
 
 
 def _init_3d(T, I, J, K):
-  for i in range(I):
-    for j in range(J):
-      for k in range(K):
-        T.insert([i, j, k], i + j + k + 1)
+    for i in range(I):
+        for j in range(J):
+            for k in range(K):
+                T.insert([i, j, k], i + j + k + 1)
 
 
 def _init_2d(T, I, J):
-  for i in range(I):
-    for j in range(J):
-      T.insert([i, j], i + j + 1)
+    for i in range(I):
+        for j in range(J):
+            T.insert([i, j], i + j + 1)
 
 
 def _init_1d_with_value(T, I, v):
-  for i in range(I):
-    T.insert([i], v)
+    for i in range(I):
+        T.insert([i], v)
 
 
 def test_expect_error(name, code, error):
-  """Executes the code then verifies the expected error message."""
-  try:
-    exec(code)
-  except ValueError as e:
-    passed = "passed" if (str(e).startswith(error)) else "failed"
-    print(f"test_{name}: {passed}")
+    """Executes the code then verifies the expected error message."""
+    try:
+        exec(code)
+    except ValueError as e:
+        passed = "passed" if (str(e).startswith(error)) else "failed"
+        print(f"test_{name}: {passed}")
 
 
 # CHECK-LABEL: test_tensor_dtype
 @testing_utils.run_test
 def test_tensor_dtype():
-  passed = mlir_pytaco.DType(mlir_pytaco.Type.INT16).is_int()
-  passed += mlir_pytaco.DType(mlir_pytaco.Type.INT32).is_int()
-  passed += mlir_pytaco.DType(mlir_pytaco.Type.INT64).is_int()
-  passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32).is_float()
-  passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64).is_float()
-  # CHECK: Number of passed: 5
-  print("Number of passed:", passed)
+    passed = mlir_pytaco.DType(mlir_pytaco.Type.INT16).is_int()
+    passed += mlir_pytaco.DType(mlir_pytaco.Type.INT32).is_int()
+    passed += mlir_pytaco.DType(mlir_pytaco.Type.INT64).is_int()
+    passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32).is_float()
+    passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64).is_float()
+    # CHECK: Number of passed: 5
+    print("Number of passed:", passed)
 
 
 # CHECK: test_mode_ordering_not_int: passed
-test_expect_error("mode_ordering_not_int",
-                  "m = mlir_pytaco.ModeOrdering(['x'])",
-                  "Ordering must be a list of integers")
+test_expect_error(
+    "mode_ordering_not_int",
+    "m = mlir_pytaco.ModeOrdering(['x'])",
+    "Ordering must be a list of integers",
+)
 
 # CHECK: test_mode_ordering_not_permutation: passed
-test_expect_error("mode_ordering_not_permutation",
-                  "m = mlir_pytaco.ModeOrdering([2, 1])", "Invalid ordering")
+test_expect_error(
+    "mode_ordering_not_permutation",
+    "m = mlir_pytaco.ModeOrdering([2, 1])",
+    "Invalid ordering",
+)
 
 # CHECK: test_mode_format_invalid: passed
-test_expect_error("mode_format_invalid",
-                  "m = mlir_pytaco.ModeFormatPack(['y'])",
-                  "Formats must be a list of ModeFormat")
+test_expect_error(
+    "mode_format_invalid",
+    "m = mlir_pytaco.ModeFormatPack(['y'])",
+    "Formats must be a list of ModeFormat",
+)
 
 # CHECK: test_expect_mode_format_pack: passed
-test_expect_error("expect_mode_format_pack", ("""
+test_expect_error(
+    "expect_mode_format_pack",
+    (
+        """
 mode_ordering = mlir_pytaco.ModeOrdering([0, 1, 2])
 f = mlir_pytaco.Format(["x"], mode_ordering)
-    """), "Expected a list of ModeFormat")
+    """
+    ),
+    "Expected a list of ModeFormat",
+)
 
 # CHECK: test_expect_mode_ordering: passed
-test_expect_error("expect_mode_ordering", ("""
+test_expect_error(
+    "expect_mode_ordering",
+    (
+        """
 mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED])
 f = mlir_pytaco.Format(mode_format_pack, "x")
-    """), "Expected ModeOrdering")
+    """
+    ),
+    "Expected ModeOrdering",
+)
 
 # CHECK: test_inconsistent_mode_format_pack_and_mode_ordering: passed
-test_expect_error("inconsistent_mode_format_pack_and_mode_ordering", ("""
+test_expect_error(
+    "inconsistent_mode_format_pack_and_mode_ordering",
+    (
+        """
 mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED])
 mode_ordering = mlir_pytaco.ModeOrdering([0, 1, 2])
 f = mlir_pytaco.Format(mode_format_pack, mode_ordering)
-    """), "Inconsistent ModeFormatPack and ModeOrdering")
+    """
+    ),
+    "Inconsistent ModeFormatPack and ModeOrdering",
+)
 
 
 # CHECK-LABEL: test_format_default_ordering
 @testing_utils.run_test
 def test_format_default_ordering():
-  f = mlir_pytaco.Format([_COMPRESSED, _COMPRESSED])
-  passed = 0
-  passed += np.array_equal(f.ordering.ordering, [0, 1])
-  # CHECK: Number of passed: 1
-  print("Number of passed:", passed)
+    f = mlir_pytaco.Format([_COMPRESSED, _COMPRESSED])
+    passed = 0
+    passed += np.array_equal(f.ordering.ordering, [0, 1])
+    # CHECK: Number of passed: 1
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_format_explicit_ordering
 @testing_utils.run_test
 def test_format_explicit_ordering():
-  f = mlir_pytaco.Format([_COMPRESSED, _DENSE], [1, 0])
-  passed = 0
-  passed += np.array_equal(f.ordering.ordering, [1, 0])
-  # CHECK: Number of passed: 1
-  print("Number of passed:", passed)
+    f = mlir_pytaco.Format([_COMPRESSED, _DENSE], [1, 0])
+    passed = 0
+    passed += np.array_equal(f.ordering.ordering, [1, 0])
+    # CHECK: Number of passed: 1
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_index_var
 @testing_utils.run_test
 def test_index_var():
-  i = mlir_pytaco.IndexVar()
-  j = mlir_pytaco.IndexVar()
-  passed = (i.name != j.name)
+    i = mlir_pytaco.IndexVar()
+    j = mlir_pytaco.IndexVar()
+    passed = i.name != j.name
 
-  vars = mlir_pytaco.get_index_vars(10)
-  passed += (len(vars) == 10)
-  passed += (all([isinstance(e, mlir_pytaco.IndexVar) for e in vars]))
+    vars = mlir_pytaco.get_index_vars(10)
+    passed += len(vars) == 10
+    passed += all([isinstance(e, mlir_pytaco.IndexVar) for e in vars])
 
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK: test_tensor_invalid_first_argument: passed
-test_expect_error("tensor_invalid_first_argument",
-                  "t = mlir_pytaco.Tensor('f')", "Invalid first argument")
+test_expect_error(
+    "tensor_invalid_first_argument",
+    "t = mlir_pytaco.Tensor('f')",
+    "Invalid first argument",
+)
 
 # CHECK: test_tensor_inconsistent_shape_and_format: passed
-test_expect_error("tensor_inconsistent_shape_and_format", ("""
+test_expect_error(
+    "tensor_inconsistent_shape_and_format",
+    (
+        """
 mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED])
 mode_ordering = mlir_pytaco.ModeOrdering([0, 1])
 f = mlir_pytaco.Format(mode_format_pack, mode_ordering)
 t = mlir_pytaco.Tensor([3], f)
-    """), "Inconsistent shape and format")
+    """
+    ),
+    "Inconsistent shape and format",
+)
 
 # CHECK: test_tensor_invalid_format: passed
-test_expect_error("tensor_invalid_format", "t = mlir_pytaco.Tensor([3], 'f')",
-                  "Invalid format argument")
+test_expect_error(
+    "tensor_invalid_format",
+    "t = mlir_pytaco.Tensor([3], 'f')",
+    "Invalid format argument",
+)
 
 # CHECK: test_tensor_insert_nonlist_coordinate: passed
-test_expect_error("tensor_insert_nonlist_coordinate", ("""
+test_expect_error(
+    "tensor_insert_nonlist_coordinate",
+    (
+        """
 t = mlir_pytaco.Tensor([3])
 t.insert(1, 0)
-    """), "Non list coordinate detected")
+    """
+    ),
+    "Non list coordinate detected",
+)
 
 # CHECK: test_tensor_insert_too_much_coordinate: passed
-test_expect_error("tensor_insert_too_much_coordinate", ("""
+test_expect_error(
+    "tensor_insert_too_much_coordinate",
+    (
+        """
 t = mlir_pytaco.Tensor([3])
 t.insert([0, 0], 0)
-    """), "Invalid coordinate")
+    """
+    ),
+    "Invalid coordinate",
+)
 
 # CHECK: test_tensor_insert_coordinate_outof_range: passed
-test_expect_error("tensor_insert_coordinate_outof_range", ("""
+test_expect_error(
+    "tensor_insert_coordinate_outof_range",
+    (
+        """
 t = mlir_pytaco.Tensor([1, 1])
 t.insert([1, 0], 0)
-    """), "Invalid coordinate")
+    """
+    ),
+    "Invalid coordinate",
+)
 
 # CHECK: test_tensor_insert_coordinate_nonint: passed
-test_expect_error("tensor_insert_coordinate_nonint", ("""
+test_expect_error(
+    "tensor_insert_coordinate_nonint",
+    (
+        """
 t = mlir_pytaco.Tensor([1, 1])
 t.insert([0, "xy"], 0)
-    """), "Non integer coordinate detected")
+    """
+    ),
+    "Non integer coordinate detected",
+)
 
 # CHECK: test_tensor_insert_invalid_value: passed
-test_expect_error("tensor_insert_invalid_value", ("""
+test_expect_error(
+    "tensor_insert_invalid_value",
+    (
+        """
 t = mlir_pytaco.Tensor([1, 1])
 t.insert([0, 0], "x")
-    """), "Value is neither int nor float")
+    """
+    ),
+    "Value is neither int nor float",
+)
 
 # CHECK: test_access_non_index_var_index: passed
-test_expect_error("access_non_index_var_index", ("""
+test_expect_error(
+    "access_non_index_var_index",
+    (
+        """
 t = mlir_pytaco.Tensor([5, 6])
 i = mlir_pytaco.IndexVar()
 a = mlir_pytaco.Access(t, (i, "j"))
-    """), "Indices contain non IndexVar")
+    """
+    ),
+    "Indices contain non IndexVar",
+)
 
 # CHECK: test_access_inconsistent_rank_indices: passed
-test_expect_error("access_inconsistent_rank_indices", ("""
+test_expect_error(
+    "access_inconsistent_rank_indices",
+    (
+        """
 t = mlir_pytaco.Tensor([5, 6])
 i = mlir_pytaco.IndexVar()
 a = mlir_pytaco.Access(t, (i,))
-    """), "Invalid indices for rank")
+    """
+    ),
+    "Invalid indices for rank",
+)
 
 # CHECK: test_access_invalid_indices_for_rank: passed
-test_expect_error("access_invalid_indices_for_rank", ("""
+test_expect_error(
+    "access_invalid_indices_for_rank",
+    (
+        """
 t = mlir_pytaco.Tensor([5, 6])
 i, j, k = mlir_pytaco.get_index_vars(3)
 a = mlir_pytaco.Access(t, (i,j, k))
-    """), "Invalid indices for rank")
+    """
+    ),
+    "Invalid indices for rank",
+)
 
 # CHECK: test_invalid_indices: passed
-test_expect_error("invalid_indices", ("""
+test_expect_error(
+    "invalid_indices",
+    (
+        """
 i, j = mlir_pytaco.get_index_vars(2)
 A = mlir_pytaco.Tensor([2, 3])
 B = mlir_pytaco.Tensor([2, 3])
 C = mlir_pytaco.Tensor([2, 3], _DENSE)
 C[i, j] = A[1, j] + B[i, j]
-    """), "Expected IndexVars")
+    """
+    ),
+    "Expected IndexVars",
+)
 
 # CHECK: test_inconsistent_rank_indices: passed
-test_expect_error("inconsistent_rank_indices", ("""
+test_expect_error(
+    "inconsistent_rank_indices",
+    (
+        """
 i, j = mlir_pytaco.get_index_vars(2)
 A = mlir_pytaco.Tensor([2, 3])
 C = mlir_pytaco.Tensor([2, 3], _DENSE)
 C[i, j] = A[i]
-    """), "Invalid indices for rank")
+    """
+    ),
+    "Invalid indices for rank",
+)
 
 # CHECK: test_destination_index_not_used_in_source: passed
-test_expect_error("destination_index_not_used_in_source", ("""
+test_expect_error(
+    "destination_index_not_used_in_source",
+    (
+        """
 i, j = mlir_pytaco.get_index_vars(2)
 A = mlir_pytaco.Tensor([3])
 C = mlir_pytaco.Tensor([3], _DENSE)
 C[j] = A[i]
 C.evaluate()
-    """), "Destination IndexVar not used in the source expression")
+    """
+    ),
+    "Destination IndexVar not used in the source expression",
+)
 
 # CHECK: test_destination_dim_not_consistent_with_source: passed
-test_expect_error("destination_dim_not_consistent_with_source", ("""
+test_expect_error(
+    "destination_dim_not_consistent_with_source",
+    (
+        """
 i = mlir_pytaco.IndexVar()
 A = mlir_pytaco.Tensor([3])
 C = mlir_pytaco.Tensor([5], _DENSE)
 C[i] = A[i]
 C.evaluate()
-    """), "Inconsistent destination dimension for IndexVar")
+    """
+    ),
+    "Inconsistent destination dimension for IndexVar",
+)
 
 # CHECK: test_inconsistent_source_dim: passed
-test_expect_error("inconsistent_source_dim", ("""
+test_expect_error(
+    "inconsistent_source_dim",
+    (
+        """
 i = mlir_pytaco.IndexVar()
 A = mlir_pytaco.Tensor([3])
 B = mlir_pytaco.Tensor([5])
 C = mlir_pytaco.Tensor([3], _DENSE)
 C[i] = A[i] + B[i]
 C.evaluate()
-    """), "Inconsistent source dimension for IndexVar")
+    """
+    ),
+    "Inconsistent source dimension for IndexVar",
+)
 
 # CHECK: test_index_var_outside_domain: passed
-test_expect_error("index_var_outside_domain", ("""
+test_expect_error(
+    "index_var_outside_domain",
+    (
+        """
 i, j = mlir_pytaco.get_index_vars(2)
 A = mlir_pytaco.Tensor([3])
 B = mlir_pytaco.Tensor([3])
 B[i] = A[i] + j
 B.evaluate()
-    """), "IndexVar is not part of the iteration domain")
+    """
+    ),
+    "IndexVar is not part of the iteration domain",
+)
 
 
 # CHECK-LABEL: test_tensor_all_dense_sparse
 @testing_utils.run_test
 def test_tensor_all_dense_sparse():
-  a = mlir_pytaco.Tensor([4], [_DENSE])
-  passed = (not a.is_dense())
-  passed += (a.order == 1)
-  passed += (a.shape[0] == 4)
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    a = mlir_pytaco.Tensor([4], [_DENSE])
+    passed = not a.is_dense()
+    passed += a.order == 1
+    passed += a.shape[0] == 4
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_tensor_true_dense
 @testing_utils.run_test
 def test_tensor_true_dense():
-  a = mlir_pytaco.Tensor.from_array(np.random.uniform(size=5))
-  passed = a.is_dense()
-  passed += (a.order == 1)
-  passed += (a.shape[0] == 5)
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    a = mlir_pytaco.Tensor.from_array(np.random.uniform(size=5))
+    passed = a.is_dense()
+    passed += a.order == 1
+    passed += a.shape[0] == 5
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_tensor_copy
 @testing_utils.run_test
 def test_tensor_copy():
-  i, j = mlir_pytaco.get_index_vars(2)
-  I = 2
-  J = 3
-  A = mlir_pytaco.Tensor([I, J])
-  A.insert([0, 1], 5.0)
-  A.insert([1, 2], 6.0)
-  B = mlir_pytaco.Tensor([I, J])
-  B[i, j] = A[i, j]
-  passed = (B._assignment is not None)
-  passed += (B._engine is None)
-  try:
+    i, j = mlir_pytaco.get_index_vars(2)
+    I = 2
+    J = 3
+    A = mlir_pytaco.Tensor([I, J])
+    A.insert([0, 1], 5.0)
+    A.insert([1, 2], 6.0)
+    B = mlir_pytaco.Tensor([I, J])
+    B[i, j] = A[i, j]
+    passed = B._assignment is not None
+    passed += B._engine is None
+    try:
+        B.compute()
+    except ValueError as e:
+        passed += str(e).startswith("Need to invoke compile")
+    B.compile()
+    passed += B._engine is not None
     B.compute()
-  except ValueError as e:
-    passed += (str(e).startswith("Need to invoke compile"))
-  B.compile()
-  passed += (B._engine is not None)
-  B.compute()
-  passed += (B._assignment is None)
-  passed += (B._engine is None)
-  indices, values = B.get_coordinates_and_values()
-  passed += np.array_equal(indices, [[0, 1], [1, 2]])
-  passed += np.allclose(values, [5.0, 6.0])
-  # No temporary tensor is used.
-  passed += (B._stats.get_total() == 0)
-  # CHECK: Number of passed: 9
-  print("Number of passed:", passed)
+    passed += B._assignment is None
+    passed += B._engine is None
+    indices, values = B.get_coordinates_and_values()
+    passed += np.array_equal(indices, [[0, 1], [1, 2]])
+    passed += np.allclose(values, [5.0, 6.0])
+    # No temporary tensor is used.
+    passed += B._stats.get_total() == 0
+    # CHECK: Number of passed: 9
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_tensor_trivial_reduction
 @testing_utils.run_test
 def test_tensor_trivial_reduction():
-  i, j = mlir_pytaco.get_index_vars(2)
-  I = 2
-  J = 3
-  A = mlir_pytaco.Tensor([I, J])
-  A.insert([0, 1], 5.0)
-  A.insert([0, 2], 3.0)
-  A.insert([1, 2], 6.0)
-  B = mlir_pytaco.Tensor([I])
-  B[i] = A[i, j]
-  indices, values = B.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0], [1]])
-  passed += np.allclose(values, [8.0, 6.0])
-  # No temporary tensor is used.
-  passed += (B._stats.get_total() == 0)
-
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    i, j = mlir_pytaco.get_index_vars(2)
+    I = 2
+    J = 3
+    A = mlir_pytaco.Tensor([I, J])
+    A.insert([0, 1], 5.0)
+    A.insert([0, 2], 3.0)
+    A.insert([1, 2], 6.0)
+    B = mlir_pytaco.Tensor([I])
+    B[i] = A[i, j]
+    indices, values = B.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[0], [1]])
+    passed += np.allclose(values, [8.0, 6.0])
+    # No temporary tensor is used.
+    passed += B._stats.get_total() == 0
+
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_binary_add
 @testing_utils.run_test
 def test_binary_add():
-  i = mlir_pytaco.IndexVar()
-  A = mlir_pytaco.Tensor([4])
-  B = mlir_pytaco.Tensor([4])
-  C = mlir_pytaco.Tensor([4])
-  A.insert([1], 10)
-  A.insert([2], 1)
-  B.insert([3], 20)
-  B.insert([2], 2)
-  C[i] = A[i] + B[i]
-  indices, values = C.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[1], [2], [3]])
-  passed += np.array_equal(values, [10., 3., 20.])
-  # No temporary tensor is used.
-  passed += (C._stats.get_total() == 0)
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    i = mlir_pytaco.IndexVar()
+    A = mlir_pytaco.Tensor([4])
+    B = mlir_pytaco.Tensor([4])
+    C = mlir_pytaco.Tensor([4])
+    A.insert([1], 10)
+    A.insert([2], 1)
+    B.insert([3], 20)
+    B.insert([2], 2)
+    C[i] = A[i] + B[i]
+    indices, values = C.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[1], [2], [3]])
+    passed += np.array_equal(values, [10.0, 3.0, 20.0])
+    # No temporary tensor is used.
+    passed += C._stats.get_total() == 0
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_binary_add_sub
 @testing_utils.run_test
 def test_binary_add_sub():
-  i = mlir_pytaco.IndexVar()
-  j = mlir_pytaco.IndexVar()
-  A = mlir_pytaco.Tensor([2, 3])
-  B = mlir_pytaco.Tensor([2, 3])
-  C = mlir_pytaco.Tensor([2, 3])
-  D = mlir_pytaco.Tensor([2, 3])
-  A.insert([0, 1], 10)
-  A.insert([1, 2], 40)
-  B.insert([0, 0], 20)
-  B.insert([1, 2], 30)
-  C.insert([0, 1], 5)
-  C.insert([1, 2], 7)
-  D[i, j] = A[i, j] + B[i, j] - C[i, j]
-  indices, values = D.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
-  passed += np.array_equal(values, [20., 5., 63.])
-  # No temporary tensor is used.
-  passed += (D._stats.get_total() == 0)
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    i = mlir_pytaco.IndexVar()
+    j = mlir_pytaco.IndexVar()
+    A = mlir_pytaco.Tensor([2, 3])
+    B = mlir_pytaco.Tensor([2, 3])
+    C = mlir_pytaco.Tensor([2, 3])
+    D = mlir_pytaco.Tensor([2, 3])
+    A.insert([0, 1], 10)
+    A.insert([1, 2], 40)
+    B.insert([0, 0], 20)
+    B.insert([1, 2], 30)
+    C.insert([0, 1], 5)
+    C.insert([1, 2], 7)
+    D[i, j] = A[i, j] + B[i, j] - C[i, j]
+    indices, values = D.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
+    passed += np.array_equal(values, [20.0, 5.0, 63.0])
+    # No temporary tensor is used.
+    passed += D._stats.get_total() == 0
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_binary_mul_add
 @testing_utils.run_test
 def test_binary_mul_add():
-  i = mlir_pytaco.IndexVar()
-  j = mlir_pytaco.IndexVar()
-  A = mlir_pytaco.Tensor([2, 3])
-  B = mlir_pytaco.Tensor([2, 3])
-  C = mlir_pytaco.Tensor([2, 3])
-  D = mlir_pytaco.Tensor([2, 3])
-  A.insert([0, 1], 10)
-  A.insert([1, 2], 40)
-  B.insert([0, 0], 20)
-  B.insert([1, 2], 30)
-  C.insert([0, 1], 5)
-  C.insert([1, 2], 7)
-  D[i, j] = A[i, j] * C[i, j] + B[i, j]
-  indices, values = D.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
-  passed += np.array_equal(values, [20., 50., 310.])
-  # No temporary tensor is used.
-  passed += (D._stats.get_total() == 0)
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    i = mlir_pytaco.IndexVar()
+    j = mlir_pytaco.IndexVar()
+    A = mlir_pytaco.Tensor([2, 3])
+    B = mlir_pytaco.Tensor([2, 3])
+    C = mlir_pytaco.Tensor([2, 3])
+    D = mlir_pytaco.Tensor([2, 3])
+    A.insert([0, 1], 10)
+    A.insert([1, 2], 40)
+    B.insert([0, 0], 20)
+    B.insert([1, 2], 30)
+    C.insert([0, 1], 5)
+    C.insert([1, 2], 7)
+    D[i, j] = A[i, j] * C[i, j] + B[i, j]
+    indices, values = D.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
+    passed += np.array_equal(values, [20.0, 50.0, 310.0])
+    # No temporary tensor is used.
+    passed += D._stats.get_total() == 0
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_binary_add_reduce_at_root
 @testing_utils.run_test
 def test_binary_add_reduce_at_root():
-  i = mlir_pytaco.IndexVar()
-  j = mlir_pytaco.IndexVar()
-  A = mlir_pytaco.Tensor([2, 3])
-  B = mlir_pytaco.Tensor([2, 3])
-  C = mlir_pytaco.Tensor([2], _DENSE)
-  A.insert([0, 1], 10)
-  A.insert([1, 2], 40)
-  B.insert([0, 0], 20)
-  B.insert([1, 2], 30)
-  C[i] = A[i, j] + B[i, j]
-  indices, values = C.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0], [1]])
-  passed += np.array_equal(values, [30., 70.])
-  # No temporary tensor is used.
-  passed += (C._stats.get_total() == 0)
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    i = mlir_pytaco.IndexVar()
+    j = mlir_pytaco.IndexVar()
+    A = mlir_pytaco.Tensor([2, 3])
+    B = mlir_pytaco.Tensor([2, 3])
+    C = mlir_pytaco.Tensor([2], _DENSE)
+    A.insert([0, 1], 10)
+    A.insert([1, 2], 40)
+    B.insert([0, 0], 20)
+    B.insert([1, 2], 30)
+    C[i] = A[i, j] + B[i, j]
+    indices, values = C.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[0], [1]])
+    passed += np.array_equal(values, [30.0, 70.0])
+    # No temporary tensor is used.
+    passed += C._stats.get_total() == 0
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_binary_add_reduce_at_child
 @testing_utils.run_test
 def test_binary_add_reduce_at_child():
-  i = mlir_pytaco.IndexVar()
-  j = mlir_pytaco.IndexVar()
-  I = 2
-  J = 3
-  A = mlir_pytaco.Tensor([I, J])
-  B = mlir_pytaco.Tensor([J])
-  C = mlir_pytaco.Tensor([I])
-  D = mlir_pytaco.Tensor([I], _DENSE)
-
-  _init_2d(A, I, J)
-  _init_1d_with_value(C, I, 2)
-  _init_1d_with_value(B, J, 1)
-
-  D[i] = A[i, j] * B[j] + C[i]
-  indices, values = D.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0], [1]])
-  passed += np.array_equal(values, [8., 11.])
-
-  # The expression is implemented as:
-  #    temp0[i] = A[i, j] * B[i]
-  #    D[i] = temp0[i] + C[i]
-  # Check the temporary tensor introduced by the implementation.
-  stats = D._stats
-  passed += (stats.get_total() == 1)
-  passed += (stats.get_formats(0) == (_COMPRESSED,))
-  passed += (stats.get_dimensions(0) == (I,))
-  # CHECK: Number of passed: 5
-  print("Number of passed:", passed)
+    i = mlir_pytaco.IndexVar()
+    j = mlir_pytaco.IndexVar()
+    I = 2
+    J = 3
+    A = mlir_pytaco.Tensor([I, J])
+    B = mlir_pytaco.Tensor([J])
+    C = mlir_pytaco.Tensor([I])
+    D = mlir_pytaco.Tensor([I], _DENSE)
+
+    _init_2d(A, I, J)
+    _init_1d_with_value(C, I, 2)
+    _init_1d_with_value(B, J, 1)
+
+    D[i] = A[i, j] * B[j] + C[i]
+    indices, values = D.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[0], [1]])
+    passed += np.array_equal(values, [8.0, 11.0])
+
+    # The expression is implemented as:
+    #    temp0[i] = A[i, j] * B[i]
+    #    D[i] = temp0[i] + C[i]
+    # Check the temporary tensor introduced by the implementation.
+    stats = D._stats
+    passed += stats.get_total() == 1
+    passed += stats.get_formats(0) == (_COMPRESSED,)
+    passed += stats.get_dimensions(0) == (I,)
+    # CHECK: Number of passed: 5
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_binary_add_reduce_3d_1
 @testing_utils.run_test
 def test_binary_add_reduce_3d_1():
-  i, j, k, l = mlir_pytaco.get_index_vars(4)
-  I = 2
-  J = 3
-  K = 4
-  L = 5
-  A = mlir_pytaco.Tensor([I, J, K])
-  B = mlir_pytaco.Tensor([I, J, L])
-  C = mlir_pytaco.Tensor([K])
-  D = mlir_pytaco.Tensor([L])
-  E = mlir_pytaco.Tensor([I], _DENSE)
-
-  _init_3d(A, I, J, K)
-  _init_3d(B, I, J, L)
-  _init_1d_with_value(C, K, 1)
-  _init_1d_with_value(D, L, 2)
-
-  E[i] = A[i, j, k] * C[k] + B[i, j, l] * D[l]
-  indices, values = E.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0], [1]])
-  passed += np.array_equal(values, [162., 204.])
-
-  # The expression is implemented as:
-  #    temp0[i, j] = A[i, j, k] * C[k]
-  #    temp1[i, j] = B[i, j, l] * D[l]
-  #    E[i] = temp0[i, j] + temp1[i, j]
-  # Check the two temporary tensors introduced by the implementation.
-  stats = E._stats
-  passed += (stats.get_total() == 2)
-  passed += (stats.get_formats(0) == (_COMPRESSED, _COMPRESSED))
-  passed += (stats.get_dimensions(0) == (I, J))
-  passed += (stats.get_formats(1) == (_COMPRESSED, _COMPRESSED))
-  passed += (stats.get_dimensions(1) == (I, J))
-  # CHECK: Number of passed: 7
-  print("Number of passed:", passed)
+    i, j, k, l = mlir_pytaco.get_index_vars(4)
+    I = 2
+    J = 3
+    K = 4
+    L = 5
+    A = mlir_pytaco.Tensor([I, J, K])
+    B = mlir_pytaco.Tensor([I, J, L])
+    C = mlir_pytaco.Tensor([K])
+    D = mlir_pytaco.Tensor([L])
+    E = mlir_pytaco.Tensor([I], _DENSE)
+
+    _init_3d(A, I, J, K)
+    _init_3d(B, I, J, L)
+    _init_1d_with_value(C, K, 1)
+    _init_1d_with_value(D, L, 2)
+
+    E[i] = A[i, j, k] * C[k] + B[i, j, l] * D[l]
+    indices, values = E.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[0], [1]])
+    passed += np.array_equal(values, [162.0, 204.0])
+
+    # The expression is implemented as:
+    #    temp0[i, j] = A[i, j, k] * C[k]
+    #    temp1[i, j] = B[i, j, l] * D[l]
+    #    E[i] = temp0[i, j] + temp1[i, j]
+    # Check the two temporary tensors introduced by the implementation.
+    stats = E._stats
+    passed += stats.get_total() == 2
+    passed += stats.get_formats(0) == (_COMPRESSED, _COMPRESSED)
+    passed += stats.get_dimensions(0) == (I, J)
+    passed += stats.get_formats(1) == (_COMPRESSED, _COMPRESSED)
+    passed += stats.get_dimensions(1) == (I, J)
+    # CHECK: Number of passed: 7
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_binary_add_reduce_3d_2
 @testing_utils.run_test
 def test_binary_add_reduce_3d_2():
-  i, j, k, l = mlir_pytaco.get_index_vars(4)
-  I = 2
-  J = 3
-  K = 4
-  L = 5
-  A = mlir_pytaco.Tensor([I, J, K], [_COMPRESSED, _COMPRESSED, _DENSE])
-  B = mlir_pytaco.Tensor([I, L, K], [_DENSE, _COMPRESSED, _COMPRESSED])
-  C = mlir_pytaco.Tensor([J, K], [_COMPRESSED, _COMPRESSED])
-  D = mlir_pytaco.Tensor([L])
-  E = mlir_pytaco.Tensor([I], _DENSE)
-
-  _init_3d(A, I, J, K)
-  _init_3d(B, I, L, K)
-  _init_2d(C, J, K)
-  _init_1d_with_value(D, L, 2)
-
-  E[i] = A[i, j, k] + C[j, k] + B[i, l, k] * D[l]
-  indices, values = E.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0], [1]])
-  passed += np.array_equal(values, [264., 316.])
-
-  # The expression is implemented as:
-  #    temp0[i, k] = A[i, j, k] + C[j, k]
-  #    temp1[i, k] = B[i, l, k] * D[l]
-  #    E[i] = temp0[i, k] + temp1[i, k]
-  # Check the two temporary tensors introduced by the implementation.
-  stats = E._stats
-  passed += (stats.get_total() == 2)
-  passed += (stats.get_formats(0) == (_COMPRESSED, _DENSE))
-  passed += (stats.get_dimensions(0) == (I, K))
-  passed += (stats.get_formats(1) == (_DENSE, _COMPRESSED))
-  passed += (stats.get_dimensions(1) == (I, K))
-  # CHECK: Number of passed: 7
-  print("Number of passed:", passed)
+    i, j, k, l = mlir_pytaco.get_index_vars(4)
+    I = 2
+    J = 3
+    K = 4
+    L = 5
+    A = mlir_pytaco.Tensor([I, J, K], [_COMPRESSED, _COMPRESSED, _DENSE])
+    B = mlir_pytaco.Tensor([I, L, K], [_DENSE, _COMPRESSED, _COMPRESSED])
+    C = mlir_pytaco.Tensor([J, K], [_COMPRESSED, _COMPRESSED])
+    D = mlir_pytaco.Tensor([L])
+    E = mlir_pytaco.Tensor([I], _DENSE)
+
+    _init_3d(A, I, J, K)
+    _init_3d(B, I, L, K)
+    _init_2d(C, J, K)
+    _init_1d_with_value(D, L, 2)
+
+    E[i] = A[i, j, k] + C[j, k] + B[i, l, k] * D[l]
+    indices, values = E.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[0], [1]])
+    passed += np.array_equal(values, [264.0, 316.0])
+
+    # The expression is implemented as:
+    #    temp0[i, k] = A[i, j, k] + C[j, k]
+    #    temp1[i, k] = B[i, l, k] * D[l]
+    #    E[i] = temp0[i, k] + temp1[i, k]
+    # Check the two temporary tensors introduced by the implementation.
+    stats = E._stats
+    passed += stats.get_total() == 2
+    passed += stats.get_formats(0) == (_COMPRESSED, _DENSE)
+    passed += stats.get_dimensions(0) == (I, K)
+    passed += stats.get_formats(1) == (_DENSE, _COMPRESSED)
+    passed += stats.get_dimensions(1) == (I, K)
+    # CHECK: Number of passed: 7
+    print("Number of passed:", passed)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py
index cce97d6ef1e25..1d5274759b6a9 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py
@@ -32,21 +32,21 @@
 # CHECK-LABEL: test_read_mtx_matrix_general
 @testing_utils.run_test
 def test_read_mtx_matrix_general():
-  with tempfile.TemporaryDirectory() as test_dir:
-    file_name = os.path.join(test_dir, "data.mtx")
-    with open(file_name, "w") as file:
-      file.write(_MTX_DATA)
-    a = mlir_pytaco_io.read(file_name, _FORMAT)
-  passed = 0
-  # The value of a is stored as an MLIR sparse tensor.
-  passed += (not a.is_unpacked())
-  a.unpack()
-  passed += (a.is_unpacked())
-  coords, values = a.get_coordinates_and_values()
-  passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
-  passed += np.allclose(values, [2.0, 3.0, 4.0])
-  # CHECK: 4
-  print(passed)
+    with tempfile.TemporaryDirectory() as test_dir:
+        file_name = os.path.join(test_dir, "data.mtx")
+        with open(file_name, "w") as file:
+            file.write(_MTX_DATA)
+        a = mlir_pytaco_io.read(file_name, _FORMAT)
+    passed = 0
+    # The value of a is stored as an MLIR sparse tensor.
+    passed += not a.is_unpacked()
+    a.unpack()
+    passed += a.is_unpacked()
+    coords, values = a.get_coordinates_and_values()
+    passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
+    passed += np.allclose(values, [2.0, 3.0, 4.0])
+    # CHECK: 4
+    print(passed)
 
 
 _TNS_DATA = """2 3
@@ -60,57 +60,57 @@ def test_read_mtx_matrix_general():
 # CHECK-LABEL: test_read_tns
 @testing_utils.run_test
 def test_read_tns():
-  with tempfile.TemporaryDirectory() as test_dir:
-    file_name = os.path.join(test_dir, "data.tns")
-    with open(file_name, "w") as file:
-      file.write(_TNS_DATA)
-    a = mlir_pytaco_io.read(file_name, _FORMAT)
-  passed = 0
-  # The value of a is stored as an MLIR sparse tensor.
-  passed += (not a.is_unpacked())
-  a.unpack()
-  passed += (a.is_unpacked())
-  coords, values = a.get_coordinates_and_values()
-  passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
-  passed += np.allclose(values, [2.0, 3.0, 4.0])
-  # CHECK: 4
-  print(passed)
+    with tempfile.TemporaryDirectory() as test_dir:
+        file_name = os.path.join(test_dir, "data.tns")
+        with open(file_name, "w") as file:
+            file.write(_TNS_DATA)
+        a = mlir_pytaco_io.read(file_name, _FORMAT)
+    passed = 0
+    # The value of a is stored as an MLIR sparse tensor.
+    passed += not a.is_unpacked()
+    a.unpack()
+    passed += a.is_unpacked()
+    coords, values = a.get_coordinates_and_values()
+    passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
+    passed += np.allclose(values, [2.0, 3.0, 4.0])
+    # CHECK: 4
+    print(passed)
 
 
 # CHECK-LABEL: test_write_unpacked_tns
 @testing_utils.run_test
 def test_write_unpacked_tns():
-  a = mlir_pytaco.Tensor([2, 3])
-  a.insert([0, 1], 10)
-  a.insert([1, 2], 40)
-  a.insert([0, 0], 20)
-  with tempfile.TemporaryDirectory() as test_dir:
-    file_name = os.path.join(test_dir, "data.tns")
-    try:
-      mlir_pytaco_io.write(file_name, a)
-    except ValueError as e:
-      # CHECK: Writing unpacked sparse tensors to file is not supported
-      print(e)
+    a = mlir_pytaco.Tensor([2, 3])
+    a.insert([0, 1], 10)
+    a.insert([1, 2], 40)
+    a.insert([0, 0], 20)
+    with tempfile.TemporaryDirectory() as test_dir:
+        file_name = os.path.join(test_dir, "data.tns")
+        try:
+            mlir_pytaco_io.write(file_name, a)
+        except ValueError as e:
+            # CHECK: Writing unpacked sparse tensors to file is not supported
+            print(e)
 
 
 # CHECK-LABEL: test_write_packed_tns
 @testing_utils.run_test
 def test_write_packed_tns():
-  a = mlir_pytaco.Tensor([2, 3])
-  a.insert([0, 1], 10)
-  a.insert([1, 2], 40)
-  a.insert([0, 0], 20)
-  b = mlir_pytaco.Tensor([2, 3])
-  i, j = mlir_pytaco.get_index_vars(2)
-  b[i, j] = a[i, j] + a[i, j]
-  with tempfile.TemporaryDirectory() as test_dir:
-    file_name = os.path.join(test_dir, "data.tns")
-    mlir_pytaco_io.write(file_name, b)
-    with open(file_name, "r") as file:
-      lines = file.readlines()
-  passed = 0
-  # Skip the comment line in the output.
-  if lines[1:] == ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]:
-    passed = 1
-  # CHECK: 1
-  print(passed)
+    a = mlir_pytaco.Tensor([2, 3])
+    a.insert([0, 1], 10)
+    a.insert([1, 2], 40)
+    a.insert([0, 0], 20)
+    b = mlir_pytaco.Tensor([2, 3])
+    i, j = mlir_pytaco.get_index_vars(2)
+    b[i, j] = a[i, j] + a[i, j]
+    with tempfile.TemporaryDirectory() as test_dir:
+        file_name = os.path.join(test_dir, "data.tns")
+        mlir_pytaco_io.write(file_name, b)
+        with open(file_name, "r") as file:
+            lines = file.readlines()
+    passed = 0
+    # Skip the comment line in the output.
+    if lines[1:] == ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]:
+        passed = 1
+    # CHECK: 1
+    print(passed)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_utils.py
index 13259698b1b12..1344f4aa741ab 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_utils.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_utils.py
@@ -20,79 +20,93 @@
 
 
 def _to_string(s: Sequence[int]) -> str:
-  """Converts a sequence of integer to a space separated value string."""
-  return " ".join(map(lambda e: str(e), s))
+    """Converts a sequence of integer to a space separated value string."""
+    return " ".join(map(lambda e: str(e), s))
 
 
 def _add_one(s: Sequence[int]) -> Sequence[int]:
-  """Adds one to each element in the sequence of integer."""
-  return [i + 1 for i in s]
+    """Adds one to each element in the sequence of integer."""
+    return [i + 1 for i in s]
 
 
 @dataclasses.dataclass(frozen=True)
 class _SparseTensorCOO:
-  """Values for a COO-flavored format sparse tensor.
-
-  Attributes:
-    rank: An integer rank for the tensor.
-    nse: An integer for the number of non-zero values.
-    shape: A sequence of integer for the dimension size.
-    values: A sequence of float for the non-zero values of the tensor.
-    indices: A sequence of coordinate, each coordinate is a sequence of integer.
-  """
-  rank: int
-  nse: int
-  shape: Sequence[int]
-  values: Sequence[float]
-  indices: Sequence[Sequence[int]]
+    """Values for a COO-flavored format sparse tensor.
+
+    Attributes:
+      rank: An integer rank for the tensor.
+      nse: An integer for the number of non-zero values.
+      shape: A sequence of integer for the dimension size.
+      values: A sequence of float for the non-zero values of the tensor.
+      indices: A sequence of coordinate, each coordinate is a sequence of integer.
+    """
+
+    rank: int
+    nse: int
+    shape: Sequence[int]
+    values: Sequence[float]
+    indices: Sequence[Sequence[int]]
 
 
 def _coo_values_to_tns_format(t: _SparseTensorCOO) -> str:
-  """Converts a sparse tensor COO-flavored values to TNS text format."""
-  # The coo_value_str contains one line for each (coordinate value) pair.
-  # Indices are 1-based in TNS text format but 0-based in MLIR.
-  coo_value_str = "\n".join(
-      map(lambda i: _to_string(_add_one(t.indices[i])) + " " + str(t.values[i]),
-          range(t.nse)))
-
-  # Returns the TNS text format representation for the tensor.
-  return f"""{t.rank} {t.nse}
+    """Converts a sparse tensor COO-flavored values to TNS text format."""
+    # The coo_value_str contains one line for each (coordinate value) pair.
+    # Indices are 1-based in TNS text format but 0-based in MLIR.
+    coo_value_str = "\n".join(
+        map(
+            lambda i: _to_string(_add_one(t.indices[i])) + " " + str(t.values[i]),
+            range(t.nse),
+        )
+    )
+
+    # Returns the TNS text format representation for the tensor.
+    return f"""{t.rank} {t.nse}
 {_to_string(t.shape)}
 {coo_value_str}
 """
 
 
 def _implement_read_tns_test(
-    t: _SparseTensorCOO,
-    sparsity_codes: Sequence[sparse_tensor.DimLevelType]) -> int:
-  tns_data = _coo_values_to_tns_format(t)
-
-  # Write sparse tensor data to a file.
-  with tempfile.TemporaryDirectory() as test_dir:
-    file_name = os.path.join(test_dir, "data.tns")
-    with open(file_name, "w") as file:
-      file.write(tns_data)
-
-    # Read the data from the file and construct an MLIR sparse tensor.
-    sparse_tensor, o_shape = pytaco_utils.create_sparse_tensor(
-        file_name, sparsity_codes, "f64")
-
-  passed = 0
-
-  # Verify the output shape for the tensor.
-  if np.array_equal(o_shape, t.shape):
-    passed += 1
-
-  # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
-  # values and verify the values.
-  o_rank, o_nse, o_shape, o_values, o_indices = (
-      pytaco_utils.sparse_tensor_to_coo_tensor(sparse_tensor, np.float64))
-  if o_rank == t.rank and o_nse == t.nse and np.array_equal(
-      o_shape, t.shape) and np.allclose(o_values, t.values) and np.array_equal(
-          o_indices, t.indices):
-    passed += 1
-
-  return passed
+    t: _SparseTensorCOO, sparsity_codes: Sequence[sparse_tensor.DimLevelType]
+) -> int:
+    tns_data = _coo_values_to_tns_format(t)
+
+    # Write sparse tensor data to a file.
+    with tempfile.TemporaryDirectory() as test_dir:
+        file_name = os.path.join(test_dir, "data.tns")
+        with open(file_name, "w") as file:
+            file.write(tns_data)
+
+        # Read the data from the file and construct an MLIR sparse tensor.
+        sparse_tensor, o_shape = pytaco_utils.create_sparse_tensor(
+            file_name, sparsity_codes, "f64"
+        )
+
+    passed = 0
+
+    # Verify the output shape for the tensor.
+    if np.array_equal(o_shape, t.shape):
+        passed += 1
+
+    # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
+    # values and verify the values.
+    (
+        o_rank,
+        o_nse,
+        o_shape,
+        o_values,
+        o_indices,
+    ) = pytaco_utils.sparse_tensor_to_coo_tensor(sparse_tensor, np.float64)
+    if (
+        o_rank == t.rank
+        and o_nse == t.nse
+        and np.array_equal(o_shape, t.shape)
+        and np.allclose(o_values, t.values)
+        and np.array_equal(o_indices, t.indices)
+    ):
+        passed += 1
+
+    return passed
 
 
 # A 2D sparse tensor data in COO-flavored format.

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
index 12c97dbfa61ba..70b4b66f4378d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
@@ -5,11 +5,11 @@ if not config.mlir_run_amx_tests:
     config.unsupported = True
 
 # No JIT on win32.
-if sys.platform == 'win32':
+if sys.platform == "win32":
     config.unsupported = True
 
 if config.intel_sde_executable:
     # Run test in emulator (Intel SDE): AMX needs Sapphire Rapids CPU.
-    config.substitutions.append(('%lli', config.intel_sde_executable + ' -spr -- lli'))
+    config.substitutions.append(("%lli", config.intel_sde_executable + " -spr -- lli"))
 else:
-    config.substitutions.append(('%lli', 'lli'))
+    config.substitutions.append(("%lli", "lli"))

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/lit.local.cfg
index 0423fc03da5f9..296b4419438e8 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/lit.local.cfg
@@ -5,5 +5,5 @@ if not config.mlir_run_arm_sme_tests:
     config.unsupported = True
 
 # No JIT on win32.
-if sys.platform == 'win32':
+if sys.platform == "win32":
     config.unsupported = True

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/lit.local.cfg
index 8a0d884509c80..37d3a74874ce4 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/lit.local.cfg
@@ -5,5 +5,5 @@ if not config.mlir_run_arm_sve_tests:
     config.unsupported = True
 
 # No JIT on win32.
-if sys.platform == 'win32':
+if sys.platform == "win32":
     config.unsupported = True

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg
index 0e22874db1d18..bde815616b2db 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg
@@ -5,11 +5,11 @@ if not config.mlir_run_x86vector_tests:
     config.unsupported = True
 
 # No JIT on win32.
-if sys.platform == 'win32':
+if sys.platform == "win32":
     config.unsupported = True
 
 if config.intel_sde_executable:
     # Run test in emulator (Intel SDE).
-    config.substitutions.append(('%lli', config.intel_sde_executable + ' -tgl -- lli'))
+    config.substitutions.append(("%lli", config.intel_sde_executable + " -tgl -- lli"))
 else:
-    config.substitutions.append(('%lli', 'lli'))
+    config.substitutions.append(("%lli", "lli"))

diff  --git a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg
index 0bdebfedeee36..acb8dd43f50b4 100644
--- a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg
@@ -1,2 +1,2 @@
 if not config.enable_cuda_runner:
-  config.unsupported = True
+    config.unsupported = True

diff  --git a/mlir/test/Integration/GPU/CUDA/TensorCore/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/TensorCore/lit.local.cfg
index 451b9fcbed3df..3bd70245df7df 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/lit.local.cfg
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/lit.local.cfg
@@ -2,4 +2,4 @@ import sys
 
 # TensorCore tests must be enabled via build flag.
 if not config.mlir_run_cuda_tensor_core_tests:
-  config.unsupported = True
+    config.unsupported = True

diff  --git a/mlir/test/Integration/GPU/CUDA/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/lit.local.cfg
index 0bdebfedeee36..acb8dd43f50b4 100644
--- a/mlir/test/Integration/GPU/CUDA/lit.local.cfg
+++ b/mlir/test/Integration/GPU/CUDA/lit.local.cfg
@@ -1,2 +1,2 @@
 if not config.enable_cuda_runner:
-  config.unsupported = True
+    config.unsupported = True

diff  --git a/mlir/test/Integration/GPU/ROCM/lit.local.cfg b/mlir/test/Integration/GPU/ROCM/lit.local.cfg
index b0d086f9d4d51..e1f864857c5c1 100644
--- a/mlir/test/Integration/GPU/ROCM/lit.local.cfg
+++ b/mlir/test/Integration/GPU/ROCM/lit.local.cfg
@@ -1,4 +1,4 @@
 if not config.enable_rocm_runner or not config.rocm_test_chipset:
-  config.unsupported = True
+    config.unsupported = True
 
-config.substitutions.append(('%chip', config.rocm_test_chipset))
+config.substitutions.append(("%chip", config.rocm_test_chipset))

diff  --git a/mlir/test/Integration/lit.local.cfg b/mlir/test/Integration/lit.local.cfg
index 80a862a1ce664..1b4a323871d75 100644
--- a/mlir/test/Integration/lit.local.cfg
+++ b/mlir/test/Integration/lit.local.cfg
@@ -3,8 +3,9 @@ from lit.llvm import llvm_config
 if not config.mlir_include_integration_tests:
     config.unsupported = True
 
+
 def configure_aarch64_lli_cmd():
-    lli_cmd = 'lli'
+    lli_cmd = "lli"
 
     # NOTE: If the SVE tests are disabled and the SME tests are enabled to run
     # under emulation, the SVE specific RUN lines in the SparseTensor tests
@@ -12,8 +13,12 @@ def configure_aarch64_lli_cmd():
     if not (config.mlir_run_arm_sve_tests or config.mlir_run_arm_sme_tests):
         return lli_cmd
 
-    config.substitutions.append(('%mlir_native_utils_lib_dir',
-        config.arm_emulator_utils_lib_dir or config.mlir_lib_dir))
+    config.substitutions.append(
+        (
+            "%mlir_native_utils_lib_dir",
+            config.arm_emulator_utils_lib_dir or config.mlir_lib_dir,
+        )
+    )
 
     if config.arm_emulator_executable:
         if config.arm_emulator_lli_executable:
@@ -23,16 +28,22 @@ def configure_aarch64_lli_cmd():
             # when running under an emulator. If the user didn't specify an lli
             # executable, use absolute path %llvm_tools_dir/lli.
             lli_cmd = llvm_config.use_llvm_tool(
-                'lli', search_env='LLI', required=True,
-                search_paths=[config.llvm_tools_dir], use_installed=False
+                "lli",
+                search_env="LLI",
+                required=True,
+                search_paths=[config.llvm_tools_dir],
+                use_installed=False,
             )
 
         # Run test in emulator (qemu or armie)
-        emulation_cmd = f'{config.arm_emulator_executable} {config.arm_emulator_options}'
-        lli_cmd = f'{emulation_cmd} {lli_cmd}'
+        emulation_cmd = (
+            f"{config.arm_emulator_executable} {config.arm_emulator_options}"
+        )
+        lli_cmd = f"{emulation_cmd} {lli_cmd}"
 
     return lli_cmd
 
+
 aarch64_lli_cmd = configure_aarch64_lli_cmd()
 
 # Configure the following AArch64 substitutions:
@@ -52,5 +63,5 @@ aarch64_lli_cmd = configure_aarch64_lli_cmd()
 # could be used in the SparseTensor tests where necessary, but the meaning
 # conveyed by the substitution name would be a misnomer if the host target
 # is not AArch64 and MLIR_RUN_ARM_SVE_TESTS=OFF.
-config.substitutions.append(('%lli_aarch64_cmd', aarch64_lli_cmd))
-config.substitutions.append(('%lli_host_or_aarch64_cmd', aarch64_lli_cmd))
+config.substitutions.append(("%lli_aarch64_cmd", aarch64_lli_cmd))
+config.substitutions.append(("%lli_host_or_aarch64_cmd", aarch64_lli_cmd))

diff  --git a/mlir/test/Unit/lit.cfg.py b/mlir/test/Unit/lit.cfg.py
index 5b66517b1788e..1898b72adf002 100644
--- a/mlir/test/Unit/lit.cfg.py
+++ b/mlir/test/Unit/lit.cfg.py
@@ -8,43 +8,43 @@
 import lit.formats
 
 # name: The name of this test suite.
-config.name = 'MLIR-Unit'
+config.name = "MLIR-Unit"
 
 # suffixes: A list of file extensions to treat as test files.
 config.suffixes = []
 
 # test_source_root: The root path where tests are located.
 # test_exec_root: The root path where tests should be run.
-config.test_exec_root = os.path.join(config.mlir_obj_root, 'unittests')
+config.test_exec_root = os.path.join(config.mlir_obj_root, "unittests")
 config.test_source_root = config.test_exec_root
 
 # testFormat: The test format to use to interpret tests.
-config.test_format = lit.formats.GoogleTest(config.llvm_build_mode, 'Tests')
+config.test_format = lit.formats.GoogleTest(config.llvm_build_mode, "Tests")
 
 # Propagate the temp directory. Windows requires this because it uses \Windows\
 # if none of these are present.
-if 'TMP' in os.environ:
-    config.environment['TMP'] = os.environ['TMP']
-if 'TEMP' in os.environ:
-    config.environment['TEMP'] = os.environ['TEMP']
+if "TMP" in os.environ:
+    config.environment["TMP"] = os.environ["TMP"]
+if "TEMP" in os.environ:
+    config.environment["TEMP"] = os.environ["TEMP"]
 
 # Propagate HOME as it can be used to override incorrect homedir in passwd
 # that causes the tests to fail.
-if 'HOME' in os.environ:
-    config.environment['HOME'] = os.environ['HOME']
+if "HOME" in os.environ:
+    config.environment["HOME"] = os.environ["HOME"]
 
 # Propagate sanitizer options.
 for var in [
-    'ASAN_SYMBOLIZER_PATH',
-    'HWASAN_SYMBOLIZER_PATH',
-    'MSAN_SYMBOLIZER_PATH',
-    'TSAN_SYMBOLIZER_PATH',
-    'UBSAN_SYMBOLIZER_PATH',
-    'ASAN_OPTIONS',
-    'HWASAN_OPTIONS',
-    'MSAN_OPTIONS',
-    'TSAN_OPTIONS',
-    'UBSAN_OPTIONS',
+    "ASAN_SYMBOLIZER_PATH",
+    "HWASAN_SYMBOLIZER_PATH",
+    "MSAN_SYMBOLIZER_PATH",
+    "TSAN_SYMBOLIZER_PATH",
+    "UBSAN_SYMBOLIZER_PATH",
+    "ASAN_OPTIONS",
+    "HWASAN_OPTIONS",
+    "MSAN_OPTIONS",
+    "TSAN_OPTIONS",
+    "UBSAN_OPTIONS",
 ]:
     if var in os.environ:
         config.environment[var] = os.environ[var]

diff  --git a/mlir/test/lib/Dialect/Test/lit.local.cfg b/mlir/test/lib/Dialect/Test/lit.local.cfg
index edb5b44b2e2fe..65a7f202dc82a 100644
--- a/mlir/test/lib/Dialect/Test/lit.local.cfg
+++ b/mlir/test/lib/Dialect/Test/lit.local.cfg
@@ -1 +1 @@
-config.suffixes.remove('.td')
\ No newline at end of file
+config.suffixes.remove(".td")

diff  --git a/mlir/test/lib/Dialect/Transform/lit.local.cfg b/mlir/test/lib/Dialect/Transform/lit.local.cfg
index edb5b44b2e2fe..65a7f202dc82a 100644
--- a/mlir/test/lib/Dialect/Transform/lit.local.cfg
+++ b/mlir/test/lib/Dialect/Transform/lit.local.cfg
@@ -1 +1 @@
-config.suffixes.remove('.td')
\ No newline at end of file
+config.suffixes.remove(".td")

diff  --git a/mlir/test/lib/Tools/PDLL/lit.local.cfg b/mlir/test/lib/Tools/PDLL/lit.local.cfg
index 8cfe5cd834f06..8ffccee1d6d79 100644
--- a/mlir/test/lib/Tools/PDLL/lit.local.cfg
+++ b/mlir/test/lib/Tools/PDLL/lit.local.cfg
@@ -1 +1 @@
-config.suffixes.remove('.pdll')
+config.suffixes.remove(".pdll")

diff  --git a/mlir/test/lib/Transforms/lit.local.cfg b/mlir/test/lib/Transforms/lit.local.cfg
index 8cfe5cd834f06..8ffccee1d6d79 100644
--- a/mlir/test/lib/Transforms/lit.local.cfg
+++ b/mlir/test/lib/Transforms/lit.local.cfg
@@ -1 +1 @@
-config.suffixes.remove('.pdll')
+config.suffixes.remove(".pdll")

diff  --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 1fc2e319d19fd..ad0b0d5567779 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -16,21 +16,32 @@
 # Configuration file for the 'lit' test runner.
 
 # name: The name of this test suite.
-config.name = 'MLIR'
+config.name = "MLIR"
 
 config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
 
 # suffixes: A list of file extensions to treat as test files.
-config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.yaml', '.test', '.pdll', '.c']
+config.suffixes = [
+    ".td",
+    ".mlir",
+    ".toy",
+    ".ll",
+    ".tc",
+    ".py",
+    ".yaml",
+    ".test",
+    ".pdll",
+    ".c",
+]
 
 # test_source_root: The root path where tests are located.
 config.test_source_root = os.path.dirname(__file__)
 
 # test_exec_root: The root path where tests should be run.
-config.test_exec_root = os.path.join(config.mlir_obj_root, 'test')
+config.test_exec_root = os.path.join(config.mlir_obj_root, "test")
 
-config.substitutions.append(('%PATH%', config.environment['PATH']))
-config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
+config.substitutions.append(("%PATH%", config.environment["PATH"]))
+config.substitutions.append(("%shlibext", config.llvm_shlib_ext))
 config.substitutions.append(("%mlir_src_root", config.mlir_src_root))
 config.substitutions.append(("%host_cxx", config.host_cxx))
 config.substitutions.append(("%host_cc", config.host_cc))
@@ -40,94 +51,109 @@
 # substitution of the same name and the found path.
 # Correctly handles the platforms shared library directory and naming conventions.
 def add_runtime(name):
-    path = ''
-    for prefix in ['', 'lib']:
-        path = os.path.join(config.llvm_shlib_dir, f'{prefix}{name}{config.llvm_shlib_ext}')
+    path = ""
+    for prefix in ["", "lib"]:
+        path = os.path.join(
+            config.llvm_shlib_dir, f"{prefix}{name}{config.llvm_shlib_ext}"
+        )
         if os.path.isfile(path):
             break
-    return ToolSubst(f'%{name}', path)
+    return ToolSubst(f"%{name}", path)
 
 
-llvm_config.with_system_environment(
-    ['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
+llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])
 
 llvm_config.use_default_substitutions()
 
 # excludes: A list of directories to exclude from the testsuite. The 'Inputs'
 # subdirectories contain auxiliary inputs for various tests in their parent
 # directories.
-config.excludes = ['Inputs', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt',
-                   'lit.cfg.py', 'lit.site.cfg.py']
+config.excludes = [
+    "Inputs",
+    "CMakeLists.txt",
+    "README.txt",
+    "LICENSE.txt",
+    "lit.cfg.py",
+    "lit.site.cfg.py",
+]
 
 # Tweak the PATH to include the tools dir.
-llvm_config.with_environment('PATH', config.mlir_tools_dir, append_path=True)
-llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
+llvm_config.with_environment("PATH", config.mlir_tools_dir, append_path=True)
+llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)
 
 tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir]
 tools = [
-    'mlir-tblgen',
-    'mlir-translate',
-    'mlir-lsp-server',
-    'mlir-capi-execution-engine-test',
-    'mlir-capi-ir-test',
-    'mlir-capi-llvm-test',
-    'mlir-capi-pass-test',
-    'mlir-capi-pdl-test',
-    'mlir-capi-quant-test',
-    'mlir-capi-sparse-tensor-test',
-    'mlir-capi-transform-test',
-    'mlir-cpu-runner',
-    add_runtime('mlir_runner_utils'),
-    add_runtime('mlir_c_runner_utils'),
-    add_runtime('mlir_async_runtime'),
-    'mlir-linalg-ods-yaml-gen',
-    'mlir-reduce',
-    'mlir-pdll',
-    'not',
+    "mlir-tblgen",
+    "mlir-translate",
+    "mlir-lsp-server",
+    "mlir-capi-execution-engine-test",
+    "mlir-capi-ir-test",
+    "mlir-capi-llvm-test",
+    "mlir-capi-pass-test",
+    "mlir-capi-pdl-test",
+    "mlir-capi-quant-test",
+    "mlir-capi-sparse-tensor-test",
+    "mlir-capi-transform-test",
+    "mlir-cpu-runner",
+    add_runtime("mlir_runner_utils"),
+    add_runtime("mlir_c_runner_utils"),
+    add_runtime("mlir_async_runtime"),
+    "mlir-linalg-ods-yaml-gen",
+    "mlir-reduce",
+    "mlir-pdll",
+    "not",
 ]
 
 if config.enable_spirv_cpu_runner:
-    tools.extend(['mlir-spirv-cpu-runner', add_runtime('mlir_test_spirv_cpu_runner_c_wrappers')])
+    tools.extend(
+        ["mlir-spirv-cpu-runner", add_runtime("mlir_test_spirv_cpu_runner_c_wrappers")]
+    )
 
 if config.enable_vulkan_runner:
-    tools.extend([add_runtime('vulkan-runtime-wrappers')])
+    tools.extend([add_runtime("vulkan-runtime-wrappers")])
 
 if config.enable_rocm_runner:
-    tools.extend([add_runtime('mlir_rocm_runtime')])
+    tools.extend([add_runtime("mlir_rocm_runtime")])
 
 if config.enable_cuda_runner:
-    tools.extend([add_runtime('mlir_cuda_runtime')])
+    tools.extend([add_runtime("mlir_cuda_runtime")])
 
 # The following tools are optional
-tools.extend([
-    ToolSubst('toyc-ch1', unresolved='ignore'),
-    ToolSubst('toyc-ch2', unresolved='ignore'),
-    ToolSubst('toyc-ch3', unresolved='ignore'),
-    ToolSubst('toyc-ch4', unresolved='ignore'),
-    ToolSubst('toyc-ch5', unresolved='ignore'),
-    ToolSubst('toyc-ch6', unresolved='ignore'),
-    ToolSubst('toyc-ch7', unresolved='ignore'),
-    ToolSubst('%mlir_lib_dir', config.mlir_lib_dir, unresolved='ignore'),
-    ToolSubst('%mlir_src_dir', config.mlir_src_root, unresolved='ignore'),
-])
+tools.extend(
+    [
+        ToolSubst("toyc-ch1", unresolved="ignore"),
+        ToolSubst("toyc-ch2", unresolved="ignore"),
+        ToolSubst("toyc-ch3", unresolved="ignore"),
+        ToolSubst("toyc-ch4", unresolved="ignore"),
+        ToolSubst("toyc-ch5", unresolved="ignore"),
+        ToolSubst("toyc-ch6", unresolved="ignore"),
+        ToolSubst("toyc-ch7", unresolved="ignore"),
+        ToolSubst("%mlir_lib_dir", config.mlir_lib_dir, unresolved="ignore"),
+        ToolSubst("%mlir_src_dir", config.mlir_src_root, unresolved="ignore"),
+    ]
+)
 
 python_executable = config.python_executable
 # Python configuration with sanitizer requires some magic preloading. This will only work on clang/linux.
 # TODO: detect Darwin/Windows situation (or mark these tests as unsupported on these platforms).
 if "asan" in config.available_features and "Linux" in config.host_os:
-  python_executable = f"LD_PRELOAD=$({config.host_cxx} -print-file-name=libclang_rt.asan-{config.host_arch}.so) {config.python_executable}"
+    python_executable = f"LD_PRELOAD=$({config.host_cxx} -print-file-name=libclang_rt.asan-{config.host_arch}.so) {config.python_executable}"
 # On Windows the path to python could contains spaces in which case it needs to be provided in quotes.
 # This is the equivalent of how %python is setup in llvm/utils/lit/lit/llvm/config.py.
 elif "Windows" in config.host_os:
-  python_executable = '"%s"' % (python_executable)
-tools.extend([
-  ToolSubst('%PYTHON', python_executable, unresolved='ignore'),
-])
+    python_executable = '"%s"' % (python_executable)
+tools.extend(
+    [
+        ToolSubst("%PYTHON", python_executable, unresolved="ignore"),
+    ]
+)
 
 if "MLIR_OPT_CHECK_IR_ROUNDTRIP" in os.environ:
-  tools.extend([
-    ToolSubst('mlir-opt', 'mlir-opt --verify-roundtrip', unresolved='fatal'),
-  ])
+    tools.extend(
+        [
+            ToolSubst("mlir-opt", "mlir-opt --verify-roundtrip", unresolved="fatal"),
+        ]
+    )
 
 llvm_config.add_tool_substitutions(tools, tool_dirs)
 
@@ -135,40 +161,48 @@ def add_runtime(name):
 # FileCheck -enable-var-scope is enabled by default in MLIR test
 # This option avoids to accidentally reuse variable across -LABEL match,
 # it can be explicitly opted-in by prefixing the variable name with $
-config.environment['FILECHECK_OPTS'] = "-enable-var-scope --allow-unused-prefixes=false"
+config.environment["FILECHECK_OPTS"] = "-enable-var-scope --allow-unused-prefixes=false"
 
 # Add the python path for both the source and binary tree.
 # Note that presently, the python sources come from the source tree and the
 # binaries come from the build tree. This should be unified to the build tree
 # by copying/linking sources to build.
 if config.enable_bindings_python:
-  llvm_config.with_environment('PYTHONPATH', [
-      os.path.join(config.mlir_obj_root, 'python_packages', 'mlir_core'),
-      os.path.join(config.mlir_obj_root, 'python_packages', 'mlir_test'),
-  ], append_path=True)
+    llvm_config.with_environment(
+        "PYTHONPATH",
+        [
+            os.path.join(config.mlir_obj_root, "python_packages", "mlir_core"),
+            os.path.join(config.mlir_obj_root, "python_packages", "mlir_test"),
+        ],
+        append_path=True,
+    )
 
 if config.enable_assertions:
-  config.available_features.add('asserts')
+    config.available_features.add("asserts")
 else:
-  config.available_features.add('noasserts')
+    config.available_features.add("noasserts")
+
 
 def have_host_jit_feature_support(feature_name):
-  mlir_cpu_runner_exe = lit.util.which('mlir-cpu-runner', config.mlir_tools_dir)
+    mlir_cpu_runner_exe = lit.util.which("mlir-cpu-runner", config.mlir_tools_dir)
+
+    if not mlir_cpu_runner_exe:
+        return False
 
-  if not mlir_cpu_runner_exe:
-    return False
+    try:
+        mlir_cpu_runner_cmd = subprocess.Popen(
+            [mlir_cpu_runner_exe, "--host-supports-" + feature_name],
+            stdout=subprocess.PIPE,
+        )
+    except OSError:
+        print("could not exec mlir-cpu-runner")
+        return False
 
-  try:
-    mlir_cpu_runner_cmd = subprocess.Popen(
-        [mlir_cpu_runner_exe, '--host-supports-' + feature_name], stdout=subprocess.PIPE)
-  except OSError:
-    print('could not exec mlir-cpu-runner')
-    return False
+    mlir_cpu_runner_out = mlir_cpu_runner_cmd.stdout.read().decode("ascii")
+    mlir_cpu_runner_cmd.wait()
 
-  mlir_cpu_runner_out = mlir_cpu_runner_cmd.stdout.read().decode('ascii')
-  mlir_cpu_runner_cmd.wait()
+    return "true" in mlir_cpu_runner_out
 
-  return 'true' in mlir_cpu_runner_out
 
-if have_host_jit_feature_support('jit'):
-  config.available_features.add('host-supports-jit')
+if have_host_jit_feature_support("jit"):
+    config.available_features.add("host-supports-jit")

diff  --git a/mlir/test/mlir-cpu-runner/lit.local.cfg b/mlir/test/mlir-cpu-runner/lit.local.cfg
index 3f59ff1bd9774..3c20d203b800f 100644
--- a/mlir/test/mlir-cpu-runner/lit.local.cfg
+++ b/mlir/test/mlir-cpu-runner/lit.local.cfg
@@ -1,12 +1,11 @@
 import sys
 
 # MSAN does not work with JIT.
-if 'msan' in config.available_features:
-  config.unsupported = True
+if "msan" in config.available_features:
+    config.unsupported = True
 
 # Requires native execution.
-if 'host-supports-jit' not in config.available_features:
+if "host-supports-jit" not in config.available_features:
     config.unsupported = True
 
-config.available_features.add(
-        config.root.native_target.lower() + '-native-target')
+config.available_features.add(config.root.native_target.lower() + "-native-target")

diff  --git a/mlir/test/mlir-pdll-lsp-server/lit.local.cfg b/mlir/test/mlir-pdll-lsp-server/lit.local.cfg
index 25d08c7aba306..aa35dbfa8c01f 100644
--- a/mlir/test/mlir-pdll-lsp-server/lit.local.cfg
+++ b/mlir/test/mlir-pdll-lsp-server/lit.local.cfg
@@ -1 +1 @@
-config.excludes = ['include']
+config.excludes = ["include"]

diff  --git a/mlir/test/mlir-pdll/lit.local.cfg b/mlir/test/mlir-pdll/lit.local.cfg
index c438027edc2c4..4cb5622aaa761 100644
--- a/mlir/test/mlir-pdll/lit.local.cfg
+++ b/mlir/test/mlir-pdll/lit.local.cfg
@@ -1,2 +1,2 @@
-config.suffixes = ['.pdll', '.mlir']
-config.excludes = ['include']
+config.suffixes = [".pdll", ".mlir"]
+config.excludes = ["include"]

diff  --git a/mlir/test/mlir-spirv-cpu-runner/lit.local.cfg b/mlir/test/mlir-spirv-cpu-runner/lit.local.cfg
index 286bea4cace67..8717dd025498e 100644
--- a/mlir/test/mlir-spirv-cpu-runner/lit.local.cfg
+++ b/mlir/test/mlir-spirv-cpu-runner/lit.local.cfg
@@ -1,4 +1,4 @@
 import sys
 
 if not config.enable_spirv_cpu_runner:
-  config.unsupported = True
+    config.unsupported = True

diff  --git a/mlir/test/mlir-vulkan-runner/lit.local.cfg b/mlir/test/mlir-vulkan-runner/lit.local.cfg
index f99be2aeb70e9..6da7fcdd610aa 100644
--- a/mlir/test/mlir-vulkan-runner/lit.local.cfg
+++ b/mlir/test/mlir-vulkan-runner/lit.local.cfg
@@ -1,2 +1,2 @@
 if not config.enable_vulkan_runner:
-  config.unsupported = True
+    config.unsupported = True

diff  --git a/mlir/test/python/develoment_files.py b/mlir/test/python/develoment_files.py
index ea0a911890a07..4dc3a0b700b1f 100644
--- a/mlir/test/python/develoment_files.py
+++ b/mlir/test/python/develoment_files.py
@@ -14,5 +14,6 @@
 all_libs = os.listdir(get_lib_dirs()[0])
 found_lib = False
 for file_name in all_libs:
-  if expected_lib_name in file_name: found_lib = True
+    if expected_lib_name in file_name:
+        found_lib = True
 assert found_lib, f"Did not find '{expected_lib_name}' lib in {all_libs}"

diff  --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index acae9b6474083..8e9613d052466 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -4,16 +4,18 @@
 import mlir.dialects.func as func
 import mlir.dialects.arith as arith
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
+    print("\nTEST:", f.__name__)
+    f()
+
 
 # CHECK-LABEL: TEST: testConstantOp
 @run
 def testConstantOps():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      arith.ConstantOp(value=42.42, result=F32Type.get())
-    # CHECK:         %cst = arith.constant 4.242000e+01 : f32
-    print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            arith.ConstantOp(value=42.42, result=F32Type.get())
+        # CHECK:         %cst = arith.constant 4.242000e+01 : f32
+        print(module)

diff  --git a/mlir/test/python/dialects/async_dialect.py b/mlir/test/python/dialects/async_dialect.py
index da3103cecddf2..f6181cc76118e 100644
--- a/mlir/test/python/dialects/async_dialect.py
+++ b/mlir/test/python/dialects/async_dialect.py
@@ -5,14 +5,17 @@
 import mlir.dialects.async_dialect.passes
 from mlir.passmanager import *
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
+    print("\nTEST:", f.__name__)
+    f()
+
 
 def testAsyncPass():
-  with Context() as context:
-    PassManager.parse('any(async-to-async-runtime)')
-  print('SUCCESS')
+    with Context() as context:
+        PassManager.parse("any(async-to-async-runtime)")
+    print("SUCCESS")
+
 
 # CHECK-LABEL: testAsyncPass
 #       CHECK: SUCCESS

diff  --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index eab24b5c796b0..18ebba61e7fea 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -7,232 +7,242 @@
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: testFromPyFunc
 @run
 def testFromPyFunc():
-  with Context() as ctx, Location.unknown() as loc:
-    ctx.allow_unregistered_dialects = True
-    m = builtin.ModuleOp()
-    f32 = F32Type.get()
-    f64 = F64Type.get()
-    with InsertionPoint(m.body):
-      # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
-      # CHECK: return %arg0 : f64
-      @func.FuncOp.from_py_func(f64)
-      def unary_return(a):
-        return a
-
-      # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
-      # CHECK: return %arg0, %arg1 : f32, f64
-      @func.FuncOp.from_py_func(f32, f64)
-      def binary_return(a, b):
-        return a, b
-
-      # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
-      # CHECK: return
-      @func.FuncOp.from_py_func(f32, f64)
-      def none_return(a, b):
-        pass
-
-      # CHECK-LABEL: func @call_unary
-      # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
-      # CHECK: return %0 : f64
-      @func.FuncOp.from_py_func(f64)
-      def call_unary(a):
-        return unary_return(a)
-
-      # CHECK-LABEL: func @call_binary
-      # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
-      # CHECK: return %0#0, %0#1 : f32, f64
-      @func.FuncOp.from_py_func(f32, f64)
-      def call_binary(a, b):
-        return binary_return(a, b)
-
-      # We expect coercion of a single result operation to a returned value.
-      # CHECK-LABEL: func @single_result_op
-      # CHECK: %0 = "custom.op1"() : () -> f32
-      # CHECK: return %0 : f32
-      @func.FuncOp.from_py_func()
-      def single_result_op():
-        return Operation.create("custom.op1", results=[f32])
-
-      # CHECK-LABEL: func @call_none
-      # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
-      # CHECK: return
-      @func.FuncOp.from_py_func(f32, f64)
-      def call_none(a, b):
-        return none_return(a, b)
-
-      ## Variants and optional feature tests.
-      # CHECK-LABEL: func @from_name_arg
-      @func.FuncOp.from_py_func(f32, f64, name="from_name_arg")
-      def explicit_name(a, b):
-        return b
-
-      @func.FuncOp.from_py_func(f32, f64)
-      def positional_func_op(a, b, func_op):
-        assert isinstance(func_op, func.FuncOp)
-        return b
-
-      @func.FuncOp.from_py_func(f32, f64)
-      def kw_func_op(a, b=None, func_op=None):
-        assert isinstance(func_op, func.FuncOp)
-        return b
-
-      @func.FuncOp.from_py_func(f32, f64)
-      def kwargs_func_op(a, b=None, **kwargs):
-        assert isinstance(kwargs["func_op"], func.FuncOp)
-        return b
-
-      # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
-      # CHECK: return %arg1 : f64
-      @func.FuncOp.from_py_func(f32, f64, results=[f64])
-      def explicit_results(a, b):
-        func.ReturnOp([b])
-
-  print(m)
+    with Context() as ctx, Location.unknown() as loc:
+        ctx.allow_unregistered_dialects = True
+        m = builtin.ModuleOp()
+        f32 = F32Type.get()
+        f64 = F64Type.get()
+        with InsertionPoint(m.body):
+            # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
+            # CHECK: return %arg0 : f64
+            @func.FuncOp.from_py_func(f64)
+            def unary_return(a):
+                return a
+
+            # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
+            # CHECK: return %arg0, %arg1 : f32, f64
+            @func.FuncOp.from_py_func(f32, f64)
+            def binary_return(a, b):
+                return a, b
+
+            # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
+            # CHECK: return
+            @func.FuncOp.from_py_func(f32, f64)
+            def none_return(a, b):
+                pass
+
+            # CHECK-LABEL: func @call_unary
+            # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
+            # CHECK: return %0 : f64
+            @func.FuncOp.from_py_func(f64)
+            def call_unary(a):
+                return unary_return(a)
+
+            # CHECK-LABEL: func @call_binary
+            # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
+            # CHECK: return %0#0, %0#1 : f32, f64
+            @func.FuncOp.from_py_func(f32, f64)
+            def call_binary(a, b):
+                return binary_return(a, b)
+
+            # We expect coercion of a single result operation to a returned value.
+            # CHECK-LABEL: func @single_result_op
+            # CHECK: %0 = "custom.op1"() : () -> f32
+            # CHECK: return %0 : f32
+            @func.FuncOp.from_py_func()
+            def single_result_op():
+                return Operation.create("custom.op1", results=[f32])
+
+            # CHECK-LABEL: func @call_none
+            # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
+            # CHECK: return
+            @func.FuncOp.from_py_func(f32, f64)
+            def call_none(a, b):
+                return none_return(a, b)
+
+            ## Variants and optional feature tests.
+            # CHECK-LABEL: func @from_name_arg
+            @func.FuncOp.from_py_func(f32, f64, name="from_name_arg")
+            def explicit_name(a, b):
+                return b
+
+            @func.FuncOp.from_py_func(f32, f64)
+            def positional_func_op(a, b, func_op):
+                assert isinstance(func_op, func.FuncOp)
+                return b
+
+            @func.FuncOp.from_py_func(f32, f64)
+            def kw_func_op(a, b=None, func_op=None):
+                assert isinstance(func_op, func.FuncOp)
+                return b
+
+            @func.FuncOp.from_py_func(f32, f64)
+            def kwargs_func_op(a, b=None, **kwargs):
+                assert isinstance(kwargs["func_op"], func.FuncOp)
+                return b
+
+            # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
+            # CHECK: return %arg1 : f64
+            @func.FuncOp.from_py_func(f32, f64, results=[f64])
+            def explicit_results(a, b):
+                func.ReturnOp([b])
+
+    print(m)
 
 
 # CHECK-LABEL: TEST: testFromPyFuncErrors
 @run
 def testFromPyFuncErrors():
-  with Context() as ctx, Location.unknown() as loc:
-    m = builtin.ModuleOp()
-    f32 = F32Type.get()
-    f64 = F64Type.get()
-    with InsertionPoint(m.body):
-      try:
+    with Context() as ctx, Location.unknown() as loc:
+        m = builtin.ModuleOp()
+        f32 = F32Type.get()
+        f64 = F64Type.get()
+        with InsertionPoint(m.body):
+            try:
 
-        @func.FuncOp.from_py_func(f64, results=[f64])
-        def unary_return(a):
-          return a
-      except AssertionError as e:
-        # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
-        print(e)
+                @func.FuncOp.from_py_func(f64, results=[f64])
+                def unary_return(a):
+                    return a
+
+            except AssertionError as e:
+                # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
+                print(e)
 
 
 # CHECK-LABEL: TEST: testBuildFuncOp
 @run
 def testBuildFuncOp():
-  ctx = Context()
-  with Location.unknown(ctx) as loc:
-    m = builtin.ModuleOp()
-
-    f32 = F32Type.get()
-    tensor_type = RankedTensorType.get((2, 3, 4), f32)
-    with InsertionPoint.at_block_begin(m.body):
-      f = func.FuncOp(name="some_func",
-                            type=FunctionType.get(
-                                inputs=[tensor_type, tensor_type],
-                                results=[tensor_type]),
-                            visibility="nested")
-      # CHECK: Name is: "some_func"
-      print("Name is: ", f.name)
-
-      # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
-      print("Type is: ", f.type)
-
-      # CHECK: Visibility is: "nested"
-      print("Visibility is: ", f.visibility)
-
-      try:
-        entry_block = f.entry_block
-      except IndexError as e:
-        # CHECK: External function does not have a body
-        print(e)
-
-      with InsertionPoint(f.add_entry_block()):
-        func.ReturnOp([f.entry_block.arguments[0]])
-        pass
-
-      try:
-        f.add_entry_block()
-      except IndexError as e:
-        # CHECK: The function already has an entry block!
-        print(e)
-
-      # Try the callback builder and passing type as tuple.
-      f = func.FuncOp(name="some_other_func",
-                            type=([tensor_type, tensor_type], [tensor_type]),
-                            visibility="nested",
-                            body_builder=lambda f: func.ReturnOp(
-                                [f.entry_block.arguments[0]]))
-
-  # CHECK: module  {
-  # CHECK:  func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
-  # CHECK:   return %arg0 : tensor<2x3x4xf32>
-  # CHECK:  }
-  # CHECK:  func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
-  # CHECK:   return %arg0 : tensor<2x3x4xf32>
-  # CHECK:  }
-  print(m)
+    ctx = Context()
+    with Location.unknown(ctx) as loc:
+        m = builtin.ModuleOp()
+
+        f32 = F32Type.get()
+        tensor_type = RankedTensorType.get((2, 3, 4), f32)
+        with InsertionPoint.at_block_begin(m.body):
+            f = func.FuncOp(
+                name="some_func",
+                type=FunctionType.get(
+                    inputs=[tensor_type, tensor_type], results=[tensor_type]
+                ),
+                visibility="nested",
+            )
+            # CHECK: Name is: "some_func"
+            print("Name is: ", f.name)
+
+            # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+            print("Type is: ", f.type)
+
+            # CHECK: Visibility is: "nested"
+            print("Visibility is: ", f.visibility)
+
+            try:
+                entry_block = f.entry_block
+            except IndexError as e:
+                # CHECK: External function does not have a body
+                print(e)
+
+            with InsertionPoint(f.add_entry_block()):
+                func.ReturnOp([f.entry_block.arguments[0]])
+                pass
+
+            try:
+                f.add_entry_block()
+            except IndexError as e:
+                # CHECK: The function already has an entry block!
+                print(e)
+
+            # Try the callback builder and passing type as tuple.
+            f = func.FuncOp(
+                name="some_other_func",
+                type=([tensor_type, tensor_type], [tensor_type]),
+                visibility="nested",
+                body_builder=lambda f: func.ReturnOp([f.entry_block.arguments[0]]),
+            )
+
+    # CHECK: module  {
+    # CHECK:  func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+    # CHECK:   return %arg0 : tensor<2x3x4xf32>
+    # CHECK:  }
+    # CHECK:  func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+    # CHECK:   return %arg0 : tensor<2x3x4xf32>
+    # CHECK:  }
+    print(m)
 
 
 # CHECK-LABEL: TEST: testFuncArgumentAccess
 @run
 def testFuncArgumentAccess():
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    module = Module.create()
-    f32 = F32Type.get()
-    f64 = F64Type.get()
-    with InsertionPoint(module.body):
-      f = func.FuncOp("some_func", ([f32, f32], [f32, f32]))
-      with InsertionPoint(f.add_entry_block()):
-        func.ReturnOp(f.arguments)
-      f.arg_attrs = ArrayAttr.get([
-          DictAttr.get({
-              "custom_dialect.foo": StringAttr.get("bar"),
-              "custom_dialect.baz": UnitAttr.get()
-          }),
-          DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])})
-      ])
-      f.result_attrs = ArrayAttr.get([
-          DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
-          DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)})
-      ])
-
-      other = func.FuncOp("other_func", ([f32, f32], []))
-      with InsertionPoint(other.add_entry_block()):
-        func.ReturnOp([])
-      other.arg_attrs = [
-          DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
-          DictAttr.get()
-      ]
-
-  # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
-  print(f.arg_attrs)
-
-  # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
-  print(f.result_attrs)
-
-  # CHECK: func @some_func(
-  # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
-  # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
-  # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
-  # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
-  # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
-  #
-  # CHECK: func @other_func(
-  # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
-  # CHECK: %{{.*}}: f32)
-  print(module)
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        module = Module.create()
+        f32 = F32Type.get()
+        f64 = F64Type.get()
+        with InsertionPoint(module.body):
+            f = func.FuncOp("some_func", ([f32, f32], [f32, f32]))
+            with InsertionPoint(f.add_entry_block()):
+                func.ReturnOp(f.arguments)
+            f.arg_attrs = ArrayAttr.get(
+                [
+                    DictAttr.get(
+                        {
+                            "custom_dialect.foo": StringAttr.get("bar"),
+                            "custom_dialect.baz": UnitAttr.get(),
+                        }
+                    ),
+                    DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])}),
+                ]
+            )
+            f.result_attrs = ArrayAttr.get(
+                [
+                    DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
+                    DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)}),
+                ]
+            )
+
+            other = func.FuncOp("other_func", ([f32, f32], []))
+            with InsertionPoint(other.add_entry_block()):
+                func.ReturnOp([])
+            other.arg_attrs = [
+                DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
+                DictAttr.get(),
+            ]
+
+    # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
+    print(f.arg_attrs)
+
+    # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
+    print(f.result_attrs)
+
+    # CHECK: func @some_func(
+    # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
+    # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
+    # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
+    # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
+    # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
+    #
+    # CHECK: func @other_func(
+    # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
+    # CHECK: %{{.*}}: f32)
+    print(module)
 
 
 # CHECK-LABEL: testDenseElementsAttr
 @run
 def testDenseElementsAttr():
-  with Context(), Location.unknown():
-    values = np.arange(4, dtype=np.int32)
-    i32 = IntegerType.get_signless(32)
-    print(DenseElementsAttr.get(values, type=i32))
-    # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>
-    print(DenseElementsAttr.get(values, type=i32, shape=(2, 2)))
-    # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
-    print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
-    # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
+    with Context(), Location.unknown():
+        values = np.arange(4, dtype=np.int32)
+        i32 = IntegerType.get_signless(32)
+        print(DenseElementsAttr.get(values, type=i32))
+        # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>
+        print(DenseElementsAttr.get(values, type=i32, shape=(2, 2)))
+        # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+        print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
+        # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>

diff  --git a/mlir/test/python/dialects/complex_dialect.py b/mlir/test/python/dialects/complex_dialect.py
index e724575b5bf5b..afad21757bc3c 100644
--- a/mlir/test/python/dialects/complex_dialect.py
+++ b/mlir/test/python/dialects/complex_dialect.py
@@ -9,24 +9,24 @@
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
+    print("\nTEST:", f.__name__)
+    f()
 
 
 # CHECK-LABEL: TEST: testComplexOps
 @run
 def testComplexOps():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(ComplexType.get(F32Type.get()))
-      def emit_add(arg):
-        return mlir_complex.AddOp(arg, arg)
-
-    # CHECK-LABEL: func @emit_add(
-    # CHECK-SAME:                  %[[ARG:.*]]: complex<f32>) -> complex<f32> {
-    # CHECK:         %[[RES:.*]] = complex.add %[[ARG]], %[[ARG]] : complex<f32>
-    # CHECK:         return %[[RES]] : complex<f32>
-    # CHECK:       }
-    print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(ComplexType.get(F32Type.get()))
+            def emit_add(arg):
+                return mlir_complex.AddOp(arg, arg)
+
+        # CHECK-LABEL: func @emit_add(
+        # CHECK-SAME:                  %[[ARG:.*]]: complex<f32>) -> complex<f32> {
+        # CHECK:         %[[RES:.*]] = complex.add %[[ARG]], %[[ARG]] : complex<f32>
+        # CHECK:         return %[[RES]] : complex<f32>
+        # CHECK:       }
+        print(module)

diff  --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index 3be9cac2c1925..161a12d78776a 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -7,13 +7,13 @@
 
 
 def constructAndPrintInModule(f):
-  print("\nTEST:", f.__name__)
-  with Context(), Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      f()
-    print(module)
-  return f
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()
+        print(module)
+    return f
 
 
 # CHECK-LABEL: TEST: testConstantOp
@@ -21,21 +21,21 @@ def constructAndPrintInModule(f):
 
 @constructAndPrintInModule
 def testConstantOp():
-  c1 = arith.ConstantOp(IntegerType.get_signless(32), 42)
-  c2 = arith.ConstantOp(IntegerType.get_signless(64), 100)
-  c3 = arith.ConstantOp(F32Type.get(), 3.14)
-  c4 = arith.ConstantOp(F64Type.get(), 1.23)
-  # CHECK: 42
-  print(c1.literal_value)
+    c1 = arith.ConstantOp(IntegerType.get_signless(32), 42)
+    c2 = arith.ConstantOp(IntegerType.get_signless(64), 100)
+    c3 = arith.ConstantOp(F32Type.get(), 3.14)
+    c4 = arith.ConstantOp(F64Type.get(), 1.23)
+    # CHECK: 42
+    print(c1.literal_value)
 
-  # CHECK: 100
-  print(c2.literal_value)
+    # CHECK: 100
+    print(c2.literal_value)
 
-  # CHECK: 3.140000104904175
-  print(c3.literal_value)
+    # CHECK: 3.140000104904175
+    print(c3.literal_value)
 
-  # CHECK: 1.23
-  print(c4.literal_value)
+    # CHECK: 1.23
+    print(c4.literal_value)
 
 
 # CHECK: = arith.constant 42 : i32
@@ -47,17 +47,17 @@ def testConstantOp():
 # CHECK-LABEL: TEST: testVectorConstantOp
 @constructAndPrintInModule
 def testVectorConstantOp():
-  int_type = IntegerType.get_signless(32)
-  vec_type = VectorType.get([2, 2], int_type)
-  c1 = arith.ConstantOp(
-      vec_type,
-      DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42)))
-  try:
-    print(c1.literal_value)
-  except ValueError as e:
-    assert "only integer and float constants have literal values" in str(e)
-  else:
-    assert False
+    int_type = IntegerType.get_signless(32)
+    vec_type = VectorType.get([2, 2], int_type)
+    c1 = arith.ConstantOp(
+        vec_type, DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42))
+    )
+    try:
+        print(c1.literal_value)
+    except ValueError as e:
+        assert "only integer and float constants have literal values" in str(e)
+    else:
+        assert False
 
 
 # CHECK: = arith.constant dense<42> : vector<2x2xi32>
@@ -66,9 +66,9 @@ def testVectorConstantOp():
 # CHECK-LABEL: TEST: testConstantIndexOp
 @constructAndPrintInModule
 def testConstantIndexOp():
-  c1 = arith.ConstantOp.create_index(10)
-  # CHECK: 10
-  print(c1.literal_value)
+    c1 = arith.ConstantOp.create_index(10)
+    # CHECK: 10
+    print(c1.literal_value)
 
 
 # CHECK: = arith.constant 10 : index
@@ -77,18 +77,18 @@ def testConstantIndexOp():
 # CHECK-LABEL: TEST: testFunctionCalls
 @constructAndPrintInModule
 def testFunctionCalls():
-  foo = func.FuncOp("foo", ([], []))
-  foo.sym_visibility = StringAttr.get("private")
-  bar = func.FuncOp("bar", ([], [IndexType.get()]))
-  bar.sym_visibility = StringAttr.get("private")
-  qux = func.FuncOp("qux", ([], [F32Type.get()]))
-  qux.sym_visibility = StringAttr.get("private")
-
-  with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()):
-    func.CallOp(foo, [])
-    func.CallOp([IndexType.get()], "bar", [])
-    func.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
-    func.ReturnOp([])
+    foo = func.FuncOp("foo", ([], []))
+    foo.sym_visibility = StringAttr.get("private")
+    bar = func.FuncOp("bar", ([], [IndexType.get()]))
+    bar.sym_visibility = StringAttr.get("private")
+    qux = func.FuncOp("qux", ([], [F32Type.get()]))
+    qux.sym_visibility = StringAttr.get("private")
+
+    with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()):
+        func.CallOp(foo, [])
+        func.CallOp([IndexType.get()], "bar", [])
+        func.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
+        func.ReturnOp([])
 
 
 # CHECK: func private @foo()

diff  --git a/mlir/test/python/dialects/gpu.py b/mlir/test/python/dialects/gpu.py
index 38bf038a5eeed..7eefaed711c2c 100644
--- a/mlir/test/python/dialects/gpu.py
+++ b/mlir/test/python/dialects/gpu.py
@@ -5,14 +5,17 @@
 import mlir.dialects.gpu.passes
 from mlir.passmanager import *
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
+    print("\nTEST:", f.__name__)
+    f()
+
 
 def testGPUPass():
-  with Context() as context:
-    PassManager.parse('any(gpu-kernel-outlining)')
-  print('SUCCESS')
+    with Context() as context:
+        PassManager.parse("any(gpu-kernel-outlining)")
+    print("SUCCESS")
+
 
 # CHECK-LABEL: testGPUPass
 #       CHECK: SUCCESS

diff  --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
index d787c5f49c441..7892d020d9067 100644
--- a/mlir/test/python/dialects/linalg/opdsl/arguments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py
@@ -34,8 +34,9 @@ def matmul(
     C=TensorDef(U, S.M, S.N, output=True),
     bfn=BinaryFnAttrDef(default=BinaryFn.mul),
     ufn=UnaryFnAttrDef(default=UnaryFn.exp),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
 
 
 # CHECK: ---
@@ -47,7 +48,7 @@ def matmul(
 # CHECK:     type_var: T
 @linalg_structured_op
 def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
-  O[D.m, D.n] = value
+    O[D.m, D.n] = value
 
 
 # CHECK: ---
@@ -71,5 +72,6 @@ def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
 def strided_copy(
     I=TensorDef(T, S.IH, S.IW),
     O=TensorDef(T, S.OH, S.OW, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 2])):
-  O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 2]),
+):
+    O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]

diff  --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py
index eacf43547b110..ad0a3eac5913e 100644
--- a/mlir/test/python/dialects/linalg/opdsl/assignments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py
@@ -35,8 +35,9 @@ def matmul(
     B=TensorDef(T, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True),
     mul=BinaryFnAttrDef(default=BinaryFn.mul),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
 
 
 # CHECK: ---
@@ -79,12 +80,12 @@ def matmul(
 # CHECK:                  scalar_const: '1.{{[0]*}}e+03 : f64'
 @linalg_structured_op
 def constants(
-    O=TensorDef(T, S.M, S.K, output=True),
-    exp=UnaryFnAttrDef(default=UnaryFn.exp)):
-  pi = TypeFn.cast_signed(T, const(3.1415926535897931))
-  cst42 = TypeFn.cast_signed(T, const(42))
-  cst1000 = TypeFn.cast_signed(T, exp(const(1e+3)))
-  O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000
+    O=TensorDef(T, S.M, S.K, output=True), exp=UnaryFnAttrDef(default=UnaryFn.exp)
+):
+    pi = TypeFn.cast_signed(T, const(3.1415926535897931))
+    cst42 = TypeFn.cast_signed(T, const(42))
+    cst1000 = TypeFn.cast_signed(T, exp(const(1e3)))
+    O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000
 
 
 # CHECK: ---
@@ -100,7 +101,7 @@ def constants(
 # CHECK:          scalar_index: 0
 @linalg_structured_op
 def indices(O=TensorDef(T, S.M, S.K, output=True)):
-  O[D.m, D.n] = index(D.n) + index(D.m)
+    O[D.m, D.n] = index(D.n) + index(D.m)
 
 
 # CHECK: ---
@@ -111,4 +112,4 @@ def indices(O=TensorDef(T, S.M, S.K, output=True)):
 # CHECK:      scalar_arg: value
 @linalg_structured_op
 def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
-  O[D.m, D.n] = value
+    O[D.m, D.n] = value

diff  --git a/mlir/test/python/dialects/linalg/opdsl/doctests.py b/mlir/test/python/dialects/linalg/opdsl/doctests.py
index 4aae768848815..d2f9cec19d570 100644
--- a/mlir/test/python/dialects/linalg/opdsl/doctests.py
+++ b/mlir/test/python/dialects/linalg/opdsl/doctests.py
@@ -3,10 +3,11 @@
 import doctest
 import importlib
 
+
 def test_module(module_name):
-  print(f"--- Testing module: {module_name}")
-  m = importlib.import_module(module_name)
-  doctest.testmod(m, verbose=True, raise_on_error=True, report=True)
+    print(f"--- Testing module: {module_name}")
+    m = importlib.import_module(module_name)
+    doctest.testmod(m, verbose=True, raise_on_error=True, report=True)
 
 
 test_module("mlir.dialects.linalg.opdsl.lang.affine")

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
index ebe2c0f33a286..d666d313767b9 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
@@ -17,43 +17,44 @@ def conv_poly(
     K=TensorDef(T2, S.KH, S.KW, S.C),
     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2])):
-  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-           D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c])
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2]),
+):
+    domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+    O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+    ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c])
 
 
 with Context() as ctx, Location.unknown():
-  module = Module.create()
-  f32 = F32Type.get()
-  i32 = IntegerType.get_signless(32)
-  with InsertionPoint(module.body):
-
-    # Convolution indexing maps.
-    # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
-    # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
-    # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
-
-    # CHECK-LABEL: @test_f32i32_conv
-    # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$CONV_MAP_K]], #[[$CONV_MAP_O]]]
-    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32)
-    # CHECK-NEXT:   %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32
-    # CHECK-NEXT:   %[[FILTER_CAST:.+]] = arith.fptosi %[[FILTER:.+]] : f32 to i32
-    # CHECK-NEXT:   %[[PROD:.+]] = arith.muli %[[IN_CAST]], %[[FILTER_CAST]] : i32
-    # CHECK-NEXT:   %[[SUM:.+]] = arith.addi %[[OUT]], %[[PROD]] : i32
-    # CHECK-NEXT:   linalg.yield %[[SUM]] : i32
-    # CHECK-NEXT: -> tensor<1x2x4x1xi32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((1, 4, 16, 1), f32),
-        RankedTensorType.get((2, 2, 1), f32),
-        RankedTensorType.get((1, 2, 4, 1), i32))
-    def test_f32i32_conv(input, filter, init_result):
-      # Use default dilations and set non-default strides.
-      return conv_poly(
-          input, filter, outs=[init_result], strides=[2, 4])
+    module = Module.create()
+    f32 = F32Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+        # Convolution indexing maps.
+        # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
+        # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+        # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+
+        # CHECK-LABEL: @test_f32i32_conv
+        # CHECK: linalg.generic
+        # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$CONV_MAP_K]], #[[$CONV_MAP_O]]]
+        # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32)
+        # CHECK-NEXT:   %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32
+        # CHECK-NEXT:   %[[FILTER_CAST:.+]] = arith.fptosi %[[FILTER:.+]] : f32 to i32
+        # CHECK-NEXT:   %[[PROD:.+]] = arith.muli %[[IN_CAST]], %[[FILTER_CAST]] : i32
+        # CHECK-NEXT:   %[[SUM:.+]] = arith.addi %[[OUT]], %[[PROD]] : i32
+        # CHECK-NEXT:   linalg.yield %[[SUM]] : i32
+        # CHECK-NEXT: -> tensor<1x2x4x1xi32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((1, 4, 16, 1), f32),
+            RankedTensorType.get((2, 2, 1), f32),
+            RankedTensorType.get((1, 2, 4, 1), i32),
+        )
+        def test_f32i32_conv(input, filter, init_result):
+            # Use default dilations and set non-default strides.
+            return conv_poly(input, filter, outs=[init_result], strides=[2, 4])
 
 
 print(module)

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
index 1f840b09b0085..ffef737755e85 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
@@ -13,47 +13,51 @@
 
 @linalg_structured_op
 def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
-  O[None] = TypeFn.cast_signed(U, value)
+    O[None] = TypeFn.cast_signed(U, value)
+
 
 @linalg_structured_op
 def fill_rank_zero_poly(I=TensorDef(T1), O=TensorDef(U, output=True)):
-  O[None] = TypeFn.cast_signed(U, I[None])
+    O[None] = TypeFn.cast_signed(U, I[None])
+
 
 with Context() as ctx, Location.unknown():
-  module = Module.create()
-  f32 = F32Type.get()
-  with InsertionPoint(module.body):
-
-    # Fill indexing maps.
-    # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
-    # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
-    # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-    # CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> ()>
-    # CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
-    # CHECK-LABEL: @test_fill_0d
-    # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]
-    # CHECK-SAME: iterator_types = []
-    @func.FuncOp.from_py_func(f32, RankedTensorType.get([], f32))
-    def test_fill_0d(value, init_result):
-      return fill_poly(value, outs=[init_result])
-
-    # CHECK-LABEL: @test_fill_2d
-    # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]]]
-    # CHECK-SAME: iterator_types = ["parallel", "parallel"]
-    @func.FuncOp.from_py_func(f32, RankedTensorType.get([4, 16], f32))
-    def test_fill_2d(value, init_result):
-      return fill_poly(value, outs=[init_result])
-
-    # CHECK-LABEL: @test_fill_rank_zero_3d
-    # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]]]
-    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get([], f32), RankedTensorType.get([4, 8, 16], f32))
-    def test_fill_rank_zero_3d(input, init_result):
-      return fill_rank_zero_poly(input, outs=[init_result])
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+
+        # Fill indexing maps.
+        # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
+        # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+        # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+        # CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> ()>
+        # CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+        # CHECK-LABEL: @test_fill_0d
+        # CHECK: linalg.generic
+        # CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]
+        # CHECK-SAME: iterator_types = []
+        @func.FuncOp.from_py_func(f32, RankedTensorType.get([], f32))
+        def test_fill_0d(value, init_result):
+            return fill_poly(value, outs=[init_result])
+
+        # CHECK-LABEL: @test_fill_2d
+        # CHECK: linalg.generic
+        # CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]]]
+        # CHECK-SAME: iterator_types = ["parallel", "parallel"]
+        @func.FuncOp.from_py_func(f32, RankedTensorType.get([4, 16], f32))
+        def test_fill_2d(value, init_result):
+            return fill_poly(value, outs=[init_result])
+
+        # CHECK-LABEL: @test_fill_rank_zero_3d
+        # CHECK: linalg.generic
+        # CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]]]
+        # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get([], f32), RankedTensorType.get([4, 8, 16], f32)
+        )
+        def test_fill_rank_zero_3d(input, init_result):
+            return fill_rank_zero_poly(input, outs=[init_result])
+
 
 print(module)

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
index 6dff754859eef..18c237c68081a 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
@@ -16,9 +16,10 @@
 def matmul_mono(
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
-    C=TensorDef(T, S.M, S.N, output=True)):
-  domain(D.m, D.n, D.k)
-  C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]
+    C=TensorDef(T, S.M, S.N, output=True),
+):
+    domain(D.m, D.n, D.k)
+    C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]
 
 
 @linalg_structured_op
@@ -26,146 +27,162 @@ def matmul_poly(
     A=TensorDef(T1, S.M, S.K),
     B=TensorDef(T2, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  domain(D.m, D.n, D.k)
-  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    domain(D.m, D.n, D.k)
+    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
 with Context() as ctx, Location.unknown():
-  module = Module.create()
-  f16 = F16Type.get()
-  f32 = F32Type.get()
-  f64 = F64Type.get()
-  i8 = IntegerType.get_signless(8)
-  i16 = IntegerType.get_signless(16)
-  i32 = IntegerType.get_signless(32)
-  with InsertionPoint(module.body):
-
-    # Multiplication indexing maps. We verify only the indexing maps of the
-    # first multiplication and then do additional tests on casting and body
-    # generation behavior.
-    # CHECK: #[[$MUL_MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-    # CHECK: #[[$MUL_MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-    # CHECK: #[[$MUL_MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-    # CHECK-LABEL: func @test_matmul_mono
-    # CHECK-SAME:  %[[A:.+]]: tensor<4x16xf32>
-    # CHECK-SAME:  %[[B:.+]]: tensor<16x8xf32>
-    # CHECK: %[[INITC:.+]] = tensor.empty() : tensor<4x8xf32>
-    # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$MUL_MAP_A]], #[[$MUL_MAP_B]], #[[$MUL_MAP_C]]]
-    # CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
-    # CHECK-SAME: ins(%[[A]], %[[B]]
-    # CHECK-SAME: outs(%[[INITC]]
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32))
-    def test_matmul_mono(lhs, rhs):
-      init_result = tensor.EmptyOp([4, 8], f32)
-      return matmul_mono(lhs, rhs, outs=[init_result.result])
-
-    # CHECK-LABEL: @test_i8i8i32_matmul
-    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
-    # CHECK-NEXT:   %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32
-    # CHECK-NEXT:   %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i8 to i32
-    # CHECK-NEXT:   %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32
-    # CHECK-NEXT:   %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32
-    # CHECK-NEXT:   linalg.yield %[[ADD]] : i32
-    # CHECK-NEXT: -> tensor<4x8xi32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
-        RankedTensorType.get((4, 8), i32))
-    def test_i8i8i32_matmul(lhs, rhs, init_result):
-      return matmul_poly(lhs, rhs, outs=[init_result])
-
-    # CHECK-LABEL: @test_i8i8i32_matmul_unsigned
-    # CHECK:   = arith.extui
-    # CHECK:   = arith.extui
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
-        RankedTensorType.get((4, 8), i32))
-    def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result):
-      return matmul_poly(
-          lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
-
-    # CHECK-LABEL: @test_i8i16i32_matmul
-    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
-    # CHECK-NEXT:   %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32
-    # CHECK-NEXT:   %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i16 to i32
-    # CHECK-NEXT:   %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32
-    # CHECK-NEXT:   %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32
-    # CHECK-NEXT:   linalg.yield %[[ADD]] : i32
-    # CHECK-NEXT: -> tensor<4x8xi32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i16),
-        RankedTensorType.get((4, 8), i32))
-    def test_i8i16i32_matmul(lhs, rhs, init_result):
-      return matmul_poly(lhs, rhs, outs=[init_result])
-
-    # CHECK-LABEL: @test_i32i32i16_matmul
-    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16)
-    # CHECK-NEXT:   %[[A_CAST:.+]] = arith.trunci %[[A_ARG]] : i32 to i16
-    # CHECK-NEXT:   %[[B_CAST:.+]] = arith.trunci %[[B_ARG]] : i32 to i16
-    # CHECK-NEXT:   %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i16
-    # CHECK-NEXT:   %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i16
-    # CHECK-NEXT:   linalg.yield %[[ADD]] : i16
-    # CHECK-NEXT: -> tensor<4x8xi16>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), i32), RankedTensorType.get((16, 8), i32),
-        RankedTensorType.get((4, 8), i16))
-    def test_i32i32i16_matmul(lhs, rhs, init_result):
-      return matmul_poly(lhs, rhs, outs=[init_result])
-
-    # CHECK-LABEL: @test_i8i8f32_matmul
-    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32)
-    # CHECK-NEXT:   %[[A_CAST:.+]] = arith.sitofp %[[A_ARG]] : i8 to f32
-    # CHECK-NEXT:   %[[B_CAST:.+]] = arith.sitofp %[[B_ARG]] : i8 to f32
-    # CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
-    # CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
-    # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
-    # CHECK-NEXT: -> tensor<4x8xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
-        RankedTensorType.get((4, 8), f32))
-    def test_i8i8f32_matmul(lhs, rhs, init_result):
-      return matmul_poly(lhs, rhs, outs=[init_result])
-
-    # CHECK-LABEL: @test_i8i8f32_matmul_unsigned
-    # CHECK:   = arith.uitofp
-    # CHECK:   = arith.uitofp
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
-        RankedTensorType.get((4, 8), f32))
-    def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result):
-      return matmul_poly(
-          lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
-
-    # CHECK-LABEL: @test_f16f16f32_matmul
-    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
-    # CHECK-NEXT:   %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
-    # CHECK-NEXT:   %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
-    # CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
-    # CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
-    # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
-    # CHECK-NEXT: -> tensor<4x8xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f16), RankedTensorType.get((16, 8), f16),
-        RankedTensorType.get((4, 8), f32))
-    def test_f16f16f32_matmul(lhs, rhs, init_result):
-      return matmul_poly(lhs, rhs, outs=[init_result])
-
-    # CHECK-LABEL: @test_f64f64f32_matmul
-    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
-    # CHECK-NEXT:   %[[A_CAST:.+]] = arith.truncf %[[A_ARG]] : f64 to f32
-    # CHECK-NEXT:   %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
-    # CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
-    # CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
-    # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
-    # CHECK-NEXT: -> tensor<4x8xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f64), RankedTensorType.get((16, 8), f64),
-        RankedTensorType.get((4, 8), f32))
-    def test_f64f64f32_matmul(lhs, rhs, init_result):
-      return matmul_poly(lhs, rhs, outs=[init_result])
+    module = Module.create()
+    f16 = F16Type.get()
+    f32 = F32Type.get()
+    f64 = F64Type.get()
+    i8 = IntegerType.get_signless(8)
+    i16 = IntegerType.get_signless(16)
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+        # Multiplication indexing maps. We verify only the indexing maps of the
+        # first multiplication and then do additional tests on casting and body
+        # generation behavior.
+        # CHECK: #[[$MUL_MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+        # CHECK: #[[$MUL_MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+        # CHECK: #[[$MUL_MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+        # CHECK-LABEL: func @test_matmul_mono
+        # CHECK-SAME:  %[[A:.+]]: tensor<4x16xf32>
+        # CHECK-SAME:  %[[B:.+]]: tensor<16x8xf32>
+        # CHECK: %[[INITC:.+]] = tensor.empty() : tensor<4x8xf32>
+        # CHECK: linalg.generic
+        # CHECK-SAME: indexing_maps = [#[[$MUL_MAP_A]], #[[$MUL_MAP_B]], #[[$MUL_MAP_C]]]
+        # CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+        # CHECK-SAME: ins(%[[A]], %[[B]]
+        # CHECK-SAME: outs(%[[INITC]]
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
+        )
+        def test_matmul_mono(lhs, rhs):
+            init_result = tensor.EmptyOp([4, 8], f32)
+            return matmul_mono(lhs, rhs, outs=[init_result.result])
+
+        # CHECK-LABEL: @test_i8i8i32_matmul
+        # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
+        # CHECK-NEXT:   %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32
+        # CHECK-NEXT:   %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i8 to i32
+        # CHECK-NEXT:   %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32
+        # CHECK-NEXT:   %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32
+        # CHECK-NEXT:   linalg.yield %[[ADD]] : i32
+        # CHECK-NEXT: -> tensor<4x8xi32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), i8),
+            RankedTensorType.get((16, 8), i8),
+            RankedTensorType.get((4, 8), i32),
+        )
+        def test_i8i8i32_matmul(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result])
+
+        # CHECK-LABEL: @test_i8i8i32_matmul_unsigned
+        # CHECK:   = arith.extui
+        # CHECK:   = arith.extui
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), i8),
+            RankedTensorType.get((16, 8), i8),
+            RankedTensorType.get((4, 8), i32),
+        )
+        def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
+
+        # CHECK-LABEL: @test_i8i16i32_matmul
+        # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
+        # CHECK-NEXT:   %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32
+        # CHECK-NEXT:   %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i16 to i32
+        # CHECK-NEXT:   %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32
+        # CHECK-NEXT:   %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32
+        # CHECK-NEXT:   linalg.yield %[[ADD]] : i32
+        # CHECK-NEXT: -> tensor<4x8xi32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), i8),
+            RankedTensorType.get((16, 8), i16),
+            RankedTensorType.get((4, 8), i32),
+        )
+        def test_i8i16i32_matmul(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result])
+
+        # CHECK-LABEL: @test_i32i32i16_matmul
+        # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16)
+        # CHECK-NEXT:   %[[A_CAST:.+]] = arith.trunci %[[A_ARG]] : i32 to i16
+        # CHECK-NEXT:   %[[B_CAST:.+]] = arith.trunci %[[B_ARG]] : i32 to i16
+        # CHECK-NEXT:   %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i16
+        # CHECK-NEXT:   %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i16
+        # CHECK-NEXT:   linalg.yield %[[ADD]] : i16
+        # CHECK-NEXT: -> tensor<4x8xi16>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), i32),
+            RankedTensorType.get((16, 8), i32),
+            RankedTensorType.get((4, 8), i16),
+        )
+        def test_i32i32i16_matmul(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result])
+
+        # CHECK-LABEL: @test_i8i8f32_matmul
+        # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32)
+        # CHECK-NEXT:   %[[A_CAST:.+]] = arith.sitofp %[[A_ARG]] : i8 to f32
+        # CHECK-NEXT:   %[[B_CAST:.+]] = arith.sitofp %[[B_ARG]] : i8 to f32
+        # CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+        # CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+        # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+        # CHECK-NEXT: -> tensor<4x8xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), i8),
+            RankedTensorType.get((16, 8), i8),
+            RankedTensorType.get((4, 8), f32),
+        )
+        def test_i8i8f32_matmul(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result])
+
+        # CHECK-LABEL: @test_i8i8f32_matmul_unsigned
+        # CHECK:   = arith.uitofp
+        # CHECK:   = arith.uitofp
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), i8),
+            RankedTensorType.get((16, 8), i8),
+            RankedTensorType.get((4, 8), f32),
+        )
+        def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
+
+        # CHECK-LABEL: @test_f16f16f32_matmul
+        # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
+        # CHECK-NEXT:   %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+        # CHECK-NEXT:   %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
+        # CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+        # CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+        # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+        # CHECK-NEXT: -> tensor<4x8xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f16),
+            RankedTensorType.get((16, 8), f16),
+            RankedTensorType.get((4, 8), f32),
+        )
+        def test_f16f16f32_matmul(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result])
+
+        # CHECK-LABEL: @test_f64f64f32_matmul
+        # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
+        # CHECK-NEXT:   %[[A_CAST:.+]] = arith.truncf %[[A_ARG]] : f64 to f32
+        # CHECK-NEXT:   %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
+        # CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+        # CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+        # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+        # CHECK-NEXT: -> tensor<4x8xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f64),
+            RankedTensorType.get((16, 8), f64),
+            RankedTensorType.get((4, 8), f32),
+        )
+        def test_f64f64f32_matmul(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result])
 
 
 print(module)

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index aad714998c108..f8e034fb0e48b 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -17,14 +17,16 @@
 
 @linalg_structured_op
 def test_const(O=TensorDef(F32, S.M, S.N, output=True)):
-  O[D.m, D.n] = TypeFn.cast_unsigned(F32, const(42)) + TypeFn.cast_unsigned(
-      F32, const(2.3283064e-10))
+    O[D.m, D.n] = TypeFn.cast_unsigned(F32, const(42)) + TypeFn.cast_unsigned(
+        F32, const(2.3283064e-10)
+    )
 
 
 @linalg_structured_op
 def test_index(O=TensorDef(I32, S.M, S.N, output=True)):
-  O[D.m, D.n] = TypeFn.cast_signed(I32, index(D.m)) + TypeFn.cast_signed(
-      I32, index(D.n))
+    O[D.m, D.n] = TypeFn.cast_signed(I32, index(D.m)) + TypeFn.cast_signed(
+        I32, index(D.n)
+    )
 
 
 @linalg_structured_op
@@ -32,120 +34,129 @@ def elemwise_unary_poly(
     I=TensorDef(T),
     O=TensorDef(U, output=True),
     fun=UnaryFnAttrDef(default=UnaryFn.exp),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  O[None] = fun(cast(U, I[None]))
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    O[None] = fun(cast(U, I[None]))
 
 
 @linalg_structured_op(op_name="custom_op_name")
 def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
-  O[D.n] = I[D.n]
+    O[D.n] = I[D.n]
 
 
 with Context() as ctx, Location.unknown():
-  module = Module.create()
-  f32 = F32Type.get()
-  c32 = ComplexType.get(f32)
-  i32 = IntegerType.get_signless(32)
-  with InsertionPoint(module.body):
-
-    # CHECK-LABEL: @test_f32_const
-    # CHECK-DAG:    %[[CST0:.+]] = arith.constant 42 : i64
-    # CHECK-DAG:    %[[CST0_CAST:.+]] = arith.uitofp %[[CST0]] : i64 to f32
-    # CHECK-DAG:    %[[CST1:.+]] = arith.constant 2.3283063999999999E-10 : f64
-    # CHECK-DAG:    %[[CST1_CAST:.+]] = arith.truncf %[[CST1]] : f64 to f32
-    # CHECK-DAG:    %[[SUM:.+]] = arith.addf %[[CST0_CAST]], %[[CST1_CAST]] : f32
-    # CHECK-NEXT:   linalg.yield %[[SUM]] : f32
-    @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32))
-    def test_f32_const(init_result):
-      return test_const(outs=[init_result])
-
-    # CHECK-LABEL: @test_i32_index
-    # CHECK-DAG:    %[[IDX0:.+]] = linalg.index 0 : index
-    # CHECK-DAG:    %[[IDX1:.+]] = linalg.index 1 : index
-    # CHECK-DAG:    %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32
-    # CHECK-DAG:    %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32
-    # CHECK-DAG:    %[[SUM:.+]] = arith.addi %[[IDX0_CAST]], %[[IDX1_CAST]] : i32
-    # CHECK-NEXT:   linalg.yield %[[SUM]] : i32
-    @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32))
-    def test_i32_index(init_result):
-      return test_index(outs=[init_result])
-
-    # CHECK-LABEL: @test_f32_elemwise_exp
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-    # CHECK-NEXT:   %[[EXP:.+]] = math.exp %[[IN]] : f32
-    # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
-    # CHECK-NEXT: -> tensor<4x16xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
-    def test_f32_elemwise_exp(input, init_result):
-      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
-
-    # CHECK-LABEL: @test_f32_elemwise_log
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-    # CHECK-NEXT:   %[[LOG:.+]] = math.log %[[IN]] : f32
-    # CHECK-NEXT:   linalg.yield %[[LOG]] : f32
-    # CHECK-NEXT: -> tensor<4x16xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
-    def test_f32_elemwise_log(input, init_result):
-      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
-
-    # CHECK-LABEL: @test_f32_elemwise_abs
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-    # CHECK-NEXT:   %[[EXP:.+]] = math.absf %[[IN]] : f32
-    # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
-    # CHECK-NEXT: -> tensor<4x16xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
-    def test_f32_elemwise_abs(input, init_result):
-      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
-
-    # CHECK-LABEL: @test_f32_elemwise_ceil
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-    # CHECK-NEXT:   %[[EXP:.+]] = math.ceil %[[IN]] : f32
-    # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
-    # CHECK-NEXT: -> tensor<4x16xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
-    def test_f32_elemwise_ceil(input, init_result):
-      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.ceil)
-
-    # CHECK-LABEL: @test_f32_elemwise_floor
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-    # CHECK-NEXT:   %[[EXP:.+]] = math.floor %[[IN]] : f32
-    # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
-    # CHECK-NEXT: -> tensor<4x16xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
-    def test_f32_elemwise_floor(input, init_result):
-      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor)
-
-    # CHECK-LABEL: @test_f32_elemwise_neg
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-    # CHECK-NEXT:   %[[EXP:.+]] = arith.negf %[[IN]] : f32
-    # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
-    # CHECK-NEXT: -> tensor<4x16xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
-    def test_f32_elemwise_neg(input, init_result):
-      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
-
-    # CHECK-LABEL: @test_c32_elemwise_neg
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
-    # CHECK-NEXT:   %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
-    # CHECK-NEXT:   linalg.yield %[[EXP]] : complex<f32>
-    # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32))
-    def test_c32_elemwise_neg(input, init_result):
-      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
-
-    # Just check that we don't assert out on name mismatch.
-    # CHECK-LABEL: @test_non_default_op_name
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((42,), f32), RankedTensorType.get((42,), f32))
-    def test_non_default_op_name(input, init_result):
-      return non_default_op_name(input, outs=[init_result])
+    module = Module.create()
+    f32 = F32Type.get()
+    c32 = ComplexType.get(f32)
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+        # CHECK-LABEL: @test_f32_const
+        # CHECK-DAG:    %[[CST0:.+]] = arith.constant 42 : i64
+        # CHECK-DAG:    %[[CST0_CAST:.+]] = arith.uitofp %[[CST0]] : i64 to f32
+        # CHECK-DAG:    %[[CST1:.+]] = arith.constant 2.3283063999999999E-10 : f64
+        # CHECK-DAG:    %[[CST1_CAST:.+]] = arith.truncf %[[CST1]] : f64 to f32
+        # CHECK-DAG:    %[[SUM:.+]] = arith.addf %[[CST0_CAST]], %[[CST1_CAST]] : f32
+        # CHECK-NEXT:   linalg.yield %[[SUM]] : f32
+        @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32))
+        def test_f32_const(init_result):
+            return test_const(outs=[init_result])
+
+        # CHECK-LABEL: @test_i32_index
+        # CHECK-DAG:    %[[IDX0:.+]] = linalg.index 0 : index
+        # CHECK-DAG:    %[[IDX1:.+]] = linalg.index 1 : index
+        # CHECK-DAG:    %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32
+        # CHECK-DAG:    %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32
+        # CHECK-DAG:    %[[SUM:.+]] = arith.addi %[[IDX0_CAST]], %[[IDX1_CAST]] : i32
+        # CHECK-NEXT:   linalg.yield %[[SUM]] : i32
+        @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32))
+        def test_i32_index(init_result):
+            return test_index(outs=[init_result])
+
+        # CHECK-LABEL: @test_f32_elemwise_exp
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[EXP:.+]] = math.exp %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_f32_elemwise_exp(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
+
+        # CHECK-LABEL: @test_f32_elemwise_log
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[LOG:.+]] = math.log %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[LOG]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_f32_elemwise_log(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
+
+        # CHECK-LABEL: @test_f32_elemwise_abs
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[EXP:.+]] = math.absf %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_f32_elemwise_abs(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+
+        # CHECK-LABEL: @test_f32_elemwise_ceil
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[EXP:.+]] = math.ceil %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_f32_elemwise_ceil(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.ceil)
+
+        # CHECK-LABEL: @test_f32_elemwise_floor
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[EXP:.+]] = math.floor %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_f32_elemwise_floor(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor)
+
+        # CHECK-LABEL: @test_f32_elemwise_neg
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[EXP:.+]] = arith.negf %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_f32_elemwise_neg(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+
+        # CHECK-LABEL: @test_c32_elemwise_neg
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT:   %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_neg(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+
+        # Just check that we don't assert out on name mismatch.
+        # CHECK-LABEL: @test_non_default_op_name
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((42,), f32), RankedTensorType.get((42,), f32)
+        )
+        def test_non_default_op_name(input, init_result):
+            return non_default_op_name(input, outs=[init_result])
 
 
 print(module)

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
index 2fd63382c4ec3..ab049d3dfae57 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
@@ -19,121 +19,134 @@ def pooling_poly(
     reduce=BinaryFnAttrDef(default=BinaryFn.max_signed),
     cast=TypeFnAttrDef(default=TypeFn.cast_signed),
     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw](
-      cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-                D.c]))
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+    O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw](
+        cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
+    )
 
 
 with Context() as ctx, Location.unknown():
-  module = Module.create()
-  f32 = F32Type.get()
-  i32 = IntegerType.get_signless(32)
-  with InsertionPoint(module.body):
-
-    # Pooling indexing maps.
-    # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
-    # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
-    # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
-
-    # CHECK-LABEL: @test_f32i32_max_pooling
-    # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
-    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
-    # CHECK-NEXT:   %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32
-    # CHECK-NEXT:   %[[MAX:.+]] = arith.maxsi %[[OUT]], %[[IN_CAST:.+]] : i32
-    # CHECK-NEXT:   linalg.yield %[[MAX]] : i32
-    # CHECK-NEXT: -> tensor<1x2x4x1xi32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((1, 4, 16, 1), f32),
-        RankedTensorType.get((2, 2), f32),
-        RankedTensorType.get((1, 2, 4, 1), i32))
-    def test_f32i32_max_pooling(input, shape, init_result):
-      return pooling_poly(
-          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
-
-    # CHECK-LABEL: @test_f32i32_max_unsigned_pooling
-    # CHECK:   = arith.fptoui
-    # CHECK:   = arith.maxui
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((1, 4, 16, 1), f32),
-        RankedTensorType.get((2, 2), f32),
-        RankedTensorType.get((1, 2, 4, 1), i32))
-    def test_f32i32_max_unsigned_pooling(input, shape, init_result):
-      return pooling_poly(
-          input,
-          shape,
-          outs=[init_result],
-          reduce=BinaryFn.max_unsigned,
-          cast=TypeFn.cast_unsigned,
-          strides=[2, 4],
-          dilations=[1, 2])
-
-    # CHECK-LABEL: @test_f32f32_max_pooling
-    # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
-    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
-    # CHECK-NEXT:   %[[MAX:.+]] = arith.maxf %[[OUT]], %[[IN:.+]] : f32
-    # CHECK-NEXT:   linalg.yield %[[MAX]] : f32
-    # CHECK-NEXT: -> tensor<1x2x4x1xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((1, 4, 16, 1), f32),
-        RankedTensorType.get((2, 2), f32),
-        RankedTensorType.get((1, 2, 4, 1), f32))
-    def test_f32f32_max_pooling(input, shape, init_result):
-      return pooling_poly(
-          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
-
-    # CHECK-LABEL: @test_f32i32_min_pooling
-    # CHECK:   = arith.fptosi
-    # CHECK:   = arith.minsi
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((1, 4, 16, 1), f32),
-        RankedTensorType.get((2, 2), f32),
-        RankedTensorType.get((1, 2, 4, 1), i32))
-    def test_f32i32_min_pooling(input, shape, init_result):
-      return pooling_poly(
-          input,
-          shape,
-          outs=[init_result],
-          reduce=BinaryFn.min_signed,
-          strides=[2, 4],
-          dilations=[1, 2])
-
-    # CHECK-LABEL: @test_f32i32_min_unsigned_pooling
-    # CHECK:   = arith.fptoui
-    # CHECK:   = arith.minui
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((1, 4, 16, 1), f32),
-        RankedTensorType.get((2, 2), f32),
-        RankedTensorType.get((1, 2, 4, 1), i32))
-    def test_f32i32_min_unsigned_pooling(input, shape, init_result):
-      return pooling_poly(
-          input,
-          shape,
-          outs=[init_result],
-          reduce=BinaryFn.min_unsigned,
-          cast=TypeFn.cast_unsigned,
-          strides=[2, 4],
-          dilations=[1, 2])
-
-    # CHECK-LABEL: @test_f32f32_min_pooling
-    # CHECK:   = arith.minf
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((1, 4, 16, 1), f32),
-        RankedTensorType.get((2, 2), f32),
-        RankedTensorType.get((1, 2, 4, 1), f32))
-    def test_f32f32_min_pooling(input, shape, init_result):
-      return pooling_poly(
-          input,
-          shape,
-          outs=[init_result],
-          reduce=BinaryFn.min_signed,
-          strides=[2, 4],
-          dilations=[1, 2])
+    module = Module.create()
+    f32 = F32Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+        # Pooling indexing maps.
+        # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
+        # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
+        # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+
+        # CHECK-LABEL: @test_f32i32_max_pooling
+        # CHECK: linalg.generic
+        # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
+        # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
+        # CHECK-NEXT:   %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32
+        # CHECK-NEXT:   %[[MAX:.+]] = arith.maxsi %[[OUT]], %[[IN_CAST:.+]] : i32
+        # CHECK-NEXT:   linalg.yield %[[MAX]] : i32
+        # CHECK-NEXT: -> tensor<1x2x4x1xi32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((1, 4, 16, 1), f32),
+            RankedTensorType.get((2, 2), f32),
+            RankedTensorType.get((1, 2, 4, 1), i32),
+        )
+        def test_f32i32_max_pooling(input, shape, init_result):
+            return pooling_poly(
+                input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]
+            )
+
+        # CHECK-LABEL: @test_f32i32_max_unsigned_pooling
+        # CHECK:   = arith.fptoui
+        # CHECK:   = arith.maxui
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((1, 4, 16, 1), f32),
+            RankedTensorType.get((2, 2), f32),
+            RankedTensorType.get((1, 2, 4, 1), i32),
+        )
+        def test_f32i32_max_unsigned_pooling(input, shape, init_result):
+            return pooling_poly(
+                input,
+                shape,
+                outs=[init_result],
+                reduce=BinaryFn.max_unsigned,
+                cast=TypeFn.cast_unsigned,
+                strides=[2, 4],
+                dilations=[1, 2],
+            )
+
+        # CHECK-LABEL: @test_f32f32_max_pooling
+        # CHECK: linalg.generic
+        # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
+        # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[MAX:.+]] = arith.maxf %[[OUT]], %[[IN:.+]] : f32
+        # CHECK-NEXT:   linalg.yield %[[MAX]] : f32
+        # CHECK-NEXT: -> tensor<1x2x4x1xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((1, 4, 16, 1), f32),
+            RankedTensorType.get((2, 2), f32),
+            RankedTensorType.get((1, 2, 4, 1), f32),
+        )
+        def test_f32f32_max_pooling(input, shape, init_result):
+            return pooling_poly(
+                input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]
+            )
+
+        # CHECK-LABEL: @test_f32i32_min_pooling
+        # CHECK:   = arith.fptosi
+        # CHECK:   = arith.minsi
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((1, 4, 16, 1), f32),
+            RankedTensorType.get((2, 2), f32),
+            RankedTensorType.get((1, 2, 4, 1), i32),
+        )
+        def test_f32i32_min_pooling(input, shape, init_result):
+            return pooling_poly(
+                input,
+                shape,
+                outs=[init_result],
+                reduce=BinaryFn.min_signed,
+                strides=[2, 4],
+                dilations=[1, 2],
+            )
+
+        # CHECK-LABEL: @test_f32i32_min_unsigned_pooling
+        # CHECK:   = arith.fptoui
+        # CHECK:   = arith.minui
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((1, 4, 16, 1), f32),
+            RankedTensorType.get((2, 2), f32),
+            RankedTensorType.get((1, 2, 4, 1), i32),
+        )
+        def test_f32i32_min_unsigned_pooling(input, shape, init_result):
+            return pooling_poly(
+                input,
+                shape,
+                outs=[init_result],
+                reduce=BinaryFn.min_unsigned,
+                cast=TypeFn.cast_unsigned,
+                strides=[2, 4],
+                dilations=[1, 2],
+            )
+
+        # CHECK-LABEL: @test_f32f32_min_pooling
+        # CHECK:   = arith.minf
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((1, 4, 16, 1), f32),
+            RankedTensorType.get((2, 2), f32),
+            RankedTensorType.get((1, 2, 4, 1), f32),
+        )
+        def test_f32f32_min_pooling(input, shape, init_result):
+            return pooling_poly(
+                input,
+                shape,
+                outs=[init_result],
+                reduce=BinaryFn.min_signed,
+                strides=[2, 4],
+                dilations=[1, 2],
+            )
 
 
 print(module)

diff  --git a/mlir/test/python/dialects/linalg/opdsl/lit.local.cfg b/mlir/test/python/dialects/linalg/opdsl/lit.local.cfg
index cead85f016b96..18d2d458ed790 100644
--- a/mlir/test/python/dialects/linalg/opdsl/lit.local.cfg
+++ b/mlir/test/python/dialects/linalg/opdsl/lit.local.cfg
@@ -4,6 +4,6 @@
 # Since both lit and the python bindings use the same python interpreter,
 # we can just check whether yaml can be imported here and exclude if not.
 try:
-  import yaml
+    import yaml
 except ModuleNotFoundError:
-  config.unsupported = True
+    config.unsupported = True

diff  --git a/mlir/test/python/dialects/linalg/opdsl/metadata.py b/mlir/test/python/dialects/linalg/opdsl/metadata.py
index a7502e9eb1aae..9c940e1060cab 100644
--- a/mlir/test/python/dialects/linalg/opdsl/metadata.py
+++ b/mlir/test/python/dialects/linalg/opdsl/metadata.py
@@ -13,8 +13,10 @@
 def matmul(
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True)):
-  implements(ContractionOpInterface)
-  defines(Canonicalizer)
-  C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
-      U, B[D.k, D.n])
+    C=TensorDef(U, S.M, S.N, output=True),
+):
+    implements(ContractionOpInterface)
+    defines(Canonicalizer)
+    C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
+        U, B[D.k, D.n]
+    )

diff  --git a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
index 871341c835a5d..4f3569b7974d8 100644
--- a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
+++ b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
@@ -22,10 +22,12 @@
 def matmul(
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True)):
-  domain(D.m, D.n, D.k)
-  C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
-      U, B[D.k, D.n])
+    C=TensorDef(U, S.M, S.N, output=True),
+):
+    domain(D.m, D.n, D.k)
+    C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
+        U, B[D.k, D.n]
+    )
 
 
 # Verifies that assignment to a scalar (represented as [None]) is represented
@@ -43,7 +45,7 @@ def matmul(
 # CHECK-NEXT: - reduction
 @linalg_structured_op
 def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
-  C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
+    C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
 
 
 # Verifies that the index_dims of shape-only operands translate to correct
@@ -64,6 +66,7 @@ def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
 def pool(
     I=TensorDef(T, S.I),
     K=TensorDef(T, S.K, index_dims=[D.k]),
-    O=TensorDef(U, S.O, output=True)):
-  domain(D.o, D.k)
-  O[D.o] += TypeFn.cast_signed(U, I[D.o * 2 + D.k])
+    O=TensorDef(U, S.O, output=True),
+):
+    domain(D.o, D.k)
+    O[D.o] += TypeFn.cast_signed(U, I[D.o * 2 + D.k])

diff  --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 1167abf84d4e4..5e8414ad4055c 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -6,145 +6,154 @@
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: testFill
 @run
 def testFill():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-      # CHECK-LABEL: func @fill_tensor
-      #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
-      #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
-      #  CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32>
-      #  CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32))
-      def fill_tensor(out):
-        zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
-        return linalg.fill(zero, outs=[out])
-
-      # CHECK-LABEL: func @fill_buffer
-      #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
-      #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
-      #  CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>)
-      #  CHECK-NEXT: return
-      @func.FuncOp.from_py_func(
-          MemRefType.get((12, ShapedType.get_dynamic_size()), f32))
-      def fill_buffer(out):
-        zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
-        linalg.fill(zero, outs=[out])
-
-  print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            # CHECK-LABEL: func @fill_tensor
+            #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
+            #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
+            #  CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32>
+            #  CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)
+            )
+            def fill_tensor(out):
+                zero = arith.ConstantOp(
+                    value=FloatAttr.get(f32, 0.0), result=f32
+                ).result
+                return linalg.fill(zero, outs=[out])
+
+            # CHECK-LABEL: func @fill_buffer
+            #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
+            #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
+            #  CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>)
+            #  CHECK-NEXT: return
+            @func.FuncOp.from_py_func(
+                MemRefType.get((12, ShapedType.get_dynamic_size()), f32)
+            )
+            def fill_buffer(out):
+                zero = arith.ConstantOp(
+                    value=FloatAttr.get(f32, 0.0), result=f32
+                ).result
+                linalg.fill(zero, outs=[out])
+
+    print(module)
 
 
 # CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
 @run
 def testNamedStructuredOpCustomForm():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32))
-      def named_form(lhs, rhs):
-        init_result = tensor.EmptyOp([4, 8], f32)
-        # Check for the named form with custom format
-        #      CHECK: linalg.elemwise_unary
-        # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
-        # CHECK-SAME:    fun = #linalg.unary_fn<exp>
-        # CHECK-SAME:    ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
-        unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
-        #      CHECK: linalg.elemwise_binary
-        # CHECK-SAME:    cast = #linalg.type_fn<cast_unsigned>
-        # CHECK-SAME:    fun = #linalg.binary_fn<mul>
-        # CHECK-SAME:    ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
-        #      CHECK: return
-        binary_result = linalg.elemwise_binary(
-            lhs,
-            rhs,
-            outs=[init_result.result],
-            fun=BinaryFn.mul,
-            cast=TypeFn.cast_unsigned)
-        return unary_result, binary_result
-
-  print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32)
+            )
+            def named_form(lhs, rhs):
+                init_result = tensor.EmptyOp([4, 8], f32)
+                # Check for the named form with custom format
+                #      CHECK: linalg.elemwise_unary
+                # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
+                # CHECK-SAME:    fun = #linalg.unary_fn<exp>
+                # CHECK-SAME:    ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
+                unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
+                #      CHECK: linalg.elemwise_binary
+                # CHECK-SAME:    cast = #linalg.type_fn<cast_unsigned>
+                # CHECK-SAME:    fun = #linalg.binary_fn<mul>
+                # CHECK-SAME:    ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
+                #      CHECK: return
+                binary_result = linalg.elemwise_binary(
+                    lhs,
+                    rhs,
+                    outs=[init_result.result],
+                    fun=BinaryFn.mul,
+                    cast=TypeFn.cast_unsigned,
+                )
+                return unary_result, binary_result
+
+    print(module)
 
 
 # CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
 @run
 def testNamedStructuredOpGenericForm():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
-                                                                   f32))
-      def named_form(lhs, rhs):
-        init_result = tensor.EmptyOp([4, 8], f32)
-        #      CHECK: "linalg.matmul"(%{{.*}})
-        # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
-        # CHECK-SAME:    operand_segment_sizes = array<i32: 2, 1>
-        # CHECK-NEXT:  ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
-        # CHECK-NEXT:    arith.mulf{{.*}} (f32, f32) -> f32
-        # CHECK-NEXT:    arith.addf{{.*}} (f32, f32) -> f32
-        # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
-        # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
-        return linalg.matmul(lhs, rhs, outs=[init_result.result])
-
-  module.operation.print(print_generic_op_form=True)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
+            )
+            def named_form(lhs, rhs):
+                init_result = tensor.EmptyOp([4, 8], f32)
+                #      CHECK: "linalg.matmul"(%{{.*}})
+                # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
+                # CHECK-SAME:    operand_segment_sizes = array<i32: 2, 1>
+                # CHECK-NEXT:  ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
+                # CHECK-NEXT:    arith.mulf{{.*}} (f32, f32) -> f32
+                # CHECK-NEXT:    arith.addf{{.*}} (f32, f32) -> f32
+                # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
+                # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
+                return linalg.matmul(lhs, rhs, outs=[init_result.result])
+
+    module.operation.print(print_generic_op_form=True)
 
 
 # CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
 @run
 def testNamedStructuredAsGenericOp():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
 
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
-                                                                   f32))
-      def generic_form(lhs, rhs):
-        init_result = tensor.EmptyOp([4, 8], f32)
-        # CHECK: linalg.generic
-        return linalg.matmul(
-            lhs, rhs, outs=[init_result.result], emit_generic=True)
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
+            )
+            def generic_form(lhs, rhs):
+                init_result = tensor.EmptyOp([4, 8], f32)
+                # CHECK: linalg.generic
+                return linalg.matmul(
+                    lhs, rhs, outs=[init_result.result], emit_generic=True
+                )
 
-  print(module)
+    print(module)
 
 
 # CHECK-LABEL: TEST: testOpResultFromOtherOp
 @run
 def testOpResultFromOtherOp():
-  with Context(), Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
-                                                                   f32))
-      def pass_an_op_directly(arg0, arg1):
-        one = arith.ConstantOp(F32Type.get(), 1.0)
-        # CHECK: %[[LHS:.*]] = linalg.fill
-        lhs = linalg.fill(one, outs=[arg0])
-        # CHECK: %[[RHS:.*]] = linalg.fill
-        rhs = linalg.fill(one, outs=[arg1])
-        # CHECK: %[[INIT:.*]] = tensor.empty
-        init = tensor.EmptyOp([4, 8], f32)
-        # CHECK: linalg.matmul
-        # CHECK: ins(%[[LHS]], %[[RHS]]
-        # CHECK: outs(%[[INIT]]
-        return linalg.matmul(lhs, rhs, outs=init)
-
-  print(module)
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
+            )
+            def pass_an_op_directly(arg0, arg1):
+                one = arith.ConstantOp(F32Type.get(), 1.0)
+                # CHECK: %[[LHS:.*]] = linalg.fill
+                lhs = linalg.fill(one, outs=[arg0])
+                # CHECK: %[[RHS:.*]] = linalg.fill
+                rhs = linalg.fill(one, outs=[arg1])
+                # CHECK: %[[INIT:.*]] = tensor.empty
+                init = tensor.EmptyOp([4, 8], f32)
+                # CHECK: linalg.matmul
+                # CHECK: ins(%[[LHS]], %[[RHS]]
+                # CHECK: outs(%[[INIT]]
+                return linalg.matmul(lhs, rhs, outs=init)
+
+    print(module)

diff  --git a/mlir/test/python/dialects/math_dialect.py b/mlir/test/python/dialects/math_dialect.py
index 04b6d848c7422..3d402c54a11e3 100644
--- a/mlir/test/python/dialects/math_dialect.py
+++ b/mlir/test/python/dialects/math_dialect.py
@@ -7,23 +7,26 @@
 import mlir.dialects.func as func
 import mlir.dialects.math as mlir_math
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
+    print("\nTEST:", f.__name__)
+    f()
+
 
 # CHECK-LABEL: TEST: testMathOps
 @run
 def testMathOps():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      @func.FuncOp.from_py_func(F32Type.get())
-      def emit_sqrt(arg):
-        return mlir_math.SqrtOp(arg)
-
-    # CHECK-LABEL: func @emit_sqrt(
-    # CHECK-SAME:                  %[[ARG:.*]]: f32) -> f32 {
-    # CHECK:         math.sqrt %[[ARG]] : f32
-    # CHECK:         return
-    # CHECK:       }
-    print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(F32Type.get())
+            def emit_sqrt(arg):
+                return mlir_math.SqrtOp(arg)
+
+        # CHECK-LABEL: func @emit_sqrt(
+        # CHECK-SAME:                  %[[ARG:.*]]: f32) -> f32 {
+        # CHECK:         math.sqrt %[[ARG]] : f32
+        # CHECK:         return
+        # CHECK:       }
+        print(module)

diff  --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
index 59092feffba23..2e3cae671a9f1 100644
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -6,17 +6,17 @@
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: testSubViewAccessors
 @run
 def testSubViewAccessors():
-  ctx = Context()
-  module = Module.parse(
-      r"""
+    ctx = Context()
+    module = Module.parse(
+        r"""
     func.func @f1(%arg0: memref<?x?xf32>) {
       %0 = arith.constant 0 : index
       %1 = arith.constant 1 : index
@@ -27,48 +27,52 @@ def testSubViewAccessors():
       memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
       return
     }
-  """, ctx)
-  func_body = module.body.operations[0].regions[0].blocks[0]
-  subview = func_body.operations[6]
+  """,
+        ctx,
+    )
+    func_body = module.body.operations[0].regions[0].blocks[0]
+    subview = func_body.operations[6]
 
-  assert subview.source == subview.operands[0]
-  assert len(subview.offsets) == 2
-  assert len(subview.sizes) == 2
-  assert len(subview.strides) == 2
-  assert subview.result == subview.results[0]
+    assert subview.source == subview.operands[0]
+    assert len(subview.offsets) == 2
+    assert len(subview.sizes) == 2
+    assert len(subview.strides) == 2
+    assert subview.result == subview.results[0]
 
-  # CHECK: SubViewOp
-  print(type(subview).__name__)
+    # CHECK: SubViewOp
+    print(type(subview).__name__)
 
-  # CHECK: constant 0
-  print(subview.offsets[0])
-  # CHECK: constant 1
-  print(subview.offsets[1])
-  # CHECK: constant 2
-  print(subview.sizes[0])
-  # CHECK: constant 3
-  print(subview.sizes[1])
-  # CHECK: constant 4
-  print(subview.strides[0])
-  # CHECK: constant 5
-  print(subview.strides[1])
+    # CHECK: constant 0
+    print(subview.offsets[0])
+    # CHECK: constant 1
+    print(subview.offsets[1])
+    # CHECK: constant 2
+    print(subview.sizes[0])
+    # CHECK: constant 3
+    print(subview.sizes[1])
+    # CHECK: constant 4
+    print(subview.strides[0])
+    # CHECK: constant 5
+    print(subview.strides[1])
 
 
 # CHECK-LABEL: TEST: testCustomBuidlers
 @run
 def testCustomBuidlers():
-  with Context() as ctx, Location.unknown(ctx):
-    module = Module.parse(r"""
+    with Context() as ctx, Location.unknown(ctx):
+        module = Module.parse(
+            r"""
       func.func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
         return
       }
-    """)
-    f = module.body.operations[0]
-    func_body = f.regions[0].blocks[0]
-    with InsertionPoint.at_block_terminator(func_body):
-      memref.LoadOp(f.arguments[0], f.arguments[1:])
+    """
+        )
+        f = module.body.operations[0]
+        func_body = f.regions[0].blocks[0]
+        with InsertionPoint.at_block_terminator(func_body):
+            memref.LoadOp(f.arguments[0], f.arguments[1:])
 
-    # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-    # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
-    print(module)
-    assert module.operation.verify()
+        # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+        # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
+        print(module)
+        assert module.operation.verify()

diff  --git a/mlir/test/python/dialects/ml_program.py b/mlir/test/python/dialects/ml_program.py
index 4d9804ff34732..f16de2add3799 100644
--- a/mlir/test/python/dialects/ml_program.py
+++ b/mlir/test/python/dialects/ml_program.py
@@ -6,23 +6,23 @@
 
 
 def constructAndPrintInModule(f):
-  print("\nTEST:", f.__name__)
-  with Context(), Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      f()
-    print(module)
-  return f
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()
+        print(module)
+    return f
 
 
 # CHECK-LABEL: testFuncOp
 @constructAndPrintInModule
 def testFuncOp():
-  # CHECK: ml_program.func @foobar(%arg0: si32) -> si32
-  f = ml_program.FuncOp(
-      name="foobar",
-      type=([IntegerType.get_signed(32)], [IntegerType.get_signed(32)]))
-  block = f.add_entry_block()
-  with InsertionPoint(block):
-    # CHECK: ml_program.return
-    ml_program.ReturnOp([block.arguments[0]])
+    # CHECK: ml_program.func @foobar(%arg0: si32) -> si32
+    f = ml_program.FuncOp(
+        name="foobar", type=([IntegerType.get_signed(32)], [IntegerType.get_signed(32)])
+    )
+    block = f.add_entry_block()
+    with InsertionPoint(block):
+        # CHECK: ml_program.return
+        ml_program.ReturnOp([block.arguments[0]])

diff  --git a/mlir/test/python/dialects/ods_helpers.py b/mlir/test/python/dialects/ods_helpers.py
index 802a1f271c565..71879bdcb51f5 100644
--- a/mlir/test/python/dialects/ods_helpers.py
+++ b/mlir/test/python/dialects/ods_helpers.py
@@ -6,207 +6,205 @@
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
 
 
 def add_dummy_value():
-  return Operation.create(
-      "custom.value",
-      results=[IntegerType.get_signless(32)]).result
+    return Operation.create(
+        "custom.value", results=[IntegerType.get_signless(32)]
+    ).result
 
 
 def testOdsBuildDefaultImplicitRegions():
-
-  class TestFixedRegionsOp(OpView):
-    OPERATION_NAME = "custom.test_op"
-    _ODS_REGIONS = (2, True)
-
-  class TestVariadicRegionsOp(OpView):
-    OPERATION_NAME = "custom.test_any_regions_op"
-    _ODS_REGIONS = (2, False)
-
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    m = Module.create()
-    with InsertionPoint(m.body):
-      op = TestFixedRegionsOp.build_generic(results=[], operands=[])
-      # CHECK: NUM_REGIONS: 2
-      print(f"NUM_REGIONS: {len(op.regions)}")
-      # Including a regions= that matches should be fine.
-      op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2)
-      print(f"NUM_REGIONS: {len(op.regions)}")
-      # Reject greater than.
-      try:
-        op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=3)
-      except ValueError as e:
-        # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
-        print(f"ERROR:{e}")
-      # Reject less than.
-      try:
-        op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=1)
-      except ValueError as e:
-        # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
-        print(f"ERROR:{e}")
-
-      # If no regions specified for a variadic region op, build the minimum.
-      op = TestVariadicRegionsOp.build_generic(results=[], operands=[])
-      # CHECK: DEFAULT_NUM_REGIONS: 2
-      print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
-      # Should also accept an explicit regions= that matches the minimum.
-      op = TestVariadicRegionsOp.build_generic(
-          results=[], operands=[], regions=2)
-      # CHECK: EQ_NUM_REGIONS: 2
-      print(f"EQ_NUM_REGIONS: {len(op.regions)}")
-      # And accept greater than minimum.
-      # Should also accept an explicit regions= that matches the minimum.
-      op = TestVariadicRegionsOp.build_generic(
-          results=[], operands=[], regions=3)
-      # CHECK: GT_NUM_REGIONS: 3
-      print(f"GT_NUM_REGIONS: {len(op.regions)}")
-      # Should reject less than minimum.
-      try:
-        op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=1)
-      except ValueError as e:
-        # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
-        print(f"ERROR:{e}")
-
+    class TestFixedRegionsOp(OpView):
+        OPERATION_NAME = "custom.test_op"
+        _ODS_REGIONS = (2, True)
+
+    class TestVariadicRegionsOp(OpView):
+        OPERATION_NAME = "custom.test_any_regions_op"
+        _ODS_REGIONS = (2, False)
+
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        m = Module.create()
+        with InsertionPoint(m.body):
+            op = TestFixedRegionsOp.build_generic(results=[], operands=[])
+            # CHECK: NUM_REGIONS: 2
+            print(f"NUM_REGIONS: {len(op.regions)}")
+            # Including a regions= that matches should be fine.
+            op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2)
+            print(f"NUM_REGIONS: {len(op.regions)}")
+            # Reject greater than.
+            try:
+                op = TestFixedRegionsOp.build_generic(
+                    results=[], operands=[], regions=3
+                )
+            except ValueError as e:
+                # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
+                print(f"ERROR:{e}")
+            # Reject less than.
+            try:
+                op = TestFixedRegionsOp.build_generic(
+                    results=[], operands=[], regions=1
+                )
+            except ValueError as e:
+                # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
+                print(f"ERROR:{e}")
+
+            # If no regions specified for a variadic region op, build the minimum.
+            op = TestVariadicRegionsOp.build_generic(results=[], operands=[])
+            # CHECK: DEFAULT_NUM_REGIONS: 2
+            print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
+            # Should also accept an explicit regions= that matches the minimum.
+            op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=2)
+            # CHECK: EQ_NUM_REGIONS: 2
+            print(f"EQ_NUM_REGIONS: {len(op.regions)}")
+            # And accept greater than minimum.
+            # Should also accept an explicit regions= that matches the minimum.
+            op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=3)
+            # CHECK: GT_NUM_REGIONS: 3
+            print(f"GT_NUM_REGIONS: {len(op.regions)}")
+            # Should reject less than minimum.
+            try:
+                op = TestVariadicRegionsOp.build_generic(
+                    results=[], operands=[], regions=1
+                )
+            except ValueError as e:
+                # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
+                print(f"ERROR:{e}")
 
 
 run(testOdsBuildDefaultImplicitRegions)
 
 
 def testOdsBuildDefaultNonVariadic():
+    class TestOp(OpView):
+        OPERATION_NAME = "custom.test_op"
+
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        m = Module.create()
+        with InsertionPoint(m.body):
+            v0 = add_dummy_value()
+            v1 = add_dummy_value()
+            t0 = IntegerType.get_signless(8)
+            t1 = IntegerType.get_signless(16)
+            op = TestOp.build_generic(results=[t0, t1], operands=[v0, v1])
+            # CHECK: %[[V0:.+]] = "custom.value"
+            # CHECK: %[[V1:.+]] = "custom.value"
+            # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
+            # CHECK-NOT: operand_segment_sizes
+            # CHECK-NOT: result_segment_sizes
+            # CHECK-SAME: : (i32, i32) -> (i8, i16)
+            print(m)
 
-  class TestOp(OpView):
-    OPERATION_NAME = "custom.test_op"
-
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    m = Module.create()
-    with InsertionPoint(m.body):
-      v0 = add_dummy_value()
-      v1 = add_dummy_value()
-      t0 = IntegerType.get_signless(8)
-      t1 = IntegerType.get_signless(16)
-      op = TestOp.build_generic(results=[t0, t1], operands=[v0, v1])
-      # CHECK: %[[V0:.+]] = "custom.value"
-      # CHECK: %[[V1:.+]] = "custom.value"
-      # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
-      # CHECK-NOT: operand_segment_sizes
-      # CHECK-NOT: result_segment_sizes
-      # CHECK-SAME: : (i32, i32) -> (i8, i16)
-      print(m)
 
 run(testOdsBuildDefaultNonVariadic)
 
 
 def testOdsBuildDefaultSizedVariadic():
+    class TestOp(OpView):
+        OPERATION_NAME = "custom.test_op"
+        _ODS_OPERAND_SEGMENTS = [1, -1, 0]
+        _ODS_RESULT_SEGMENTS = [-1, 0, 1]
+
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        m = Module.create()
+        with InsertionPoint(m.body):
+            v0 = add_dummy_value()
+            v1 = add_dummy_value()
+            v2 = add_dummy_value()
+            v3 = add_dummy_value()
+            t0 = IntegerType.get_signless(8)
+            t1 = IntegerType.get_signless(16)
+            t2 = IntegerType.get_signless(32)
+            t3 = IntegerType.get_signless(64)
+            # CHECK: %[[V0:.+]] = "custom.value"
+            # CHECK: %[[V1:.+]] = "custom.value"
+            # CHECK: %[[V2:.+]] = "custom.value"
+            # CHECK: %[[V3:.+]] = "custom.value"
+            # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
+            # CHECK-SAME: operand_segment_sizes = array<i32: 1, 2, 1>
+            # CHECK-SAME: result_segment_sizes = array<i32: 2, 1, 1>
+            # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
+            op = TestOp.build_generic(
+                results=[[t0, t1], t2, t3], operands=[v0, [v1, v2], v3]
+            )
+
+            # Now test with optional omitted.
+            # CHECK: "custom.test_op"(%[[V0]])
+            # CHECK-SAME: operand_segment_sizes = array<i32: 1, 0, 0>
+            # CHECK-SAME: result_segment_sizes = array<i32: 0, 0, 1>
+            # CHECK-SAME: (i32) -> i64
+            op = TestOp.build_generic(
+                results=[None, None, t3], operands=[v0, None, None]
+            )
+            print(m)
+
+            # And verify that errors are raised for None in a required operand.
+            try:
+                op = TestOp.build_generic(
+                    results=[None, None, t3], operands=[None, None, None]
+                )
+            except ValueError as e:
+                # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
+                print(f"OPERAND_CAST_ERROR:{e}")
+
+            # And verify that errors are raised for None in a required result.
+            try:
+                op = TestOp.build_generic(
+                    results=[None, None, None], operands=[v0, None, None]
+                )
+            except ValueError as e:
+                # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
+                print(f"RESULT_CAST_ERROR:{e}")
+
+            # Variadic lists with None elements should reject.
+            try:
+                op = TestOp.build_generic(
+                    results=[None, None, t3], operands=[v0, [None], None]
+                )
+            except ValueError as e:
+                # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
+                print(f"OPERAND_LIST_CAST_ERROR:{e}")
+            try:
+                op = TestOp.build_generic(
+                    results=[[None], None, t3], operands=[v0, None, None]
+                )
+            except ValueError as e:
+                # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
+                print(f"RESULT_LIST_CAST_ERROR:{e}")
 
-  class TestOp(OpView):
-    OPERATION_NAME = "custom.test_op"
-    _ODS_OPERAND_SEGMENTS = [1, -1, 0]
-    _ODS_RESULT_SEGMENTS = [-1, 0, 1]
-
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    m = Module.create()
-    with InsertionPoint(m.body):
-      v0 = add_dummy_value()
-      v1 = add_dummy_value()
-      v2 = add_dummy_value()
-      v3 = add_dummy_value()
-      t0 = IntegerType.get_signless(8)
-      t1 = IntegerType.get_signless(16)
-      t2 = IntegerType.get_signless(32)
-      t3 = IntegerType.get_signless(64)
-      # CHECK: %[[V0:.+]] = "custom.value"
-      # CHECK: %[[V1:.+]] = "custom.value"
-      # CHECK: %[[V2:.+]] = "custom.value"
-      # CHECK: %[[V3:.+]] = "custom.value"
-      # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
-      # CHECK-SAME: operand_segment_sizes = array<i32: 1, 2, 1>
-      # CHECK-SAME: result_segment_sizes = array<i32: 2, 1, 1>
-      # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
-      op = TestOp.build_generic(
-          results=[[t0, t1], t2, t3],
-          operands=[v0, [v1, v2], v3])
-
-      # Now test with optional omitted.
-      # CHECK: "custom.test_op"(%[[V0]])
-      # CHECK-SAME: operand_segment_sizes = array<i32: 1, 0, 0>
-      # CHECK-SAME: result_segment_sizes = array<i32: 0, 0, 1>
-      # CHECK-SAME: (i32) -> i64
-      op = TestOp.build_generic(
-          results=[None, None, t3],
-          operands=[v0, None, None])
-      print(m)
-
-      # And verify that errors are raised for None in a required operand.
-      try:
-        op = TestOp.build_generic(
-            results=[None, None, t3],
-            operands=[None, None, None])
-      except ValueError as e:
-        # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
-        print(f"OPERAND_CAST_ERROR:{e}")
-
-      # And verify that errors are raised for None in a required result.
-      try:
-        op = TestOp.build_generic(
-            results=[None, None, None],
-            operands=[v0, None, None])
-      except ValueError as e:
-        # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
-        print(f"RESULT_CAST_ERROR:{e}")
-
-      # Variadic lists with None elements should reject.
-      try:
-        op = TestOp.build_generic(
-            results=[None, None, t3],
-            operands=[v0, [None], None])
-      except ValueError as e:
-        # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
-        print(f"OPERAND_LIST_CAST_ERROR:{e}")
-      try:
-        op = TestOp.build_generic(
-            results=[[None], None, t3],
-            operands=[v0, None, None])
-      except ValueError as e:
-        # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
-        print(f"RESULT_LIST_CAST_ERROR:{e}")
 
 run(testOdsBuildDefaultSizedVariadic)
 
 
 def testOdsBuildDefaultCastError():
+    class TestOp(OpView):
+        OPERATION_NAME = "custom.test_op"
+
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        m = Module.create()
+        with InsertionPoint(m.body):
+            v0 = add_dummy_value()
+            v1 = add_dummy_value()
+            t0 = IntegerType.get_signless(8)
+            t1 = IntegerType.get_signless(16)
+            try:
+                op = TestOp.build_generic(results=[t0, t1], operands=[None, v1])
+            except ValueError as e:
+                # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
+                print(f"ERROR: {e}")
+            try:
+                op = TestOp.build_generic(results=[t0, None], operands=[v0, v1])
+            except ValueError as e:
+                # CHECK: Result 1 of operation "custom.test_op" must be a Type
+                print(f"ERROR: {e}")
 
-  class TestOp(OpView):
-    OPERATION_NAME = "custom.test_op"
-
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    m = Module.create()
-    with InsertionPoint(m.body):
-      v0 = add_dummy_value()
-      v1 = add_dummy_value()
-      t0 = IntegerType.get_signless(8)
-      t1 = IntegerType.get_signless(16)
-      try:
-        op = TestOp.build_generic(
-            results=[t0, t1],
-            operands=[None, v1])
-      except ValueError as e:
-        # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
-        print(f"ERROR: {e}")
-      try:
-        op = TestOp.build_generic(
-            results=[t0, None],
-            operands=[v0, v1])
-      except ValueError as e:
-        # CHECK: Result 1 of operation "custom.test_op" must be a Type
-        print(f"ERROR: {e}")
 
 run(testOdsBuildDefaultCastError)

diff  --git a/mlir/test/python/dialects/pdl_ops.py b/mlir/test/python/dialects/pdl_ops.py
index 3d9cd19278997..0d364f9222a65 100644
--- a/mlir/test/python/dialects/pdl_ops.py
+++ b/mlir/test/python/dialects/pdl_ops.py
@@ -5,13 +5,13 @@
 
 
 def constructAndPrintInModule(f):
-  print("\nTEST:", f.__name__)
-  with Context(), Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      f()
-    print(module)
-  return f
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()
+        print(module)
+    return f
 
 
 # CHECK: module  {
@@ -27,15 +27,15 @@ def constructAndPrintInModule(f):
 # CHECK: }
 @constructAndPrintInModule
 def test_operations():
-  pattern = PatternOp(1, "operations")
-  with InsertionPoint(pattern.body):
-    attr = AttributeOp()
-    ty = TypeOp()
-    op0 = OperationOp(attributes={"attr": attr}, types=[ty])
-    op0_result = ResultOp(op0, 0)
-    input = OperandOp()
-    root = OperationOp(args=[op0_result, input])
-    RewriteOp(root, "rewriter")
+    pattern = PatternOp(1, "operations")
+    with InsertionPoint(pattern.body):
+        attr = AttributeOp()
+        ty = TypeOp()
+        op0 = OperationOp(attributes={"attr": attr}, types=[ty])
+        op0_result = ResultOp(op0, 0)
+        input = OperandOp()
+        root = OperationOp(args=[op0_result, input])
+        RewriteOp(root, "rewriter")
 
 
 # CHECK: module  {
@@ -47,11 +47,12 @@ def test_operations():
 # CHECK: }
 @constructAndPrintInModule
 def test_rewrite_with_args():
-  pattern = PatternOp(1, "rewrite_with_args")
-  with InsertionPoint(pattern.body):
-    input = OperandOp()
-    root = OperationOp(args=[input])
-    RewriteOp(root, "rewriter", args=[input])
+    pattern = PatternOp(1, "rewrite_with_args")
+    with InsertionPoint(pattern.body):
+        input = OperandOp()
+        root = OperationOp(args=[input])
+        RewriteOp(root, "rewriter", args=[input])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @rewrite_multi_root_optimal : benefit(1)  {
@@ -69,18 +70,19 @@ def test_rewrite_with_args():
 # CHECK: }
 @constructAndPrintInModule
 def test_rewrite_multi_root_optimal():
-  pattern = PatternOp(1, "rewrite_multi_root_optimal")
-  with InsertionPoint(pattern.body):
-    input1 = OperandOp()
-    input2 = OperandOp()
-    ty = TypeOp()
-    op1 = OperationOp(args=[input1], types=[ty])
-    val1 = ResultOp(op1, 0)
-    root1 = OperationOp(args=[val1])
-    op2 = OperationOp(args=[input2], types=[ty])
-    val2 = ResultOp(op2, 0)
-    root2 = OperationOp(args=[val1, val2])
-    RewriteOp(name="rewriter", args=[root1, root2])
+    pattern = PatternOp(1, "rewrite_multi_root_optimal")
+    with InsertionPoint(pattern.body):
+        input1 = OperandOp()
+        input2 = OperandOp()
+        ty = TypeOp()
+        op1 = OperationOp(args=[input1], types=[ty])
+        val1 = ResultOp(op1, 0)
+        root1 = OperationOp(args=[val1])
+        op2 = OperationOp(args=[input2], types=[ty])
+        val2 = ResultOp(op2, 0)
+        root2 = OperationOp(args=[val1, val2])
+        RewriteOp(name="rewriter", args=[root1, root2])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @rewrite_multi_root_forced : benefit(1)  {
@@ -98,18 +100,19 @@ def test_rewrite_multi_root_optimal():
 # CHECK: }
 @constructAndPrintInModule
 def test_rewrite_multi_root_forced():
-  pattern = PatternOp(1, "rewrite_multi_root_forced")
-  with InsertionPoint(pattern.body):
-    input1 = OperandOp()
-    input2 = OperandOp()
-    ty = TypeOp()
-    op1 = OperationOp(args=[input1], types=[ty])
-    val1 = ResultOp(op1, 0)
-    root1 = OperationOp(args=[val1])
-    op2 = OperationOp(args=[input2], types=[ty])
-    val2 = ResultOp(op2, 0)
-    root2 = OperationOp(args=[val1, val2])
-    RewriteOp(root1, name="rewriter", args=[root2])
+    pattern = PatternOp(1, "rewrite_multi_root_forced")
+    with InsertionPoint(pattern.body):
+        input1 = OperandOp()
+        input2 = OperandOp()
+        ty = TypeOp()
+        op1 = OperationOp(args=[input1], types=[ty])
+        val1 = ResultOp(op1, 0)
+        root1 = OperationOp(args=[val1])
+        op2 = OperationOp(args=[input2], types=[ty])
+        val2 = ResultOp(op2, 0)
+        root2 = OperationOp(args=[val1, val2])
+        RewriteOp(root1, name="rewriter", args=[root2])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @rewrite_add_body : benefit(1)  {
@@ -125,16 +128,17 @@ def test_rewrite_multi_root_forced():
 # CHECK: }
 @constructAndPrintInModule
 def test_rewrite_add_body():
-  pattern = PatternOp(1, "rewrite_add_body")
-  with InsertionPoint(pattern.body):
-    ty1 = TypeOp(IntegerType.get_signless(32))
-    ty2 = TypeOp()
-    root = OperationOp(types=[ty1, ty2])
-    rewrite = RewriteOp(root)
-    with InsertionPoint(rewrite.add_body()):
-      ty3 = TypeOp()
-      newOp = OperationOp(name="foo.op", types=[ty1, ty3])
-      ReplaceOp(root, with_op=newOp)
+    pattern = PatternOp(1, "rewrite_add_body")
+    with InsertionPoint(pattern.body):
+        ty1 = TypeOp(IntegerType.get_signless(32))
+        ty2 = TypeOp()
+        root = OperationOp(types=[ty1, ty2])
+        rewrite = RewriteOp(root)
+        with InsertionPoint(rewrite.add_body()):
+            ty3 = TypeOp()
+            newOp = OperationOp(name="foo.op", types=[ty1, ty3])
+            ReplaceOp(root, with_op=newOp)
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @rewrite_type : benefit(1)  {
@@ -148,14 +152,15 @@ def test_rewrite_add_body():
 # CHECK: }
 @constructAndPrintInModule
 def test_rewrite_type():
-  pattern = PatternOp(1, "rewrite_type")
-  with InsertionPoint(pattern.body):
-    ty1 = TypeOp(IntegerType.get_signless(32))
-    ty2 = TypeOp()
-    root = OperationOp(types=[ty1, ty2])
-    rewrite = RewriteOp(root)
-    with InsertionPoint(rewrite.add_body()):
-      newOp = OperationOp(name="foo.op", types=[ty1, ty2])
+    pattern = PatternOp(1, "rewrite_type")
+    with InsertionPoint(pattern.body):
+        ty1 = TypeOp(IntegerType.get_signless(32))
+        ty2 = TypeOp()
+        root = OperationOp(types=[ty1, ty2])
+        rewrite = RewriteOp(root)
+        with InsertionPoint(rewrite.add_body()):
+            newOp = OperationOp(name="foo.op", types=[ty1, ty2])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @rewrite_types : benefit(1)  {
@@ -169,14 +174,17 @@ def test_rewrite_type():
 # CHECK: }
 @constructAndPrintInModule
 def test_rewrite_types():
-  pattern = PatternOp(1, "rewrite_types")
-  with InsertionPoint(pattern.body):
-    types = TypesOp()
-    root = OperationOp(types=[types])
-    rewrite = RewriteOp(root)
-    with InsertionPoint(rewrite.add_body()):
-      otherTypes = TypesOp([IntegerType.get_signless(32), IntegerType.get_signless(64)])
-      newOp = OperationOp(name="foo.op", types=[types, otherTypes])
+    pattern = PatternOp(1, "rewrite_types")
+    with InsertionPoint(pattern.body):
+        types = TypesOp()
+        root = OperationOp(types=[types])
+        rewrite = RewriteOp(root)
+        with InsertionPoint(rewrite.add_body()):
+            otherTypes = TypesOp(
+                [IntegerType.get_signless(32), IntegerType.get_signless(64)]
+            )
+            newOp = OperationOp(name="foo.op", types=[types, otherTypes])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @rewrite_operands : benefit(1)  {
@@ -190,14 +198,15 @@ def test_rewrite_types():
 # CHECK: }
 @constructAndPrintInModule
 def test_rewrite_operands():
-  pattern = PatternOp(1, "rewrite_operands")
-  with InsertionPoint(pattern.body):
-    types = TypesOp()
-    operands = OperandsOp(types)
-    root = OperationOp(args=[operands])
-    rewrite = RewriteOp(root)
-    with InsertionPoint(rewrite.add_body()):
-      newOp = OperationOp(name="foo.op", types=[types])
+    pattern = PatternOp(1, "rewrite_operands")
+    with InsertionPoint(pattern.body):
+        types = TypesOp()
+        operands = OperandsOp(types)
+        root = OperationOp(args=[operands])
+        rewrite = RewriteOp(root)
+        with InsertionPoint(rewrite.add_body()):
+            newOp = OperationOp(name="foo.op", types=[types])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @native_rewrite : benefit(1)  {
@@ -209,12 +218,13 @@ def test_rewrite_operands():
 # CHECK: }
 @constructAndPrintInModule
 def test_native_rewrite():
-  pattern = PatternOp(1, "native_rewrite")
-  with InsertionPoint(pattern.body):
-    root = OperationOp()
-    rewrite = RewriteOp(root)
-    with InsertionPoint(rewrite.add_body()):
-      ApplyNativeRewriteOp([], "NativeRewrite", args=[root])
+    pattern = PatternOp(1, "native_rewrite")
+    with InsertionPoint(pattern.body):
+        root = OperationOp()
+        rewrite = RewriteOp(root)
+        with InsertionPoint(rewrite.add_body()):
+            ApplyNativeRewriteOp([], "NativeRewrite", args=[root])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @attribute_with_value : benefit(1)  {
@@ -227,13 +237,14 @@ def test_native_rewrite():
 # CHECK: }
 @constructAndPrintInModule
 def test_attribute_with_value():
-  pattern = PatternOp(1, "attribute_with_value")
-  with InsertionPoint(pattern.body):
-    root = OperationOp()
-    rewrite = RewriteOp(root)
-    with InsertionPoint(rewrite.add_body()):
-      attr = AttributeOp(value=Attribute.parse('"value"'))
-      ApplyNativeRewriteOp([], "NativeRewrite", args=[attr])
+    pattern = PatternOp(1, "attribute_with_value")
+    with InsertionPoint(pattern.body):
+        root = OperationOp()
+        rewrite = RewriteOp(root)
+        with InsertionPoint(rewrite.add_body()):
+            attr = AttributeOp(value=Attribute.parse('"value"'))
+            ApplyNativeRewriteOp([], "NativeRewrite", args=[attr])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @erase : benefit(1)  {
@@ -245,12 +256,13 @@ def test_attribute_with_value():
 # CHECK: }
 @constructAndPrintInModule
 def test_erase():
-  pattern = PatternOp(1, "erase")
-  with InsertionPoint(pattern.body):
-    root = OperationOp()
-    rewrite = RewriteOp(root)
-    with InsertionPoint(rewrite.add_body()):
-      EraseOp(root)
+    pattern = PatternOp(1, "erase")
+    with InsertionPoint(pattern.body):
+        root = OperationOp()
+        rewrite = RewriteOp(root)
+        with InsertionPoint(rewrite.add_body()):
+            EraseOp(root)
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @operation_results : benefit(1)  {
@@ -263,14 +275,15 @@ def test_erase():
 # CHECK: }
 @constructAndPrintInModule
 def test_operation_results():
-  valueRange = RangeType.get(ValueType.get())
-  pattern = PatternOp(1, "operation_results")
-  with InsertionPoint(pattern.body):
-    types = TypesOp()
-    inputOp = OperationOp(types=[types])
-    results = ResultsOp(valueRange, inputOp)
-    root = OperationOp(args=[results])
-    RewriteOp(root, name="rewriter")
+    valueRange = RangeType.get(ValueType.get())
+    pattern = PatternOp(1, "operation_results")
+    with InsertionPoint(pattern.body):
+        types = TypesOp()
+        inputOp = OperationOp(types=[types])
+        results = ResultsOp(valueRange, inputOp)
+        root = OperationOp(args=[results])
+        RewriteOp(root, name="rewriter")
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern : benefit(1)  {
@@ -282,9 +295,9 @@ def test_operation_results():
 # CHECK: }
 @constructAndPrintInModule
 def test_apply_native_constraint():
-  pattern = PatternOp(1)
-  with InsertionPoint(pattern.body):
-    resultType = TypeOp()
-    ApplyNativeConstraintOp("typeConstraint", args=[resultType])
-    root = OperationOp(types=[resultType])
-    RewriteOp(root, name="rewrite")
+    pattern = PatternOp(1)
+    with InsertionPoint(pattern.body):
+        resultType = TypeOp()
+        ApplyNativeConstraintOp("typeConstraint", args=[resultType])
+        root = OperationOp(types=[resultType])
+        RewriteOp(root, name="rewrite")

diff  --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 2ca79b29f567c..72a765c75e52c 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -5,367 +5,373 @@
 import mlir.dialects.python_test as test
 import mlir.dialects.tensor as tensor
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
+
 
 # CHECK-LABEL: TEST: testAttributes
 @run
 def testAttributes():
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-
-    #
-    # Check op construction with attributes.
-    #
-
-    i32 = IntegerType.get_signless(32)
-    one = IntegerAttr.get(i32, 1)
-    two = IntegerAttr.get(i32, 2)
-    unit = UnitAttr.get()
-
-    # CHECK: "python_test.attributed_op"() {
-    # CHECK-DAG: mandatory_i32 = 1 : i32
-    # CHECK-DAG: optional_i32 = 2 : i32
-    # CHECK-DAG: unit
-    # CHECK: }
-    op = test.AttributedOp(one, optional_i32=two, unit=unit)
-    print(f"{op}")
-
-    # CHECK: "python_test.attributed_op"() {
-    # CHECK: mandatory_i32 = 2 : i32
-    # CHECK: }
-    op2 = test.AttributedOp(two)
-    print(f"{op2}")
-
-    #
-    # Check generic "attributes" access and mutation.
-    #
-
-    assert "additional" not in op.attributes
-
-    # CHECK: "python_test.attributed_op"() {
-    # CHECK-DAG: additional = 1 : i32
-    # CHECK-DAG: mandatory_i32 = 2 : i32
-    # CHECK: }
-    op2.attributes["additional"] = one
-    print(f"{op2}")
-
-    # CHECK: "python_test.attributed_op"() {
-    # CHECK-DAG: additional = 2 : i32
-    # CHECK-DAG: mandatory_i32 = 2 : i32
-    # CHECK: }
-    op2.attributes["additional"] = two
-    print(f"{op2}")
-
-    # CHECK: "python_test.attributed_op"() {
-    # CHECK-NOT: additional = 2 : i32
-    # CHECK:     mandatory_i32 = 2 : i32
-    # CHECK: }
-    del op2.attributes["additional"]
-    print(f"{op2}")
-
-    try:
-      print(op.attributes["additional"])
-    except KeyError:
-      pass
-    else:
-      assert False, "expected KeyError on unknown attribute key"
-
-    #
-    # Check accessors to defined attributes.
-    #
-
-    # CHECK: Mandatory: 1
-    # CHECK: Optional: 2
-    # CHECK: Unit: True
-    print(f"Mandatory: {op.mandatory_i32.value}")
-    print(f"Optional: {op.optional_i32.value}")
-    print(f"Unit: {op.unit}")
-
-    # CHECK: Mandatory: 2
-    # CHECK: Optional: None
-    # CHECK: Unit: False
-    print(f"Mandatory: {op2.mandatory_i32.value}")
-    print(f"Optional: {op2.optional_i32}")
-    print(f"Unit: {op2.unit}")
-
-    # CHECK: Mandatory: 2
-    # CHECK: Optional: None
-    # CHECK: Unit: False
-    op.mandatory_i32 = two
-    op.optional_i32 = None
-    op.unit = False
-    print(f"Mandatory: {op.mandatory_i32.value}")
-    print(f"Optional: {op.optional_i32}")
-    print(f"Unit: {op.unit}")
-    assert "optional_i32" not in op.attributes
-    assert "unit" not in op.attributes
-
-    try:
-      op.mandatory_i32 = None
-    except ValueError:
-      pass
-    else:
-      assert False, "expected ValueError on setting a mandatory attribute to None"
-
-    # CHECK: Optional: 2
-    op.optional_i32 = two
-    print(f"Optional: {op.optional_i32.value}")
-
-    # CHECK: Optional: None
-    del op.optional_i32
-    print(f"Optional: {op.optional_i32}")
-
-    # CHECK: Unit: False
-    op.unit = None
-    print(f"Unit: {op.unit}")
-    assert "unit" not in op.attributes
-
-    # CHECK: Unit: True
-    op.unit = True
-    print(f"Unit: {op.unit}")
-
-    # CHECK: Unit: False
-    del op.unit
-    print(f"Unit: {op.unit}")
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+
+        #
+        # Check op construction with attributes.
+        #
+
+        i32 = IntegerType.get_signless(32)
+        one = IntegerAttr.get(i32, 1)
+        two = IntegerAttr.get(i32, 2)
+        unit = UnitAttr.get()
+
+        # CHECK: "python_test.attributed_op"() {
+        # CHECK-DAG: mandatory_i32 = 1 : i32
+        # CHECK-DAG: optional_i32 = 2 : i32
+        # CHECK-DAG: unit
+        # CHECK: }
+        op = test.AttributedOp(one, optional_i32=two, unit=unit)
+        print(f"{op}")
+
+        # CHECK: "python_test.attributed_op"() {
+        # CHECK: mandatory_i32 = 2 : i32
+        # CHECK: }
+        op2 = test.AttributedOp(two)
+        print(f"{op2}")
+
+        #
+        # Check generic "attributes" access and mutation.
+        #
+
+        assert "additional" not in op.attributes
+
+        # CHECK: "python_test.attributed_op"() {
+        # CHECK-DAG: additional = 1 : i32
+        # CHECK-DAG: mandatory_i32 = 2 : i32
+        # CHECK: }
+        op2.attributes["additional"] = one
+        print(f"{op2}")
+
+        # CHECK: "python_test.attributed_op"() {
+        # CHECK-DAG: additional = 2 : i32
+        # CHECK-DAG: mandatory_i32 = 2 : i32
+        # CHECK: }
+        op2.attributes["additional"] = two
+        print(f"{op2}")
+
+        # CHECK: "python_test.attributed_op"() {
+        # CHECK-NOT: additional = 2 : i32
+        # CHECK:     mandatory_i32 = 2 : i32
+        # CHECK: }
+        del op2.attributes["additional"]
+        print(f"{op2}")
+
+        try:
+            print(op.attributes["additional"])
+        except KeyError:
+            pass
+        else:
+            assert False, "expected KeyError on unknown attribute key"
+
+        #
+        # Check accessors to defined attributes.
+        #
+
+        # CHECK: Mandatory: 1
+        # CHECK: Optional: 2
+        # CHECK: Unit: True
+        print(f"Mandatory: {op.mandatory_i32.value}")
+        print(f"Optional: {op.optional_i32.value}")
+        print(f"Unit: {op.unit}")
+
+        # CHECK: Mandatory: 2
+        # CHECK: Optional: None
+        # CHECK: Unit: False
+        print(f"Mandatory: {op2.mandatory_i32.value}")
+        print(f"Optional: {op2.optional_i32}")
+        print(f"Unit: {op2.unit}")
+
+        # CHECK: Mandatory: 2
+        # CHECK: Optional: None
+        # CHECK: Unit: False
+        op.mandatory_i32 = two
+        op.optional_i32 = None
+        op.unit = False
+        print(f"Mandatory: {op.mandatory_i32.value}")
+        print(f"Optional: {op.optional_i32}")
+        print(f"Unit: {op.unit}")
+        assert "optional_i32" not in op.attributes
+        assert "unit" not in op.attributes
+
+        try:
+            op.mandatory_i32 = None
+        except ValueError:
+            pass
+        else:
+            assert False, "expected ValueError on setting a mandatory attribute to None"
+
+        # CHECK: Optional: 2
+        op.optional_i32 = two
+        print(f"Optional: {op.optional_i32.value}")
+
+        # CHECK: Optional: None
+        del op.optional_i32
+        print(f"Optional: {op.optional_i32}")
+
+        # CHECK: Unit: False
+        op.unit = None
+        print(f"Unit: {op.unit}")
+        assert "unit" not in op.attributes
+
+        # CHECK: Unit: True
+        op.unit = True
+        print(f"Unit: {op.unit}")
+
+        # CHECK: Unit: False
+        del op.unit
+        print(f"Unit: {op.unit}")
+
 
 # CHECK-LABEL: TEST: attrBuilder
 @run
 def attrBuilder():
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    op = test.AttributesOp(x_bool=True,
-                           x_i16=1,
-                           x_i32=2,
-                           x_i64=3,
-                           x_si16=-1,
-                           x_si32=-2,
-                           x_f32=1.5,
-                           x_f64=2.5,
-                           x_str='x_str',
-                           x_i32_array=[1, 2, 3],
-                           x_i64_array=[4, 5, 6],
-                           x_f32_array=[1.5, -2.5, 3.5],
-                           x_f64_array=[4.5, 5.5, -6.5],
-                           x_i64_dense=[1, 2, 3, 4, 5, 6])
-    print(op)
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        op = test.AttributesOp(
+            x_bool=True,
+            x_i16=1,
+            x_i32=2,
+            x_i64=3,
+            x_si16=-1,
+            x_si32=-2,
+            x_f32=1.5,
+            x_f64=2.5,
+            x_str="x_str",
+            x_i32_array=[1, 2, 3],
+            x_i64_array=[4, 5, 6],
+            x_f32_array=[1.5, -2.5, 3.5],
+            x_f64_array=[4.5, 5.5, -6.5],
+            x_i64_dense=[1, 2, 3, 4, 5, 6],
+        )
+        print(op)
 
 
 # CHECK-LABEL: TEST: inferReturnTypes
 @run
 def inferReturnTypes():
-  with Context() as ctx, Location.unknown(ctx):
-    test.register_python_test_dialect(ctx)
-    module = Module.create()
-    with InsertionPoint(module.body):
-      op = test.InferResultsOp()
-      dummy = test.DummyOp()
-
-    # CHECK: [Type(i32), Type(i64)]
-    iface = InferTypeOpInterface(op)
-    print(iface.inferReturnTypes())
-
-    # CHECK: [Type(i32), Type(i64)]
-    iface_static = InferTypeOpInterface(test.InferResultsOp)
-    print(iface.inferReturnTypes())
-
-    assert isinstance(iface.opview, test.InferResultsOp)
-    assert iface.opview == iface.operation.opview
-
-    try:
-      iface_static.opview
-    except TypeError:
-      pass
-    else:
-      assert False, ("not expected to be able to obtain an opview from a static"
-                     " interface")
-
-    try:
-      InferTypeOpInterface(dummy)
-    except ValueError:
-      pass
-    else:
-      assert False, "not expected dummy op to implement the interface"
-
-    try:
-      InferTypeOpInterface(test.DummyOp)
-    except ValueError:
-      pass
-    else:
-      assert False, "not expected dummy op class to implement the interface"
+    with Context() as ctx, Location.unknown(ctx):
+        test.register_python_test_dialect(ctx)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            op = test.InferResultsOp()
+            dummy = test.DummyOp()
+
+        # CHECK: [Type(i32), Type(i64)]
+        iface = InferTypeOpInterface(op)
+        print(iface.inferReturnTypes())
+
+        # CHECK: [Type(i32), Type(i64)]
+        iface_static = InferTypeOpInterface(test.InferResultsOp)
+        print(iface.inferReturnTypes())
+
+        assert isinstance(iface.opview, test.InferResultsOp)
+        assert iface.opview == iface.operation.opview
+
+        try:
+            iface_static.opview
+        except TypeError:
+            pass
+        else:
+            assert False, (
+                "not expected to be able to obtain an opview from a static" " interface"
+            )
+
+        try:
+            InferTypeOpInterface(dummy)
+        except ValueError:
+            pass
+        else:
+            assert False, "not expected dummy op to implement the interface"
+
+        try:
+            InferTypeOpInterface(test.DummyOp)
+        except ValueError:
+            pass
+        else:
+            assert False, "not expected dummy op class to implement the interface"
 
 
 # CHECK-LABEL: TEST: resultTypesDefinedByTraits
 @run
 def resultTypesDefinedByTraits():
-  with Context() as ctx, Location.unknown(ctx):
-    test.register_python_test_dialect(ctx)
-    module = Module.create()
-    with InsertionPoint(module.body):
-      inferred = test.InferResultsOp()
-      same = test.SameOperandAndResultTypeOp([inferred.results[0]])
-      # CHECK-COUNT-2: i32
-      print(same.one.type)
-      print(same.two.type)
-
-      first_type_attr = test.FirstAttrDeriveTypeAttrOp(
-          inferred.results[1], TypeAttr.get(IndexType.get()))
-      # CHECK-COUNT-2: index
-      print(first_type_attr.one.type)
-      print(first_type_attr.two.type)
-
-      first_attr = test.FirstAttrDeriveAttrOp(
-          FloatAttr.get(F32Type.get(), 3.14))
-      # CHECK-COUNT-3: f32
-      print(first_attr.one.type)
-      print(first_attr.two.type)
-      print(first_attr.three.type)
-
-      implied = test.InferResultsImpliedOp()
-      # CHECK: i32
-      print(implied.integer.type)
-      # CHECK: f64
-      print(implied.flt.type)
-      # CHECK: index
-      print(implied.index.type)
+    with Context() as ctx, Location.unknown(ctx):
+        test.register_python_test_dialect(ctx)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            inferred = test.InferResultsOp()
+            same = test.SameOperandAndResultTypeOp([inferred.results[0]])
+            # CHECK-COUNT-2: i32
+            print(same.one.type)
+            print(same.two.type)
+
+            first_type_attr = test.FirstAttrDeriveTypeAttrOp(
+                inferred.results[1], TypeAttr.get(IndexType.get())
+            )
+            # CHECK-COUNT-2: index
+            print(first_type_attr.one.type)
+            print(first_type_attr.two.type)
+
+            first_attr = test.FirstAttrDeriveAttrOp(FloatAttr.get(F32Type.get(), 3.14))
+            # CHECK-COUNT-3: f32
+            print(first_attr.one.type)
+            print(first_attr.two.type)
+            print(first_attr.three.type)
+
+            implied = test.InferResultsImpliedOp()
+            # CHECK: i32
+            print(implied.integer.type)
+            # CHECK: f64
+            print(implied.flt.type)
+            # CHECK: index
+            print(implied.index.type)
 
 
 # CHECK-LABEL: TEST: testOptionalOperandOp
 @run
 def testOptionalOperandOp():
-  with Context() as ctx, Location.unknown():
-    test.register_python_test_dialect(ctx)
+    with Context() as ctx, Location.unknown():
+        test.register_python_test_dialect(ctx)
 
-    module = Module.create()
-    with InsertionPoint(module.body):
+        module = Module.create()
+        with InsertionPoint(module.body):
 
-      op1 = test.OptionalOperandOp()
-      # CHECK: op1.input is None: True
-      print(f"op1.input is None: {op1.input is None}")
+            op1 = test.OptionalOperandOp()
+            # CHECK: op1.input is None: True
+            print(f"op1.input is None: {op1.input is None}")
 
-      op2 = test.OptionalOperandOp(input=op1)
-      # CHECK: op2.input is None: False
-      print(f"op2.input is None: {op2.input is None}")
+            op2 = test.OptionalOperandOp(input=op1)
+            # CHECK: op2.input is None: False
+            print(f"op2.input is None: {op2.input is None}")
 
 
 # CHECK-LABEL: TEST: testCustomAttribute
 @run
 def testCustomAttribute():
-  with Context() as ctx:
-    test.register_python_test_dialect(ctx)
-    a = test.TestAttr.get()
-    # CHECK: #python_test.test_attr
-    print(a)
-
-    # The following cast must not assert.
-    b = test.TestAttr(a)
-
-    unit = UnitAttr.get()
-    try:
-      test.TestAttr(unit)
-    except ValueError as e:
-      assert "Cannot cast attribute to TestAttr" in str(e)
-    else:
-      raise
-
-    # The following must trigger a TypeError from our adaptors and must not
-    # crash.
-    try:
-      test.TestAttr(42)
-    except TypeError as e:
-      assert "Expected an MLIR object" in str(e)
-    else:
-      raise
-
-    # The following must trigger a TypeError from pybind (therefore, not
-    # checking its message) and must not crash.
-    try:
-      test.TestAttr(42, 56)
-    except TypeError:
-      pass
-    else:
-      raise
+    with Context() as ctx:
+        test.register_python_test_dialect(ctx)
+        a = test.TestAttr.get()
+        # CHECK: #python_test.test_attr
+        print(a)
+
+        # The following cast must not assert.
+        b = test.TestAttr(a)
+
+        unit = UnitAttr.get()
+        try:
+            test.TestAttr(unit)
+        except ValueError as e:
+            assert "Cannot cast attribute to TestAttr" in str(e)
+        else:
+            raise
+
+        # The following must trigger a TypeError from our adaptors and must not
+        # crash.
+        try:
+            test.TestAttr(42)
+        except TypeError as e:
+            assert "Expected an MLIR object" in str(e)
+        else:
+            raise
+
+        # The following must trigger a TypeError from pybind (therefore, not
+        # checking its message) and must not crash.
+        try:
+            test.TestAttr(42, 56)
+        except TypeError:
+            pass
+        else:
+            raise
 
 
 @run
 def testCustomType():
-  with Context() as ctx:
-    test.register_python_test_dialect(ctx)
-    a = test.TestType.get()
-    # CHECK: !python_test.test_type
-    print(a)
-
-    # The following cast must not assert.
-    b = test.TestType(a)
-    # Instance custom types should have typeids
-    assert isinstance(b.typeid, TypeID)
-    # Subclasses of ir.Type should not have a static_typeid
-    # CHECK: 'TestType' object has no attribute 'static_typeid'
-    try:
-      b.static_typeid
-    except AttributeError as e:
-      print(e)
-
-    i8 = IntegerType.get_signless(8)
-    try:
-      test.TestType(i8)
-    except ValueError as e:
-      assert "Cannot cast type to TestType" in str(e)
-    else:
-      raise
-
-    # The following must trigger a TypeError from our adaptors and must not
-    # crash.
-    try:
-      test.TestType(42)
-    except TypeError as e:
-      assert "Expected an MLIR object" in str(e)
-    else:
-      raise
-
-    # The following must trigger a TypeError from pybind (therefore, not
-    # checking its message) and must not crash.
-    try:
-      test.TestType(42, 56)
-    except TypeError:
-      pass
-    else:
-      raise
+    with Context() as ctx:
+        test.register_python_test_dialect(ctx)
+        a = test.TestType.get()
+        # CHECK: !python_test.test_type
+        print(a)
+
+        # The following cast must not assert.
+        b = test.TestType(a)
+        # Instance custom types should have typeids
+        assert isinstance(b.typeid, TypeID)
+        # Subclasses of ir.Type should not have a static_typeid
+        # CHECK: 'TestType' object has no attribute 'static_typeid'
+        try:
+            b.static_typeid
+        except AttributeError as e:
+            print(e)
+
+        i8 = IntegerType.get_signless(8)
+        try:
+            test.TestType(i8)
+        except ValueError as e:
+            assert "Cannot cast type to TestType" in str(e)
+        else:
+            raise
+
+        # The following must trigger a TypeError from our adaptors and must not
+        # crash.
+        try:
+            test.TestType(42)
+        except TypeError as e:
+            assert "Expected an MLIR object" in str(e)
+        else:
+            raise
+
+        # The following must trigger a TypeError from pybind (therefore, not
+        # checking its message) and must not crash.
+        try:
+            test.TestType(42, 56)
+        except TypeError:
+            pass
+        else:
+            raise
 
 
 @run
 # CHECK-LABEL: TEST: testTensorValue
 def testTensorValue():
-  with Context() as ctx, Location.unknown():
-    test.register_python_test_dialect(ctx)
+    with Context() as ctx, Location.unknown():
+        test.register_python_test_dialect(ctx)
 
-    i8 = IntegerType.get_signless(8)
+        i8 = IntegerType.get_signless(8)
 
-    class Tensor(test.TestTensorValue):
-      def __str__(self):
-        return super().__str__().replace("Value", "Tensor")
+        class Tensor(test.TestTensorValue):
+            def __str__(self):
+                return super().__str__().replace("Value", "Tensor")
 
-    module = Module.create()
-    with InsertionPoint(module.body):
-      t = tensor.EmptyOp([10, 10], i8).result
+        module = Module.create()
+        with InsertionPoint(module.body):
+            t = tensor.EmptyOp([10, 10], i8).result
 
-      # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
-      print(Value(t))
+            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
+            print(Value(t))
 
-      tt = Tensor(t)
-      # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
-      print(tt)
+            tt = Tensor(t)
+            # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
+            print(tt)
 
-      # CHECK: False
-      print(tt.is_null())
+            # CHECK: False
+            print(tt.is_null())
 
-      # Classes of custom types that inherit from concrete types should have
-      # static_typeid
-      assert isinstance(test.TestTensorType.static_typeid, TypeID)
-      # And it should be equal to the in-tree concrete type
-      assert test.TestTensorType.static_typeid == t.type.typeid
+            # Classes of custom types that inherit from concrete types should have
+            # static_typeid
+            assert isinstance(test.TestTensorType.static_typeid, TypeID)
+            # And it should be equal to the in-tree concrete type
+            assert test.TestTensorType.static_typeid == t.type.typeid
 
 
 # CHECK-LABEL: TEST: inferReturnTypeComponents
@@ -412,7 +418,7 @@ def inferReturnTypeComponents():
         # CHECK: shape: None
         iface = InferShapedTypeOpInterface(unranked_op)
         shaped_type_components = iface.inferReturnTypeComponents(
-          operands=[unranked_op.operand]
+            operands=[unranked_op.operand]
         )[0]
         print("has rank:", shaped_type_components.has_rank)
         print("rank:", shaped_type_components.rank)

diff  --git a/mlir/test/python/dialects/quant.py b/mlir/test/python/dialects/quant.py
index 32614be31f377..0ee3327dec152 100644
--- a/mlir/test/python/dialects/quant.py
+++ b/mlir/test/python/dialects/quant.py
@@ -5,127 +5,133 @@
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: test_type_hierarchy
 @run
 def test_type_hierarchy():
-  with Context():
-    i8 = IntegerType.get_signless(8)
-    any = Type.parse("!quant.any<i8<-8:7>:f32>")
-    uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
-    per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
-    calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
+    with Context():
+        i8 = IntegerType.get_signless(8)
+        any = Type.parse("!quant.any<i8<-8:7>:f32>")
+        uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
+        per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
+        calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
 
-    assert not quant.QuantizedType.isinstance(i8)
-    assert quant.QuantizedType.isinstance(any)
-    assert quant.QuantizedType.isinstance(uniform)
-    assert quant.QuantizedType.isinstance(per_axis)
-    assert quant.QuantizedType.isinstance(calibrated)
+        assert not quant.QuantizedType.isinstance(i8)
+        assert quant.QuantizedType.isinstance(any)
+        assert quant.QuantizedType.isinstance(uniform)
+        assert quant.QuantizedType.isinstance(per_axis)
+        assert quant.QuantizedType.isinstance(calibrated)
 
-    assert quant.AnyQuantizedType.isinstance(any)
-    assert quant.UniformQuantizedType.isinstance(uniform)
-    assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
-    assert quant.CalibratedQuantizedType.isinstance(calibrated)
+        assert quant.AnyQuantizedType.isinstance(any)
+        assert quant.UniformQuantizedType.isinstance(uniform)
+        assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
+        assert quant.CalibratedQuantizedType.isinstance(calibrated)
 
-    assert not quant.AnyQuantizedType.isinstance(uniform)
-    assert not quant.UniformQuantizedType.isinstance(per_axis)
+        assert not quant.AnyQuantizedType.isinstance(uniform)
+        assert not quant.UniformQuantizedType.isinstance(per_axis)
 
 
 # CHECK-LABEL: TEST: test_any_quantized_type
 @run
 def test_any_quantized_type():
-  with Context():
-    i8 = IntegerType.get_signless(8)
-    f32 = F32Type.get()
-    any = quant.AnyQuantizedType.get(quant.QuantizedType.FLAG_SIGNED, i8, f32,
-                                     -8, 7)
-
-    # CHECK: flags: 1
-    print(f"flags: {any.flags}")
-    # CHECK: signed: True
-    print(f"signed: {any.is_signed}")
-    # CHECK: storage type: i8
-    print(f"storage type: {any.storage_type}")
-    # CHECK: expressed type: f32
-    print(f"expressed type: {any.expressed_type}")
-    # CHECK: storage min: -8
-    print(f"storage min: {any.storage_type_min}")
-    # CHECK: storage max: 7
-    print(f"storage max: {any.storage_type_max}")
-    # CHECK: storage width: 8
-    print(f"storage width: {any.storage_type_integral_width}")
-    # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
-    print(f"quantized element type: {any.quantized_element_type}")
-    # CHECK: !quant.any<i8<-8:7>:f32>
-    print(any)
-    assert any == Type.parse("!quant.any<i8<-8:7>:f32>")
+    with Context():
+        i8 = IntegerType.get_signless(8)
+        f32 = F32Type.get()
+        any = quant.AnyQuantizedType.get(
+            quant.QuantizedType.FLAG_SIGNED, i8, f32, -8, 7
+        )
+
+        # CHECK: flags: 1
+        print(f"flags: {any.flags}")
+        # CHECK: signed: True
+        print(f"signed: {any.is_signed}")
+        # CHECK: storage type: i8
+        print(f"storage type: {any.storage_type}")
+        # CHECK: expressed type: f32
+        print(f"expressed type: {any.expressed_type}")
+        # CHECK: storage min: -8
+        print(f"storage min: {any.storage_type_min}")
+        # CHECK: storage max: 7
+        print(f"storage max: {any.storage_type_max}")
+        # CHECK: storage width: 8
+        print(f"storage width: {any.storage_type_integral_width}")
+        # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
+        print(f"quantized element type: {any.quantized_element_type}")
+        # CHECK: !quant.any<i8<-8:7>:f32>
+        print(any)
+        assert any == Type.parse("!quant.any<i8<-8:7>:f32>")
 
 
 # CHECK-LABEL: TEST: test_uniform_type
 @run
 def test_uniform_type():
-  with Context():
-    i8 = IntegerType.get_signless(8)
-    f32 = F32Type.get()
-    uniform = quant.UniformQuantizedType.get(
-        quant.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7)
-
-    # CHECK: scale: 0.99872
-    print(f"scale: {uniform.scale}")
-    # CHECK: zero point: 127
-    print(f"zero point: {uniform.zero_point}")
-    # CHECK: fixed point: False
-    print(f"fixed point: {uniform.is_fixed_point}")
-    # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
-    print(uniform)
-    assert uniform == Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
+    with Context():
+        i8 = IntegerType.get_signless(8)
+        f32 = F32Type.get()
+        uniform = quant.UniformQuantizedType.get(
+            quant.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7
+        )
+
+        # CHECK: scale: 0.99872
+        print(f"scale: {uniform.scale}")
+        # CHECK: zero point: 127
+        print(f"zero point: {uniform.zero_point}")
+        # CHECK: fixed point: False
+        print(f"fixed point: {uniform.is_fixed_point}")
+        # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
+        print(uniform)
+        assert uniform == Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
 
 
 # CHECK-LABEL: TEST: test_uniform_per_axis_type
 @run
 def test_uniform_per_axis_type():
-  with Context():
-    i8 = IntegerType.get_signless(8)
-    f32 = F32Type.get()
-    per_axis = quant.UniformQuantizedPerAxisType.get(
-        quant.QuantizedType.FLAG_SIGNED,
-        i8,
-        f32, [200, 0.99872], [0, 120],
-        quantized_dimension=1,
-        storage_type_min=quant.QuantizedType.default_minimum_for_integer(
-            is_signed=True, integral_width=8),
-        storage_type_max=quant.QuantizedType.default_maximum_for_integer(
-            is_signed=True, integral_width=8))
-
-    # CHECK: scales: None
-    print(f"scales: {per_axis.scales}")
-    # CHECK: zero_points: None
-    print(f"zero_points: {per_axis.zero_points}")
-    # CHECK: quantized dim: 1
-    print(f"quantized dim: {per_axis.quantized_dimension}")
-    # CHECK: fixed point: False
-    print(f"fixed point: {per_axis.is_fixed_point}")
-    # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
-    print(per_axis)
-    assert per_axis == Type.parse(
-        "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
+    with Context():
+        i8 = IntegerType.get_signless(8)
+        f32 = F32Type.get()
+        per_axis = quant.UniformQuantizedPerAxisType.get(
+            quant.QuantizedType.FLAG_SIGNED,
+            i8,
+            f32,
+            [200, 0.99872],
+            [0, 120],
+            quantized_dimension=1,
+            storage_type_min=quant.QuantizedType.default_minimum_for_integer(
+                is_signed=True, integral_width=8
+            ),
+            storage_type_max=quant.QuantizedType.default_maximum_for_integer(
+                is_signed=True, integral_width=8
+            ),
+        )
+
+        # CHECK: scales: None
+        print(f"scales: {per_axis.scales}")
+        # CHECK: zero_points: None
+        print(f"zero_points: {per_axis.zero_points}")
+        # CHECK: quantized dim: 1
+        print(f"quantized dim: {per_axis.quantized_dimension}")
+        # CHECK: fixed point: False
+        print(f"fixed point: {per_axis.is_fixed_point}")
+        # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
+        print(per_axis)
+        assert per_axis == Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
 
 
 # CHECK-LABEL: TEST: test_calibrated_type
 @run
 def test_calibrated_type():
-  with Context():
-    f32 = F32Type.get()
-    calibrated = quant.CalibratedQuantizedType.get(f32, -0.998, 1.2321)
-
-    # CHECK: min: -0.998
-    print(f"min: {calibrated.min}")
-    # CHECK: max: 1.2321
-    print(f"max: {calibrated.max}")
-    # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
-    print(calibrated)
-    assert calibrated == Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
+    with Context():
+        f32 = F32Type.get()
+        calibrated = quant.CalibratedQuantizedType.get(f32, -0.998, 1.2321)
+
+        # CHECK: min: -0.998
+        print(f"min: {calibrated.min}")
+        # CHECK: max: 1.2321
+        print(f"max: {calibrated.max}")
+        # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
+        print(calibrated)
+        assert calibrated == Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")

diff  --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 4a618ff4eecc3..8cb55fdf6a1eb 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -8,26 +8,26 @@
 
 
 def constructAndPrintInModule(f):
-  print("\nTEST:", f.__name__)
-  with Context(), Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      f()
-    print(module)
-  return f
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()
+        print(module)
+    return f
 
 
 # CHECK-LABEL: TEST: testSimpleLoop
 @constructAndPrintInModule
 def testSimpleLoop():
-  index_type = IndexType.get()
+    index_type = IndexType.get()
 
-  @func.FuncOp.from_py_func(index_type, index_type, index_type)
-  def simple_loop(lb, ub, step):
-    loop = scf.ForOp(lb, ub, step, [lb, lb])
-    with InsertionPoint(loop.body):
-      scf.YieldOp(loop.inner_iter_args)
-    return
+    @func.FuncOp.from_py_func(index_type, index_type, index_type)
+    def simple_loop(lb, ub, step):
+        loop = scf.ForOp(lb, ub, step, [lb, lb])
+        with InsertionPoint(loop.body):
+            scf.YieldOp(loop.inner_iter_args)
+        return
 
 
 # CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
@@ -39,14 +39,14 @@ def simple_loop(lb, ub, step):
 # CHECK-LABEL: TEST: testInductionVar
 @constructAndPrintInModule
 def testInductionVar():
-  index_type = IndexType.get()
+    index_type = IndexType.get()
 
-  @func.FuncOp.from_py_func(index_type, index_type, index_type)
-  def induction_var(lb, ub, step):
-    loop = scf.ForOp(lb, ub, step, [lb])
-    with InsertionPoint(loop.body):
-      scf.YieldOp([loop.induction_variable])
-    return
+    @func.FuncOp.from_py_func(index_type, index_type, index_type)
+    def induction_var(lb, ub, step):
+        loop = scf.ForOp(lb, ub, step, [lb])
+        with InsertionPoint(loop.body):
+            scf.YieldOp([loop.induction_variable])
+        return
 
 
 # CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
@@ -56,19 +56,18 @@ def induction_var(lb, ub, step):
 
 @constructAndPrintInModule
 def testOpsAsArguments():
-  index_type = IndexType.get()
-  callee = func.FuncOp(
-      "callee", ([], [index_type, index_type]), visibility="private")
-  f = func.FuncOp("ops_as_arguments", ([], []))
-  with InsertionPoint(f.add_entry_block()):
-    lb = arith.ConstantOp.create_index(0)
-    ub = arith.ConstantOp.create_index(42)
-    step = arith.ConstantOp.create_index(2)
-    iter_args = func.CallOp(callee, [])
-    loop = scf.ForOp(lb, ub, step, iter_args)
-    with InsertionPoint(loop.body):
-      scf.YieldOp(loop.inner_iter_args)
-    func.ReturnOp([])
+    index_type = IndexType.get()
+    callee = func.FuncOp("callee", ([], [index_type, index_type]), visibility="private")
+    f = func.FuncOp("ops_as_arguments", ([], []))
+    with InsertionPoint(f.add_entry_block()):
+        lb = arith.ConstantOp.create_index(0)
+        ub = arith.ConstantOp.create_index(42)
+        step = arith.ConstantOp.create_index(2)
+        iter_args = func.CallOp(callee, [])
+        loop = scf.ForOp(lb, ub, step, iter_args)
+        with InsertionPoint(loop.body):
+            scf.YieldOp(loop.inner_iter_args)
+        func.ReturnOp([])
 
 
 # CHECK-LABEL: TEST: testOpsAsArguments
@@ -86,17 +85,17 @@ def testOpsAsArguments():
 
 @constructAndPrintInModule
 def testIfWithoutElse():
-  bool = IntegerType.get_signless(1)
-  i32 = IntegerType.get_signless(32)
+    bool = IntegerType.get_signless(1)
+    i32 = IntegerType.get_signless(32)
 
-  @func.FuncOp.from_py_func(bool)
-  def simple_if(cond):
-    if_op = scf.IfOp(cond)
-    with InsertionPoint(if_op.then_block):
-      one = arith.ConstantOp(i32, 1)
-      add = arith.AddIOp(one, one)
-      scf.YieldOp([])
-    return
+    @func.FuncOp.from_py_func(bool)
+    def simple_if(cond):
+        if_op = scf.IfOp(cond)
+        with InsertionPoint(if_op.then_block):
+            one = arith.ConstantOp(i32, 1)
+            add = arith.AddIOp(one, one)
+            scf.YieldOp([])
+        return
 
 
 # CHECK: func @simple_if(%[[ARG0:.*]]: i1)
@@ -108,22 +107,22 @@ def simple_if(cond):
 
 @constructAndPrintInModule
 def testIfWithElse():
-  bool = IntegerType.get_signless(1)
-  i32 = IntegerType.get_signless(32)
-
-  @func.FuncOp.from_py_func(bool)
-  def simple_if_else(cond):
-    if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
-    with InsertionPoint(if_op.then_block):
-      x_true = arith.ConstantOp(i32, 0)
-      y_true = arith.ConstantOp(i32, 1)
-      scf.YieldOp([x_true, y_true])
-    with InsertionPoint(if_op.else_block):
-      x_false = arith.ConstantOp(i32, 2)
-      y_false = arith.ConstantOp(i32, 3)
-      scf.YieldOp([x_false, y_false])
-    add = arith.AddIOp(if_op.results[0], if_op.results[1])
-    return
+    bool = IntegerType.get_signless(1)
+    i32 = IntegerType.get_signless(32)
+
+    @func.FuncOp.from_py_func(bool)
+    def simple_if_else(cond):
+        if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
+        with InsertionPoint(if_op.then_block):
+            x_true = arith.ConstantOp(i32, 0)
+            y_true = arith.ConstantOp(i32, 1)
+            scf.YieldOp([x_true, y_true])
+        with InsertionPoint(if_op.else_block):
+            x_false = arith.ConstantOp(i32, 2)
+            y_false = arith.ConstantOp(i32, 3)
+            scf.YieldOp([x_false, y_false])
+        add = arith.AddIOp(if_op.results[0], if_op.results[1])
+        return
 
 
 # CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)

diff  --git a/mlir/test/python/dialects/shape.py b/mlir/test/python/dialects/shape.py
index 3e7a8b27d4c09..ad755852f5d37 100644
--- a/mlir/test/python/dialects/shape.py
+++ b/mlir/test/python/dialects/shape.py
@@ -7,36 +7,38 @@
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: testConstShape
 @run
 def testConstShape():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32))
-      def const_shape_tensor(arg):
-        shape.ConstWitnessOp(False)
-        shape.ConstSizeOp(30)
-        shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40))
-        x = shape.ConstShapeOp([1, 2])
-        shape.MeetOp(x, x, error="impossible")
-        return shape.ConstShapeOp(
-            DenseElementsAttr.get(
-                np.array([3, 4], dtype=np.int64), type=IndexType.get()))
-
-
-
-    # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
-    # CHECK-DAG: shape.const_witness false
-    # CHECK-DAG: shape.const_size 30
-    # CHECK-DAG: shape.const_size 40
-    # CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex>
-    # CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex>
-    print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)
+            )
+            def const_shape_tensor(arg):
+                shape.ConstWitnessOp(False)
+                shape.ConstSizeOp(30)
+                shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40))
+                x = shape.ConstShapeOp([1, 2])
+                shape.MeetOp(x, x, error="impossible")
+                return shape.ConstShapeOp(
+                    DenseElementsAttr.get(
+                        np.array([3, 4], dtype=np.int64), type=IndexType.get()
+                    )
+                )
+
+        # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
+        # CHECK-DAG: shape.const_witness false
+        # CHECK-DAG: shape.const_size 30
+        # CHECK-DAG: shape.const_size 40
+        # CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex>
+        # CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex>
+        print(module)

diff  --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 6190bebcd5e98..b7a06067b5f56 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -3,97 +3,106 @@
 from mlir.ir import *
 from mlir.dialects import sparse_tensor as st
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: testEncodingAttr1D
 @run
 def testEncodingAttr1D():
-  with Context() as ctx:
-    parsed = Attribute.parse('#sparse_tensor.encoding<{'
-                             '  lvlTypes = [ "compressed" ],'
-                             '  posWidth = 16,'
-                             '  crdWidth = 32'
-                             '}>')
-    # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 16, crdWidth = 32 }>
-    print(parsed)
-
-    casted = st.EncodingAttr(parsed)
-    # CHECK: equal: True
-    print(f"equal: {casted == parsed}")
-
-    # CHECK: lvl_types: [<DimLevelType.compressed: 8>]
-    print(f"lvl_types: {casted.lvl_types}")
-    # CHECK: dim_ordering: None
-    print(f"dim_ordering: {casted.dim_ordering}")
-    # CHECK: pos_width: 16
-    print(f"pos_width: {casted.pos_width}")
-    # CHECK: crd_width: 32
-    print(f"crd_width: {casted.crd_width}")
-
-    created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
-    # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>
-    print(created)
-    # CHECK: created_equal: False
-    print(f"created_equal: {created == casted}")
-
-    # Verify that the factory creates an instance of the proper type.
-    # CHECK: is_proper_instance: True
-    print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
-    # CHECK: created_pos_width: 0
-    print(f"created_pos_width: {created.pos_width}")
+    with Context() as ctx:
+        parsed = Attribute.parse(
+            "#sparse_tensor.encoding<{"
+            '  lvlTypes = [ "compressed" ],'
+            "  posWidth = 16,"
+            "  crdWidth = 32"
+            "}>"
+        )
+        # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 16, crdWidth = 32 }>
+        print(parsed)
+
+        casted = st.EncodingAttr(parsed)
+        # CHECK: equal: True
+        print(f"equal: {casted == parsed}")
+
+        # CHECK: lvl_types: [<DimLevelType.compressed: 8>]
+        print(f"lvl_types: {casted.lvl_types}")
+        # CHECK: dim_ordering: None
+        print(f"dim_ordering: {casted.dim_ordering}")
+        # CHECK: pos_width: 16
+        print(f"pos_width: {casted.pos_width}")
+        # CHECK: crd_width: 32
+        print(f"crd_width: {casted.crd_width}")
+
+        created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
+        # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>
+        print(created)
+        # CHECK: created_equal: False
+        print(f"created_equal: {created == casted}")
+
+        # Verify that the factory creates an instance of the proper type.
+        # CHECK: is_proper_instance: True
+        print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
+        # CHECK: created_pos_width: 0
+        print(f"created_pos_width: {created.pos_width}")
 
 
 # CHECK-LABEL: TEST: testEncodingAttr2D
 @run
 def testEncodingAttr2D():
-  with Context() as ctx:
-    parsed = Attribute.parse('#sparse_tensor.encoding<{'
-                             '  lvlTypes = [ "dense", "compressed" ],'
-                             '  dimOrdering = affine_map<(d0, d1) -> (d1, d0)>,'
-                             '  posWidth = 8,'
-                             '  crdWidth = 32'
-                             '}>')
-    # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }>
-    print(parsed)
-
-    casted = st.EncodingAttr(parsed)
-    # CHECK: equal: True
-    print(f"equal: {casted == parsed}")
-
-    # CHECK: lvl_types: [<DimLevelType.dense: 4>, <DimLevelType.compressed: 8>]
-    print(f"lvl_types: {casted.lvl_types}")
-    # CHECK: dim_ordering: (d0, d1) -> (d1, d0)
-    print(f"dim_ordering: {casted.dim_ordering}")
-    # CHECK: pos_width: 8
-    print(f"pos_width: {casted.pos_width}")
-    # CHECK: crd_width: 32
-    print(f"crd_width: {casted.crd_width}")
-
-    created = st.EncodingAttr.get(casted.lvl_types, casted.dim_ordering,
-                                  casted.higher_ordering, 8, 32)
-    # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }>
-    print(created)
-    # CHECK: created_equal: True
-    print(f"created_equal: {created == casted}")
+    with Context() as ctx:
+        parsed = Attribute.parse(
+            "#sparse_tensor.encoding<{"
+            '  lvlTypes = [ "dense", "compressed" ],'
+            "  dimOrdering = affine_map<(d0, d1) -> (d1, d0)>,"
+            "  posWidth = 8,"
+            "  crdWidth = 32"
+            "}>"
+        )
+        # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }>
+        print(parsed)
+
+        casted = st.EncodingAttr(parsed)
+        # CHECK: equal: True
+        print(f"equal: {casted == parsed}")
+
+        # CHECK: lvl_types: [<DimLevelType.dense: 4>, <DimLevelType.compressed: 8>]
+        print(f"lvl_types: {casted.lvl_types}")
+        # CHECK: dim_ordering: (d0, d1) -> (d1, d0)
+        print(f"dim_ordering: {casted.dim_ordering}")
+        # CHECK: pos_width: 8
+        print(f"pos_width: {casted.pos_width}")
+        # CHECK: crd_width: 32
+        print(f"crd_width: {casted.crd_width}")
+
+        created = st.EncodingAttr.get(
+            casted.lvl_types, casted.dim_ordering, casted.higher_ordering, 8, 32
+        )
+        # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }>
+        print(created)
+        # CHECK: created_equal: True
+        print(f"created_equal: {created == casted}")
 
 
 # CHECK-LABEL: TEST: testEncodingAttrOnTensorType
 @run
 def testEncodingAttrOnTensorType():
-  with Context() as ctx, Location.unknown():
-    encoding = st.EncodingAttr(
-        Attribute.parse('#sparse_tensor.encoding<{'
-                        '  lvlTypes = [ "compressed" ], '
-                        '  posWidth = 64,'
-                        '  crdWidth = 32'
-                        '}>'))
-    tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
-    # CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }>>
-    print(tt)
-    # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }>
-    print(tt.encoding)
-    assert tt.encoding == encoding
+    with Context() as ctx, Location.unknown():
+        encoding = st.EncodingAttr(
+            Attribute.parse(
+                "#sparse_tensor.encoding<{"
+                '  lvlTypes = [ "compressed" ], '
+                "  posWidth = 64,"
+                "  crdWidth = 32"
+                "}>"
+            )
+        )
+        tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
+        # CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }>>
+        print(tt)
+        # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }>
+        print(tt.encoding)
+        assert tt.encoding == encoding

diff  --git a/mlir/test/python/dialects/sparse_tensor/passes.py b/mlir/test/python/dialects/sparse_tensor/passes.py
index 9319e16e054de..c37c5207ebd9f 100644
--- a/mlir/test/python/dialects/sparse_tensor/passes.py
+++ b/mlir/test/python/dialects/sparse_tensor/passes.py
@@ -7,16 +7,16 @@
 
 
 def run(f):
-  print('\nTEST:', f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: testSparseTensorPass
 @run
 def testSparseTensorPass():
-  with Context() as context:
-    PassManager.parse('any(sparsification)')
-    PassManager.parse('any(sparse-tensor-conversion)')
-  # CHECK: SUCCESS
-  print('SUCCESS')
+    with Context() as context:
+        PassManager.parse("any(sparsification)")
+        PassManager.parse("any(sparse-tensor-conversion)")
+    # CHECK: SUCCESS
+    print("SUCCESS")

diff  --git a/mlir/test/python/dialects/tensor.py b/mlir/test/python/dialects/tensor.py
index b0ad4b4ad4d65..b690c934dc46b 100644
--- a/mlir/test/python/dialects/tensor.py
+++ b/mlir/test/python/dialects/tensor.py
@@ -7,125 +7,135 @@
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: testDimOp
 @run
 def testDimOp():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32Type = F32Type.get()
-    indexType = IndexType.get()
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get(
-              (ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()),
-              f32Type))
-      #      CHECK: func @tensor_static_dim
-      # CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
-      #  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-      #  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-      #      CHECK:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-      #      CHECK:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-      #      CHECK:   return %[[D0]], %[[D1]]
-      def tensor_static_dim(t):
-        c0 = arith.ConstantOp(indexType, 0)
-        c1 = arith.ConstantOp(indexType, 1)
-        d0 = tensor.DimOp(t, c0)
-        d1 = tensor.DimOp(t, c1)
-        return [d0.result, d1.result]
-
-    print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32Type = F32Type.get()
+        indexType = IndexType.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get(
+                    (ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()),
+                    f32Type,
+                )
+            )
+            #      CHECK: func @tensor_static_dim
+            # CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+            #  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+            #  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+            #      CHECK:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+            #      CHECK:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+            #      CHECK:   return %[[D0]], %[[D1]]
+            def tensor_static_dim(t):
+                c0 = arith.ConstantOp(indexType, 0)
+                c1 = arith.ConstantOp(indexType, 1)
+                d0 = tensor.DimOp(t, c0)
+                d1 = tensor.DimOp(t, c1)
+                return [d0.result, d1.result]
+
+        print(module)
 
 
 # CHECK-LABEL: TEST: testEmptyOp
 @run
 def testEmptyOp():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-      # CHECK-LABEL: func @static_sizes
-      # CHECK: %0 = tensor.empty() : tensor<3x4xf32>
-      @func.FuncOp.from_py_func()
-      def static_sizes():
-        return tensor.EmptyOp([3, 4], f32)
-
-      # CHECK-LABEL: func @dynamic_sizes
-      # CHECK: %0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
-      @func.FuncOp.from_py_func(IndexType.get(), IndexType.get())
-      def dynamic_sizes(d0, d1):
-        return tensor.EmptyOp([d0, d1], f32)
-
-      # CHECK-LABEL: func @mixed_static_dynamic_sizes
-      # CHECK: %0 = tensor.empty(%arg0) : tensor<?x4xf32>
-      @func.FuncOp.from_py_func(IndexType.get())
-      def mixed_static_dynamic_sizes(d0):
-        return tensor.EmptyOp([d0, 4], f32)
-
-      # CHECK-LABEL: func @zero_d
-      # CHECK: %0 = tensor.empty() : tensor<f32>
-      @func.FuncOp.from_py_func()
-      def zero_d():
-        return tensor.EmptyOp([], f32)
-
-  print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            # CHECK-LABEL: func @static_sizes
+            # CHECK: %0 = tensor.empty() : tensor<3x4xf32>
+            @func.FuncOp.from_py_func()
+            def static_sizes():
+                return tensor.EmptyOp([3, 4], f32)
+
+            # CHECK-LABEL: func @dynamic_sizes
+            # CHECK: %0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
+            @func.FuncOp.from_py_func(IndexType.get(), IndexType.get())
+            def dynamic_sizes(d0, d1):
+                return tensor.EmptyOp([d0, d1], f32)
+
+            # CHECK-LABEL: func @mixed_static_dynamic_sizes
+            # CHECK: %0 = tensor.empty(%arg0) : tensor<?x4xf32>
+            @func.FuncOp.from_py_func(IndexType.get())
+            def mixed_static_dynamic_sizes(d0):
+                return tensor.EmptyOp([d0, 4], f32)
+
+            # CHECK-LABEL: func @zero_d
+            # CHECK: %0 = tensor.empty() : tensor<f32>
+            @func.FuncOp.from_py_func()
+            def zero_d():
+                return tensor.EmptyOp([], f32)
+
+    print(module)
 
 
 # CHECK-LABEL: TEST: testInferTypesInsertSlice
 @run
 def testInferTypesInsertSlice():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32Type = F32Type.get()
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get((1, 1), f32Type),
-          RankedTensorType.get((1, 1), f32Type))
-      # CHECK: func @f
-      # CHECK:      tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] :
-      # CHECK-SAME:   tensor<1x1xf32> into tensor<1x1xf32>
-      def f(source, dest):
-        d0 = tensor.InsertSliceOp(source, dest, [], [], [],
-                                  DenseI64ArrayAttr.get([0, 0]),
-                                  DenseI64ArrayAttr.get([1, 1]),
-                                  DenseI64ArrayAttr.get([0, 0]))
-        return [d0.result]
-
-  print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32Type = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((1, 1), f32Type),
+                RankedTensorType.get((1, 1), f32Type),
+            )
+            # CHECK: func @f
+            # CHECK:      tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] :
+            # CHECK-SAME:   tensor<1x1xf32> into tensor<1x1xf32>
+            def f(source, dest):
+                d0 = tensor.InsertSliceOp(
+                    source,
+                    dest,
+                    [],
+                    [],
+                    [],
+                    DenseI64ArrayAttr.get([0, 0]),
+                    DenseI64ArrayAttr.get([1, 1]),
+                    DenseI64ArrayAttr.get([0, 0]),
+                )
+                return [d0.result]
+
+    print(module)
 
 
 # CHECK-LABEL: TEST: testFromElementsOp
 @run
 def testFromElementsOp():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-      @func.FuncOp.from_py_func()
-      def default_builder():
-        c0 = arith.ConstantOp(f32, 0.0)
-        # CHECK: %[[C0:.*]] = "arith.constant
-        # CHECK-SAME: value = 0.000000e+00 : f32
-        print(c0)
-        c1 = arith.ConstantOp(f32, 1.0)
-        # CHECK: %[[C1:.*]] = "arith.constant
-        # CHECK-SAME: value = 1.000000e+00 : f32
-        print(c1)
-
-        t = tensor.FromElementsOp(RankedTensorType.get((2,), f32), [c0, c1])
-        # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2xf32>
-        print(t)
-
-        t = tensor.FromElementsOp(RankedTensorType.get((2, 1), f32), [c0, c1])
-        # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2x1xf32>
-        print(t)
-
-        t = tensor.FromElementsOp(RankedTensorType.get((1, 2), f32), [c0, c1])
-        # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32>
-        print(t)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func()
+            def default_builder():
+                c0 = arith.ConstantOp(f32, 0.0)
+                # CHECK: %[[C0:.*]] = "arith.constant
+                # CHECK-SAME: value = 0.000000e+00 : f32
+                print(c0)
+                c1 = arith.ConstantOp(f32, 1.0)
+                # CHECK: %[[C1:.*]] = "arith.constant
+                # CHECK-SAME: value = 1.000000e+00 : f32
+                print(c1)
+
+                t = tensor.FromElementsOp(RankedTensorType.get((2,), f32), [c0, c1])
+                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2xf32>
+                print(t)
+
+                t = tensor.FromElementsOp(RankedTensorType.get((2, 1), f32), [c0, c1])
+                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2x1xf32>
+                print(t)
+
+                t = tensor.FromElementsOp(RankedTensorType.get((1, 2), f32), [c0, c1])
+                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32>
+                print(t)

diff  --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 6b36c025eafa3..ca6499b5706d1 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -6,158 +6,188 @@
 
 
 def run(f):
-  with Context(), Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      print("\nTEST:", f.__name__)
-      f()
-    print(module)
-  return f
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            print("\nTEST:", f.__name__)
+            f()
+        print(module)
+    return f
 
 
 @run
 def testTypes():
-  # CHECK-LABEL: TEST: testTypes
-  # CHECK: !transform.any_op
-  any_op = transform.AnyOpType.get()
-  print(any_op)
+    # CHECK-LABEL: TEST: testTypes
+    # CHECK: !transform.any_op
+    any_op = transform.AnyOpType.get()
+    print(any_op)
 
-  # CHECK: !transform.op<"foo.bar">
-  # CHECK: foo.bar
-  concrete_op = transform.OperationType.get("foo.bar")
-  print(concrete_op)
-  print(concrete_op.operation_name)
+    # CHECK: !transform.op<"foo.bar">
+    # CHECK: foo.bar
+    concrete_op = transform.OperationType.get("foo.bar")
+    print(concrete_op)
+    print(concrete_op.operation_name)
 
 
 @run
 def testSequenceOp():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [transform.AnyOpType.get()],
-                                  transform.AnyOpType.get())
-  with InsertionPoint(sequence.body):
-    transform.YieldOp([sequence.bodyTarget])
-  # CHECK-LABEL: TEST: testSequenceOp
-  # CHECK: = transform.sequence -> !transform.any_op failures(propagate) {
-  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
-  # CHECK:   yield %[[ARG0]] : !transform.any_op
-  # CHECK: }
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [transform.AnyOpType.get()],
+        transform.AnyOpType.get(),
+    )
+    with InsertionPoint(sequence.body):
+        transform.YieldOp([sequence.bodyTarget])
+    # CHECK-LABEL: TEST: testSequenceOp
+    # CHECK: = transform.sequence -> !transform.any_op failures(propagate) {
+    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+    # CHECK:   yield %[[ARG0]] : !transform.any_op
+    # CHECK: }
 
 
 @run
 def testNestedSequenceOp():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get())
-  with InsertionPoint(sequence.body):
-    nested = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], sequence.bodyTarget)
-    with InsertionPoint(nested.body):
-      doubly_nested = transform.SequenceOp(
-          transform.FailurePropagationMode.PROPAGATE,
-          [transform.AnyOpType.get()], nested.bodyTarget)
-      with InsertionPoint(doubly_nested.body):
-        transform.YieldOp([doubly_nested.bodyTarget])
-      transform.YieldOp()
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testNestedSequenceOp
-  # CHECK: transform.sequence failures(propagate) {
-  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
-  # CHECK:   sequence %[[ARG0]] : !transform.any_op failures(propagate) {
-  # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-  # CHECK:     = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
-  # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
-  # CHECK:       yield %[[ARG2]] : !transform.any_op
-  # CHECK:     }
-  # CHECK:   }
-  # CHECK: }
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        nested = transform.SequenceOp(
+            transform.FailurePropagationMode.PROPAGATE, [], sequence.bodyTarget
+        )
+        with InsertionPoint(nested.body):
+            doubly_nested = transform.SequenceOp(
+                transform.FailurePropagationMode.PROPAGATE,
+                [transform.AnyOpType.get()],
+                nested.bodyTarget,
+            )
+            with InsertionPoint(doubly_nested.body):
+                transform.YieldOp([doubly_nested.bodyTarget])
+            transform.YieldOp()
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testNestedSequenceOp
+    # CHECK: transform.sequence failures(propagate) {
+    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+    # CHECK:   sequence %[[ARG0]] : !transform.any_op failures(propagate) {
+    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+    # CHECK:     = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
+    # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
+    # CHECK:       yield %[[ARG2]] : !transform.any_op
+    # CHECK:     }
+    # CHECK:   }
+    # CHECK: }
 
 
 @run
 def testSequenceOpWithExtras():
-  sequence = transform.SequenceOp(
-      transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(),
-      [transform.AnyOpType.get(),
-       transform.OperationType.get("foo.bar")])
-  with InsertionPoint(sequence.body):
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testSequenceOpWithExtras
-  # CHECK: transform.sequence failures(propagate)
-  # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.AnyOpType.get(),
+        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+    )
+    with InsertionPoint(sequence.body):
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testSequenceOpWithExtras
+    # CHECK: transform.sequence failures(propagate)
+    # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
 
 
 @run
 def testNestedSequenceOpWithExtras():
-  sequence = transform.SequenceOp(
-      transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(),
-      [transform.AnyOpType.get(),
-       transform.OperationType.get("foo.bar")])
-  with InsertionPoint(sequence.body):
-    nested = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], sequence.bodyTarget,
-                                  sequence.bodyExtraArgs)
-    with InsertionPoint(nested.body):
-      transform.YieldOp()
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
-  # CHECK: transform.sequence failures(propagate)
-  # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
-  # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.AnyOpType.get(),
+        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+    )
+    with InsertionPoint(sequence.body):
+        nested = transform.SequenceOp(
+            transform.FailurePropagationMode.PROPAGATE,
+            [],
+            sequence.bodyTarget,
+            sequence.bodyExtraArgs,
+        )
+        with InsertionPoint(nested.body):
+            transform.YieldOp()
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
+    # CHECK: transform.sequence failures(propagate)
+    # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
+    # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
 
 
 @run
 def testTransformPDLOps():
-  withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
-  with InsertionPoint(withPdl.body):
-    sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                    [transform.AnyOpType.get()],
-                                    withPdl.bodyTarget)
-    with InsertionPoint(sequence.body):
-      match = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher")
-      transform.YieldOp(match)
-  # CHECK-LABEL: TEST: testTransformPDLOps
-  # CHECK: transform.with_pdl_patterns {
-  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
-  # CHECK:   = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
-  # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-  # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
-  # CHECK:     yield %[[RES]] : !transform.any_op
-  # CHECK:   }
-  # CHECK: }
+    withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
+    with InsertionPoint(withPdl.body):
+        sequence = transform.SequenceOp(
+            transform.FailurePropagationMode.PROPAGATE,
+            [transform.AnyOpType.get()],
+            withPdl.bodyTarget,
+        )
+        with InsertionPoint(sequence.body):
+            match = transform_pdl.PDLMatchOp(
+                transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
+            )
+            transform.YieldOp(match)
+    # CHECK-LABEL: TEST: testTransformPDLOps
+    # CHECK: transform.with_pdl_patterns {
+    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+    # CHECK:   = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
+    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+    # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
+    # CHECK:     yield %[[RES]] : !transform.any_op
+    # CHECK:   }
+    # CHECK: }
 
 
 @run
 def testGetClosestIsolatedParentOp():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get())
-  with InsertionPoint(sequence.body):
-    transform.GetClosestIsolatedParentOp(transform.AnyOpType.get(), sequence.bodyTarget)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testGetClosestIsolatedParentOp
-  # CHECK: transform.sequence
-  # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-  # CHECK:   = get_closest_isolated_parent %[[ARG1]]
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        transform.GetClosestIsolatedParentOp(
+            transform.AnyOpType.get(), sequence.bodyTarget
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testGetClosestIsolatedParentOp
+    # CHECK: transform.sequence
+    # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+    # CHECK:   = get_closest_isolated_parent %[[ARG1]]
 
 
 @run
 def testMergeHandlesOp():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get())
-  with InsertionPoint(sequence.body):
-    transform.MergeHandlesOp([sequence.bodyTarget])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testMergeHandlesOp
-  # CHECK: transform.sequence
-  # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-  # CHECK:   = merge_handles %[[ARG1]]
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        transform.MergeHandlesOp([sequence.bodyTarget])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMergeHandlesOp
+    # CHECK: transform.sequence
+    # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+    # CHECK:   = merge_handles %[[ARG1]]
 
 
 @run
 def testReplicateOp():
-  with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
-  with InsertionPoint(with_pdl.body):
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget)
-    with InsertionPoint(sequence.body):
-      m1 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "first")
-      m2 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "second")
-      transform.ReplicateOp(m1, [m2])
-      transform.YieldOp()
-  # CHECK-LABEL: TEST: testReplicateOp
-  # CHECK: %[[FIRST:.+]] = pdl_match
-  # CHECK: %[[SECOND:.+]] = pdl_match
-  # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
+    with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
+    with InsertionPoint(with_pdl.body):
+        sequence = transform.SequenceOp(
+            transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget
+        )
+        with InsertionPoint(sequence.body):
+            m1 = transform_pdl.PDLMatchOp(
+                transform.AnyOpType.get(), sequence.bodyTarget, "first"
+            )
+            m2 = transform_pdl.PDLMatchOp(
+                transform.AnyOpType.get(), sequence.bodyTarget, "second"
+            )
+            transform.ReplicateOp(m1, [m2])
+            transform.YieldOp()
+    # CHECK-LABEL: TEST: testReplicateOp
+    # CHECK: %[[FIRST:.+]] = pdl_match
+    # CHECK: %[[SECOND:.+]] = pdl_match
+    # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]

diff  --git a/mlir/test/python/dialects/transform_loop_ext.py b/mlir/test/python/dialects/transform_loop_ext.py
index 067a8b60d4f89..28a022a400fe6 100644
--- a/mlir/test/python/dialects/transform_loop_ext.py
+++ b/mlir/test/python/dialects/transform_loop_ext.py
@@ -7,70 +7,92 @@
 
 
 def run(f):
-  with Context(), Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      print("\nTEST:", f.__name__)
-      f()
-    print(module)
-  return f
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            print("\nTEST:", f.__name__)
+            f()
+        print(module)
+    return f
 
 
 @run
 def getParentLoop():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    loop.GetParentForOp(transform.OperationType.get("scf.for"), sequence.bodyTarget, num_loops=2)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: getParentLoop
-  # CHECK: = transform.loop.get_parent_for %
-  # CHECK: num_loops = 2
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        loop.GetParentForOp(
+            transform.OperationType.get("scf.for"), sequence.bodyTarget, num_loops=2
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: getParentLoop
+    # CHECK: = transform.loop.get_parent_for %
+    # CHECK: num_loops = 2
 
 
 @run
 def loopOutline():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], transform.OperationType.get("scf.for"))
-  with InsertionPoint(sequence.body):
-    loop.LoopOutlineOp(transform.AnyOpType.get(), transform.AnyOpType.get(), sequence.bodyTarget, func_name="foo")
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: loopOutline
-  # CHECK: = transform.loop.outline %
-  # CHECK: func_name = "foo"
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("scf.for"),
+    )
+    with InsertionPoint(sequence.body):
+        loop.LoopOutlineOp(
+            transform.AnyOpType.get(),
+            transform.AnyOpType.get(),
+            sequence.bodyTarget,
+            func_name="foo",
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: loopOutline
+    # CHECK: = transform.loop.outline %
+    # CHECK: func_name = "foo"
 
 
 @run
 def loopPeel():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], transform.OperationType.get("scf.for"))
-  with InsertionPoint(sequence.body):
-    loop.LoopPeelOp(pdl.OperationType.get(), sequence.bodyTarget)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: loopPeel
-  # CHECK: = transform.loop.peel %
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("scf.for"),
+    )
+    with InsertionPoint(sequence.body):
+        loop.LoopPeelOp(pdl.OperationType.get(), sequence.bodyTarget)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: loopPeel
+    # CHECK: = transform.loop.peel %
 
 
 @run
 def loopPipeline():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], transform.OperationType.get("scf.for"))
-  with InsertionPoint(sequence.body):
-    loop.LoopPipelineOp(pdl.OperationType.get(), sequence.bodyTarget, iteration_interval=3)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: loopPipeline
-  # CHECK: = transform.loop.pipeline %
-  # CHECK-DAG: iteration_interval = 3
-  # (read_latency has default value and is not printed)
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("scf.for"),
+    )
+    with InsertionPoint(sequence.body):
+        loop.LoopPipelineOp(
+            pdl.OperationType.get(), sequence.bodyTarget, iteration_interval=3
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: loopPipeline
+    # CHECK: = transform.loop.pipeline %
+    # CHECK-DAG: iteration_interval = 3
+    # (read_latency has default value and is not printed)
 
 
 @run
 def loopUnroll():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], transform.OperationType.get("scf.for"))
-  with InsertionPoint(sequence.body):
-    loop.LoopUnrollOp(sequence.bodyTarget, factor=42)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: loopUnroll
-  # CHECK: transform.loop.unroll %
-  # CHECK: factor = 42
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("scf.for"),
+    )
+    with InsertionPoint(sequence.body):
+        loop.LoopUnrollOp(sequence.bodyTarget, factor=42)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: loopUnroll
+    # CHECK: transform.loop.unroll %
+    # CHECK: factor = 42

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index d2a82b8218f25..2dfae47bdfb49 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -8,204 +8,230 @@
 
 
 def run(f):
-  with Context(), Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      print("\nTEST:", f.__name__)
-      f()
-    print(module)
-  return f
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            print("\nTEST:", f.__name__)
+            f()
+        print(module)
+    return f
 
 
 @run
 def testDecompose():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.DecomposeOp(sequence.bodyTarget)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testDecompose
-  # CHECK: transform.sequence
-  # CHECK: transform.structured.decompose
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.DecomposeOp(sequence.bodyTarget)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testDecompose
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.decompose
 
 
 @run
 def testGeneralize():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.GeneralizeOp(sequence.bodyTarget)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testGeneralize
-  # CHECK: transform.sequence
-  # CHECK: transform.structured.generalize
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.GeneralizeOp(sequence.bodyTarget)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testGeneralize
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.generalize
 
 
 @run
 def testInterchange():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.InterchangeOp(
-        sequence.bodyTarget,
-        iterator_interchange=[1, 0])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testInterchange
-  # CHECK: transform.sequence
-  # CHECK: transform.structured.interchange
-  # CHECK: iterator_interchange = [1, 0]
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.InterchangeOp(sequence.bodyTarget, iterator_interchange=[1, 0])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testInterchange
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.interchange
+    # CHECK: iterator_interchange = [1, 0]
 
 
 @run
 def testMultitileSizes():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.MultiTileSizesOp(pdl.OperationType.get(),
-                                sequence.bodyTarget,
-                                dimension=1,
-                                target_size=42)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testMultitileSizes
-  # CHECK: transform.sequence
-  # CHECK: transform.structured.multitile_sizes
-  # CHECK-DAG: dimension = 1
-  # CHECK-DAG: target_size = 42
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MultiTileSizesOp(
+            pdl.OperationType.get(), sequence.bodyTarget, dimension=1, target_size=42
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMultitileSizes
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.multitile_sizes
+    # CHECK-DAG: dimension = 1
+    # CHECK-DAG: target_size = 42
 
 
 @run
 def testPad():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.PadOp(
-        sequence.bodyTarget,
-        padding_values=[FloatAttr.get_f32(42.0)],
-        padding_dimensions=[1],
-        transpose_paddings=[[1, 0]])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testPad
-  # CHECK: transform.sequence
-  # CHECK: transform.structured.pad
-  # CHECK-DAG: padding_values = [4.200000e+01 : f32]
-  # CHECK-DAG: padding_dimensions = [1]
-  # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
-  # (pack_paddings has default values)
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.PadOp(
+            sequence.bodyTarget,
+            padding_values=[FloatAttr.get_f32(42.0)],
+            padding_dimensions=[1],
+            transpose_paddings=[[1, 0]],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testPad
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.pad
+    # CHECK-DAG: padding_values = [4.200000e+01 : f32]
+    # CHECK-DAG: padding_dimensions = [1]
+    # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
+    # (pack_paddings has default values)
+
 
 @run
 def testScalarize():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.ScalarizeOp(sequence.bodyTarget)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testScalarize
-  # CHECK: transform.structured.scalarize
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.ScalarizeOp(sequence.bodyTarget)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testScalarize
+    # CHECK: transform.structured.scalarize
 
 
 @run
 def testSplit():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
-    structured.SplitOp(
-        split.results[0], dimension=3, split_point=split.results[1])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testSplit
-  # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
-  # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
+        structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testSplit
+    # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
+    # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
+
 
 @run
 def testTileCompact():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.TileOp(sequence.bodyTarget,
-                      sizes=[4, 8],
-                      interchange=[0, 1])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testTileCompact
-  # CHECK: transform.sequence
-  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
-  # CHECK: interchange = [0, 1]
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileCompact
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
+    # CHECK: interchange = [0, 1]
+
 
 @run
 def testTileAttributes():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  attr = DenseI64ArrayAttr.get([4, 8])
-  ichange = DenseI64ArrayAttr.get([0, 1])
-  with InsertionPoint(sequence.body):
-    structured.TileOp(sequence.bodyTarget,
-                      sizes=attr,
-                      interchange=ichange)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testTileAttributes
-  # CHECK: transform.sequence
-  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
-  # CHECK: interchange = [0, 1]
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    attr = DenseI64ArrayAttr.get([4, 8])
+    ichange = DenseI64ArrayAttr.get([0, 1])
+    with InsertionPoint(sequence.body):
+        structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileAttributes
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
+    # CHECK: interchange = [0, 1]
+
 
 @run
 def testTileZero():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.TileOp(sequence.bodyTarget,
-                      sizes=[4, 0, 2, 0],
-                      interchange=[0, 1, 2, 3])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testTileZero
-  # CHECK: transform.sequence
-  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
-  # CHECK: interchange = [0, 1, 2, 3]
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileOp(
+            sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3]
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileZero
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
+    # CHECK: interchange = [0, 1, 2, 3]
+
 
 @run
 def testTileDynamic():
-  with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get())
-  with InsertionPoint(with_pdl.body):
-    sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [],
-                                    with_pdl.bodyTarget)
-    with InsertionPoint(sequence.body):
-      m1 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first")
-      m2 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second")
-      structured.TileOp(sequence.bodyTarget,
-                        sizes=[m1, 3, m2, 0])
-      transform.YieldOp()
-  # CHECK-LABEL: TEST: testTileDynamic
-  # CHECK: %[[FIRST:.+]] = pdl_match
-  # CHECK: %[[SECOND:.+]] = pdl_match
-  # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]
+    with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get())
+    with InsertionPoint(with_pdl.body):
+        sequence = transform.SequenceOp(
+            transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget
+        )
+        with InsertionPoint(sequence.body):
+            m1 = transform_pdl.PDLMatchOp(
+                pdl.OperationType.get(), sequence.bodyTarget, "first"
+            )
+            m2 = transform_pdl.PDLMatchOp(
+                pdl.OperationType.get(), sequence.bodyTarget, "second"
+            )
+            structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0])
+            transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileDynamic
+    # CHECK: %[[FIRST:.+]] = pdl_match
+    # CHECK: %[[SECOND:.+]] = pdl_match
+    # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]
 
 
 @run
 def testTileExplicitLoopTypeSingle():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], transform.AnyOpType.get())
-  with InsertionPoint(sequence.body):
-    structured.TileOp(transform.OperationType.get("scf.for"),
-                      sequence.bodyTarget,
-                      sizes=[2, 3, 4])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle
-  # CHECK: = transform.structured.tile %{{.*}} : (!{{.*}}) ->
-  # CHECK-COUNT-3: !transform.op<"scf.for">
-
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileOp(
+            transform.OperationType.get("scf.for"), sequence.bodyTarget, sizes=[2, 3, 4]
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle
+    # CHECK: = transform.structured.tile %{{.*}} : (!{{.*}}) ->
+    # CHECK-COUNT-3: !transform.op<"scf.for">
 
 
 @run
 def testTileExplicitLoopTypeAll():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], transform.AnyOpType.get())
-  types = [
-      transform.OperationType.get(x)
-      for x in ["scf.for", "scf.parallel", "scf.forall"]
-  ]
-  with InsertionPoint(sequence.body):
-    structured.TileOp(types, sequence.bodyTarget, sizes=[2, 3, 4])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testTileExplicitLoopTypeAll
-  # CHECK: = transform.structured.tile
-  # CHECK-SAME : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,
-  # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    types = [
+        transform.OperationType.get(x)
+        for x in ["scf.for", "scf.parallel", "scf.forall"]
+    ]
+    with InsertionPoint(sequence.body):
+        structured.TileOp(types, sequence.bodyTarget, sizes=[2, 3, 4])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileExplicitLoopTypeAll
+    # CHECK: = transform.structured.tile
+    # CHECK-SAME : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,
+    # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">
+
 
 @run
 def testVectorize():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testVectorize
-  # CHECK: transform.sequence
-  # CHECK: = transform.structured.vectorize
-  # CHECK: {vectorize_padding}
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testVectorize
+    # CHECK: transform.sequence
+    # CHECK: = transform.structured.vectorize
+    # CHECK: {vectorize_padding}

diff  --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py
index 83c09616d56fe..2347abb62b410 100644
--- a/mlir/test/python/dialects/vector.py
+++ b/mlir/test/python/dialects/vector.py
@@ -5,57 +5,62 @@
 import mlir.dialects.func as func
 import mlir.dialects.vector as vector
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  with Context(), Location.unknown():
-    f()
-  return f
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        f()
+    return f
+
 
 # CHECK-LABEL: TEST: testPrintOp
 @run
 def testPrintOp():
-  module = Module.create()
-  with InsertionPoint(module.body):
+    module = Module.create()
+    with InsertionPoint(module.body):
 
-    @func.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get()))
-    def print_vector(arg):
-      return vector.PrintOp(arg)
+        @func.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get()))
+        def print_vector(arg):
+            return vector.PrintOp(arg)
 
-  # CHECK-LABEL: func @print_vector(
-  # CHECK-SAME:                     %[[ARG:.*]]: vector<12x5xf32>) {
-  #       CHECK:   vector.print %[[ARG]] : vector<12x5xf32>
-  #       CHECK:   return
-  #       CHECK: }
-  print(module)
+    # CHECK-LABEL: func @print_vector(
+    # CHECK-SAME:                     %[[ARG:.*]]: vector<12x5xf32>) {
+    #       CHECK:   vector.print %[[ARG]] : vector<12x5xf32>
+    #       CHECK:   return
+    #       CHECK: }
+    print(module)
 
 
 # CHECK-LABEL: TEST: testTransferReadOp
 @run
 def testTransferReadOp():
-  module = Module.create()
-  with InsertionPoint(module.body):
-    vector_type = VectorType.get([2, 3], F32Type.get())
-    memref_type = MemRefType.get(
-        [ShapedType.get_dynamic_size(),
-         ShapedType.get_dynamic_size()], F32Type.get())
-    index_type = IndexType.get()
-    mask_type = VectorType.get(vector_type.shape, IntegerType.get_signless(1))
-    identity_map = AffineMap.get_identity(vector_type.rank)
-    identity_map_attr = AffineMapAttr.get(identity_map)
-    f = func.FuncOp("transfer_read",
-                          ([memref_type, index_type,
-                            F32Type.get(), mask_type], []))
-    with InsertionPoint(f.add_entry_block()):
-      A, zero, padding, mask = f.arguments
-      vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
-                            padding, mask=mask)
-      vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
-                            padding)
-      func.ReturnOp([])
-
-  # CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index,
-  # CHECK: %[[PAD:.*]]: f32, %[[MASK:.*]]: vector<2x3xi1>)
-  # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]]
-  # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]]
-  # CHECK-NOT: %[[MASK]]
-  print(module)
+    module = Module.create()
+    with InsertionPoint(module.body):
+        vector_type = VectorType.get([2, 3], F32Type.get())
+        memref_type = MemRefType.get(
+            [ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()],
+            F32Type.get(),
+        )
+        index_type = IndexType.get()
+        mask_type = VectorType.get(vector_type.shape, IntegerType.get_signless(1))
+        identity_map = AffineMap.get_identity(vector_type.rank)
+        identity_map_attr = AffineMapAttr.get(identity_map)
+        f = func.FuncOp(
+            "transfer_read", ([memref_type, index_type, F32Type.get(), mask_type], [])
+        )
+        with InsertionPoint(f.add_entry_block()):
+            A, zero, padding, mask = f.arguments
+            vector.TransferReadOp(
+                vector_type, A, [zero, zero], identity_map_attr, padding, mask=mask
+            )
+            vector.TransferReadOp(
+                vector_type, A, [zero, zero], identity_map_attr, padding
+            )
+            func.ReturnOp([])
+
+    # CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index,
+    # CHECK: %[[PAD:.*]]: f32, %[[MASK:.*]]: vector<2x3xi1>)
+    # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]]
+    # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]]
+    # CHECK-NOT: %[[MASK]]
+    print(module)

diff  --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 973810d5b82f4..50d6e82348a9f 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -10,34 +10,36 @@
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
 def log(*args):
-  print(*args, file=sys.stderr)
-  sys.stderr.flush()
+    print(*args, file=sys.stderr)
+