[Mlir-commits] [mlir] [mlir][tensor] Fix crash when canonicalizing invalid IR (PR #72888)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 20 08:20:15 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit fixes a crash of the canonicalizer when slice ops have offset/size SSA values with a negative constant value. Such ops are invalid if they are reachable, so their offsets/sizes should not be folded to static integer values. (Such ops may, however, legitimately appear in unreachable blocks.)

This commit partially fixes #71150. The canonicalizer no longer crashes, but invalid IR is still being produced.

---
Full diff: https://github.com/llvm/llvm-project/pull/72888.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+4-2) 
- (modified) mlir/include/mlir/Interfaces/ViewLikeInterface.h (+2-2) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+7-3) 
- (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+10-3) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+16) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 23a366036b9dd6f..c2fbaea726abcbb 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -141,8 +141,10 @@ getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
 
 /// Returns "success" when any of the elements in `ofrs` is a constant value. In
 /// that case the value is replaced by an attribute. Returns "failure" when no
-/// folding happened.
-LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs);
+/// folding happened. If `onlyNonNegative` is set, only non-negative constant
+/// values are folded.
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
+                                   bool onlyNonNegative = false);
 
 /// Return the number of iterations for a loop with a lower bound `lb`, upper
 /// bound `ub` and step `step`.
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index a114e9af126f112..931309b0c596296 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -67,8 +67,8 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
     SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
 
     // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(mixedOffsets)) &&
-        failed(foldDynamicIndexList(mixedSizes)) &&
+    if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
+        failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
         failed(foldDynamicIndexList(mixedStrides)))
       return failure();
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e469815496e1832..5bfcb35127b5267 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2361,8 +2361,8 @@ class InsertSliceOpConstantArgumentFolder final
     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
 
     // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(mixedOffsets)) &&
-        failed(foldDynamicIndexList(mixedSizes)) &&
+    if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
+        failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
         failed(foldDynamicIndexList(mixedStrides)))
       return failure();
 
@@ -2497,8 +2497,12 @@ struct InsertSliceOpSourceCastInserter final
                                      srcType.getShape().end());
     for (int64_t i = 0; i < srcType.getRank(); ++i) {
       if (std::optional<int64_t> constInt =
-              getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
+              getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
+        // Bail on invalid IR.
+        if (*constInt < 0)
+          return failure();
         newSrcShape[i] = *constInt;
+      }
     }
 
     RankedTensorType newSrcType =
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 8a4ccc990331a7f..1cc3b054762a2c1 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -256,13 +256,20 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
   return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
 }
 
-LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs) {
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
+                                   bool onlyNonNegative) {
   bool valuesChanged = false;
   for (OpFoldResult &ofr : ofrs) {
     if (ofr.is<Attribute>())
       continue;
-    Attribute attr;
-    if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
+    APInt intVal;
+    if (matchPattern(ofr.get<Value>(), m_ConstantInt(&intVal))) {
+      if (intVal.isNegative() && onlyNonNegative)
+        continue;
+      Attribute attr;
+      bool isConstant = matchPattern(ofr.get<Value>(), m_Constant(&attr));
+      (void)isConstant;
+      assert(isConstant && "expected constant value");
       ofr = attr;
       valuesChanged = true;
     }
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ea8c17640d7c143..41bfd6fe7b6eedc 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1925,3 +1925,19 @@ func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init :
 //       CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
 //  CHECK-SAME:       into %[[INIT]]
 //       CHECK:   return %[[UNPACK]]
+
+// -----
+
+// The IR in this test case is invalid. This test tests that the canonicalizer
+// does not crash.
+
+// CHECK-LABEL: func @invalid_slice_ops(
+//       CHECK:   %[[c:.*]] = arith.constant -5 : index
+//       CHECK:   tensor.extract_slice {{.*}}%[[c]]
+//       CHECK:   tensor.insert_slice {{.*}}%[[c]]
+func.func @invalid_slice_ops(%t: tensor<?xf32>, %t2: tensor<?xf32>) -> tensor<?xf32> {
+  %c = arith.constant -5 : index
+  %0 = tensor.extract_slice %t[0][%c][1] : tensor<?xf32> to tensor<?xf32>
+  %1 = tensor.insert_slice %0 into %t2[2][%c][1] : tensor<?xf32> into tensor<?xf32>
+  return %1 : tensor<?xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/72888


More information about the Mlir-commits mailing list