[Mlir-commits] [mlir] [mlir][vector] Replace cast_or_null with dyn_cast_or_null (PR #108048)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 10 08:55:25 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Longsheng Mou (CoTinker)
<details>
<summary>Changes</summary>
This patch replaces `cast_or_null` with `dyn_cast_or_null`, which fixes a crash when the operation inside `vector.mask` does not implement the `MaskableOpInterface`: the pattern now fails to match instead of asserting. Fixes #107811.
---
Full diff: https://github.com/llvm/llvm-project/pull/108048.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp (+1-1)
- (modified) mlir/test/Dialect/Vector/lower-vector-mask.mlir (+9)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index bfc05c71f53401..028dcfd72fcb91 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -188,7 +188,7 @@ struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
private:
LogicalResult matchAndRewrite(MaskOp maskOp,
PatternRewriter &rewriter) const final {
- auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
+ auto maskableOp = dyn_cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
if (!maskableOp)
return failure();
SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
diff --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
index a8a1164e2f762b..bd37e0172f6182 100644
--- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir
+++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -88,3 +88,12 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
return %0 : vector<8xf32>
}
+// -----
+
+// CHECK-LABEL: func @vector_mask_with_unmaskable_op
+// CHECK: vector.mask
+func.func @vector_mask_with_unmaskable_op(%arg0: vector<2xf32>) -> vector<2xi32> {
+ %mask = arith.constant dense<[0, 1]> : vector<2xi1>
+ %0 = vector.mask %mask { vector.bitcast %arg0 : vector<2xf32> to vector<2xi32> } : vector<2xi1> -> vector<2xi32>
+ return %0 : vector<2xi32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/108048
More information about the Mlir-commits
mailing list