[Mlir-commits] [mlir] [mlir][tensor] Avoid asserts in reshape verifiers on unranked tensors (PR #179005)
Samarth Narang
llvmlistbot at llvm.org
Mon Feb 2 11:57:12 PST 2026
https://github.com/snarang181 updated https://github.com/llvm/llvm-project/pull/179005
>From 69e33d969c2f6a10fda4e924e269de5a431f3815 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 19:42:16 -0500
Subject: [PATCH 1/5] [mlir][tensor] Avoid assert fail in verifier
The verifiers of tensor.collapse_shape and tensor.expand_shape
assert when the operand/result is unranked.
Guard the verifiers using dyn_cast and emit
a meaningful error message instead.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 22 +++++++++++++++++++---
1 file changed, 19 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d885d2c871e3f..2c406cb351c2d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2051,8 +2051,14 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
}
LogicalResult ExpandShapeOp::verify() {
- auto srcType = getSrcType();
- auto resultType = getResultType();
+ auto srcType = llvm::dyn_cast<RankedTensorType>(getSrc().getType());
+ if (!srcType)
+ return emitOpError("expects ranked tensor source type, but got ")
+ << getSrc().getType();
+ auto resultType = llvm::dyn_cast<RankedTensorType>(getResult().getType());
+ if (!resultType)
+ return emitOpError("expects ranked tensor result type, but got ")
+ << getResult().getType();
if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
return emitOpError("expected number of static shape dims to be equal to "
@@ -2077,7 +2083,17 @@ LogicalResult CollapseShapeOp::verify() {
[](ReassociationIndices group) { return group.empty(); })) {
return op.emitOpError("reassociation indices must not be empty");
}
- return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
+ auto srcType = llvm::dyn_cast<RankedTensorType>(op.getSrc().getType());
+ if (!srcType)
+ return op.emitOpError("expects ranked tensor source type, but got ")
+ << op.getSrc().getType();
+
+ auto resultType = llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
+ if (!resultType)
+ return op.emitOpError("expects ranked tensor result type, but got ")
+ << op.getResult().getType();
+
+ return verifyTensorReshapeOp(op, srcType, resultType);
}
namespace {
>From 627e65c8118d13ab90e38bc9fd8091816b202761 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 19:48:29 -0500
Subject: [PATCH 2/5] Add test cases
---
mlir/test/Dialect/Tensor/invalid.mlir | 16 ++++++++++++++++
1 file changed, 16 insertions(+)
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 63be5493e8935..dadf586005173 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -690,3 +690,19 @@ func.func @test_empty_reassociation(%arg0: tensor<1x?xf32>) -> tensor<?x10xf32>
return %0 : tensor<?x10xf32>
}
+// -----
+
+func.func @collapse_shape_requires_ranked_tensor(%arg0: tensor<*xf32>) {
+ // expected-error @+1 {{expects ranked tensor source type}}
+ %0 = tensor.collapse_shape %arg0 [[0]] : tensor<*xf32> into tensor<f32>
+ return
+}
+
+// -----
+
+func.func @expand_shape_requires_ranked_tensor(%arg0: tensor<*xf32>) {
+ // expected-error @+1 {{expects ranked tensor source type}}
+ %0 = tensor.expand_shape %arg0 [[0]] output_shape [1] : tensor<*xf32> into tensor<1xf32>
+ return
+}
+
>From a06d75252d503659ebebd09f6a7c673d80ae4949 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Mon, 2 Feb 2026 14:25:07 -0500
Subject: [PATCH 3/5] Require a ranked source tensor for ExpandShapeOp in .td
---
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 2 +-
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 10 ++--------
2 files changed, 3 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 8b10c00008865..ebad49f66da16 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1133,7 +1133,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
```
}];
- let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
+ let arguments = (ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation,
Variadic<Index>:$output_shape,
DenseI64ArrayAttr:$static_output_shape);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2c406cb351c2d..58768c2fe6873 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2051,14 +2051,8 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
}
LogicalResult ExpandShapeOp::verify() {
- auto srcType = llvm::dyn_cast<RankedTensorType>(getSrc().getType());
- if (!srcType)
- return emitOpError("expects ranked tensor source type, but got ")
- << getSrc().getType();
- auto resultType = llvm::dyn_cast<RankedTensorType>(getResult().getType());
- if (!resultType)
- return emitOpError("expects ranked tensor result type, but got ")
- << getResult().getType();
+ auto srcType = llvm::cast<RankedTensorType>(getSrc().getType());
+ auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
return emitOpError("expected number of static shape dims to be equal to "
>From 155d15504f669f91858130eb67f893bd42f75528 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Mon, 2 Feb 2026 14:28:32 -0500
Subject: [PATCH 4/5] Require a ranked source tensor for CollapseShapeOp in .td
---
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 2 +-
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 11 ++---------
2 files changed, 3 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index ebad49f66da16..15e0577c6ad70 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1191,7 +1191,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
let summary = "operation to produce a tensor with a smaller rank";
- let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
+ let arguments = (ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation);
let description = [{
The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
rank whose dimension sizes are a reassociation of the original `src` dimensions.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 58768c2fe6873..b0d323a080bd5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2077,15 +2077,8 @@ LogicalResult CollapseShapeOp::verify() {
[](ReassociationIndices group) { return group.empty(); })) {
return op.emitOpError("reassociation indices must not be empty");
}
- auto srcType = llvm::dyn_cast<RankedTensorType>(op.getSrc().getType());
- if (!srcType)
- return op.emitOpError("expects ranked tensor source type, but got ")
- << op.getSrc().getType();
-
- auto resultType = llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
- if (!resultType)
- return op.emitOpError("expects ranked tensor result type, but got ")
- << op.getResult().getType();
+ auto srcType = llvm::cast<RankedTensorType>(op.getSrc().getType());
+ auto resultType = llvm::cast<RankedTensorType>(op.getResult().getType());
return verifyTensorReshapeOp(op, srcType, resultType);
}
>From ed86731f7e91711eca4f23c577a0a7df384db927 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Mon, 2 Feb 2026 14:54:57 -0500
Subject: [PATCH 5/5] Update tests
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 +-
mlir/test/Dialect/Tensor/invalid.mlir | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b0d323a080bd5..a8b67b4c73ad5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2077,7 +2077,7 @@ LogicalResult CollapseShapeOp::verify() {
[](ReassociationIndices group) { return group.empty(); })) {
return op.emitOpError("reassociation indices must not be empty");
}
- auto srcType = llvm::cast<RankedTensorType>(op.getSrc().getType());
+ auto srcType = llvm::cast<RankedTensorType>(getSrc().getType());
auto resultType = llvm::cast<RankedTensorType>(op.getResult().getType());
return verifyTensorReshapeOp(op, srcType, resultType);
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index dadf586005173..c149c39f99dce 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -693,7 +693,7 @@ func.func @test_empty_reassociation(%arg0: tensor<1x?xf32>) -> tensor<?x10xf32>
// -----
func.func @collapse_shape_requires_ranked_tensor(%arg0: tensor<*xf32>) {
- // expected-error @+1 {{expects ranked tensor source type}}
+ // expected-error @+1 {{custom op 'tensor.collapse_shape' invalid kind of type specified: expected builtin.tensor, but found 'tensor<*xf32>'}}
%0 = tensor.collapse_shape %arg0 [[0]] : tensor<*xf32> into tensor<f32>
return
}
@@ -701,7 +701,7 @@ func.func @collapse_shape_requires_ranked_tensor(%arg0: tensor<*xf32>) {
// -----
func.func @expand_shape_requires_ranked_tensor(%arg0: tensor<*xf32>) {
- // expected-error @+1 {{expects ranked tensor source type}}
+ // expected-error @+1 {{custom op 'tensor.expand_shape' invalid kind of type specified: expected builtin.tensor, but found 'tensor<*xf32>'}}
%0 = tensor.expand_shape %arg0 [[0]] output_shape [1] : tensor<*xf32> into tensor<1xf32>
return
}
More information about the Mlir-commits
mailing list