[Mlir-commits] [mlir] [mlir][vector] Handle empty `MaskOp` in `LowerVectorMask`, `MaskOpRewritePattern` (PR #72031)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Nov 11 08:42:43 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Felix Schneider (ubfx)

<details>
<summary>Changes</summary>

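This patch adds handling of an empty `MaskOp` to `MaskOpRewritePattern` and thereby fixes a crash: an empty `MaskOp` has no maskable operation, so `getMaskableOp()` yields null. The pattern previously passed that null straight to an unconditional `cast<MaskableOpInterface>`; it now uses `cast_or_null` and bails out with `failure()`.
It also adds the `MaskOp` canonicalization patterns to the pattern set of the `LowerVectorMask` pass, so that empty `MaskOp`s are folded away by the pass itself.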

Fix https://github.com/llvm/llvm-project/issues/71036

---
Full diff: https://github.com/llvm/llvm-project/pull/72031.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp (+4-1) 
- (modified) mlir/test/Dialect/Vector/lower-vector-mask.mlir (+10) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 887d1af7645419f..806301bb5568ca3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -188,7 +188,9 @@ struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
 private:
   LogicalResult matchAndRewrite(MaskOp maskOp,
                                 PatternRewriter &rewriter) const final {
-    auto maskableOp = cast<MaskableOpInterface>(maskOp.getMaskableOp());
+    auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
+    if (!maskableOp)
+      return failure();
     SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
     if (!sourceOp)
       return failure();
@@ -282,6 +284,7 @@ struct LowerVectorMaskPass
 
     RewritePatternSet loweringPatterns(context);
     populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);
+    MaskOp::populateCanonicalizationPatterns(loweringPatterns, context);
 
     if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
       signalPassFailure();
diff --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
index 8f8fae095cac37c..05bff7d3575bb8b 100644
--- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir
+++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -77,3 +77,13 @@ func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor<
 // CHECK:           %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
 // CHECK:           %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32>
 
+// -----
+
+// CHECK-LABEL: func @empty_vector_mask_with_return
+//  CHECK-SAME:     %[[IN:.*]]: vector<8xf32>
+func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1>) -> vector<8xf32> {
+//   CHECK-NOT:   vector.mask
+//       CHECK:   return %[[IN]] : vector<8xf32>
+  %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
\ No newline at end of file

``````````

</details>


https://github.com/llvm/llvm-project/pull/72031


More information about the Mlir-commits mailing list