[Mlir-commits] [mlir] [mlir][shape] Fix crash when folding tensor.extract(shape_of(memref)) (PR #186270)

llvmlistbot at llvm.org
Thu Mar 12 15:54:22 PDT 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir-shape

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

The `ExtractFromShapeOfExtentTensor` canonicalization pattern was unconditionally rewriting:

  tensor.extract(shape.shape_of(%arg), %idx) -> tensor.dim(%arg, %idx)

even when `%arg` is a memref.  This produced an invalid `tensor.dim` (whose source operand must be a tensor), which then caused an assertion failure in `DimOp::getSource()` when subsequent canonicalization patterns tried to match the op:

  Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"'
  failed.  [To = TypedValue<TensorType>, From = Value]

Fix: add an `IsTensorType` constraint to `ExtractFromShapeOfExtentTensor` in `ShapeCanonicalization.td` so the pattern only fires when `%arg` is a tensor type.  The memref case is intentionally left unfolded (the correct lowering to `memref.dim` would require adding a MemRef dependency to the Shape dialect, which is not desirable).
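
To make the failure mode concrete, here is a small IR sketch of the memref case. The function name `@sketch` is hypothetical, and the commented-out ops are illustrations of what the description above implies, not output captured from any tool.

```mlir
// Input: shape_of on a memref yields an extent tensor, which is then indexed.
// With the IsTensorType constraint in place, this IR is left as-is.
func.func @sketch(%arg0: memref<?xf32>) -> index {
  %c0 = arith.constant 0 : index
  %shape = shape.shape_of %arg0 : memref<?xf32> -> tensor<1xindex>
  %dim = tensor.extract %shape[%c0] : tensor<1xindex>
  return %dim : index
}

// What the unconstrained pattern effectively built (invalid: the source
// operand of tensor.dim must be a tensor):
//   %dim = tensor.dim %arg0, %c0 : memref<?xf32>

// The valid memref counterpart, which this patch deliberately does not emit
// because the op belongs to the MemRef dialect:
//   %dim = memref.dim %arg0, %c0 : memref<?xf32>
```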

Tests cover both the positive case (tensor arg folds to tensor.dim) and the negative case (memref arg is left unmodified).

Fixes #185248

Assisted-by: Claude Code

---
Full diff: https://github.com/llvm/llvm-project/pull/186270.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td (+9-3) 
- (modified) mlir/test/Dialect/Shape/canonicalize.mlir (+31) 


``````````diff
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index cb294ae2978fc..829f3e0adfbed 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -14,6 +14,9 @@ def HasStaticShape : Constraint<CPred< [{
   ::llvm::dyn_cast<ShapedType>($0.getType()).hasStaticShape()
 }]>>;
 
+def IsTensorType : Constraint<CPred<"isa<TensorType>($0.getType())">,
+                              "is a tensor type">;
+
 // Helper that takes the first element of a range.
 def TakeFront : NativeCodeCall<"$0.front()">;
 
@@ -45,8 +48,11 @@ def TensorCastConstShape : Pat <
   (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
   [(HasStaticShape $res)]>;
 
-// tensor.extract from shape_of -> tensor.dim. We can take the first index
-// because shape_of always returns a 1D tensor.
+// tensor.extract from shape_of(tensor) -> tensor.dim. We can take the first
+// index because shape_of always returns a 1D tensor.  Only applies when $arg
+// is a tensor; the memref case is not handled here (memref.dim would be
+// needed, but adding a MemRef dependency to this file is not desirable).
 def ExtractFromShapeOfExtentTensor : Pat<
   (Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices),
-  (Tensor_DimOp $arg, (TakeFront $indices))>;
+  (Tensor_DimOp $arg, (TakeFront $indices)),
+  [(IsTensorType $arg)]>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index d1b5e7bb035bf..89aa167360168 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1654,3 +1654,34 @@ func.func @broadcast_no_crash_on_poison() {
   %3 = tensor.rank %2 : tensor<3xindex>
   return
 }
+
+// -----
+
+// tensor.extract(shape.shape_of(tensor)) folds to tensor.dim.
+// CHECK-LABEL: func @extract_from_shape_of_tensor(
+func.func @extract_from_shape_of_tensor(%arg0: tensor<?xf32>) -> index {
+  // CHECK:      %[[DIM:.*]] = tensor.dim %arg0
+  // CHECK-NOT:  shape.shape_of
+  // CHECK-NOT:  tensor.extract
+  // CHECK:      return %[[DIM]]
+  %c0 = arith.constant 0 : index
+  %shape = shape.shape_of %arg0 : tensor<?xf32> -> tensor<1xindex>
+  %dim = tensor.extract %shape[%c0] : tensor<1xindex>
+  return %dim : index
+}
+
+// -----
+
+// tensor.extract(shape.shape_of(memref)) must NOT be folded to tensor.dim
+// because tensor.dim requires a tensor source.  Previously this pattern
+// incorrectly created tensor.dim with a memref operand, causing a crash.
+// CHECK-LABEL: func @extract_from_shape_of_memref_no_fold(
+func.func @extract_from_shape_of_memref_no_fold(%arg0: memref<?xf32>) -> index {
+  // CHECK:      shape.shape_of
+  // CHECK:      tensor.extract
+  // CHECK-NOT:  tensor.dim
+  %c0 = arith.constant 0 : index
+  %shape = shape.shape_of %arg0 : memref<?xf32> -> tensor<1xindex>
+  %dim = tensor.extract %shape[%c0] : tensor<1xindex>
+  return %dim : index
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/186270

