[Mlir-commits] [mlir] [MLIR][Python] Add structured.fuseop to generator. (PR #120601)

Hugo Trachino llvmlistbot at llvm.org
Thu Dec 19 08:27:15 PST 2024


https://github.com/nujaa created https://github.com/llvm/llvm-project/pull/120601

Implements a Python interface for structured.FuseOp, allowing more freedom on its inputs.
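
For illustration, a minimal sketch of how the new constructor can be driven from Python. The SequenceOp/YieldOp scaffolding below is only assumed here for context (the usual transform-dialect boilerplate); the structured.FuseOp call itself mirrors the tests added in this patch.

    from mlir import ir
    from mlir.dialects import transform
    from mlir.dialects.transform import structured

    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        with ir.InsertionPoint(module.body):
            # Transform sequence whose block argument is the payload handle.
            sequence = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with ir.InsertionPoint(sequence.body):
                # Compact form: loop result types default to !transform.any_op,
                # one per non-zero tile size.
                structured.FuseOp(
                    sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1]
                )
                # The explicit form spells the loop type(s) first, e.g.
                # structured.FuseOp(transform.AnyOpType.get(), handle, sizes=[4, 8]).
                transform.YieldOp()
        print(module)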

From 375fa4f42ebcd3fa874944e13e49be6461810a3f Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Fri, 16 Aug 2024 17:05:59 +0800
Subject: [PATCH 1/2] [mlir][vector] Rename LowerVectorToLLVM to
 ConvertVectorToLLVM (NFC)

---
 mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h   | 2 +-
 .../lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp | 6 +++---
 .../SparseTensor/Pipelines/SparseTensorPipelines.cpp        | 6 +++---
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index 90021ffa7c380bc..efbe5c56a219b37 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -162,7 +162,7 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
   }
 
   /// Projects out the options for `createConvertVectorToLLVMPass`.
-  ConvertVectorToLLVMPassOptions lowerVectorToLLVMOptions() const {
+  ConvertVectorToLLVMPassOptions convertVectorToLLVMOptions() const {
     ConvertVectorToLLVMPassOptions opts{};
     opts.reassociateFPReductions = reassociateFPReductions;
     opts.force32BitVectorIndices = force32BitVectorIndices;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 55143d5939ba257..842d239cf6a512c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -35,8 +35,8 @@ using namespace mlir;
 using namespace mlir::vector;
 
 namespace {
-struct LowerVectorToLLVMPass
-    : public impl::ConvertVectorToLLVMPassBase<LowerVectorToLLVMPass> {
+struct ConvertVectorToLLVMPass
+    : public impl::ConvertVectorToLLVMPassBase<ConvertVectorToLLVMPass> {
 
   using Base::Base;
 
@@ -58,7 +58,7 @@ struct LowerVectorToLLVMPass
 };
 } // namespace
 
-void LowerVectorToLLVMPass::runOnOperation() {
+void ConvertVectorToLLVMPass::runOnOperation() {
   // Perform progressive lowering of operations on slices and
   // all contraction operations. Also applies folding and DCE.
   {
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index c5eb965884396ae..5e49252c0e57d98 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -76,16 +76,16 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
   pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());
   pm.addPass(memref::createExpandStridedMetadataPass());
   pm.addPass(createLowerAffinePass());
-  pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
+  pm.addPass(createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
   pm.addPass(createFinalizeMemRefToLLVMConversionPass());
   pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass());
   pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
   pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
   pm.addPass(createConvertMathToLibmPass());
   pm.addPass(createConvertComplexToLibmPass());
-  pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
+  pm.addPass(createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
   pm.addPass(createConvertComplexToLLVMPass());
-  pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
+  pm.addPass(createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
   pm.addPass(createConvertFuncToLLVMPass());
 
   // Finalize GPU code generation.

From 8a4bf109fbf158d1cff2e90625a6516adc532a7f Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Thu, 19 Dec 2024 19:30:52 +0800
Subject: [PATCH 2/2] [MLIR][Python] Add structured.fuseop to generator.

---
 .../mlir/dialects/transform/structured.py     | 52 +++++++++++++++++++
 .../dialects/transform_structured_ext.py      | 21 ++++++++
 2 files changed, 73 insertions(+)

diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 41051c0d5b2ffb6..b97a1aa8a82bfcf 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -139,6 +139,58 @@ def __init__(
             ip=ip,
         )
 
+@_ods_cext.register_operation(_Dialect, replace=True)
+class FuseOp(FuseOp):
+    """Specialization for FuseOp class."""
+
+    @overload
+    def __init__(
+        self,
+        target: Union[Operation, Value, OpView],
+        *,
+        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+        interchange: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    def __init__(
+        self,
+        loop_types_or_target: Union[Type, List[Type], Operation, Value],
+        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+        *,
+        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+        interchange: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        sizes = sizes if sizes else []
+        num_loops = sum(v if v == 0 else 1 for v in sizes)
+
+        if isinstance(loop_types_or_target, (Operation, Value, OpView)):
+            loop_types = [transform.AnyOpType.get()] * num_loops
+            target = loop_types_or_target
+            assert (
+                target_or_none is None
+            ), "Cannot construct FuseOp with two targets."
+        else:
+            loop_types = (
+                ([loop_types_or_target] * num_loops)
+                if isinstance(loop_types_or_target, Type)
+                else loop_types_or_target
+            )
+            target = target_or_none
+        super().__init__(
+            target.type,
+            loop_types,
+            target,
+            tile_sizes=sizes,
+            tile_interchange=interchange,
+            loc=loc,
+            ip=ip,
+        )
+
 
 @_ods_cext.register_operation(_Dialect, replace=True)
 class GeneralizeOp(GeneralizeOp):
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 3ea73e8beea3688..551c2fa1e48acd3 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -101,6 +101,27 @@ def testFuseIntoContainingOpCompact(target):
     # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
 
 
+@run
+@create_sequence
+def testFuseOpCompact(target):
+    structured.FuseOp(target, sizes=[4, 8], interchange=[0, 1])
+    # CHECK-LABEL: TEST: testFuseOpCompact
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+    # CHECK-SAME: interchange [0, 1]
+    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
+def testFuseOpNoArg(target):
+    structured.FuseOp(target)
+    # CHECK-LABEL: TEST: testFuseOpNoArg
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}} = transform.structured.fuse %{{.*}} :
+    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
 @run
 @create_sequence
 def testGeneralize(target):


