[Mlir-commits] [mlir] [mlir][tensor] Fix crash when canonicalizing invalid IR (PR #72888)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 20 08:20:15 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit fixes a crash in the canonicalizer when slice ops have offset/size SSA values with a negative constant value. Such ops are invalid if they are reachable, so their offsets/sizes must not be folded to static integer values. (Such ops may still appear in unreachable blocks, where the verifier does not reject them.)
This commit partially fixes #71150. The canonicalizer no longer crashes, but it still produces invalid IR.
---
Full diff: https://github.com/llvm/llvm-project/pull/72888.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+4-2)
- (modified) mlir/include/mlir/Interfaces/ViewLikeInterface.h (+2-2)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+7-3)
- (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+10-3)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+16)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 23a366036b9dd6f..c2fbaea726abcbb 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -141,8 +141,10 @@ getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
/// Returns "success" when any of the elements in `ofrs` is a constant value. In
/// that case the value is replaced by an attribute. Returns "failure" when no
-/// folding happened.
-LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs);
+/// folding happened. If `onlyNonNegative` is set, only non-negative constant
+/// values are folded.
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
+ bool onlyNonNegative = false);
/// Return the number of iterations for a loop with a lower bound `lb`, upper
/// bound `ub` and step `step`.
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index a114e9af126f112..931309b0c596296 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -67,8 +67,8 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
// No constant operands were folded, just return;
- if (failed(foldDynamicIndexList(mixedOffsets)) &&
- failed(foldDynamicIndexList(mixedSizes)) &&
+ if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
+ failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
failed(foldDynamicIndexList(mixedStrides)))
return failure();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e469815496e1832..5bfcb35127b5267 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2361,8 +2361,8 @@ class InsertSliceOpConstantArgumentFolder final
SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
// No constant operands were folded, just return;
- if (failed(foldDynamicIndexList(mixedOffsets)) &&
- failed(foldDynamicIndexList(mixedSizes)) &&
+ if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
+ failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
failed(foldDynamicIndexList(mixedStrides)))
return failure();
@@ -2497,8 +2497,12 @@ struct InsertSliceOpSourceCastInserter final
srcType.getShape().end());
for (int64_t i = 0; i < srcType.getRank(); ++i) {
if (std::optional<int64_t> constInt =
- getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
+ getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
+ // Bail on invalid IR.
+ if (*constInt < 0)
+ return failure();
newSrcShape[i] = *constInt;
+ }
}
RankedTensorType newSrcType =
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 8a4ccc990331a7f..1cc3b054762a2c1 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -256,13 +256,20 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
}
-LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs) {
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
+ bool onlyNonNegative) {
bool valuesChanged = false;
for (OpFoldResult &ofr : ofrs) {
if (ofr.is<Attribute>())
continue;
- Attribute attr;
- if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
+ APInt intVal;
+ if (matchPattern(ofr.get<Value>(), m_ConstantInt(&intVal))) {
+ if (intVal.isNegative() && onlyNonNegative)
+ continue;
+ Attribute attr;
+ bool isConstant = matchPattern(ofr.get<Value>(), m_Constant(&attr));
+ (void)isConstant;
+ assert(isConstant && "expected constant value");
ofr = attr;
valuesChanged = true;
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ea8c17640d7c143..41bfd6fe7b6eedc 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1925,3 +1925,19 @@ func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init :
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME: into %[[INIT]]
// CHECK: return %[[UNPACK]]
+
+// -----
+
+// The IR in this test case is invalid. This test checks that the canonicalizer
+// does not crash.
+
+// CHECK-LABEL: func @invalid_slice_ops(
+// CHECK: %[[c:.*]] = arith.constant -5 : index
+// CHECK: tensor.extract_slice {{.*}}%[[c]]
+// CHECK: tensor.insert_slice {{.*}}%[[c]]
+func.func @invalid_slice_ops(%t: tensor<?xf32>, %t2: tensor<?xf32>) -> tensor<?xf32> {
+ %c = arith.constant -5 : index
+ %0 = tensor.extract_slice %t[0][%c][1] : tensor<?xf32> to tensor<?xf32>
+ %1 = tensor.insert_slice %0 into %t2[2][%c][1] : tensor<?xf32> into tensor<?xf32>
+ return %1 : tensor<?xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/72888
More information about the Mlir-commits
mailing list