[Mlir-commits] [mlir] [mlir][vector] Handle empty `MaskOp` in `LowerVectorMask`, `MaskOpRewritePattern` (PR #72031)

Felix Schneider llvmlistbot at llvm.org
Sat Nov 11 08:43:45 PST 2023


https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/72031

>From a23333b2c4ec4aa21f44d9d032c03f456a9c3b13 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 11 Nov 2023 17:34:51 +0100
Subject: [PATCH] [mlir][vector] Handle empty `MaskOp` in `LowerVectorMask`,
 `MaskOpRewritePattern`

This patch makes `MaskOpRewritePattern` handle an empty `MaskOp` (one whose
region contains no maskable operation): the pattern now fails to match
instead of crashing when casting the missing maskable op.
It also adds the `MaskOp` canonicalization patterns to the `LowerVectorMask`
pass so that empty `MaskOp`s are folded away by the pass.

Fixes https://github.com/llvm/llvm-project/issues/71036
---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp |  5 ++++-
 mlir/test/Dialect/Vector/lower-vector-mask.mlir        | 10 ++++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 887d1af7645419f..806301bb5568ca3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -188,7 +188,9 @@ struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
 private:
   LogicalResult matchAndRewrite(MaskOp maskOp,
                                 PatternRewriter &rewriter) const final {
-    auto maskableOp = cast<MaskableOpInterface>(maskOp.getMaskableOp());
+    auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
+    if (!maskableOp)
+      return failure();
     SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
     if (!sourceOp)
       return failure();
@@ -282,6 +284,7 @@ struct LowerVectorMaskPass
 
     RewritePatternSet loweringPatterns(context);
     populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);
+    MaskOp::populateCanonicalizationPatterns(loweringPatterns, context);
 
     if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
       signalPassFailure();
diff --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
index 8f8fae095cac37c..05bff7d3575bb8b 100644
--- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir
+++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -77,3 +77,13 @@ func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor<
 // CHECK:           %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
 // CHECK:           %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32>
 
+// -----
+
+// CHECK-LABEL: func @empty_vector_mask_with_return
+//  CHECK-SAME:     %[[IN:.*]]: vector<8xf32>
+func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1>) -> vector<8xf32> {
+//   CHECK-NOT:   vector.mask
+//       CHECK:   return %[[IN]] : vector<8xf32>
+  %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
\ No newline at end of file



More information about the Mlir-commits mailing list