[Mlir-commits] [mlir] e3373c6 - [mlir][memref] Fix crash in SubViewReturnTypeCanonicalizer
Matthias Springer
llvmlistbot at llvm.org
Fri Aug 25 07:02:05 PDT 2023
Author: Matthias Springer
Date: 2023-08-25T16:01:49+02:00
New Revision: e3373c6c83d3855adb78f1952a3bf0398baf359e
URL: https://github.com/llvm/llvm-project/commit/e3373c6c83d3855adb78f1952a3bf0398baf359e
DIFF: https://github.com/llvm/llvm-project/commit/e3373c6c83d3855adb78f1952a3bf0398baf359e.diff
LOG: [mlir][memref] Fix crash in SubViewReturnTypeCanonicalizer
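`SubViewReturnTypeCanonicalizer` is used by `OpWithOffsetSizesAndStridesConstantArgumentFolder`, which folds dynamic sizes that are defined by constant SSA values into static sizes. The previous implementation crashed when a dynamic size was folded into a static `1` dimension, which was then mistaken for a rank reduction. The new implementation takes the dropped dimensions from the original op (`getDroppedDims`) instead of re-deriving the rank reduction from the newly inferred type.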
Differential Revision: https://reviews.llvm.org/D158721
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b2909962027de2..e1b8dd62450a77 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -31,23 +31,17 @@ namespace {
namespace saturated_arith {
struct Wrapper {
static Wrapper stride(int64_t v) {
- return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
- : Wrapper{false, v};
+ return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
}
static Wrapper offset(int64_t v) {
- return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
- : Wrapper{false, v};
+ return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
}
static Wrapper size(int64_t v) {
return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
}
- int64_t asOffset() {
- return saturated ? ShapedType::kDynamic : v;
- }
+ int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
- int64_t asStride() {
- return saturated ? ShapedType::kDynamic : v;
- }
+ int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
bool operator==(Wrapper other) {
return (saturated && other.saturated) ||
(!saturated && !other.saturated && v == other.v);
@@ -731,8 +725,7 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
for (auto it : llvm::zip(sourceStrides, resultStrides)) {
auto ss = std::get<0>(it), st = std::get<1>(it);
if (ss != st)
- if (ShapedType::isDynamic(ss) &&
- !ShapedType::isDynamic(st))
+ if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
return false;
}
@@ -765,8 +758,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
// same. They are also compatible if either one is dynamic (see
// description of MemRefCastOp for details).
auto checkCompatible = [](int64_t a, int64_t b) {
- return (ShapedType::isDynamic(a) ||
- ShapedType::isDynamic(b) || a == b);
+ return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
};
if (!checkCompatible(aOffset, bOffset))
return false;
@@ -1889,8 +1881,7 @@ LogicalResult ReinterpretCastOp::verify() {
// Match offset in result memref type and in static_offsets attribute.
int64_t expectedOffset = getStaticOffsets().front();
if (!ShapedType::isDynamic(resultOffset) &&
- !ShapedType::isDynamic(expectedOffset) &&
- resultOffset != expectedOffset)
+ !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
return emitError("expected result type with offset = ")
<< expectedOffset << " instead of " << resultOffset;
@@ -2944,18 +2935,6 @@ static MemRefType getCanonicalSubViewResultType(
nonRankReducedType.getMemorySpace());
}
-/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
-/// to deduce the result type. Additionally, reduce the rank of the inferred
-/// result type if `currentResultType` is lower rank than `sourceType`.
-static MemRefType getCanonicalSubViewResultType(
- MemRefType currentResultType, MemRefType sourceType,
- ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
- ArrayRef<OpFoldResult> mixedStrides) {
- return getCanonicalSubViewResultType(currentResultType, sourceType,
- sourceType, mixedOffsets, mixedSizes,
- mixedStrides);
-}
-
Value mlir::memref::createCanonicalRankReducingSubViewOp(
OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
auto memrefType = llvm::cast<MemRefType>(memref.getType());
@@ -3108,9 +3087,32 @@ struct SubViewReturnTypeCanonicalizer {
MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
- return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
- mixedOffsets, mixedSizes,
- mixedStrides);
+ // Infer a memref type without taking into account any rank reductions.
+ MemRefType nonReducedType = cast<MemRefType>(SubViewOp::inferResultType(
+ op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides));
+
+ // Directly return the non-rank reduced type if there are no dropped dims.
+ llvm::SmallBitVector droppedDims = op.getDroppedDims();
+ if (droppedDims.empty())
+ return nonReducedType;
+
+ // Take the strides and offset from the non-rank reduced type.
+ auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
+
+ // Drop dims from shape and strides.
+ SmallVector<int64_t> targetShape;
+ SmallVector<int64_t> targetStrides;
+ for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
+ if (droppedDims.test(i))
+ continue;
+ targetStrides.push_back(nonReducedStrides[i]);
+ targetShape.push_back(nonReducedType.getDimSize(i));
+ }
+
+ return MemRefType::get(targetShape, nonReducedType.getElementType(),
+ StridedLayoutAttr::get(nonReducedType.getContext(),
+ offset, targetStrides),
+ nonReducedType.getMemorySpace());
}
};
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index b65426cad30b6d..df66705e83e0e2 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -931,7 +931,7 @@ func.func @fold_multiple_memory_space_cast(%arg : memref<?xf32>) -> memref<?xf32
// -----
-// CHECK-lABEL: func @ub_negative_alloc_size
+// CHECK-LABEL: func private @ub_negative_alloc_size
func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
%idx1 = index.constant 1
%c-2 = arith.constant -2 : index
@@ -940,3 +940,18 @@ func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
%alloc = memref.alloc(%c15, %c-2, %idx1) : memref<?x?x?xi1>
return %alloc : memref<?x?x?xi1>
}
+
+// -----
+
+// CHECK-LABEL: func @subview_rank_reduction(
+// CHECK-SAME: %[[arg0:.*]]: memref<1x384x384xf32>, %[[arg1:.*]]: index
+func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
+ -> memref<?x?xf32, strided<[384, 1], offset: ?>> {
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[subview:.*]] = memref.subview %[[arg0]][0, %[[arg1]], %[[arg1]]] [1, 1, %[[arg1]]] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
+ // CHECK: %[[cast:.*]] = memref.cast %[[subview]] : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?x?xf32, strided<[384, 1], offset: ?>>
+ %0 = memref.subview %arg0[0, %idx, %idx] [1, %c1, %idx] [1, 1, 1]
+ : memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>
+ // CHECK: return %[[cast]]
+ return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
+}
More information about the Mlir-commits
mailing list