[Mlir-commits] [mlir] [mlir][vector] Replace cast_or_null with dyn_cast_or_null (PR #108048)

Longsheng Mou llvmlistbot at llvm.org
Tue Sep 10 08:54:53 PDT 2024


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/108048

This patch replaces `cast_or_null` with `dyn_cast_or_null` in `MaskOpRewritePattern::matchAndRewrite`. This fixes a crash that occurred when the operation inside `vector.mask` does not implement `MaskableOpInterface`. Fixes #107811.
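The difference between the two casts can be shown with a small standalone model. This sketch does not use LLVM. `Op`, `MaskableOp`, `BitcastOp`, `castOrNull`, `dynCastOrNull`, and `matchMaskedOp` are hypothetical simplifications that stand in for `mlir::Operation`, `MaskableOpInterface`, `llvm::cast_or_null`, `llvm::dyn_cast_or_null`, and the pattern's guard. The model assumes LLVM-style `classof` RTTI and is not the real Casting.h implementation.

```cpp
#include <cassert>

// Hypothetical toy hierarchy (not MLIR classes): a base "operation" tagged
// with a kind, one "maskable" subclass and one non-maskable subclass.
struct Op {
  enum class Kind { Maskable, Bitcast };
  explicit Op(Kind k) : kind(k) {}
  virtual ~Op() = default;
  Kind kind;
};

struct MaskableOp : Op {
  MaskableOp() : Op(Kind::Maskable) {}
  // LLVM-style RTTI hook consulted by the cast helpers below.
  static bool classof(const Op *op) { return op->kind == Kind::Maskable; }
};

struct BitcastOp : Op {
  BitcastOp() : Op(Kind::Bitcast) {}
  static bool classof(const Op *op) { return op->kind == Kind::Bitcast; }
};

// Simplified model of cast_or_null: a null input passes through, but a
// non-null value of the wrong kind trips an assertion.
template <typename To, typename From>
To *castOrNull(From *val) {
  if (!val)
    return nullptr;
  assert(To::classof(val) && "castOrNull: argument of incompatible type");
  return static_cast<To *>(val);
}

// Simplified model of dyn_cast_or_null: both a null input and a wrong-kind
// input yield nullptr, so the caller can bail out gracefully.
template <typename To, typename From>
To *dynCastOrNull(From *val) {
  if (!val || !To::classof(val))
    return nullptr;
  return static_cast<To *>(val);
}

// Hypothetical mirror of the patched guard: an unmaskable (or missing) op is
// treated as "no match" (the pattern's failure()) instead of crashing.
bool matchMaskedOp(Op *masked) {
  MaskableOp *maskable = dynCastOrNull<MaskableOp>(masked);
  return maskable != nullptr;
}
```

With this model, `matchMaskedOp` returns `false` for a `BitcastOp`. That is analogous to the pattern returning `failure()` and leaving `vector.mask` in place, which is what the new `CHECK: vector.mask` line in the test verifies. Calling `castOrNull<MaskableOp>` on the same `BitcastOp` would instead abort in an assertions-enabled build.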

From 24620873e7b00838394165d47fc2ccbe47370abd Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Tue, 10 Sep 2024 23:49:37 +0800
Subject: [PATCH] [mlir][vector] Replace cast_or_null with dyn_cast_or_null

This patch replaces cast_or_null with dyn_cast_or_null, which fixes
a crash when the operation inside `vector.mask` does not implement
`MaskableOpInterface`.
---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp | 2 +-
 mlir/test/Dialect/Vector/lower-vector-mask.mlir        | 9 +++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index bfc05c71f53401..028dcfd72fcb91 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -188,7 +188,7 @@ struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
 private:
   LogicalResult matchAndRewrite(MaskOp maskOp,
                                 PatternRewriter &rewriter) const final {
-    auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
+    auto maskableOp = dyn_cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
     if (!maskableOp)
       return failure();
     SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
diff --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
index a8a1164e2f762b..bd37e0172f6182 100644
--- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir
+++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -88,3 +88,12 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
   return %0 : vector<8xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @vector_mask_with_unmaskable_op
+//       CHECK:   vector.mask
+func.func @vector_mask_with_unmaskable_op(%arg0: vector<2xf32>) -> vector<2xi32> {
+  %mask = arith.constant dense<[0, 1]> : vector<2xi1>
+  %0 = vector.mask %mask { vector.bitcast %arg0 : vector<2xf32> to vector<2xi32> } : vector<2xi1> -> vector<2xi32>
+  return %0 : vector<2xi32>
+}
