[Mlir-commits] [mlir] [mlir][vector] Handle empty `MaskOp` in `LowerVectorMask`, `MaskOpRewritePattern` (PR #72031)
Felix Schneider
llvmlistbot at llvm.org
Sat Nov 11 08:44:46 PST 2023
https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/72031
>From a5e3445c636a4759b426cb0ac895fd30c8e72d91 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 11 Nov 2023 17:34:51 +0100
Subject: [PATCH] [mlir][vector] Handle empty `MaskOp` in `LowerVectorMask`,
`MaskOpRewritePattern`
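This patch makes `MaskOpRewritePattern` handle an empty `MaskOp` (one
with no maskable operation, for which `getMaskableOp()` returns null).
Previously the pattern applied an unconditional `cast` to that null
value and crashed; it now bails out with `failure()` instead.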
It also adds the `MaskOp` canonicalization patterns to the
`LowerVectorMask` pass, so that the pass folds empty `MaskOp`s away.
Fixes https://github.com/llvm/llvm-project/issues/71036
---
.../lib/Dialect/Vector/Transforms/LowerVectorMask.cpp | 5 ++++-
mlir/test/Dialect/Vector/lower-vector-mask.mlir | 11 +++++++++++
2 files changed, 15 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 887d1af7645419f..806301bb5568ca3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -188,7 +188,9 @@ struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
private:
LogicalResult matchAndRewrite(MaskOp maskOp,
PatternRewriter &rewriter) const final {
- auto maskableOp = cast<MaskableOpInterface>(maskOp.getMaskableOp());
+ auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
+ if (!maskableOp)
+ return failure();
SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
if (!sourceOp)
return failure();
@@ -282,6 +284,7 @@ struct LowerVectorMaskPass
RewritePatternSet loweringPatterns(context);
populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);
+ MaskOp::populateCanonicalizationPatterns(loweringPatterns, context);
if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
signalPassFailure();
diff --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
index 8f8fae095cac37c..a8a1164e2f762b8 100644
--- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir
+++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -77,3 +77,14 @@ func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor<
// CHECK: %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
// CHECK: %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32>
+// -----
+
+// CHECK-LABEL: func @empty_vector_mask_with_return
+// CHECK-SAME: %[[IN:.*]]: vector<8xf32>
+func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1>) -> vector<8xf32> {
+// CHECK-NOT: vector.mask
+// CHECK: return %[[IN]] : vector<8xf32>
+ %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
More information about the Mlir-commits
mailing list