[Mlir-commits] [mlir] [mlir][vector] Replace cast_or_null with dyn_cast_or_null (PR #108048)

Longsheng Mou llvmlistbot at llvm.org
Tue Sep 10 09:04:32 PDT 2024


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/108048

>From 13ea7f8a0cbdb426fd9c8d81221b078171c945f0 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Tue, 10 Sep 2024 23:49:37 +0800
Subject: [PATCH] [mlir][vector] Replace cast_or_null with dyn_cast_or_null

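This patch replaces `cast_or_null` with `dyn_cast_or_null`, which fixes
a crash when the operation inside `vector.mask` does not implement
`MaskableOpInterface`. `cast_or_null` asserts that a non-null value has
the requested type, so lowering crashed on such ops (e.g.
`vector.bitcast`). `dyn_cast_or_null` returns null instead, and the
pattern now simply fails to match.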
---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp | 3 ++-
 mlir/test/Dialect/Vector/lower-vector-mask.mlir        | 9 +++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index bfc05c71f53401..794a9328fc7096 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -188,7 +188,8 @@ struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
 private:
   LogicalResult matchAndRewrite(MaskOp maskOp,
                                 PatternRewriter &rewriter) const final {
-    auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
+    auto maskableOp =
+        dyn_cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
     if (!maskableOp)
       return failure();
     SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
diff --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
index a8a1164e2f762b..bd37e0172f6182 100644
--- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir
+++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -88,3 +88,12 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
   return %0 : vector<8xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @vector_mask_with_unmaskable_op
+//       CHECK:   vector.mask
+func.func @vector_mask_with_unmaskable_op(%arg0: vector<2xf32>) -> vector<2xi32> {
+  %mask = arith.constant dense<[0, 1]> : vector<2xi1>
+  %0 = vector.mask %mask { vector.bitcast %arg0 : vector<2xf32> to vector<2xi32> } : vector<2xi1> -> vector<2xi32>
+  return %0 : vector<2xi32>
+}



More information about the Mlir-commits mailing list