[Mlir-commits] [mlir] 68386a7 - [mlir][tensor] Fix crash when canonicalizing invalid IR (#72888)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 21 00:20:23 PST 2023


Author: Matthias Springer
Date: 2023-11-21T09:20:18+01:00
New Revision: 68386a74ba48cf084a56bf54d1efb6b822e28939

URL: https://github.com/llvm/llvm-project/commit/68386a74ba48cf084a56bf54d1efb6b822e28939
DIFF: https://github.com/llvm/llvm-project/commit/68386a74ba48cf084a56bf54d1efb6b822e28939.diff

LOG: [mlir][tensor] Fix crash when canonicalizing invalid IR (#72888)

This commit fixes a crash of the canonicalizer when there are slice ops
with offset/size SSA values that have a negative constant value. Such
ops are invalid if they are reachable, and their offsets/sizes should not
be folded to static integer values. (But such ops may appear in
non-reachable blocks.)

This commit fixes #71150.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/include/mlir/Interfaces/ViewLikeInterface.h
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 23a366036b9dd6f..c2fbaea726abcbb 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -141,8 +141,10 @@ getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
 
 /// Returns "success" when any of the elements in `ofrs` is a constant value. In
 /// that case the value is replaced by an attribute. Returns "failure" when no
-/// folding happened.
-LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs);
+/// folding happened. If `onlyNonNegative` is set, only non-negative constant
+/// values are folded.
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
+                                   bool onlyNonNegative = false);
 
 /// Return the number of iterations for a loop with a lower bound `lb`, upper
 /// bound `ub` and step `step`.

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index a114e9af126f112..931309b0c596296 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -67,8 +67,8 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
     SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
 
     // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(mixedOffsets)) &&
-        failed(foldDynamicIndexList(mixedSizes)) &&
+    if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
+        failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
         failed(foldDynamicIndexList(mixedStrides)))
       return failure();
 

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e469815496e1832..5bfcb35127b5267 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2361,8 +2361,8 @@ class InsertSliceOpConstantArgumentFolder final
     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
 
     // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(mixedOffsets)) &&
-        failed(foldDynamicIndexList(mixedSizes)) &&
+    if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
+        failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
         failed(foldDynamicIndexList(mixedStrides)))
       return failure();
 
@@ -2497,8 +2497,12 @@ struct InsertSliceOpSourceCastInserter final
                                      srcType.getShape().end());
     for (int64_t i = 0; i < srcType.getRank(); ++i) {
       if (std::optional<int64_t> constInt =
-              getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
+              getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
+        // Bail on invalid IR.
+        if (*constInt < 0)
+          return failure();
         newSrcShape[i] = *constInt;
+      }
     }
 
     RankedTensorType newSrcType =

diff  --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 8a4ccc990331a7f..c7a3d8fc8eb2841 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -256,13 +256,17 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
   return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
 }
 
-LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs) {
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
+                                   bool onlyNonNegative) {
   bool valuesChanged = false;
   for (OpFoldResult &ofr : ofrs) {
     if (ofr.is<Attribute>())
       continue;
     Attribute attr;
     if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
+      // Note: All ofrs have index type.
+      if (onlyNonNegative && *getConstantIntValue(attr) < 0)
+        continue;
       ofr = attr;
       valuesChanged = true;
     }

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ea8c17640d7c143..41bfd6fe7b6eedc 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1925,3 +1925,19 @@ func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init :
 //       CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
 //  CHECK-SAME:       into %[[INIT]]
 //       CHECK:   return %[[UNPACK]]
+
+// -----
+
+// The IR in this test case is invalid. This test tests that the canonicalizer
+// does not crash.
+
+// CHECK-LABEL: func @invalid_slice_ops(
+//       CHECK:   %[[c:.*]] = arith.constant -5 : index
+//       CHECK:   tensor.extract_slice {{.*}}%[[c]]
+//       CHECK:   tensor.insert_slice {{.*}}%[[c]]
+func.func @invalid_slice_ops(%t: tensor<?xf32>, %t2: tensor<?xf32>) -> tensor<?xf32> {
+  %c = arith.constant -5 : index
+  %0 = tensor.extract_slice %t[0][%c][1] : tensor<?xf32> to tensor<?xf32>
+  %1 = tensor.insert_slice %0 into %t2[2][%c][1] : tensor<?xf32> into tensor<?xf32>
+  return %1 : tensor<?xf32>
+}


        


More information about the Mlir-commits mailing list