[Mlir-commits] [mlir] [mlir][ArmSME] Migrate `arm-sme-vector-legalization` to dialect conversion (PR #121101)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 25 01:37:59 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

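Use the regular dialect conversion driver instead of the 1:N dialect conversion driver, which will be removed soon. As part of the migration, the preprocessing rewrite patterns (which are not conversion patterns) are now applied up front with the greedy pattern rewrite driver, and legality is expressed through a `ConversionTarget` based on the type converter.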


---
Full diff: https://github.com/llvm/llvm-project/pull/121101.diff


1 File Affected:

- (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+56-38) 


``````````diff
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 61767f3b21c9c3..12c65a72babcb8 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -17,7 +17,7 @@
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -25,7 +25,8 @@
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "arm-sme-vector-legalization"
 
@@ -172,12 +173,12 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
 /// Legalize `arith.constant dense<value>` splat operations to fit within SME
 /// tiles by decomposing them into tile-sized operations.
 struct LegalizeArithConstantOpsByDecomposition
-    : public OneToNOpConversionPattern<arith::ConstantOp> {
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+    : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+                  ConversionPatternRewriter &rewriter) const override {
     auto vectorType = dyn_cast<VectorType>(constantOp.getType());
     auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
     if (!vectorType || !denseAttr || !denseAttr.isSplat())
@@ -191,8 +192,8 @@ struct LegalizeArithConstantOpsByDecomposition
     auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
     auto tileSplat = rewriter.create<arith::ConstantOp>(
         constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
-    rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
-                       adaptor.getResultMapping());
+    SmallVector<Value> repl(tileCount, tileSplat);
+    rewriter.replaceOpWithMultiple(constantOp, {repl});
 
     return success();
   }
@@ -201,12 +202,13 @@ struct LegalizeArithConstantOpsByDecomposition
 /// Legalize `vector.outerproduct` operations to fit within SME tiles by
 /// decomposing them into tile-sized operations.
 struct LegalizeVectorOuterProductOpsByDecomposition
-    : public OneToNOpConversionPattern<vector::OuterProductOp> {
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+    : public OpConversionPattern<vector::OuterProductOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(vector::OuterProductOp outerProductOp,
+                  OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     auto vectorType = outerProductOp.getResultVectorType();
     if (!isMultipleOfSMETileVectorType(vectorType))
       return rewriter.notifyMatchFailure(outerProductOp,
@@ -219,6 +221,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
       auto maskOp = outerProductOp.getMaskingOp();
       mask = maskOp.getMask();
       rootOp = maskOp;
+      rewriter.setInsertionPoint(rootOp);
     }
 
     if (!isSupportedMaskOp(mask))
@@ -248,7 +251,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
       resultSMETiles.push_back(maskedOuterProduct->getResult(0));
     }
 
-    rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
+    rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
     return success();
   }
 };
@@ -259,12 +262,12 @@ struct LegalizeVectorOuterProductOpsByDecomposition
 // (invalid). This pattern matches on `vector.mask` then calls into the
 // `vector.outerproduct` pattern to work around this issue.
 struct LegalizeMaskedVectorOuterProductOpsByDecomposition
-    : public OneToNOpConversionPattern<vector::MaskOp> {
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+    : public OpConversionPattern<vector::MaskOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
             maskOp.getMaskableOp())) {
       LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
@@ -279,12 +282,12 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition
 /// Legalize `vector.transfer_read` operations to fit within SME tiles by
 /// decomposing them into tile-sized operations.
 struct LegalizeTransferReadOpsByDecomposition
-    : public OneToNOpConversionPattern<vector::TransferReadOp> {
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+    : public OpConversionPattern<vector::TransferReadOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     auto vectorType = readOp.getVectorType();
     if (!isMultipleOfSMETileVectorType(vectorType))
       return rewriter.notifyMatchFailure(readOp,
@@ -319,7 +322,7 @@ struct LegalizeTransferReadOpsByDecomposition
       resultSMETiles.push_back(smeRead);
     }
 
-    rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
+    rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
     return success();
   }
 };
@@ -327,12 +330,12 @@ struct LegalizeTransferReadOpsByDecomposition
 /// Legalize `vector.transfer_write` operations to fit within SME tiles by
 /// decomposing them into tile-sized operations.
 struct LegalizeTransferWriteOpsByDecomposition
-    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+    : public OpConversionPattern<vector::TransferWriteOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     auto vectorType = writeOp.getVectorType();
     if (!isMultipleOfSMETileVectorType(vectorType))
       return rewriter.notifyMatchFailure(writeOp,
@@ -409,12 +412,12 @@ struct LegalizeTransferWriteOpsByDecomposition
 /// }
 /// ```
 struct LegalizeMultiTileTransferWriteAsStoreLoop
-    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+    : public OpConversionPattern<vector::TransferWriteOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     if (writeOp.hasPureTensorSemantics())
       return rewriter.notifyMatchFailure(
           writeOp, "TODO: tensor semantics are unsupported");
@@ -936,10 +939,16 @@ struct VectorLegalizationPass
           return success();
         });
 
-    patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
-                 LiftIllegalVectorTransposeToMemory,
-                 ConvertIllegalShapeCastOpsToTransposes,
-                 LowerIllegalTransposeStoreViaZA>(context);
+    // Apply preprocessing patterns.
+    RewritePatternSet rewritePatterns(context);
+    rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
+                        LiftIllegalVectorTransposeToMemory,
+                        ConvertIllegalShapeCastOpsToTransposes,
+                        LowerIllegalTransposeStoreViaZA>(context);
+    if (failed(
+            applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
+      return signalPassFailure();
+
     // Note: These two patterns are added with a high benefit to ensure:
     //  - Masked outer products are handled before unmasked ones
     //  - Multi-tile writes are lowered as a store loop (if possible)
@@ -950,11 +959,20 @@ struct VectorLegalizationPass
                  LegalizeVectorOuterProductOpsByDecomposition,
                  LegalizeTransferReadOpsByDecomposition,
                  LegalizeTransferWriteOpsByDecomposition>(converter, context);
-    populateFuncTypeConversionPatterns(converter, patterns);
-    scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
-
-    if (failed(applyPartialOneToNConversion(getOperation(), converter,
-                                            std::move(patterns))))
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                   converter);
+    populateCallOpTypeConversionPattern(patterns, converter);
+    populateReturnOpTypeConversionPattern(patterns, converter);
+    scf::populateSCFStructuralTypeConversions(converter, patterns);
+
+    ConversionTarget target(getContext());
+    target.markUnknownOpDynamicallyLegal(
+        [&](Operation *op) { return converter.isLegal(op); });
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return converter.isSignatureLegal(op.getFunctionType());
+    });
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
       return signalPassFailure();
   }
 };

``````````

</details>


https://github.com/llvm/llvm-project/pull/121101


More information about the Mlir-commits mailing list