[Mlir-commits] [mlir] 0572ad6 - [mlir][shape] Fix crash when folding tensor.extract(shape_of(memref)) (#186270)
llvmlistbot at llvm.org
Fri Mar 13 06:17:13 PDT 2026
Author: Mehdi Amini
Date: 2026-03-13T14:17:07+01:00
New Revision: 0572ad60f354f86deddf4bd364fd0145d7a146ea
URL: https://github.com/llvm/llvm-project/commit/0572ad60f354f86deddf4bd364fd0145d7a146ea
DIFF: https://github.com/llvm/llvm-project/commit/0572ad60f354f86deddf4bd364fd0145d7a146ea.diff
LOG: [mlir][shape] Fix crash when folding tensor.extract(shape_of(memref)) (#186270)
The `ExtractFromShapeOfExtentTensor` canonicalization pattern was
unconditionally rewriting:
tensor.extract(shape.shape_of(%arg), %idx) -> tensor.dim(%arg, %idx)
even when `%arg` is a memref. This produced an invalid `tensor.dim`
(whose source operand must be a tensor), which then caused an assertion
failure in `DimOp::getSource()` when subsequent canonicalization
patterns tried to match the op:
Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"'
failed. [To = TypedValue<TensorType>, From = Value]
Fix: add an `IsTensorType` constraint to
`ExtractFromShapeOfExtentTensor` in `ShapeCanonicalization.td` so the
pattern only fires when `%arg` is a tensor type. The memref case is
intentionally left unfolded (the correct lowering to `memref.dim` would
require adding a MemRef dependency to the Shape dialect, which is not
desirable).
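To show what the added constraint changes, here is a minimal Python sketch that models the DRR pattern as a plain function. It is a toy model, not MLIR's API: the classes `TensorType`, `MemRefType`, `Value`, `ShapeOf`, `Extract`, `Dim` and the function `extract_from_shape_of_extent_tensor` are hypothetical stand-ins. The `Dim` constructor rejects non-tensor sources to mirror the rule that `tensor.dim` takes a tensor operand. The `is_tensor_type` guard plays the role of `IsTensorType`, and `indices[0]` plays the role of `TakeFront`.

```python
from dataclasses import dataclass
from typing import Optional, Union


# Hypothetical stand-ins for builtin types (not MLIR classes).
@dataclass(frozen=True)
class TensorType:
    pass


@dataclass(frozen=True)
class MemRefType:
    pass


@dataclass(frozen=True)
class Value:
    name: str
    type: Union[TensorType, MemRefType]


# Toy op models: shape.shape_of, tensor.extract, tensor.dim.
@dataclass(frozen=True)
class ShapeOf:
    arg: Value


@dataclass(frozen=True)
class Extract:
    source: object  # producer of the extent tensor (e.g. a ShapeOf)
    indices: tuple


@dataclass(frozen=True)
class Dim:
    source: Value
    index: object

    def __post_init__(self):
        # Mirrors the requirement that tensor.dim's source is a tensor.
        if not isinstance(self.source.type, TensorType):
            raise ValueError("tensor.dim requires a tensor source")


def is_tensor_type(v: Value) -> bool:
    # Toy analogue of the IsTensorType constraint.
    return isinstance(v.type, TensorType)


def extract_from_shape_of_extent_tensor(op) -> Optional[Dim]:
    """Return the rewritten op, or None if the pattern does not apply."""
    if not isinstance(op, Extract) or not isinstance(op.source, ShapeOf):
        return None  # not extract(shape_of(...))
    arg = op.source.arg
    if not is_tensor_type(arg):
        return None  # memref case is intentionally left unfolded
    return Dim(arg, op.indices[0])  # TakeFront: shape_of result is 1-D
```

Without the `is_tensor_type` check, the memref case would try to build a `Dim` with a memref source. That is the toy analogue of the invalid `tensor.dim` that later tripped the assertion in `DimOp::getSource()`.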
Tests cover both the positive case (tensor arg folds to tensor.dim) and
the negative case (memref arg is left unmodified).
Fixes #185248
Assisted-by: Claude Code
Added:
Modified:
mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index cb294ae2978fc..829f3e0adfbed 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -14,6 +14,9 @@ def HasStaticShape : Constraint<CPred< [{
::llvm::dyn_cast<ShapedType>($0.getType()).hasStaticShape()
}]>>;
+def IsTensorType : Constraint<CPred<"isa<TensorType>($0.getType())">,
+ "is a tensor type">;
+
// Helper that takes the first element of a range.
def TakeFront : NativeCodeCall<"$0.front()">;
@@ -45,8 +48,11 @@ def TensorCastConstShape : Pat <
(Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
[(HasStaticShape $res)]>;
-// tensor.extract from shape_of -> tensor.dim. We can take the first index
-// because shape_of always returns a 1D tensor.
+// tensor.extract from shape_of(tensor) -> tensor.dim. We can take the first
+// index because shape_of always returns a 1D tensor. Only applies when $arg
+// is a tensor; the memref case is not handled here (memref.dim would be
+// needed, but adding a MemRef dependency to this file is not desirable).
def ExtractFromShapeOfExtentTensor : Pat<
(Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices),
- (Tensor_DimOp $arg, (TakeFront $indices))>;
+ (Tensor_DimOp $arg, (TakeFront $indices)),
+ [(IsTensorType $arg)]>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index d1b5e7bb035bf..89aa167360168 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1654,3 +1654,34 @@ func.func @broadcast_no_crash_on_poison() {
%3 = tensor.rank %2 : tensor<3xindex>
return
}
+
+// -----
+
+// tensor.extract(shape.shape_of(tensor)) folds to tensor.dim.
+// CHECK-LABEL: func @extract_from_shape_of_tensor(
+func.func @extract_from_shape_of_tensor(%arg0: tensor<?xf32>) -> index {
+ // CHECK: %[[DIM:.*]] = tensor.dim %arg0
+ // CHECK-NOT: shape.shape_of
+ // CHECK-NOT: tensor.extract
+ // CHECK: return %[[DIM]]
+ %c0 = arith.constant 0 : index
+ %shape = shape.shape_of %arg0 : tensor<?xf32> -> tensor<1xindex>
+ %dim = tensor.extract %shape[%c0] : tensor<1xindex>
+ return %dim : index
+}
+
+// -----
+
+// tensor.extract(shape.shape_of(memref)) must NOT be folded to tensor.dim
+// because tensor.dim requires a tensor source. Previously this pattern
+// incorrectly created tensor.dim with a memref operand, causing a crash.
+// CHECK-LABEL: func @extract_from_shape_of_memref_no_fold(
+func.func @extract_from_shape_of_memref_no_fold(%arg0: memref<?xf32>) -> index {
+ // CHECK: shape.shape_of
+ // CHECK: tensor.extract
+ // CHECK-NOT: tensor.dim
+ %c0 = arith.constant 0 : index
+ %shape = shape.shape_of %arg0 : memref<?xf32> -> tensor<1xindex>
+ %dim = tensor.extract %shape[%c0] : tensor<1xindex>
+ return %dim : index
+}