[clang] [clang-tools-extra] [llvm] [mlir] [mlir] Fix a zero stride canonicalizer crash (PR #74200)
via cfe-commits
cfe-commits at lists.llvm.org
Sat Dec 2 09:10:23 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-tensor
Author: Rik Huijzer (rikhuijzer)
<details>
<summary>Changes</summary>
This PR fixes https://github.com/llvm/llvm-project/issues/73383 and is another shot at the refactoring proposed in https://github.com/llvm/llvm-project/pull/72885.
---
Full diff: https://github.com/llvm/llvm-project/pull/74200.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+27-3)
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+6-11)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+7-10)
- (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+26-1)
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+12)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 502ab93ddbfa7..a1853438ccf7f 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -139,12 +139,36 @@ SmallVector<int64_t>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
llvm::function_ref<bool(Attribute, Attribute)> compare);
+/// Helper function to check whether the passed in `sizes` or `offsets` are
+/// valid, i.e. that none of their static (non-dynamic) entries is negative.
+/// Used to re-check dimensions after constant folding the dynamic ones.
+bool hasValidSizesOffsets(ArrayRef<int64_t> sizesOrOffsets);
+
+/// Helper function to check whether the passed in `strides` are valid, i.e.
+/// that none of their static (non-dynamic) entries is zero. Used to re-check
+/// strides after constant folding the dynamic ones.
+bool hasValidStrides(ArrayRef<int64_t> strides);
+
/// Returns "success" when any of the elements in `ofrs` is a constant value. In
/// that case the value is replaced by an attribute. Returns "failure" when no
-/// folding happened. If `onlyNonNegative` is set, only non-negative constant
-/// values are folded.
+/// folding happened. If `onlyNonNegative` (resp. `onlyNonZero`) is set, only
+/// non-negative (resp. non-zero) constant values are folded.
LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
- bool onlyNonNegative = false);
+ bool onlyNonNegative = false,
+ bool onlyNonZero = false);
+
+/// Returns "success" when any of the elements in `offsetsOrSizes` is a
+/// constant value. In that case the value is replaced by an attribute. Returns
+/// "failure" when no folding happened. Negative constants are left unfolded
+/// to avoid canonicalization crashes.
+LogicalResult
+foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
+
+/// Returns "success" when any of the elements in `strides` is a constant
+/// value. In that case the value is replaced by an attribute. Returns
+/// "failure" when no folding happened. Zero constants are left unfolded to
+/// avoid canonicalization crashes.
+LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
/// Return the number of iterations for a loop with a lower bound `lb`, upper
/// bound `ub` and step `step`.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index dce96cca016ff..b2d52e400e52d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2581,17 +2581,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
- // If one of the offsets or sizes is invalid, fail the canonicalization.
- // These checks also occur in the verifier, but they are needed here
- // because some dynamic dimensions may have been constant folded.
- for (int64_t offset : staticOffsets)
- if (offset < 0 && !ShapedType::isDynamic(offset))
- return {};
- for (int64_t size : staticSizes)
- if (size < 0 && !ShapedType::isDynamic(size))
- return {};
-
+ if (!hasValidSizesOffsets(staticOffsets))
+ return {};
+ if (!hasValidSizesOffsets(staticSizes))
+ return {};
+ if (!hasValidStrides(staticStrides))
+ return {};
return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
staticSizes, staticStrides);
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8970ea1c73b40..94b7b734f88fe 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1446,13 +1446,8 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
SmallVector<int64_t> newShape;
operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
- for (int64_t newdim : newShape) {
- // This check also occurs in the verifier, but we need it here too
- // since intermediate passes may have replaced some dynamic dimensions
- // by constants.
- if (newdim < 0 && !ShapedType::isDynamic(newdim))
- return failure();
- }
+ if (!hasValidSizesOffsets(newShape))
+ return failure();
if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
return failure();
@@ -2548,9 +2543,9 @@ class InsertSliceOpConstantArgumentFolder final
SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
// No constant operands were folded, just return;
- if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
- failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
- failed(foldDynamicIndexList(mixedStrides)))
+ if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
+ failed(foldDynamicOffsetSizeList(mixedSizes)) &&
+ failed(foldDynamicStrideList(mixedStrides)))
return failure();
// Create the new op in canonical form.
@@ -2691,6 +2686,8 @@ struct InsertSliceOpSourceCastInserter final
newSrcShape[i] = *constInt;
}
}
+ if (!hasValidSizesOffsets(newSrcShape))
+ return failure();
RankedTensorType newSrcType =
RankedTensorType::get(newSrcShape, srcType.getElementType());
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index c7a3d8fc8eb28..0c8a88da789e2 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -256,8 +256,20 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
}
+bool hasValidSizesOffsets(ArrayRef<int64_t> sizesOrOffsets) {
+ return llvm::none_of(sizesOrOffsets, [](int64_t value) {
+ return !ShapedType::isDynamic(value) && value < 0;
+ });
+}
+
+bool hasValidStrides(ArrayRef<int64_t> strides) {
+ return llvm::none_of(strides, [](int64_t value) {
+ return !ShapedType::isDynamic(value) && value == 0;
+ });
+}
+
LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
- bool onlyNonNegative) {
+ bool onlyNonNegative, bool onlyNonZero) {
bool valuesChanged = false;
for (OpFoldResult &ofr : ofrs) {
if (ofr.is<Attribute>())
@@ -267,6 +279,8 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
// Note: All ofrs have index type.
if (onlyNonNegative && *getConstantIntValue(attr) < 0)
continue;
+ if (onlyNonZero && *getConstantIntValue(attr) == 0)
+ continue;
ofr = attr;
valuesChanged = true;
}
@@ -274,4 +288,15 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
return success(valuesChanged);
}
+LogicalResult
+foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) {
+ return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
+ /*onlyNonZero=*/false);
+}
+
+LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) {
+ return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
+ /*onlyNonZero=*/true);
+}
+
} // namespace mlir
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a1f8673638ff8..d3406c630f6dd 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -191,6 +191,18 @@ func.func @no_fold_subview_negative_size(%input: memref<4x1024xf32>) -> memref<?
// -----
+// CHECK-LABEL: func @no_fold_subview_zero_stride
+// CHECK: %[[SUBVIEW:.+]] = memref.subview
+// CHECK-SAME: [1] [1] [%{{.*}}]
+// CHECK: return %[[SUBVIEW]]
+func.func @no_fold_subview_zero_stride(%arg0 : memref<10xf32>) -> memref<1xf32, strided<[?], offset: 1>> {
+ %c0 = arith.constant 0 : index
+ %1 = memref.subview %arg0[1] [1] [%c0] : memref<10xf32> to memref<1xf32, strided<[?], offset: 1>>
+ return %1 : memref<1xf32, strided<[?], offset: 1>>
+}
+
+// -----
+
// CHECK-LABEL: func @no_fold_of_store
// CHECK: %[[cst:.+]] = memref.cast %arg
// CHECK: memref.store %[[cst]]
``````````
</details>
https://github.com/llvm/llvm-project/pull/74200
More information about the cfe-commits mailing list