[Mlir-commits] [mlir] [mlir][mesh] Add null check for dyn_cast to prevent crash (PR #149266)

Longsheng Mou llvmlistbot at llvm.org
Wed Jul 16 23:58:00 PDT 2025


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/149266

This PR adds a null check on the `dyn_cast` result before it is used, preventing a crash, and uses `isa` instead of `dyn_cast` to make the code cleaner. Fixes #148619.

>From 71bd6a1d6e06bf90ccf83ce698c9fee65fccfbe9 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 17 Jul 2025 14:52:30 +0800
Subject: [PATCH 1/2] [mlir][mesh] Add null check for dyn_cast to prevent crash

This PR adds a null check for the `dyn_cast` result to prevent a crash, and uses `isa` instead of `dyn_cast` to make the code cleaner.
---
 .../mlir/Dialect/Mesh/Transforms/Simplifications.h        | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index 3f1041cb25103..243dbf081b999 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -62,9 +62,11 @@ void populateAllReduceEndomorphismSimplificationPatterns(
   auto isEndomorphismOp = [reduction](Operation *op,
                                       std::optional<Operation *> referenceOp) {
     auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
+    if (!allReduceOp)
+      return false;
     auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
     auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
-    if (!allReduceOp || inType.getElementType() != outType.getElementType() ||
+    if (inType.getElementType() != outType.getElementType() ||
         allReduceOp.getReduction() != reduction) {
       return false;
     }
@@ -87,9 +89,7 @@ void populateAllReduceEndomorphismSimplificationPatterns(
     return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
            inType.getElementType() == refType.getElementType();
   };
-  auto isAlgebraicOp = [](Operation *op) {
-    return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
-  };
+  auto isAlgebraicOp = [](Operation *op) { return isa<AlgebraicOp>(op); };
 
   using ConcreteEndomorphismSimplification = EndomorphismSimplification<
       std::decay_t<decltype(getEndomorphismOpOperand)>,

>From 7071c472001e5b08e289cc3bedb9362b0e345e2e Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 17 Jul 2025 14:54:59 +0800
Subject: [PATCH 2/2] add test

---
 mlir/test/Dialect/Mesh/simplifications.mlir | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
index 2540fbf9510c4..e955f4c134259 100644
--- a/mlir/test/Dialect/Mesh/simplifications.mlir
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -165,3 +165,15 @@ func.func @all_reduce_arith_minsi_endomorphism(
   // CHECK: return %[[ALL_REDUCE_RES]]
   return %2 : tensor<5xi32>
 }
+
+// Ensure that this case, which has no endomorphism op, does not crash.
+// CHECK-LABEL: func.func @no_endomorphism_op
+func.func @no_endomorphism_op(%arg0: tensor<2xi64>) -> i64 {
+  %c0 = arith.constant 0 : index
+  %c1_i64 = arith.constant 1 : i64
+  // CHECK: tensor.extract
+  %extracted = tensor.extract %arg0[%c0] : tensor<2xi64>
+  // CHECK: arith.maxsi
+  %0 = arith.maxsi %extracted, %c1_i64 : i64
+  return %0 : i64
+}



More information about the Mlir-commits mailing list