[Mlir-commits] [mlir] 4cf9bf6 - [mlir][MemRef] Compute unused dimensions of a rank-reducing subview using strides as well.
llvmlistbot at llvm.org
Mon Sep 20 11:05:55 PDT 2021
Author: MaheshRavishankar
Date: 2021-09-20T11:05:30-07:00
New Revision: 4cf9bf6c9f64cca1111134acc9f84efe8f27e8d1
URL: https://github.com/llvm/llvm-project/commit/4cf9bf6c9f64cca1111134acc9f84efe8f27e8d1
DIFF: https://github.com/llvm/llvm-project/commit/4cf9bf6c9f64cca1111134acc9f84efe8f27e8d1.diff
LOG: [mlir][MemRef] Compute unused dimensions of a rank-reducing subview using strides as well.
For `memref.subview` operations, when there is more than one
unit dimension, the strides need to be used to figure out which of
the unit dims are actually dropped.
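For illustration (a minimal sketch, not taken from this patch; the value names
%src, %ofs and %sz are hypothetical), the rank-reduced shape alone cannot tell
which unit dimension of a memref<1x384x384xf32> source (strides [147456, 384, 1])
was dropped, but the result strides can:

func @dropped_dim_disambiguation(%src : memref<1x384x384xf32>, %ofs : index, %sz : index) {
  // Dropping source dim 0 keeps strides [384, 1] in the result layout.
  %a = memref.subview %src[0, 0, %ofs] [1, 1, %sz] [1, 1, 1]
      : memref<1x384x384xf32> to memref<1x?xf32, offset: ?, strides: [384, 1]>
  // Dropping source dim 1 instead keeps strides [147456, 1].
  %b = memref.subview %src[0, 0, %ofs] [1, 1, %sz] [1, 1, 1]
      : memref<1x384x384xf32> to memref<1x?xf32, offset: ?, strides: [147456, 1]>
  return
}

The new computeMemRefRankReductionMask counts stride occurrences in the original
and candidate types to decide which of the unit dims were dropped.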
Differential Revision: https://reviews.llvm.org/D109418
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Interfaces/ViewLikeInterface.td
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
mlir/test/Dialect/Linalg/loops.mlir
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/MemRef/fold-subview-ops.mlir
mlir/test/Dialect/MemRef/invalid.mlir
mlir/test/IR/invalid-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index e0cb3816efafe..dd8455a7f9190 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1379,6 +1379,10 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
/// Return the number of leading operands before the `offsets`, `sizes` and
/// `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+
+ /// Return the dimensions of the source type that are dropped when
+ /// the result is rank-reduced.
+ llvm::SmallDenseSet<unsigned> getDroppedDims();
}];
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index f5227361165b3..50ebeaa44a5c3 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -66,7 +66,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
let cppNamespace = "::mlir";
let methods = [
- InterfaceMethod<
+ StaticInterfaceMethod<
/*desc=*/[{
Return the number of leading operands before the `offsets`, `sizes` and
`strides` operands.
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index ebca204ab8486..85c05f04b07f3 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1272,12 +1272,8 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
extracted);
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
- auto shape = viewMemRefType.getShape();
- auto inferredShape = inferredType.getShape();
- size_t inferredShapeRank = inferredShape.size();
- size_t resultShapeRank = shape.size();
- llvm::SmallDenseSet<unsigned> unusedDims =
- computeRankReductionMask(inferredShape, shape).getValue();
+ size_t inferredShapeRank = inferredType.getRank();
+ size_t resultShapeRank = viewMemRefType.getRank();
// Extract strides needed to compute offset.
SmallVector<Value, 4> strideValues;
@@ -1315,6 +1311,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
assert(mixedSizes.size() == mixedStrides.size() &&
"expected sizes and strides of equal length");
+ llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
i >= 0 && j >= 0; --i) {
if (unusedDims.contains(i))
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f0bf8c639e4a4..f80d373c41e0d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -690,6 +690,92 @@ static LogicalResult verify(DimOp op) {
return success();
}
+/// Return a map whose keys are the elements of `vals` and whose values are
+/// the number of occurrences of each element. Use std::map, since the `vals`
+/// here are strides and the dynamic stride value is the same as the tombstone
+/// value for `DenseMap<int64_t>`.
+static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
+ std::map<int64_t, unsigned> numOccurences;
+ for (auto val : vals)
+ numOccurences[val]++;
+ return numOccurences;
+}
+
+/// Given the un-rank-reduced subview result type and the rank-reduced result
+/// type, compute the dropped dimensions. This accounts for cases where there
+/// are multiple unit-dims, but only a subset of them are dropped. For
+/// MemRefTypes these can be disambiguated using the strides: if a dimension is
+/// dropped, its stride must be dropped too.
+static llvm::Optional<llvm::SmallDenseSet<unsigned>>
+computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
+ ArrayAttr staticSizes) {
+ llvm::SmallDenseSet<unsigned> unusedDims;
+ if (originalType.getRank() == reducedType.getRank())
+ return unusedDims;
+
+ for (auto dim : llvm::enumerate(staticSizes))
+ if (dim.value().cast<IntegerAttr>().getInt() == 1)
+ unusedDims.insert(dim.index());
+ SmallVector<int64_t> originalStrides, candidateStrides;
+ int64_t originalOffset, candidateOffset;
+ if (failed(
+ getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
+ failed(
+ getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
+ return llvm::None;
+
+  // For memrefs, a dimension is truly dropped if its corresponding stride is
+  // also dropped. This is particularly important when more than one of the
+  // dims is 1. Track the number of occurrences of the strides in the original
+  // type and the candidate type. For each unused dim, that stride should not
+  // be present in the candidate type. Note that multiple dimensions can have
+  // the same size; we don't need to figure out exactly which dim corresponds
+  // to which stride, we just need to verify that the number of repetitions of
+  // a stride in the original + the number of unused dims with that stride ==
+  // the number of repetitions of that stride in the candidate.
+ std::map<int64_t, unsigned> currUnaccountedStrides =
+ getNumOccurences(originalStrides);
+ std::map<int64_t, unsigned> candidateStridesNumOccurences =
+ getNumOccurences(candidateStrides);
+ llvm::SmallDenseSet<unsigned> prunedUnusedDims;
+ for (unsigned dim : unusedDims) {
+ int64_t originalStride = originalStrides[dim];
+ if (currUnaccountedStrides[originalStride] >
+ candidateStridesNumOccurences[originalStride]) {
+ // This dim can be treated as dropped.
+ currUnaccountedStrides[originalStride]--;
+ continue;
+ }
+ if (currUnaccountedStrides[originalStride] ==
+ candidateStridesNumOccurences[originalStride]) {
+ // The stride for this is not dropped. Keep as is.
+ prunedUnusedDims.insert(dim);
+ continue;
+ }
+ if (currUnaccountedStrides[originalStride] <
+ candidateStridesNumOccurences[originalStride]) {
+      // This should never happen. Can't have a stride in the rank-reduced type
+      // that wasn't in the original one.
+ return llvm::None;
+ }
+ }
+
+ for (auto prunedDim : prunedUnusedDims)
+ unusedDims.erase(prunedDim);
+ if (unusedDims.size() + reducedType.getRank() != originalType.getRank())
+ return llvm::None;
+ return unusedDims;
+}
+
+llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
+ MemRefType sourceType = getSourceType();
+ MemRefType resultType = getType();
+ llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
+ computeMemRefRankReductionMask(sourceType, resultType, static_sizes());
+ assert(unusedDims && "unable to find unused dims of subview");
+ return *unusedDims;
+}
+
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
// All forms of folding require a known index.
auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
@@ -725,6 +811,25 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return *(view.getDynamicSizes().begin() +
memrefType.getDynamicDimIndex(unsignedIndex));
+ if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
+ llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims();
+ unsigned resultIndex = 0;
+ unsigned sourceRank = subview.getSourceType().getRank();
+ unsigned sourceIndex = 0;
+ for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
+ if (unusedDims.count(i))
+ continue;
+ if (resultIndex == unsignedIndex) {
+ sourceIndex = i;
+ break;
+ }
+ resultIndex++;
+ }
+ assert(subview.isDynamicSize(sourceIndex) &&
+ "expected dynamic subview size");
+ return subview.getDynamicSize(sourceIndex);
+ }
+
if (auto sizeInterface =
dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
assert(sizeInterface.isDynamicSize(unsignedIndex) &&
@@ -1887,7 +1992,7 @@ enum SubViewVerificationResult {
/// not matching dimension must be 1.
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
- std::string *errMsg = nullptr) {
+ ArrayAttr staticSizes, std::string *errMsg = nullptr) {
if (originalType == candidateReducedType)
return SubViewVerificationResult::Success;
if (!originalType.isa<MemRefType>())
@@ -1908,8 +2013,11 @@ isRankReducedType(Type originalType, Type candidateReducedType,
if (candidateReducedRank > originalRank)
return SubViewVerificationResult::RankTooLarge;
+ MemRefType original = originalType.cast<MemRefType>();
+ MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
+
auto optionalUnusedDimsMask =
- computeRankReductionMask(originalShape, candidateReducedShape);
+ computeMemRefRankReductionMask(original, candidateReduced, staticSizes);
// Sizes cannot be matched in case empty vector is returned.
if (!optionalUnusedDimsMask.hasValue())
@@ -1920,42 +2028,8 @@ isRankReducedType(Type originalType, Type candidateReducedType,
return SubViewVerificationResult::ElemTypeMismatch;
// Strided layout logic is relevant for MemRefType only.
- MemRefType original = originalType.cast<MemRefType>();
- MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
if (original.getMemorySpace() != candidateReduced.getMemorySpace())
return SubViewVerificationResult::MemSpaceMismatch;
-
- llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
- auto inferredType =
- getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
- AffineMap candidateLayout;
- if (candidateReduced.getAffineMaps().empty())
- candidateLayout = getStridedLinearLayoutMap(candidateReduced);
- else
- candidateLayout = candidateReduced.getAffineMaps().front();
- assert(inferredType.getNumResults() == 1 &&
- candidateLayout.getNumResults() == 1);
- if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
- inferredType.getNumDims() != candidateLayout.getNumDims()) {
- if (errMsg) {
- llvm::raw_string_ostream os(*errMsg);
- os << "inferred type: " << inferredType;
- }
- return SubViewVerificationResult::AffineMapMismatch;
- }
-  // Check that the difference of the affine maps simplifies to 0.
-  AffineExpr diffExpr =
-      inferredType.getResult(0) - candidateLayout.getResult(0);
-  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
-                                inferredType.getNumSymbols());
-  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
- if (!(cst && cst.getValue() == 0)) {
- if (errMsg) {
- llvm::raw_string_ostream os(*errMsg);
- os << "inferred type: " << inferredType;
- }
- return SubViewVerificationResult::AffineMapMismatch;
- }
return SubViewVerificationResult::Success;
}
@@ -2012,7 +2086,8 @@ static LogicalResult verify(SubViewOp op) {
extractFromI64ArrayAttr(op.static_strides()));
std::string errMsg;
- auto result = isRankReducedType(expectedType, subViewType, &errMsg);
+ auto result =
+ isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg);
return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
index 4e1424083e96b..17ec4a1ba7fe6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -49,18 +49,13 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
SmallVector<Value> useIndices;
// Check if this is a rank-reducing case. Then for every unit-dim size add a
// zero to the indices.
- ArrayRef<int64_t> resultShape = subViewOp.getType().getShape();
unsigned resultDim = 0;
- for (auto size : llvm::enumerate(mixedSizes)) {
- auto attr = size.value().dyn_cast<Attribute>();
- // Check if this dimension has been dropped, i.e. the size is 1, but the
- // associated dimension is not 1.
- if (attr && attr.cast<IntegerAttr>().getInt() == 1 &&
- (resultDim >= resultShape.size() || resultShape[resultDim] != 1))
+ llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
+ for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
+ if (unusedDims.count(dim))
useIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
- else if (resultDim < resultShape.size()) {
+ else
useIndices.push_back(indices[resultDim++]);
- }
}
if (useIndices.size() != mixedOffsets.size())
return failure();
@@ -104,6 +99,25 @@ static Value getMemRefOperand(vector::TransferWriteOp op) {
return op.source();
}
+/// Given the permutation map of the original
+/// `vector.transfer_read`/`vector.transfer_write` operation, compute the
+/// permutation map to use after the subview is folded into it.
+static AffineMap getPermutationMap(MLIRContext *context,
+ memref::SubViewOp subViewOp,
+ AffineMap currPermutationMap) {
+ llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
+ SmallVector<AffineExpr> exprs;
+ unsigned resultIdx = 0;
+ int64_t sourceRank = subViewOp.getSourceType().getRank();
+ for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
+ if (unusedDims.count(dim))
+ continue;
+ exprs.push_back(getAffineDimExpr(resultIdx++, context));
+ }
+ auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
+ return currPermutationMap.compose(resultDimToSourceDimMap);
+}
+
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
@@ -153,7 +167,9 @@ void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
- loadOp.permutation_map(), loadOp.padding(), loadOp.in_boundsAttr());
+ getPermutationMap(rewriter.getContext(), subViewOp,
+ loadOp.permutation_map()),
+ loadOp.padding(), loadOp.in_boundsAttr());
}
template <>
@@ -170,7 +186,9 @@ void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
- sourceIndices, transferWriteOp.permutation_map(),
+ sourceIndices,
+ getPermutationMap(rewriter.getContext(), subViewOp,
+ transferWriteOp.permutation_map()),
transferWriteOp.in_boundsAttr());
}
} // namespace
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 718dd2f9a789f..747471623c248 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1418,3 +1418,28 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+
+// -----
+
+func @lower_to_loops_with_rank_reducing_subviews(
+ %arg0 : memref<?xi32>, %arg1 : memref<?x?xi32>, %arg2 : index,
+ %arg3 : index, %arg4 : index) {
+ %0 = memref.subview %arg0[%arg2] [%arg3] [1]
+ : memref<?xi32> to memref<?xi32, offset: ?, strides: [1]>
+ %1 = memref.subview %arg1[0, %arg4] [1, %arg3] [1, 1]
+ : memref<?x?xi32> to memref<?xi32, offset: ?, strides : [1]>
+ linalg.copy(%0, %1)
+ : memref<?xi32, offset: ?, strides: [1]>, memref<?xi32, offset: ?, strides: [1]>
+ return
+}
+// CHECK-LABEL: func @lower_to_loops_with_rank_reducing_subviews
+// CHECK: scf.for %[[IV:.+]] = %{{.+}} to %{{.+}} step %{{.+}} {
+// CHECK: %[[VAL:.+]] = memref.load %{{.+}}[%[[IV]]]
+// CHECK: memref.store %[[VAL]], %{{.+}}[%[[IV]]]
+// CHECK: }
+
+// CHECKPARALLEL-LABEL: func @lower_to_loops_with_rank_reducing_subviews
+// CHECKPARALLEL: scf.parallel (%[[IV:.+]]) = (%{{.+}}) to (%{{.+}}) step (%{{.+}}) {
+// CHECKPARALLEL: %[[VAL:.+]] = memref.load %{{.+}}[%[[IV]]]
+// CHECKPARALLEL: memref.store %[[VAL]], %{{.+}}[%[[IV]]]
+// CHECKPARALLEL: }
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index d73bb136025ce..ec57845a00f14 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -159,6 +159,63 @@ func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : inde
// CHECK: %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
// CHECK: return %[[RESULT]]
+// -----
+
+func @multiple_reducing_dims(%arg0 : memref<1x384x384xf32>,
+ %arg1 : index, %arg2 : index, %arg3 : index) -> memref<?xf32, offset: ?, strides: [1]>
+{
+ %c1 = constant 1 : index
+ %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref<1x384x384xf32> to memref<?x?xf32, offset: ?, strides: [384, 1]>
+ %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<?x?xf32, offset: ?, strides: [384, 1]> to memref<?xf32, offset: ?, strides: [1]>
+ return %1 : memref<?xf32, offset: ?, strides: [1]>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 384 + s0 + d1)>
+// CHECK: func @multiple_reducing_dims
+// CHECK: %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME: : memref<1x384x384xf32> to memref<1x?xf32, #[[MAP1]]>
+// CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
+// CHECK-SAME: : memref<1x?xf32, #[[MAP1]]> to memref<?xf32, #[[MAP0]]>
+
+// -----
+
+func @multiple_reducing_dims_dynamic(%arg0 : memref<?x?x?xf32>,
+ %arg1 : index, %arg2 : index, %arg3 : index) -> memref<?xf32, offset: ?, strides: [1]>
+{
+ %c1 = constant 1 : index
+ %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref<?x?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+ %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?xf32, offset: ?, strides: [1]>
+ return %1 : memref<?xf32, offset: ?, strides: [1]>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK: func @multiple_reducing_dims_dynamic
+// CHECK: %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME: : memref<?x?x?xf32> to memref<1x?xf32, #[[MAP1]]>
+// CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
+// CHECK-SAME: : memref<1x?xf32, #[[MAP1]]> to memref<?xf32, #[[MAP0]]>
+
+// -----
+
+func @multiple_reducing_dims_all_dynamic(%arg0 : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>,
+ %arg1 : index, %arg2 : index, %arg3 : index) -> memref<?xf32, offset: ?, strides: [?]>
+{
+ %c1 = constant 1 : index
+ %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1]
+ : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+ %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?xf32, offset: ?, strides: [?]>
+ return %1 : memref<?xf32, offset: ?, strides: [?]>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+// CHECK: func @multiple_reducing_dims_all_dynamic
+// CHECK: %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME: : memref<?x?x?xf32, #[[MAP2]]> to memref<1x?xf32, #[[MAP1]]>
+// CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
+// CHECK-SAME: : memref<1x?xf32, #[[MAP1]]> to memref<?xf32, #[[MAP0]]>
+
+
// -----
// CHECK-LABEL: @clone_before_dealloc
@@ -567,4 +624,3 @@ func @collapse_after_memref_cast(%arg0 : memref<?x512x1x?xf32>) -> memref<?x?xf3
%collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
return %collapsed : memref<?x?xf32>
}
-
diff --git a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
index 246c0b3552947..558b44350af7b 100644
--- a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
@@ -160,3 +160,66 @@ func @fold_rank_reducing_subview_with_load
// CHECK-DAG: %[[I5:.+]] = affine.apply #[[MAP]](%[[ARG16]])[%[[ARG11]], %[[ARG5]]]
// CHECK-DAG: %[[I6:.+]] = affine.apply #[[MAP]](%[[C0]])[%[[ARG12]], %[[ARG6]]]
// CHECK: memref.load %[[ARG0]][%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]]
+
+// -----
+
+func @fold_vector_transfer_read_with_rank_reduced_subview(
+ %arg0 : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>,
+ %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
+ %arg6 : index) -> vector<4xf32> {
+ %cst = constant 0.0 : f32
+ %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %arg3, %arg4] [1, 1, 1]
+ : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]> to
+ memref<?x?xf32, offset: ?, strides: [?, ?]>
+ %1 = vector.transfer_read %0[%arg5, %arg6], %cst {in_bounds = [true]}
+ : memref<?x?xf32, offset: ?, strides: [?, ?]>, vector<4xf32>
+ return %1 : vector<4xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK: func @fold_vector_transfer_read_with_rank_reduced_subview
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?xf32, #[[MAP0]]>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]](%[[ARG5]])[%[[ARG1]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]](%[[ARG6]])[%[[ARG2]]]
+// CHECK: vector.transfer_read %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]]
+// CHECK-SAME: permutation_map = #[[MAP2]]
+
+// -----
+
+func @fold_vector_transfer_write_with_rank_reduced_subview(
+ %arg0 : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>,
+ %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+ %arg5: index, %arg6 : index, %arg7 : index) {
+ %cst = constant 0.0 : f32
+ %0 = memref.subview %arg0[0, %arg2, %arg3] [1, %arg4, %arg5] [1, 1, 1]
+ : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]> to
+ memref<?x?xf32, offset: ?, strides: [?, ?]>
+ vector.transfer_write %arg1, %0[%arg6, %arg7] {in_bounds = [true]}
+ : vector<4xf32>, memref<?x?xf32, offset: ?, strides: [?, ?]>
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK: func @fold_vector_transfer_write_with_rank_reduced_subview
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?xf32, #[[MAP0]]>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]](%[[ARG6]])[%[[ARG2]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]](%[[ARG7]])[%[[ARG3]]]
+// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]]
+// CHECK-SAME: permutation_map = #[[MAP2]]
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index dcd1a6b128498..b93815533119c 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -353,3 +353,12 @@ func @collapse_shape_illegal_mixed_memref_2(%arg0 : memref<?x4x5xf32>)
: memref<?x4x5xf32> into memref<?x?xf32>
return %0 : memref<?x?xf32>
}
+
+// -----
+
+func @static_stride_to_dynamic_stride(%arg0 : memref<?x?x?xf32>, %arg1 : index,
+ %arg2 : index) -> memref<?x?xf32, offset:?, strides: [?, ?]> {
+ // expected-error @+1 {{expected result type to be 'memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>' or a rank-reduced version. (mismatch of result sizes)}}
+ %0 = memref.subview %arg0[0, 0, 0] [1, %arg1, %arg2] [1, 1, 1] : memref<?x?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+ return %0 : memref<?x?xf32, offset: ?, strides: [?, ?]>
+}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 30b1b411d2df8..265f095fe2272 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -960,17 +960,6 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
// -----
-func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
- %0 = memref.alloc() : memref<8x16x4xf32>
- // expected-error at +1 {{expected result type to be 'memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>' or a rank-reduced version. (mismatch of result affine map)}}
- %1 = memref.subview %0[%arg0, %arg1, %arg2][%arg0, %arg1, %arg2][%arg0, %arg1, %arg2]
- : memref<8x16x4xf32> to
- memref<?x?x?xf32, offset: ?, strides: [64, 4, 1]>
- return
-}
-
-// -----
-
func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
%0 = memref.alloc() : memref<8x16x4xf32>
// expected-error at +1 {{expected result element type to be 'f32'}}
@@ -1014,22 +1003,13 @@ func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index)
// -----
func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
- // expected-error at +1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result affine map)}}
+ // expected-error at +1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result sizes)}}
%0 = memref.subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32>
return
}
// -----
-// The affine map affine_map<(d0)[s0, s1, s2] -> (d0 * s1 + s0)> has an extra unused symbol.
-func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
- // expected-error at +1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result affine map) inferred type: (d0)[s0, s1] -> (d0 * s1 + s0)}}
- %0 = memref.subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32, affine_map<(d0)[s0, s1, s2] -> (d0 * s1 + s0)>>
- return
-}
-
-// -----
-
func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) {
// expected-error at +1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}}
%0 = memref.cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]>