[Mlir-commits] [mlir] 68f0bc6 - [mlir] Fix a zero stride canonicalizer crash (#74200)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 5 22:35:23 PST 2023


Author: Rik Huijzer
Date: 2023-12-06T07:35:18+01:00
New Revision: 68f0bc6f2e869dc7d3e8394b99fba7052ad8116a

URL: https://github.com/llvm/llvm-project/commit/68f0bc6f2e869dc7d3e8394b99fba7052ad8116a
DIFF: https://github.com/llvm/llvm-project/commit/68f0bc6f2e869dc7d3e8394b99fba7052ad8116a.diff

LOG: [mlir] Fix a zero stride canonicalizer crash (#74200)

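This PR fixes https://github.com/llvm/llvm-project/issues/73383 and is
another attempt at the refactoring proposed in
https://github.com/llvm/llvm-project/pull/72885. Constant zero strides are
no longer folded into static strides, and subview result-type inference now
rejects static zero strides, so canonicalization no longer crashes on them.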

---------

Co-authored-by: Kai Sasaki <lewuathe at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 502ab93ddbfa7..1dc0398494dcc 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -139,12 +139,36 @@ SmallVector<int64_t>
 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
                      llvm::function_ref<bool(Attribute, Attribute)> compare);
 
+/// Helper function to check whether the passed in `sizes` or `offsets` are
+/// valid, i.e. contain no negative static value. This can be used to re-check
+/// dimensions after constant folding the dynamic dimensions.
+bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets);
+
+/// Helper function to check whether the passed in `strides` are valid, i.e.
+/// contain no static zero stride. This can be used to re-check strides after
+/// constant folding the dynamic dimensions.
+bool hasValidStrides(SmallVector<int64_t> strides);
+
 /// Returns "success" when any of the elements in `ofrs` is a constant value. In
 /// that case the value is replaced by an attribute. Returns "failure" when no
-/// folding happened. If `onlyNonNegative` is set, only non-negative constant
-/// values are folded.
+/// folding happened. If `onlyNonNegative` is set, negative constants are not
+/// folded; if `onlyNonZero` is set, zero constants are not folded.
 LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
-                                   bool onlyNonNegative = false);
+                                   bool onlyNonNegative = false,
+                                   bool onlyNonZero = false);
+
+/// Returns "success" when any of the elements in `offsetsOrSizes` is a
+/// constant value. In that case the value is replaced by an attribute. Returns
+/// "failure" when no folding happened. Invalid values are not folded to avoid
+/// canonicalization crashes.
+LogicalResult
+foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
+
+/// Returns "success" when any of the elements in `strides` is a constant
+/// value. In that case the value is replaced by an attribute. Returns
+/// "failure" when no folding happened. Invalid values are not folded to avoid
+/// canonicalization crashes.
+LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
 
 /// Return the number of iterations for a loop with a lower bound `lb`, upper
 /// bound `ub` and step `step`.

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a397506629cfe..93327a28234ea 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2582,17 +2582,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  // If one of the offsets or sizes is invalid, fail the canonicalization.
-  // These checks also occur in the verifier, but they are needed here
-  // because some dynamic dimensions may have been constant folded.
-  for (int64_t offset : staticOffsets)
-    if (offset < 0 && !ShapedType::isDynamic(offset))
-      return {};
-  for (int64_t size : staticSizes)
-    if (size < 0 && !ShapedType::isDynamic(size))
-      return {};
-
+  if (!hasValidSizesOffsets(staticOffsets))
+    return {};
+  if (!hasValidSizesOffsets(staticSizes))
+    return {};
+  if (!hasValidStrides(staticStrides))
+    return {};
   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                     staticSizes, staticStrides);
 }

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f15695383d34a..55f813df78b85 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1447,13 +1447,8 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
     SmallVector<int64_t> newShape;
     operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
 
-    for (int64_t newdim : newShape) {
-      // This check also occurs in the verifier, but we need it here too
-      // since intermediate passes may have replaced some dynamic dimensions
-      // by constants.
-      if (newdim < 0 && !ShapedType::isDynamic(newdim))
-        return failure();
-    }
+    if (!hasValidSizesOffsets(newShape))
+      return failure();
 
     if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
       return failure();
@@ -2549,9 +2544,9 @@ class InsertSliceOpConstantArgumentFolder final
     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
 
     // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
-        failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
-        failed(foldDynamicIndexList(mixedStrides)))
+    if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
+        failed(foldDynamicOffsetSizeList(mixedSizes)) &&
+        failed(foldDynamicStrideList(mixedStrides)))
       return failure();
 
     // Create the new op in canonical form.
@@ -2692,6 +2687,8 @@ struct InsertSliceOpSourceCastInserter final
         newSrcShape[i] = *constInt;
       }
     }
+    if (!hasValidSizesOffsets(newSrcShape))
+      return failure();
 
     RankedTensorType newSrcType =
         RankedTensorType::get(newSrcShape, srcType.getElementType());

diff  --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index c7a3d8fc8eb28..0c8a88da789e2 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -256,8 +256,20 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
   return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
 }
 
+bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
+  return llvm::none_of(sizesOrOffsets, [](int64_t value) {
+    return !ShapedType::isDynamic(value) && value < 0;
+  });
+}
+
+bool hasValidStrides(SmallVector<int64_t> strides) {
+  return llvm::none_of(strides, [](int64_t value) {
+    return !ShapedType::isDynamic(value) && value == 0;
+  });
+}
+
 LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
-                                   bool onlyNonNegative) {
+                                   bool onlyNonNegative, bool onlyNonZero) {
   bool valuesChanged = false;
   for (OpFoldResult &ofr : ofrs) {
     if (ofr.is<Attribute>())
@@ -267,6 +279,8 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
       // Note: All ofrs have index type.
       if (onlyNonNegative && *getConstantIntValue(attr) < 0)
         continue;
+      if (onlyNonZero && *getConstantIntValue(attr) == 0)
+        continue;
       ofr = attr;
       valuesChanged = true;
     }
@@ -274,4 +288,15 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
   return success(valuesChanged);
 }
 
+LogicalResult
+foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) {
+  return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
+                              /*onlyNonZero=*/false);
+}
+
+LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) {
+  return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
+                              /*onlyNonZero=*/true);
+}
+
 } // namespace mlir

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a1f8673638ff8..d3406c630f6dd 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -191,6 +191,18 @@ func.func @no_fold_subview_negative_size(%input: memref<4x1024xf32>) -> memref<?
 
 // -----
 
+// CHECK-LABEL: func @no_fold_subview_zero_stride
+//  CHECK:        %[[SUBVIEW:.+]] = memref.subview
+//  CHECK:        return %[[SUBVIEW]]
+func.func @no_fold_subview_zero_stride(%arg0 : memref<10xf32>) -> memref<1xf32, strided<[?], offset: 1>> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = memref.subview %arg0[1] [1] [%c0] : memref<10xf32> to memref<1xf32, strided<[?], offset: 1>>
+  return %1 : memref<1xf32, strided<[?], offset: 1>>
+}
+
+// -----
+
 // CHECK-LABEL: func @no_fold_of_store
 //  CHECK:   %[[cst:.+]] = memref.cast %arg
 //  CHECK:   memref.store %[[cst]]


        


More information about the Mlir-commits mailing list