[Mlir-commits] [mlir] Tensor shape ops no assert (PR #179005)
Samarth Narang
llvmlistbot at llvm.org
Fri Jan 30 16:49:04 PST 2026
https://github.com/snarang181 created https://github.com/llvm/llvm-project/pull/179005
Fixes https://github.com/llvm/llvm-project/issues/178228
>From 2976b2b28cd7afffb822fedfb2ff714720fc0026 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 19:42:16 -0500
Subject: [PATCH 1/2] [mlir][tensor] Avoid assertion failure in verifier
tensor.collapse_shape and tensor.expand_shape hit an assertion in their
verifiers when the source or result type is unranked. Guard the verifiers
with dyn_cast<RankedTensorType> and emit a meaningful error message instead.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 22 +++++++++++++++++++---
1 file changed, 19 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d885d2c871e3f..2c406cb351c2d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2051,8 +2051,14 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
}
LogicalResult ExpandShapeOp::verify() {
- auto srcType = getSrcType();
- auto resultType = getResultType();
+ auto srcType = llvm::dyn_cast<RankedTensorType>(getSrc().getType());
+ if (!srcType)
+ return emitOpError("expects ranked tensor source type, but got ")
+ << getSrc().getType();
+ auto resultType = llvm::dyn_cast<RankedTensorType>(getResult().getType());
+ if (!resultType)
+ return emitOpError("expects ranked tensor result type, but got ")
+ << getResult().getType();
if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
return emitOpError("expected number of static shape dims to be equal to "
@@ -2077,7 +2083,17 @@ LogicalResult CollapseShapeOp::verify() {
[](ReassociationIndices group) { return group.empty(); })) {
return op.emitOpError("reassociation indices must not be empty");
}
- return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
+ auto srcType = llvm::dyn_cast<RankedTensorType>(op.getSrc().getType());
+ if (!srcType)
+ return op.emitOpError("expects ranked tensor source type, but got ")
+ << op.getSrc().getType();
+
+ auto resultType = llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
+ if (!resultType)
+ return op.emitOpError("expects ranked tensor result type, but got ")
+ << op.getResult().getType();
+
+ return verifyTensorReshapeOp(op, srcType, resultType);
}
namespace {
>From d80cc501a97d55344e318dd91b4f430260353352 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 19:48:29 -0500
Subject: [PATCH 2/2] [mlir][tensor] Add tests for unranked collapse_shape/expand_shape
---
mlir/test/Dialect/Tensor/invalid.mlir | 16 ++++++++++++++++
1 file changed, 16 insertions(+)
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 63be5493e8935..dadf586005173 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -690,3 +690,19 @@ func.func @test_empty_reassociation(%arg0: tensor<1x?xf32>) -> tensor<?x10xf32>
return %0 : tensor<?x10xf32>
}
+// -----
+
+func.func @collapse_shape_requires_ranked_tensor(%arg0: tensor<*xf32>) {
+ // expected-error@+1 {{expects ranked tensor source type}}
+ %0 = tensor.collapse_shape %arg0 [[0]] : tensor<*xf32> into tensor<f32>
+ return
+}
+
+// -----
+
+func.func @expand_shape_requires_ranked_tensor(%arg0: tensor<*xf32>) {
+ // expected-error@+1 {{expects ranked tensor source type}}
+ %0 = tensor.expand_shape %arg0 [[0]] output_shape [1] : tensor<*xf32> into tensor<1xf32>
+ return
+}
+
More information about the Mlir-commits
mailing list