[Mlir-commits] [mlir] [mlir][shape] Fix crash when folding tensor.extract(shape_of(memref)) (PR #186270)

Mehdi Amini llvmlistbot at llvm.org
Thu Mar 12 15:53:48 PDT 2026


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/186270

The `ExtractFromShapeOfExtentTensor` canonicalization pattern was unconditionally rewriting:

  tensor.extract(shape.shape_of(%arg), %idx) -> tensor.dim(%arg, %idx)

even when `%arg` is a memref.  This produced an invalid `tensor.dim` (whose source operand must be a tensor), which then caused an assertion failure in `DimOp::getSource()` when subsequent canonicalization patterns tried to match the op:

  Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"'
  failed.  [To = TypedValue<TensorType>, From = Value]

Fix: add an `IsTensorType` constraint to `ExtractFromShapeOfExtentTensor` in `ShapeCanonicalization.td` so the pattern only fires when `%arg` is a tensor type.  The memref case is intentionally left unfolded (the correct lowering to `memref.dim` would require adding a MemRef dependency to the Shape dialect, which is not desirable).
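
The guard described above can be illustrated with a small standalone model. This is only a sketch of the control flow, not MLIR code: `Kind`, `ExtractOfShapeOf`, `TensorDim`, `tryFoldToTensorDim`, and `demo` are hypothetical names invented for illustration, and the real pattern is the declarative rule in `ShapeCanonicalization.td`.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for the type category of %arg (not an MLIR type).
enum class Kind { Tensor, MemRef };

// Toy model of tensor.extract(shape.shape_of(%arg), %idx).
struct ExtractOfShapeOf {
  Kind argKind; // whether %arg is a tensor or a memref
  int index;    // the single index into the 1D extent tensor
};

// Toy model of the replacement tensor.dim(%arg, %idx).
struct TensorDim {
  int index;
};

// Mirrors the guarded pattern: the rewrite is produced only when the
// shape_of operand is a tensor; otherwise the op is left unfolded.
std::optional<TensorDim> tryFoldToTensorDim(const ExtractOfShapeOf &op) {
  if (op.argKind != Kind::Tensor)
    return std::nullopt; // memref case: no rewrite, so no invalid tensor.dim
  return TensorDim{op.index};
}

// Usage: the tensor case folds, the memref case is left alone.
inline void demo() {
  assert(tryFoldToTensorDim({Kind::Tensor, 0}).has_value());
  assert(!tryFoldToTensorDim({Kind::MemRef, 0}).has_value());
}
```

Returning "no rewrite" for the memref case corresponds to the pattern failing to match, which is what the `IsTensorType` constraint achieves in the declarative rule.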

Tests cover both the positive case (a tensor argument folds to `tensor.dim`) and the negative case (a memref argument is left unmodified).

Fixes #185248

Assisted-by: Claude Code

From 367d3d2b16121b3158f6282e428e0a7e39dd0b7c Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 12 Mar 2026 11:25:08 -0700
Subject: [PATCH] [mlir][shape] Fix crash when folding
 tensor.extract(shape_of(memref))

The `ExtractFromShapeOfExtentTensor` canonicalization pattern was
unconditionally rewriting:

  tensor.extract(shape.shape_of(%arg), %idx) -> tensor.dim(%arg, %idx)

even when `%arg` is a memref.  This produced an invalid `tensor.dim`
(whose source operand must be a tensor), which then caused an assertion
failure in `DimOp::getSource()` when subsequent canonicalization
patterns tried to match the op:

  Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"'
  failed.  [To = TypedValue<TensorType>, From = Value]

Fix: add an `IsTensorType` constraint to `ExtractFromShapeOfExtentTensor`
in `ShapeCanonicalization.td` so the pattern only fires when `%arg` is a
tensor type.  The memref case is intentionally left unfolded (the correct
lowering to `memref.dim` would require adding a MemRef dependency to the
Shape dialect, which is not desirable).

Tests cover both the positive case (tensor arg folds to tensor.dim) and
the negative case (memref arg is left unmodified).

Fixes #185248

Assisted-by: Claude Code
---
 .../Dialect/Shape/IR/ShapeCanonicalization.td | 12 +++++--
 mlir/test/Dialect/Shape/canonicalize.mlir     | 31 +++++++++++++++++++
 2 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index cb294ae2978fc..829f3e0adfbed 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -14,6 +14,9 @@ def HasStaticShape : Constraint<CPred< [{
   ::llvm::dyn_cast<ShapedType>($0.getType()).hasStaticShape()
 }]>>;
 
+def IsTensorType : Constraint<CPred<"isa<TensorType>($0.getType())">,
+                              "is a tensor type">;
+
 // Helper that takes the first element of a range.
 def TakeFront : NativeCodeCall<"$0.front()">;
 
@@ -45,8 +48,11 @@ def TensorCastConstShape : Pat <
   (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
   [(HasStaticShape $res)]>;
 
-// tensor.extract from shape_of -> tensor.dim. We can take the first index
-// because shape_of always returns a 1D tensor.
+// tensor.extract from shape_of(tensor) -> tensor.dim. We can take the first
+// index because shape_of always returns a 1D tensor.  Only applies when $arg
+// is a tensor; the memref case is not handled here (memref.dim would be
+// needed, but adding a MemRef dependency to this file is not desirable).
 def ExtractFromShapeOfExtentTensor : Pat<
   (Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices),
-  (Tensor_DimOp $arg, (TakeFront $indices))>;
+  (Tensor_DimOp $arg, (TakeFront $indices)),
+  [(IsTensorType $arg)]>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index d1b5e7bb035bf..89aa167360168 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1654,3 +1654,34 @@ func.func @broadcast_no_crash_on_poison() {
   %3 = tensor.rank %2 : tensor<3xindex>
   return
 }
+
+// -----
+
+// tensor.extract(shape.shape_of(tensor)) folds to tensor.dim.
+// CHECK-LABEL: func @extract_from_shape_of_tensor(
+func.func @extract_from_shape_of_tensor(%arg0: tensor<?xf32>) -> index {
+  // CHECK:      %[[DIM:.*]] = tensor.dim %arg0
+  // CHECK-NOT:  shape.shape_of
+  // CHECK-NOT:  tensor.extract
+  // CHECK:      return %[[DIM]]
+  %c0 = arith.constant 0 : index
+  %shape = shape.shape_of %arg0 : tensor<?xf32> -> tensor<1xindex>
+  %dim = tensor.extract %shape[%c0] : tensor<1xindex>
+  return %dim : index
+}
+
+// -----
+
+// tensor.extract(shape.shape_of(memref)) must NOT be folded to tensor.dim
+// because tensor.dim requires a tensor source.  Previously this pattern
+// incorrectly created tensor.dim with a memref operand, causing a crash.
+// CHECK-LABEL: func @extract_from_shape_of_memref_no_fold(
+func.func @extract_from_shape_of_memref_no_fold(%arg0: memref<?xf32>) -> index {
+  // CHECK:      shape.shape_of
+  // CHECK:      tensor.extract
+  // CHECK-NOT:  tensor.dim
+  %c0 = arith.constant 0 : index
+  %shape = shape.shape_of %arg0 : memref<?xf32> -> tensor<1xindex>
+  %dim = tensor.extract %shape[%c0] : tensor<1xindex>
+  return %dim : index
+}


