[Mlir-commits] [mlir] [mlir][tensor] Avoid asserts in reshape verifiers on unranked tensors (PR #179005)

Samarth Narang llvmlistbot at llvm.org
Mon Feb 2 11:57:12 PST 2026


https://github.com/snarang181 updated https://github.com/llvm/llvm-project/pull/179005

>From 69e33d969c2f6a10fda4e924e269de5a431f3815 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 19:42:16 -0500
Subject: [PATCH 1/5] [mlir][tensor] Avoid assertion failures in reshape verifiers

The verifiers of tensor.collapse_shape and tensor.expand_shape
hit an assertion failure when the source or result type is unranked.

Guard the verifiers with dyn_cast and emit a meaningful
error message instead.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d885d2c871e3f..2c406cb351c2d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2051,8 +2051,14 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
 }
 
 LogicalResult ExpandShapeOp::verify() {
-  auto srcType = getSrcType();
-  auto resultType = getResultType();
+  auto srcType = llvm::dyn_cast<RankedTensorType>(getSrc().getType());
+  if (!srcType)
+    return emitOpError("expects ranked tensor source type, but got ")
+           << getSrc().getType();
+  auto resultType = llvm::dyn_cast<RankedTensorType>(getResult().getType());
+  if (!resultType)
+    return emitOpError("expects ranked tensor result type, but got ")
+           << getResult().getType();
 
   if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
     return emitOpError("expected number of static shape dims to be equal to "
@@ -2077,7 +2083,17 @@ LogicalResult CollapseShapeOp::verify() {
                    [](ReassociationIndices group) { return group.empty(); })) {
     return op.emitOpError("reassociation indices must not be empty");
   }
-  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
+  auto srcType = llvm::dyn_cast<RankedTensorType>(op.getSrc().getType());
+  if (!srcType)
+    return op.emitOpError("expects ranked tensor source type, but got ")
+           << op.getSrc().getType();
+
+  auto resultType = llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
+  if (!resultType)
+    return op.emitOpError("expects ranked tensor result type, but got ")
+           << op.getResult().getType();
+
+  return verifyTensorReshapeOp(op, srcType, resultType);
 }
 
 namespace {

>From 627e65c8118d13ab90e38bc9fd8091816b202761 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 19:48:29 -0500
Subject: [PATCH 2/5] Add tests for reshape ops with unranked sources

---
 mlir/test/Dialect/Tensor/invalid.mlir | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 63be5493e8935..dadf586005173 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -690,3 +690,19 @@ func.func @test_empty_reassociation(%arg0: tensor<1x?xf32>) -> tensor<?x10xf32>
   return %0 : tensor<?x10xf32>
 }
 
+// -----
+
+func.func @collapse_shape_requires_ranked_tensor(%arg0: tensor<*xf32>) {
+  // expected-error at +1 {{expects ranked tensor source type}}
+  %0 = tensor.collapse_shape %arg0 [[0]] : tensor<*xf32> into tensor<f32>
+  return
+}
+
+// -----
+
+func.func @expand_shape_requires_ranked_tensor(%arg0: tensor<*xf32>) {
+  // expected-error at +1 {{expects ranked tensor source type}}
+  %0 = tensor.expand_shape %arg0 [[0]] output_shape [1] : tensor<*xf32> into tensor<1xf32>
+  return
+}
+

>From a06d75252d503659ebebd09f6a7c673d80ae4949 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Mon, 2 Feb 2026 14:25:07 -0500
Subject: [PATCH 3/5] Constrain ExpandShapeOp source to AnyRankedTensor in .td

