[Mlir-commits] [mlir] Tensor shape ops no assert (PR #179005)

Samarth Narang llvmlistbot at llvm.org
Fri Jan 30 16:49:04 PST 2026


https://github.com/snarang181 created https://github.com/llvm/llvm-project/pull/179005

Fixes https://github.com/llvm/llvm-project/issues/178228

From 2976b2b28cd7afffb822fedfb2ff714720fc0026 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 19:42:16 -0500
Subject: [PATCH 1/2] [mlir][tensor] Avoid assert fail in verifier

The verifiers of tensor.collapse_shape and tensor.expand_shape
assert when the source or result type is unranked.

Guard both verifiers with dyn_cast<RankedTensorType> and emit
a meaningful error message instead of asserting.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d885d2c871e3f..2c406cb351c2d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2051,8 +2051,14 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
 }
 
 LogicalResult ExpandShapeOp::verify() {
-  auto srcType = getSrcType();
-  auto resultType = getResultType();
+  auto srcType = llvm::dyn_cast<RankedTensorType>(getSrc().getType());
+  if (!srcType)
+    return emitOpError("expects ranked tensor source type, but got ")
+           << getSrc().getType();
+  auto resultType = llvm::dyn_cast<RankedTensorType>(getResult().getType());
+  if (!resultType)
+    return emitOpError("expects ranked tensor result type, but got ")
+           << getResult().getType();
 
   if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
     return emitOpError("expected number of static shape dims to be equal to "
@@ -2077,7 +2083,17 @@ LogicalResult CollapseShapeOp::verify() {
                    [](ReassociationIndices group) { return group.empty(); })) {
     return op.emitOpError("reassociation indices must not be empty");
   }
-  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
+  auto srcType = llvm::dyn_cast<RankedTensorType>(op.getSrc().getType());
+  if (!srcType)
+    return op.emitOpError("expects ranked tensor source type, but got ")
+           << op.getSrc().getType();
+
+  auto resultType = llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
+  if (!resultType)
+    return op.emitOpError("expects ranked tensor result type, but got ")
+           << op.getResult().getType();
+
+  return verifyTensorReshapeOp(op, srcType, resultType);
 }
 
 namespace {

From d80cc501a97d55344e318dd91b4f430260353352 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 19:48:29 -0500
Subject: [PATCH 2/2] [mlir][tensor] Add tests for unranked collapse_shape/expand_shape

---
 mlir/test/Dialect/Tensor/invalid.mlir | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 63be5493e8935..dadf586005173 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -690,3 +690,19 @@ func.func @test_empty_reassociation(%arg0: tensor<1x?xf32>) -> tensor<?x10xf32>
   return %0 : tensor<?x10xf32>
 }
 
+// -----
+
+func.func @collapse_shape_requires_ranked_tensor(%arg0: tensor<*xf32>) { // unranked source must be diagnosed, not assert
+  // expected-error@+1 {{expects ranked tensor source type}}
+  %0 = tensor.collapse_shape %arg0 [[0]] : tensor<*xf32> into tensor<f32>
+  return
+}
+
+// -----
+
+func.func @expand_shape_requires_ranked_tensor(%arg0: tensor<*xf32>) { // unranked source must be diagnosed, not assert
+  // expected-error@+1 {{expects ranked tensor source type}}
+  %0 = tensor.expand_shape %arg0 [[0]] output_shape [1] : tensor<*xf32> into tensor<1xf32>
+  return
+}
+



More information about the Mlir-commits mailing list