[Mlir-commits] [mlir] [mlir][shape] Fix crash in ShapeOfOpToConstShapeOp (PR #180737)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 10 05:50:17 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-shape

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

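This PR fixes a crash in the `ShapeOfOpToConstShapeOp` canonicalization pattern when a `shape.shape_of` op has a statically shaped argument and a `!shape.shape` result type; the pattern now bails out in that case instead of attempting the rewrite. Fixes #180719.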

---
Full diff: https://github.com/llvm/llvm-project/pull/180737.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Shape/IR/Shape.cpp (+7-3) 
- (modified) mlir/test/Dialect/Shape/canonicalize.mlir (+11) 


``````````diff
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index c1210eef4e589..5db18ae5b3b99 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1706,14 +1706,18 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
     auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
     if (!type || !type.hasStaticShape())
       return failure();
+    Type resultType = op.getResult().getType();
+    if (isa<ShapeType>(resultType))
+      return failure();
+
     Location loc = op.getLoc();
     Value constShape =
         ConstShapeOp::create(rewriter, loc,
                              rewriter.getIndexTensorAttr(type.getShape()))
             .getResult();
-    if (constShape.getType() != op.getResult().getType())
-      constShape = tensor::CastOp::create(rewriter, loc,
-                                          op.getResult().getType(), constShape);
+    if (constShape.getType() != resultType)
+      constShape =
+          tensor::CastOp::create(rewriter, loc, resultType, constShape);
     rewriter.replaceOp(op, constShape);
     return success();
   }
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index f3c25b8c8100e..697bfb19e70a4 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1626,3 +1626,14 @@ func.func @shape_of_0d(%arg0: tensor<f32>) -> tensor<?xindex> {
   %0 = shape.shape_of %arg0 : tensor<f32> -> tensor<?xindex>
   return %0 : tensor<?xindex>
 }
+
+// -----
+
+// Ensure this case does not crash.
+
+// CHECK-LABEL: func @shape_of_static_with_shape_result(
+func.func @shape_of_static_with_shape_result(%arg0: tensor<f32>) -> !shape.shape {
+  // CHECK: shape.shape_of
+  %0 = shape.shape_of %arg0 : tensor<f32> -> !shape.shape
+  return %0 : !shape.shape
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/180737


More information about the Mlir-commits mailing list