[Mlir-commits] [mlir] 0572ad6 - [mlir][shape] Fix crash when folding tensor.extract(shape_of(memref)) (#186270)

llvmlistbot at llvm.org
Fri Mar 13 06:17:13 PDT 2026


Author: Mehdi Amini
Date: 2026-03-13T14:17:07+01:00
New Revision: 0572ad60f354f86deddf4bd364fd0145d7a146ea

URL: https://github.com/llvm/llvm-project/commit/0572ad60f354f86deddf4bd364fd0145d7a146ea
DIFF: https://github.com/llvm/llvm-project/commit/0572ad60f354f86deddf4bd364fd0145d7a146ea.diff

LOG: [mlir][shape] Fix crash when folding tensor.extract(shape_of(memref)) (#186270)

The `ExtractFromShapeOfExtentTensor` canonicalization pattern was
unconditionally rewriting:

  tensor.extract(shape.shape_of(%arg), %idx) -> tensor.dim(%arg, %idx)

even when `%arg` is a memref. This produced an invalid `tensor.dim`
(whose source operand must be a tensor), which then caused an assertion
failure in `DimOp::getSource()` when subsequent canonicalization
patterns tried to match the op:

Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"'
  failed.  [To = TypedValue<TensorType>, From = Value]

Fix: add an `IsTensorType` constraint to
`ExtractFromShapeOfExtentTensor` in `ShapeCanonicalization.td` so the
pattern only fires when `%arg` is a tensor type. The memref case is
intentionally left unfolded (the correct lowering to `memref.dim` would
require adding a MemRef dependency to the Shape dialect, which is not
desirable).
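
As a rough illustration only, the following self-contained C++ sketch is a toy model of how a constraint such as `IsTensorType` gates the rewrite. It is not MLIR's API and not the code mlir-tblgen generates. The `TypeKind`, `Operand`, and `rewriteExtractOfShapeOf` names are hypothetical, and the returned string is a simplified stand-in for the real `tensor.dim` op. In MLIR itself, the check is the `isa<TensorType>($0.getType())` predicate evaluated by the DRR-generated matcher.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for an MLIR type; real MLIR uses mlir::Type and
// isa<TensorType>() / isa<MemRefType>().
enum class TypeKind { Tensor, MemRef };

// Hypothetical model of the %arg operand of shape.shape_of.
struct Operand {
  std::string name;
  TypeKind kind;
};

// Models the IsTensorType constraint: true only for tensor-typed operands.
bool isTensorType(const Operand &arg) { return arg.kind == TypeKind::Tensor; }

// Models ExtractFromShapeOfExtentTensor after the fix. Returns simplified
// text for the replacement op when the pattern fires, or std::nullopt when
// the constraint fails and the IR is left unchanged.
std::optional<std::string> rewriteExtractOfShapeOf(const Operand &arg,
                                                   const std::string &index) {
  assert(!index.empty() && "tensor.extract on a 1-D extent tensor needs an index");
  if (!isTensorType(arg))
    return std::nullopt; // memref case: intentionally not folded
  return "tensor.dim " + arg.name + ", " + index;
}
```

With a tensor operand the sketch returns `tensor.dim %arg0, %c0`. With a memref operand it returns `std::nullopt`, which mirrors the negative test rather than producing an ill-typed `tensor.dim`.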

Tests cover both the positive case (a tensor argument folds to `tensor.dim`)
and the negative case (a memref argument is left unmodified).

Fixes #185248

Assisted-by: Claude Code

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index cb294ae2978fc..829f3e0adfbed 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -14,6 +14,9 @@ def HasStaticShape : Constraint<CPred< [{
   ::llvm::dyn_cast<ShapedType>($0.getType()).hasStaticShape()
 }]>>;
 
+def IsTensorType : Constraint<CPred<"isa<TensorType>($0.getType())">,
+                              "is a tensor type">;
+
 // Helper that takes the first element of a range.
 def TakeFront : NativeCodeCall<"$0.front()">;
 
@@ -45,8 +48,11 @@ def TensorCastConstShape : Pat <
   (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
   [(HasStaticShape $res)]>;
 
-// tensor.extract from shape_of -> tensor.dim. We can take the first index
-// because shape_of always returns a 1D tensor.
+// tensor.extract from shape_of(tensor) -> tensor.dim. We can take the first
+// index because shape_of always returns a 1D tensor.  Only applies when $arg
+// is a tensor; the memref case is not handled here (memref.dim would be
+// needed, but adding a MemRef dependency to this file is not desirable).
 def ExtractFromShapeOfExtentTensor : Pat<
   (Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices),
-  (Tensor_DimOp $arg, (TakeFront $indices))>;
+  (Tensor_DimOp $arg, (TakeFront $indices)),
+  [(IsTensorType $arg)]>;

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index d1b5e7bb035bf..89aa167360168 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1654,3 +1654,34 @@ func.func @broadcast_no_crash_on_poison() {
   %3 = tensor.rank %2 : tensor<3xindex>
   return
 }
+
+// -----
+
+// tensor.extract(shape.shape_of(tensor)) folds to tensor.dim.
+// CHECK-LABEL: func @extract_from_shape_of_tensor(
+func.func @extract_from_shape_of_tensor(%arg0: tensor<?xf32>) -> index {
+  // CHECK:      %[[DIM:.*]] = tensor.dim %arg0
+  // CHECK-NOT:  shape.shape_of
+  // CHECK-NOT:  tensor.extract
+  // CHECK:      return %[[DIM]]
+  %c0 = arith.constant 0 : index
+  %shape = shape.shape_of %arg0 : tensor<?xf32> -> tensor<1xindex>
+  %dim = tensor.extract %shape[%c0] : tensor<1xindex>
+  return %dim : index
+}
+
+// -----
+
+// tensor.extract(shape.shape_of(memref)) must NOT be folded to tensor.dim
+// because tensor.dim requires a tensor source.  Previously this pattern
+// incorrectly created tensor.dim with a memref operand, causing a crash.
+// CHECK-LABEL: func @extract_from_shape_of_memref_no_fold(
+func.func @extract_from_shape_of_memref_no_fold(%arg0: memref<?xf32>) -> index {
+  // CHECK:      shape.shape_of
+  // CHECK:      tensor.extract
+  // CHECK-NOT:  tensor.dim
+  %c0 = arith.constant 0 : index
+  %shape = shape.shape_of %arg0 : memref<?xf32> -> tensor<1xindex>
+  %dim = tensor.extract %shape[%c0] : tensor<1xindex>
+  return %dim : index
+}
