[Mlir-commits] [mlir] 68386a7 - [mlir][tensor] Fix crash when canonicalizing invalid IR (#72888)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 21 00:20:23 PST 2023
Author: Matthias Springer
Date: 2023-11-21T09:20:18+01:00
New Revision: 68386a74ba48cf084a56bf54d1efb6b822e28939
URL: https://github.com/llvm/llvm-project/commit/68386a74ba48cf084a56bf54d1efb6b822e28939
DIFF: https://github.com/llvm/llvm-project/commit/68386a74ba48cf084a56bf54d1efb6b822e28939.diff
LOG: [mlir][tensor] Fix crash when canonicalizing invalid IR (#72888)
This commit fixes a crash of the canonicalizer when there are slice ops
with offset/size SSA values that have a negative constant value. Such
ops are invalid if they are reachable, and their offsets/sizes should not
be folded to static integer values. (Such ops may, however, appear in
unreachable blocks.)
This commit fixes #71150.
Added:
Modified:
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 23a366036b9dd6f..c2fbaea726abcbb 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -141,8 +141,10 @@ getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
/// Returns "success" when any of the elements in `ofrs` is a constant value. In
/// that case the value is replaced by an attribute. Returns "failure" when no
-/// folding happened.
-LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs);
+/// folding happened. If `onlyNonNegative` is set, only non-negative constant
+/// values are folded.
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
+ bool onlyNonNegative = false);
/// Return the number of iterations for a loop with a lower bound `lb`, upper
/// bound `ub` and step `step`.
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index a114e9af126f112..931309b0c596296 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -67,8 +67,8 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
// No constant operands were folded, just return;
- if (failed(foldDynamicIndexList(mixedOffsets)) &&
- failed(foldDynamicIndexList(mixedSizes)) &&
+ if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
+ failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
failed(foldDynamicIndexList(mixedStrides)))
return failure();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e469815496e1832..5bfcb35127b5267 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2361,8 +2361,8 @@ class InsertSliceOpConstantArgumentFolder final
SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
// No constant operands were folded, just return;
- if (failed(foldDynamicIndexList(mixedOffsets)) &&
- failed(foldDynamicIndexList(mixedSizes)) &&
+ if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
+ failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
failed(foldDynamicIndexList(mixedStrides)))
return failure();
@@ -2497,8 +2497,12 @@ struct InsertSliceOpSourceCastInserter final
srcType.getShape().end());
for (int64_t i = 0; i < srcType.getRank(); ++i) {
if (std::optional<int64_t> constInt =
- getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
+ getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
+ // Bail on invalid IR.
+ if (*constInt < 0)
+ return failure();
newSrcShape[i] = *constInt;
+ }
}
RankedTensorType newSrcType =
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 8a4ccc990331a7f..c7a3d8fc8eb2841 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -256,13 +256,17 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
}
-LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs) {
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
+ bool onlyNonNegative) {
bool valuesChanged = false;
for (OpFoldResult &ofr : ofrs) {
if (ofr.is<Attribute>())
continue;
Attribute attr;
if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
+ // Note: All ofrs have index type.
+ if (onlyNonNegative && *getConstantIntValue(attr) < 0)
+ continue;
ofr = attr;
valuesChanged = true;
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ea8c17640d7c143..41bfd6fe7b6eedc 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1925,3 +1925,19 @@ func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init :
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME: into %[[INIT]]
// CHECK: return %[[UNPACK]]
+
+// -----
+
+// The IR in this test case is invalid. This test checks that the canonicalizer
+// does not crash.
+
+// CHECK-LABEL: func @invalid_slice_ops(
+// CHECK: %[[c:.*]] = arith.constant -5 : index
+// CHECK: tensor.extract_slice {{.*}}%[[c]]
+// CHECK: tensor.insert_slice {{.*}}%[[c]]
+func.func @invalid_slice_ops(%t: tensor<?xf32>, %t2: tensor<?xf32>) -> tensor<?xf32> {
+ %c = arith.constant -5 : index
+ %0 = tensor.extract_slice %t[0][%c][1] : tensor<?xf32> to tensor<?xf32>
+ %1 = tensor.insert_slice %0 into %t2[2][%c][1] : tensor<?xf32> into tensor<?xf32>
+ return %1 : tensor<?xf32>
+}
More information about the Mlir-commits
mailing list