[Mlir-commits] [mlir] d5a0fb3 - [mlir][vector] Handle empty `MaskOp` in `LowerVectorMask`, `MaskOpRewritePattern` (#72031)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Nov 11 23:12:32 PST 2023
Author: Felix Schneider
Date: 2023-11-12T08:12:28+01:00
New Revision: d5a0fb39ae1d481fe75c3d2c3d42df3de977762b
URL: https://github.com/llvm/llvm-project/commit/d5a0fb39ae1d481fe75c3d2c3d42df3de977762b
DIFF: https://github.com/llvm/llvm-project/commit/d5a0fb39ae1d481fe75c3d2c3d42df3de977762b.diff
LOG: [mlir][vector] Handle empty `MaskOp` in `LowerVectorMask`, `MaskOpRewritePattern` (#72031)
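This patch makes `MaskOpRewritePattern` handle an empty `MaskOp`, i.e. one
whose region contains no maskable operation (only a `vector.yield`). The
pattern now uses `cast_or_null` and returns `failure()` in that case instead
of performing an unchecked `cast`, which fixes a crash.
It also adds the `MaskOp` canonicalization patterns to the `LowerVectorMask`
pass so that empty `MaskOp`s are folded away when the pass runs.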
Fixes https://github.com/llvm/llvm-project/issues/71036
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
mlir/test/Dialect/Vector/lower-vector-mask.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 887d1af7645419f..f53bb5157eb37bc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -188,7 +188,9 @@ struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
private:
LogicalResult matchAndRewrite(MaskOp maskOp,
PatternRewriter &rewriter) const final {
- auto maskableOp = cast<MaskableOpInterface>(maskOp.getMaskableOp());
+ auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
+ if (!maskableOp)
+ return failure();
SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
if (!sourceOp)
return failure();
@@ -282,6 +284,7 @@ struct LowerVectorMaskPass
RewritePatternSet loweringPatterns(context);
populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);
+ MaskOp::getCanonicalizationPatterns(loweringPatterns, context);
if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
signalPassFailure();
diff --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
index 8f8fae095cac37c..a8a1164e2f762b8 100644
--- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir
+++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -77,3 +77,14 @@ func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor<
// CHECK: %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
// CHECK: %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32>
+// -----
+
+// CHECK-LABEL: func @empty_vector_mask_with_return
+// CHECK-SAME: %[[IN:.*]]: vector<8xf32>
+func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1>) -> vector<8xf32> {
+// CHECK-NOT: vector.mask
+// CHECK: return %[[IN]] : vector<8xf32>
+ %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
More information about the Mlir-commits
mailing list