[Mlir-commits] [mlir] [mlir][ArmSME] Migrate `arm-sme-vector-legalization` to dialect conversion (PR #121101)
Matthias Springer
llvmlistbot at llvm.org
Wed Dec 25 01:37:29 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/121101
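Use the regular dialect conversion driver instead of the 1:N dialect conversion driver, which will be removed soon. The non-conversion rewrite patterns (`FoldExtractFromVectorOfSMELikeCreateMasks`, `LiftIllegalVectorTransposeToMemory`, `ConvertIllegalShapeCastOpsToTransposes`, `LowerIllegalTransposeStoreViaZA`) are now applied in a separate greedy pre-processing step before the conversion runs.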
From 9548de0fb71cf19f04d725e57e26b909460d9ac3 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 24 Dec 2024 17:16:47 +0100
Subject: [PATCH] migrate to dialect conversion
---
.../ArmSME/Transforms/VectorLegalization.cpp | 94 +++++++++++--------
1 file changed, 56 insertions(+), 38 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 61767f3b21c9c3..12c65a72babcb8 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -17,7 +17,7 @@
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -25,7 +25,8 @@
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "arm-sme-vector-legalization"
@@ -172,12 +173,12 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
/// Legalize `arith.constant dense<value>` splat operations to fit within SME
/// tiles by decomposing them into tile-sized operations.
struct LegalizeArithConstantOpsByDecomposition
- : public OneToNOpConversionPattern<arith::ConstantOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<arith::ConstantOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ ConversionPatternRewriter &rewriter) const override {
auto vectorType = dyn_cast<VectorType>(constantOp.getType());
auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
if (!vectorType || !denseAttr || !denseAttr.isSplat())
@@ -191,8 +192,8 @@ struct LegalizeArithConstantOpsByDecomposition
auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
auto tileSplat = rewriter.create<arith::ConstantOp>(
constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
- rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
- adaptor.getResultMapping());
+ SmallVector<Value> repl(tileCount, tileSplat);
+ rewriter.replaceOpWithMultiple(constantOp, {repl});
return success();
}
@@ -201,12 +202,13 @@ struct LegalizeArithConstantOpsByDecomposition
/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeVectorOuterProductOpsByDecomposition
- : public OneToNOpConversionPattern<vector::OuterProductOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<vector::OuterProductOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(vector::OuterProductOp outerProductOp,
+ OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto vectorType = outerProductOp.getResultVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
return rewriter.notifyMatchFailure(outerProductOp,
@@ -219,6 +221,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
auto maskOp = outerProductOp.getMaskingOp();
mask = maskOp.getMask();
rootOp = maskOp;
+ rewriter.setInsertionPoint(rootOp);
}
if (!isSupportedMaskOp(mask))
@@ -248,7 +251,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
resultSMETiles.push_back(maskedOuterProduct->getResult(0));
}
- rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
+ rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
return success();
}
};
@@ -259,12 +262,12 @@ struct LegalizeVectorOuterProductOpsByDecomposition
// (invalid). This pattern matches on `vector.mask` then calls into the
// `vector.outerproduct` pattern to work around this issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
- : public OneToNOpConversionPattern<vector::MaskOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<vector::MaskOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
maskOp.getMaskableOp())) {
LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
@@ -279,12 +282,12 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition
/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferReadOpsByDecomposition
- : public OneToNOpConversionPattern<vector::TransferReadOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<vector::TransferReadOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto vectorType = readOp.getVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
return rewriter.notifyMatchFailure(readOp,
@@ -319,7 +322,7 @@ struct LegalizeTransferReadOpsByDecomposition
resultSMETiles.push_back(smeRead);
}
- rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
+ rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
return success();
}
};
@@ -327,12 +330,12 @@ struct LegalizeTransferReadOpsByDecomposition
/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferWriteOpsByDecomposition
- : public OneToNOpConversionPattern<vector::TransferWriteOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<vector::TransferWriteOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto vectorType = writeOp.getVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
return rewriter.notifyMatchFailure(writeOp,
@@ -409,12 +412,12 @@ struct LegalizeTransferWriteOpsByDecomposition
/// }
/// ```
struct LegalizeMultiTileTransferWriteAsStoreLoop
- : public OneToNOpConversionPattern<vector::TransferWriteOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<vector::TransferWriteOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
if (writeOp.hasPureTensorSemantics())
return rewriter.notifyMatchFailure(
writeOp, "TODO: tensor semantics are unsupported");
@@ -936,10 +939,16 @@ struct VectorLegalizationPass
return success();
});
- patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
- LiftIllegalVectorTransposeToMemory,
- ConvertIllegalShapeCastOpsToTransposes,
- LowerIllegalTransposeStoreViaZA>(context);
+ // Apply preprocessing patterns.
+ RewritePatternSet rewritePatterns(context);
+ rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
+ LiftIllegalVectorTransposeToMemory,
+ ConvertIllegalShapeCastOpsToTransposes,
+ LowerIllegalTransposeStoreViaZA>(context);
+ if (failed(
+ applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
+ return signalPassFailure();
+
// Note: These two patterns are added with a high benefit to ensure:
// - Masked outer products are handled before unmasked ones
// - Multi-tile writes are lowered as a store loop (if possible)
@@ -950,11 +959,20 @@ struct VectorLegalizationPass
LegalizeVectorOuterProductOpsByDecomposition,
LegalizeTransferReadOpsByDecomposition,
LegalizeTransferWriteOpsByDecomposition>(converter, context);
- populateFuncTypeConversionPatterns(converter, patterns);
- scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
-
- if (failed(applyPartialOneToNConversion(getOperation(), converter,
- std::move(patterns))))
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+ converter);
+ populateCallOpTypeConversionPattern(patterns, converter);
+ populateReturnOpTypeConversionPattern(patterns, converter);
+ scf::populateSCFStructuralTypeConversions(converter, patterns);
+
+ ConversionTarget target(getContext());
+ target.markUnknownOpDynamicallyLegal(
+ [&](Operation *op) { return converter.isLegal(op); });
+ target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+ return converter.isSignatureLegal(op.getFunctionType());
+ });
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
return signalPassFailure();
}
};
More information about the Mlir-commits
mailing list