[Mlir-commits] [mlir] [mlir][vector] Replace cast_or_null with dyn_cast_or_null (PR #108048)
Longsheng Mou
llvmlistbot at llvm.org
Tue Sep 10 09:04:32 PDT 2024
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/108048
>From 13ea7f8a0cbdb426fd9c8d81221b078171c945f0 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Tue, 10 Sep 2024 23:49:37 +0800
Subject: [PATCH] [mlir][vector] Replace cast_or_null with dyn_cast_or_null
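This patch replaces `cast_or_null` with `dyn_cast_or_null`, which fixes
a crash when the operation inside `vector.mask` does not implement
`MaskableOpInterface`. `cast_or_null` asserts that the cast succeeds,
whereas `dyn_cast_or_null` returns null, so the pattern now simply
fails to match instead of crashing.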
---
mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp | 3 ++-
mlir/test/Dialect/Vector/lower-vector-mask.mlir | 9 +++++++++
2 files changed, 11 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index bfc05c71f53401..794a9328fc7096 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -188,7 +188,8 @@ struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
private:
LogicalResult matchAndRewrite(MaskOp maskOp,
PatternRewriter &rewriter) const final {
- auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
+ auto maskableOp =
+ dyn_cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
if (!maskableOp)
return failure();
SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
diff --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
index a8a1164e2f762b..bd37e0172f6182 100644
--- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir
+++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -88,3 +88,12 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
return %0 : vector<8xf32>
}
+// -----
+
+// CHECK-LABEL: func @vector_mask_with_unmaskable_op
+// CHECK: vector.mask
+func.func @vector_mask_with_unmaskable_op(%arg0: vector<2xf32>) -> vector<2xi32> {
+ %mask = arith.constant dense<[0, 1]> : vector<2xi1>
+ %0 = vector.mask %mask { vector.bitcast %arg0 : vector<2xf32> to vector<2xi32> } : vector<2xi1> -> vector<2xi32>
+ return %0 : vector<2xi32>
+}
More information about the Mlir-commits mailing list