[Mlir-commits] [mlir] [mlir] Fix crash when folding tensor.dim(tensor.collapse()) on out-of-bound dim (PR #119941)
Mehdi Amini
llvmlistbot at llvm.org
Fri Dec 13 16:50:53 PST 2024
https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/119941
#119866
>From 5bc86db9896fcf524ad13a2d185f18002313168d Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Fri, 13 Dec 2024 16:39:04 -0800
Subject: [PATCH] [mlir] Fix crash when folding tensor.dim(tensor.collapse()) on out-of-bound dim
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 3 ++-
mlir/test/Dialect/Tensor/canonicalize.mlir | 14 ++++++++++++++
2 files changed, 16 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 9bb628781342ca..21f78cf96c70e9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2012,7 +2012,8 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
// Only constant dimension values are supported.
std::optional<int64_t> dim = dimOp.getConstantIndex();
- if (!dim.has_value())
+ if (!dim.has_value() ||
+ dim.value() >= collapseShapeOp.getResultType().getRank())
return failure();
// Skip static dims. These are folded to constant ops.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 613ec066337294..e8fc4ce834e18f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2344,6 +2344,20 @@ func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
// -----
+// Can't fold when dim is out of bound.
+// CHECK-LABEL: func @out_of_bound_dim_of_collapse_shape(
+// CHECK: %[[DIM:.*]] = tensor.dim
+// CHECK: return %[[DIM]]
+func.func @out_of_bound_dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
+ %c5 = arith.constant 5 : index
+ %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]
+ : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
+ %1 = tensor.dim %0, %c5 : tensor<?x?xf32>
+ return %1 : index
+}
+
+// -----
+
// CHECK-LABEL: func @collapse_expand_fold_to_cast(
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
// CHECK: return %[[t]]
More information about the Mlir-commits mailing list