[Mlir-commits] [mlir] [mlir][vector] Replace cast_or_null with dyn_cast_or_null (PR #108048)

llvmlistbot at llvm.org
Tue Sep 10 08:55:25 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

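This patch replaces `cast_or_null` with `dyn_cast_or_null` in `MaskOpRewritePattern::matchAndRewrite`. `cast_or_null` asserts that a non-null operand has the requested type, so the pattern crashed when the operation inside `vector.mask` does not implement `MaskableOpInterface`. `dyn_cast_or_null` returns null in that case instead, so the pattern now returns `failure()` and leaves the `vector.mask` unchanged. Fixes #107811.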
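The difference between the two casts can be seen in a small, self-contained model. This is a sketch only: it does not use LLVM headers, and `Op`, `MaskableOp`, `BitcastOp`, `castOrNull`, `dynCastOrNull` and `tryLowerMask` are hypothetical stand-ins. Interface membership, which MLIR checks at runtime, is approximated here by a kind tag plus a static `classof` predicate in the style of LLVM's RTTI.

```cpp
#include <cassert>

// Hypothetical base "operation" with a kind tag used for LLVM-style RTTI.
struct Op {
  enum Kind { Maskable, Bitcast };
  explicit Op(Kind k) : kind(k) {}
  virtual ~Op() = default;
  Kind kind;
};

// Stand-in for an op that implements MaskableOpInterface.
struct MaskableOp : Op {
  MaskableOp() : Op(Maskable) {}
  static bool classof(const Op *op) { return op->kind == Maskable; }
};

// Stand-in for an op such as vector.bitcast that does not implement it.
struct BitcastOp : Op {
  BitcastOp() : Op(Bitcast) {}
  static bool classof(const Op *op) { return op->kind == Bitcast; }
};

// Models cast_or_null: null passes through, but any non-null operand must
// already have the target type; otherwise the assertion fires.
template <typename To>
To *castOrNull(Op *op) {
  if (!op)
    return nullptr;
  assert(To::classof(op) && "castOrNull<To>() argument of incompatible type!");
  return static_cast<To *>(op);
}

// Models dyn_cast_or_null: a null operand or a type mismatch both yield null.
template <typename To>
To *dynCastOrNull(Op *op) {
  if (!op || !To::classof(op))
    return nullptr;
  return static_cast<To *>(op);
}

// Models the guard at the top of the patched matchAndRewrite: returning
// false plays the role of failure(), leaving the vector.mask untouched.
bool tryLowerMask(Op *maskedOp) {
  MaskableOp *maskable = dynCastOrNull<MaskableOp>(maskedOp);
  if (!maskable)
    return false; // pattern does not apply (null body or non-maskable op)
  return true;    // a real pattern would perform the rewrite here
}
```

With this model, `tryLowerMask` returns `true` for a `MaskableOp` and `false` for both a `BitcastOp` and `nullptr`. Calling `castOrNull<MaskableOp>` on a `BitcastOp` would instead trip the assertion in assertion-enabled builds, which mirrors the crash reported in #107811.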

---
Full diff: https://github.com/llvm/llvm-project/pull/108048.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp (+1-1) 
- (modified) mlir/test/Dialect/Vector/lower-vector-mask.mlir (+9) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index bfc05c71f53401..028dcfd72fcb91 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -188,7 +188,7 @@ struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
 private:
   LogicalResult matchAndRewrite(MaskOp maskOp,
                                 PatternRewriter &rewriter) const final {
-    auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
+    auto maskableOp = dyn_cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
     if (!maskableOp)
       return failure();
     SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
diff --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
index a8a1164e2f762b..bd37e0172f6182 100644
--- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir
+++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -88,3 +88,12 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
   return %0 : vector<8xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @vector_mask_with_unmaskable_op
+//       CHECK:   vector.mask
+func.func @vector_mask_with_unmaskable_op(%arg0: vector<2xf32>) -> vector<2xi32> {
+  %mask = arith.constant dense<[0, 1]> : vector<2xi1>
+  %0 = vector.mask %mask { vector.bitcast %arg0 : vector<2xf32> to vector<2xi32> } : vector<2xi1> -> vector<2xi32>
+  return %0 : vector<2xi32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/108048

