[Mlir-commits] [mlir] 1a09ffe - [mlir][ArmSME][NFC] Check early for unsupported mask ops (#135955)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 16 23:48:16 PDT 2025
Author: Matthias Springer
Date: 2025-04-17T08:48:12+02:00
New Revision: 1a09ffea31e22172ce3de2fd553b770f52add576
URL: https://github.com/llvm/llvm-project/commit/1a09ffea31e22172ce3de2fd553b770f52add576
DIFF: https://github.com/llvm/llvm-project/commit/1a09ffea31e22172ce3de2fd553b770f52add576.diff
LOG: [mlir][ArmSME][NFC] Check early for unsupported mask ops (#135955)
This is to avoid rollbacks in the dialect conversion, which are
expensive.
Note: This is in preparation for the One-Shot Dialect Conversion
refactoring.
Added:
Modified:
mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 6ed29903ea407..630414030d98b 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -77,11 +77,6 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
Value upperBound;
if (mask) {
auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
- if (!createMaskOp)
- return rewriter.notifyMatchFailure(
- loc, "unsupported mask op, only 'vector.create_mask' is "
- "currently supported");
-
auto maskDim0 = createMaskOp.getOperands()[0];
auto maskDim1 = createMaskOp.getOperands()[1];
@@ -184,6 +179,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
Value initTile;
if (mask) {
+ if (!mask.getDefiningOp<vector::CreateMaskOp>())
+ return rewriter.notifyMatchFailure(
+ loc, "unsupported mask op, only 'vector.create_mask' is "
+ "currently supported");
auto padOp = tileLoadOp.getPadding();
assert(padOp && "expected padding when masking!");
@@ -373,6 +372,14 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
PatternRewriter &rewriter) const override {
+ if (Value mask = tileStoreOp.getMask()) {
+ if (!mask.getDefiningOp<vector::CreateMaskOp>())
+ return rewriter.notifyMatchFailure(
+ tileStoreOp.getLoc(),
+ "unsupported mask op, only 'vector.create_mask' is "
+ "currently supported");
+ }
+
// Create a loop that stores each active ZA tile slice from memory.
return createLoadStoreForOverTileSlices(
rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
More information about the Mlir-commits
mailing list