[Mlir-commits] [mlir] d5a0fb3 - [mlir][vector] Handle empty `MaskOp` in `LowerVectorMask`, `MaskOpRewritePattern` (#72031)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Nov 11 23:12:32 PST 2023


Author: Felix Schneider
Date: 2023-11-12T08:12:28+01:00
New Revision: d5a0fb39ae1d481fe75c3d2c3d42df3de977762b

URL: https://github.com/llvm/llvm-project/commit/d5a0fb39ae1d481fe75c3d2c3d42df3de977762b
DIFF: https://github.com/llvm/llvm-project/commit/d5a0fb39ae1d481fe75c3d2c3d42df3de977762b.diff

LOG: [mlir][vector] Handle empty `MaskOp` in `LowerVectorMask`, `MaskOpRewritePattern` (#72031)

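This patch makes `MaskOpRewritePattern` handle an empty `MaskOp`, i.e. one
whose region contains no maskable operation (only a `vector.yield`).
Previously, the pattern unconditionally `cast` the (null) maskable operation,
which crashed.
It also adds the `MaskOp` canonicalization patterns to the `LowerVectorMask`
pass so that empty `MaskOp`s are folded away by the pass.

Fixes https://github.com/llvm/llvm-project/issues/71036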

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
    mlir/test/Dialect/Vector/lower-vector-mask.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 887d1af7645419f..f53bb5157eb37bc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -188,7 +188,9 @@ struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
 private:
   LogicalResult matchAndRewrite(MaskOp maskOp,
                                 PatternRewriter &rewriter) const final {
-    auto maskableOp = cast<MaskableOpInterface>(maskOp.getMaskableOp());
+    auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
+    if (!maskableOp)
+      return failure();
     SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
     if (!sourceOp)
       return failure();
@@ -282,6 +284,7 @@ struct LowerVectorMaskPass
 
     RewritePatternSet loweringPatterns(context);
     populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);
+    MaskOp::getCanonicalizationPatterns(loweringPatterns, context);
 
     if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
       signalPassFailure();

diff  --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
index 8f8fae095cac37c..a8a1164e2f762b8 100644
--- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir
+++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -77,3 +77,14 @@ func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor<
 // CHECK:           %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
 // CHECK:           %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32>
 
+// -----
+
+// CHECK-LABEL: func @empty_vector_mask_with_return
+//  CHECK-SAME:     %[[IN:.*]]: vector<8xf32>
+func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1>) -> vector<8xf32> {
+//   CHECK-NOT:   vector.mask
+//       CHECK:   return %[[IN]] : vector<8xf32>
+  %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+


        


More information about the Mlir-commits mailing list