[Mlir-commits] [mlir] [mlir][shape] Fix crash in ShapeOfOpToConstShapeOp (PR #180737)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 10 05:50:17 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-shape
Author: Longsheng Mou (CoTinker)
<details>
<summary>Changes</summary>
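This PR fixes a crash in the `ShapeOfOpToConstShapeOp` canonicalization pattern when a `shape.shape_of` op has a statically shaped argument and a `!shape.shape` result type. Previously, the pattern built an index-tensor `shape.const_shape` and then tried to `tensor.cast` it to the result type, which is invalid when the result is `!shape.shape`. The pattern now bails out in that case. Fixes #180719.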
---
Full diff: https://github.com/llvm/llvm-project/pull/180737.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Shape/IR/Shape.cpp (+7-3)
- (modified) mlir/test/Dialect/Shape/canonicalize.mlir (+11)
``````````diff
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index c1210eef4e589..5db18ae5b3b99 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1706,14 +1706,18 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
if (!type || !type.hasStaticShape())
return failure();
+ Type resultType = op.getResult().getType();
+ if (isa<ShapeType>(resultType))
+ return failure();
+
Location loc = op.getLoc();
Value constShape =
ConstShapeOp::create(rewriter, loc,
rewriter.getIndexTensorAttr(type.getShape()))
.getResult();
- if (constShape.getType() != op.getResult().getType())
- constShape = tensor::CastOp::create(rewriter, loc,
- op.getResult().getType(), constShape);
+ if (constShape.getType() != resultType)
+ constShape =
+ tensor::CastOp::create(rewriter, loc, resultType, constShape);
rewriter.replaceOp(op, constShape);
return success();
}
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index f3c25b8c8100e..697bfb19e70a4 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1626,3 +1626,14 @@ func.func @shape_of_0d(%arg0: tensor<f32>) -> tensor<?xindex> {
%0 = shape.shape_of %arg0 : tensor<f32> -> tensor<?xindex>
return %0 : tensor<?xindex>
}
+
+// -----
+
+// Ensure this case does not crash.
+
+// CHECK-LABEL: func @shape_of_static_with_shape_result(
+func.func @shape_of_static_with_shape_result(%arg0: tensor<f32>) -> !shape.shape {
+ // CHECK: shape.shape_of
+ %0 = shape.shape_of %arg0 : tensor<f32> -> !shape.shape
+ return %0 : !shape.shape
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/180737
More information about the Mlir-commits
mailing list