---
 mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td |  2 +-
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp         | 10 ++--------
 2 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 8b10c00008865..ebad49f66da16 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1133,7 +1133,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
     ```
   }];
 
-  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
+  let arguments = (ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation,
                        Variadic<Index>:$output_shape,
                        DenseI64ArrayAttr:$static_output_shape);
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2c406cb351c2d..58768c2fe6873 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2051,14 +2051,8 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
 }
 
 LogicalResult ExpandShapeOp::verify() {
-  auto srcType = llvm::dyn_cast<RankedTensorType>(getSrc().getType());
-  if (!srcType)
-    return emitOpError("expects ranked tensor source type, but got ")
-           << getSrc().getType();
-  auto resultType = llvm::dyn_cast<RankedTensorType>(getResult().getType());
-  if (!resultType)
-    return emitOpError("expects ranked tensor result type, but got ")
-           << getResult().getType();
+  auto srcType = llvm::cast<RankedTensorType>(getSrc().getType());
+  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
 
   if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
     return emitOpError("expected number of static shape dims to be equal to "

>From 155d15504f669f91858130eb67f893bd42f75528 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Mon, 2 Feb 2026 14:28:32 -0500
Subject: [PATCH 4/5] Constrain CollapseShapeOp source to AnyRankedTensor in .td

---
 mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td |  2 +-
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp         | 11 ++---------
 2 files changed, 3 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index ebad49f66da16..15e0577c6ad70 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1191,7 +1191,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
 
 def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
   let summary = "operation to produce a tensor with a smaller rank";
-  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
+  let arguments = (ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation);
   let description = [{
     The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
     rank whose dimension sizes are a reassociation of the original `src` dimensions.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 58768c2fe6873..b0d323a080bd5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2077,15 +2077,8 @@ LogicalResult CollapseShapeOp::verify() {
                    [](ReassociationIndices group) { return group.empty(); })) {
     return op.emitOpError("reassociation indices must not be empty");
   }
-  auto srcType = llvm::dyn_cast<RankedTensorType>(op.getSrc().getType());
-  if (!srcType)
-    return op.emitOpError("expects ranked tensor source type, but got ")
-           << op.getSrc().getType();
-
-  auto resultType = llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
-  if (!resultType)
-    return op.emitOpError("expects ranked tensor result type, but got ")
-           << op.getResult().getType();
+  auto srcType = llvm::cast<RankedTensorType>(op.getSrc().getType());
+  auto resultType = llvm::cast<RankedTensorType>(op.getResult().getType());
 
   return verifyTensorReshapeOp(op, srcType, resultType);
 }

>From ed86731f7e91711eca4f23c577a0a7df384db927 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Mon, 2 Feb 2026 14:54:57 -0500
Subject: [PATCH 5/5] Update expected diagnostics in tests and tidy CollapseShapeOp verifier

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 +-
 mlir/test/Dialect/Tensor/invalid.mlir    | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b0d323a080bd5..a8b67b4c73ad5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2077,7 +2077,7 @@ LogicalResult CollapseShapeOp::verify() {
                    [](ReassociationIndices group) { return group.empty(); })) {
     return op.emitOpError("reassociation indices must not be empty");
   }
-  auto srcType = llvm::cast<RankedTensorType>(op.getSrc().getType());
+  auto srcType = llvm::cast<RankedTensorType>(getSrc().getType());
   auto resultType = llvm::cast<RankedTensorType>(op.getResult().getType());
 
   return verifyTensorReshapeOp(op, srcType, resultType);
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index dadf586005173..c149c39f99dce 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -693,7 +693,7 @@ func.func @test_empty_reassociation(%arg0: tensor<1x?xf32>) -> tensor<?x10xf32>
 // -----
 
 func.func @collapse_shape_requires_ranked_tensor(%arg0: tensor<*xf32>) {
-  // expected-error at +1 {{expects ranked tensor source type}}
+  // expected-error at +1 {{custom op 'tensor.collapse_shape' invalid kind of type specified: expected builtin.tensor, but found 'tensor<*xf32>'}}
   %0 = tensor.collapse_shape %arg0 [[0]] : tensor<*xf32> into tensor<f32>
   return
 }
@@ -701,7 +701,7 @@ func.func @collapse_shape_requires_ranked_tensor(%arg0: tensor<*xf32>) {
 // -----
 
 func.func @expand_shape_requires_ranked_tensor(%arg0: tensor<*xf32>) {
-  // expected-error at +1 {{expects ranked tensor source type}}
+  // expected-error at +1 {{custom op 'tensor.expand_shape' invalid kind of type specified: expected builtin.tensor, but found 'tensor<*xf32>'}}
   %0 = tensor.expand_shape %arg0 [[0]] output_shape [1] : tensor<*xf32> into tensor<1xf32>
   return
 }



More information about the Mlir-commits mailing list