[Mlir-commits] [mlir] 311dd55 - [mlir][MemRef] Fix SubViewOp canonicalization when a subset of unit-dims are dropped.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Nov 30 12:38:32 PST 2021
Author: MaheshRavishankar
Date: 2021-11-30T20:37:06Z
New Revision: 311dd55c9eb9342b1c889f6db7728f15b05378bb
URL: https://github.com/llvm/llvm-project/commit/311dd55c9eb9342b1c889f6db7728f15b05378bb
DIFF: https://github.com/llvm/llvm-project/commit/311dd55c9eb9342b1c889f6db7728f15b05378bb.diff
LOG: [mlir][MemRef] Fix SubViewOp canonicalization when a subset of unit-dims are dropped.
The canonical type of the result of the `memref.subview` needs to make
sure that the previously dropped unit-dimensions are the ones dropped
for the canonicalized type as well. This means the generic
`inferRankReducedResultType` cannot be used. Instead, the currently
dropped dimensions need to be queried, and the same dimensions need to be dropped.
Reviewed By: nicolasvasilache, ThomasRaoux
Differential Revision: https://reviews.llvm.org/D114751
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
index 11d81d7050394..4c3799cb9428a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
@@ -63,6 +63,8 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
ResultTypeFunc resultTypeFunc;
auto resultType =
resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides);
+ if (!resultType)
+ return failure();
auto newOp =
rewriter.create<OpType>(op.getLoc(), resultType, op.source(),
mixedOffsets, mixedSizes, mixedStrides);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index ac36d1dc11bbb..b9bd01b439a9d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -511,14 +511,16 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
/// dimension is dropped the stride must be dropped too.
static llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
- ArrayAttr staticSizes) {
+ ArrayRef<OpFoldResult> sizes) {
llvm::SmallDenseSet<unsigned> unusedDims;
if (originalType.getRank() == reducedType.getRank())
return unusedDims;
- for (auto dim : llvm::enumerate(staticSizes))
- if (dim.value().cast<IntegerAttr>().getInt() == 1)
- unusedDims.insert(dim.index());
+ for (auto dim : llvm::enumerate(sizes))
+ if (auto attr = dim.value().dyn_cast<Attribute>())
+ if (attr.cast<IntegerAttr>().getInt() == 1)
+ unusedDims.insert(dim.index());
+
SmallVector<int64_t> originalStrides, candidateStrides;
int64_t originalOffset, candidateOffset;
if (failed(
@@ -574,7 +576,7 @@ llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
MemRefType sourceType = getSourceType();
MemRefType resultType = getType();
llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
- computeMemRefRankReductionMask(sourceType, resultType, static_sizes());
+ computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
assert(unusedDims && "unable to find unused dims of subview");
return *unusedDims;
}
@@ -1718,7 +1720,7 @@ enum SubViewVerificationResult {
/// not matching dimension must be 1.
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
- ArrayAttr staticSizes, std::string *errMsg = nullptr) {
+ ArrayRef<OpFoldResult> sizes, std::string *errMsg = nullptr) {
if (originalType == candidateReducedType)
return SubViewVerificationResult::Success;
if (!originalType.isa<MemRefType>())
@@ -1743,7 +1745,7 @@ isRankReducedType(Type originalType, Type candidateReducedType,
MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
auto optionalUnusedDimsMask =
- computeMemRefRankReductionMask(original, candidateReduced, staticSizes);
+ computeMemRefRankReductionMask(original, candidateReduced, sizes);
// Sizes cannot be matched in case empty vector is returned.
if (!optionalUnusedDimsMask.hasValue())
@@ -1813,7 +1815,7 @@ static LogicalResult verify(SubViewOp op) {
std::string errMsg;
auto result =
- isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg);
+ isRankReducedType(expectedType, subViewType, op.getMixedSizes(), &errMsg);
return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}
@@ -1854,21 +1856,29 @@ SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
/// Infer the canonical type of the result of a subview operation. Returns a
/// type with rank `resultRank` that is either the rank of the rank-reduced
/// type, or the non-rank-reduced type.
-static MemRefType
-getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
- ArrayRef<OpFoldResult> mixedOffsets,
- ArrayRef<OpFoldResult> mixedSizes,
- ArrayRef<OpFoldResult> mixedStrides) {
- auto resultType =
- SubViewOp::inferRankReducedResultType(
- resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
- .cast<MemRefType>();
- if (resultType.getRank() != resultRank) {
- resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
- mixedSizes, mixedStrides)
- .cast<MemRefType>();
+static MemRefType getCanonicalSubViewResultType(
+ MemRefType currentResultType, MemRefType sourceType,
+ ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
+ ArrayRef<OpFoldResult> mixedStrides) {
+ auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
+ mixedSizes, mixedStrides)
+ .cast<MemRefType>();
+ llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
+ computeMemRefRankReductionMask(sourceType, currentResultType, mixedSizes);
+ // Return nullptr as failure mode.
+ if (!unusedDims)
+ return nullptr;
+ SmallVector<int64_t> shape;
+ for (auto sizes : llvm::enumerate(nonRankReducedType.getShape())) {
+ if (unusedDims->count(sizes.index()))
+ continue;
+ shape.push_back(sizes.value());
}
- return resultType;
+ AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
+ if (!layoutMap.isIdentity())
+ layoutMap = getProjectedMap(layoutMap, unusedDims.getValue());
+ return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
+ nonRankReducedType.getMemorySpace());
}
namespace {
@@ -1911,8 +1921,7 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
/// the cast source operand type and the SubViewOp static information. This
/// is the resulting type if the MemRefCastOp were folded.
auto resultType = getCanonicalSubViewResultType(
- subViewOp.getType().getRank(),
- castOp.source().getType().cast<MemRefType>(),
+ subViewOp.getType(), castOp.source().getType().cast<MemRefType>(),
subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
subViewOp.getMixedStrides());
Value newSubView = rewriter.create<SubViewOp>(
@@ -1931,9 +1940,9 @@ struct SubViewReturnTypeCanonicalizer {
MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
- return getCanonicalSubViewResultType(op.getType().getRank(),
- op.getSourceType(), mixedOffsets,
- mixedSizes, mixedStrides);
+ return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
+ mixedOffsets, mixedSizes,
+ mixedStrides);
}
};
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 71f4c3e538743..a568d5f3887db 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -47,7 +47,7 @@ func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
// -----
-#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
%arg2 : index) -> memref<?x?xf32, #map0>
{
@@ -395,3 +395,25 @@ func @collapse_after_memref_cast(%arg0 : memref<?x512x1x?xf32>) -> memref<?x?xf3
%collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
return %collapsed : memref<?x?xf32>
}
+
+// -----
+
+func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index)
+ -> memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> {
+ %c0 = arith.constant 0 : index
+ %c5 = arith.constant 5 : index
+ %c4 = arith.constant 4 : index
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+ %0 = memref.subview %arg0[%arg1, %arg1, %arg1, 0] [%c1, %c4, %c1, 1] [1, 1, 1, 1]
+ : memref<2x5x7x1xf32> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
+ %1 = memref.cast %0
+ : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> to
+ memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
+ return %1 : memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
+}
+
+// CHECK-LABEL: func @reduced_memref
+// CHECK: %[[RESULT:.+]] = memref.subview
+// CHECK-SAME: memref<2x5x7x1xf32> to memref<1x4x1xf32, #{{.+}}>
+// CHECK: return %[[RESULT]]
More information about the Mlir-commits mailing list