[clang] [clang-tools-extra] [llvm] [mlir] [mlir] Fix a zero stride canonicalizer crash (PR #74200)

via cfe-commits cfe-commits at lists.llvm.org
Sat Dec 2 09:10:23 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir-tensor

Author: Rik Huijzer (rikhuijzer)

<details>
<summary>Changes</summary>

This PR fixes https://github.com/llvm/llvm-project/issues/73383 (a canonicalizer crash on zero strides) and is another attempt at the refactoring proposed in https://github.com/llvm/llvm-project/pull/72885.

---
Full diff: https://github.com/llvm/llvm-project/pull/74200.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+27-3) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+6-11) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+7-10) 
- (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+26-1) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+12) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 502ab93ddbfa7..a1853438ccf7f 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -139,12 +139,36 @@ SmallVector<int64_t>
 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
                      llvm::function_ref<bool(Attribute, Attribute)> compare);
 
+/// Helper function to check whether the passed in `sizesOrOffsets` are
+/// valid, i.e. that no static (non-dynamic) value is negative. This can be
+/// used to re-check dimensions after constant folding the dynamic dimensions.
+bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets);
+
+/// Helper function to check whether the passed in `strides` are valid, i.e.
+/// that no static (non-dynamic) stride is zero. This can be used to re-check
+/// strides after constant folding the dynamic dimensions.
+bool hasValidStrides(SmallVector<int64_t> strides);
+
 /// Returns "success" when any of the elements in `ofrs` is a constant value. In
 /// that case the value is replaced by an attribute. Returns "failure" when no
-/// folding happened. If `onlyNonNegative` is set, only non-negative constant
-/// values are folded.
+/// folding happened. If `onlyNonNegative` is set, negative constant values are
+/// not folded; if `onlyNonZero` is set, zero constant values are not folded.
 LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
-                                   bool onlyNonNegative = false);
+                                   bool onlyNonNegative = false,
+                                   bool onlyNonZero = false);
+
+/// Returns "success" when any of the elements in `offsetsOrSizes` is a
+/// constant value. In that case the value is replaced by an attribute. Returns
+/// "failure" when no folding happened. Negative constants are not folded to
+/// avoid canonicalization crashes.
+LogicalResult
+foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
+
+/// Returns "success" when any of the elements in `strides` is a constant
+/// value. In that case the value is replaced by an attribute. Returns
+/// "failure" when no folding happened. Zero constants are not folded to
+/// avoid canonicalization crashes.
+LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
 
 /// Return the number of iterations for a loop with a lower bound `lb`, upper
 /// bound `ub` and step `step`.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index dce96cca016ff..b2d52e400e52d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2581,17 +2581,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  // If one of the offsets or sizes is invalid, fail the canonicalization.
-  // These checks also occur in the verifier, but they are needed here
-  // because some dynamic dimensions may have been constant folded.
-  for (int64_t offset : staticOffsets)
-    if (offset < 0 && !ShapedType::isDynamic(offset))
-      return {};
-  for (int64_t size : staticSizes)
-    if (size < 0 && !ShapedType::isDynamic(size))
-      return {};
-
+  if (!hasValidSizesOffsets(staticOffsets))
+    return {};
+  if (!hasValidSizesOffsets(staticSizes))
+    return {};
+  if (!hasValidStrides(staticStrides))
+    return {};
   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                     staticSizes, staticStrides);
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8970ea1c73b40..94b7b734f88fe 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1446,13 +1446,8 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
     SmallVector<int64_t> newShape;
     operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
 
-    for (int64_t newdim : newShape) {
-      // This check also occurs in the verifier, but we need it here too
-      // since intermediate passes may have replaced some dynamic dimensions
-      // by constants.
-      if (newdim < 0 && !ShapedType::isDynamic(newdim))
-        return failure();
-    }
+    if (!hasValidSizesOffsets(newShape))
+      return failure();
 
     if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
       return failure();
@@ -2548,9 +2543,9 @@ class InsertSliceOpConstantArgumentFolder final
     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
 
     // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
-        failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
-        failed(foldDynamicIndexList(mixedStrides)))
+    if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
+        failed(foldDynamicOffsetSizeList(mixedSizes)) &&
+        failed(foldDynamicStrideList(mixedStrides)))
       return failure();
 
     // Create the new op in canonical form.
@@ -2691,6 +2686,8 @@ struct InsertSliceOpSourceCastInserter final
         newSrcShape[i] = *constInt;
       }
     }
+    if (!hasValidSizesOffsets(newSrcShape))
+      return failure();
 
     RankedTensorType newSrcType =
         RankedTensorType::get(newSrcShape, srcType.getElementType());
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index c7a3d8fc8eb28..0c8a88da789e2 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -256,8 +256,20 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
   return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
 }
 
+bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
+  return llvm::none_of(sizesOrOffsets, [](int64_t value) {
+    return !ShapedType::isDynamic(value) && value < 0;
+  });
+}
+
+bool hasValidStrides(SmallVector<int64_t> strides) {
+  return llvm::none_of(strides, [](int64_t value) {
+    return !ShapedType::isDynamic(value) && value == 0;
+  });
+}
+
 LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
-                                   bool onlyNonNegative) {
+                                   bool onlyNonNegative, bool onlyNonZero) {
   bool valuesChanged = false;
   for (OpFoldResult &ofr : ofrs) {
     if (ofr.is<Attribute>())
@@ -267,6 +279,8 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
       // Note: All ofrs have index type.
       if (onlyNonNegative && *getConstantIntValue(attr) < 0)
         continue;
+      if (onlyNonZero && *getConstantIntValue(attr) == 0)
+        continue;
       ofr = attr;
       valuesChanged = true;
     }
@@ -274,4 +288,15 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
   return success(valuesChanged);
 }
 
+LogicalResult
+foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) {
+  return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
+                              /*onlyNonZero=*/false);
+}
+
+LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) {
+  return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
+                              /*onlyNonZero=*/true);
+}
+
 } // namespace mlir
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a1f8673638ff8..d3406c630f6dd 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -191,6 +191,18 @@ func.func @no_fold_subview_negative_size(%input: memref<4x1024xf32>) -> memref<?
 
 // -----
 
+// CHECK-LABEL: func @no_fold_subview_zero_stride
+//  CHECK:        %[[SUBVIEW:.+]] = memref.subview
+//  CHECK:        return %[[SUBVIEW]]
+func.func @no_fold_subview_zero_stride(%arg0 : memref<10xf32>) -> memref<1xf32, strided<[?], offset: 1>> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = memref.subview %arg0[1] [1] [%c0] : memref<10xf32> to memref<1xf32, strided<[?], offset: 1>>
+  return %1 : memref<1xf32, strided<[?], offset: 1>>
+}
+
+// -----
+
 // CHECK-LABEL: func @no_fold_of_store
 //  CHECK:   %[[cst:.+]] = memref.cast %arg
 //  CHECK:   memref.store %[[cst]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/74200


More information about the cfe-commits mailing list