[Mlir-commits] [mlir] 6c85a49 - [mlir][memref] Use current source type in getCanonicalSubViewResultType.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 13 06:50:54 PST 2021
Author: gysit
Date: 2021-12-13T14:50:41Z
New Revision: 6c85a49e22021a8709d3d4f2a976b299c7cbcaa6
URL: https://github.com/llvm/llvm-project/commit/6c85a49e22021a8709d3d4f2a976b299c7cbcaa6
DIFF: https://github.com/llvm/llvm-project/commit/6c85a49e22021a8709d3d4f2a976b299c7cbcaa6.diff
LOG: [mlir][memref] Use current source type in getCanonicalSubViewResultType.
Use the current source type, rather than the new source type, to compute the rank-reduction map in getCanonicalSubViewResultType. Otherwise, the computation of the rank-reduction map fails when folding a cast into a subview, because the strides of the new source type cannot be related to the strides of the current result type.
Depends On D115428
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D115446
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 78b4f67f82569..4badc0b31ddb6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1833,18 +1833,23 @@ SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
return res;
}
-/// Infer the canonical type of the result of a subview operation. Returns a
-/// type with rank `resultRank` that is either the rank of the rank-reduced
-/// type, or the non-rank-reduced type.
+/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
+/// deduce the result type for the given `sourceType`. Additionally, reduce the
+/// rank of the inferred result type if `currentResultType` is lower rank than
+/// `currentSourceType`. Use this signature if `sourceType` is updated together
+/// with the result type. In this case, it is important to compute the dropped
+/// dimensions using `currentSourceType` whose strides align with
+/// `currentResultType`.
static MemRefType getCanonicalSubViewResultType(
- MemRefType currentResultType, MemRefType sourceType,
- ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
- ArrayRef<OpFoldResult> mixedStrides) {
+ MemRefType currentResultType, MemRefType currentSourceType,
+ MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
+ ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
mixedSizes, mixedStrides)
.cast<MemRefType>();
llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
- computeMemRefRankReductionMask(sourceType, currentResultType, mixedSizes);
+ computeMemRefRankReductionMask(currentSourceType, currentResultType,
+ mixedSizes);
// Return nullptr as failure mode.
if (!unusedDims)
return nullptr;
@@ -1861,6 +1866,18 @@ static MemRefType getCanonicalSubViewResultType(
nonRankReducedType.getMemorySpace());
}
+/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
+/// deduce the result type. Additionally, reduce the rank of the inferred result
+/// type if `currentResultType` is lower rank than `sourceType`.
+static MemRefType getCanonicalSubViewResultType(
+ MemRefType currentResultType, MemRefType sourceType,
+ ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
+ ArrayRef<OpFoldResult> mixedStrides) {
+ return getCanonicalSubViewResultType(currentResultType, sourceType,
+ sourceType, mixedOffsets, mixedSizes,
+ mixedStrides);
+}
+
namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
@@ -1897,13 +1914,18 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
if (!CastOp::canFoldIntoConsumerOp(castOp))
return failure();
- /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
- /// the cast source operand type and the SubViewOp static information. This
- /// is the resulting type if the MemRefCastOp were folded.
+ // Compute the SubViewOp result type after folding the MemRefCastOp. Use the
+ // MemRefCastOp source operand type to infer the result type and the current
+ // SubViewOp source operand type to compute the dropped dimensions if the
+ // operation is rank-reducing.
auto resultType = getCanonicalSubViewResultType(
- subViewOp.getType(), castOp.source().getType().cast<MemRefType>(),
+ subViewOp.getType(), subViewOp.getSourceType(),
+ castOp.source().getType().cast<MemRefType>(),
subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
subViewOp.getMixedStrides());
+ if (!resultType)
+ return failure();
+
Value newSubView = rewriter.create<SubViewOp>(
subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a568d5f3887db..251658fac7653 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1,11 +1,11 @@
// RUN: mlir-opt %s -canonicalize --split-input-file -allow-unregistered-dialect | FileCheck %s
-// CHECK-LABEL: func @subview_of_memcast
+// CHECK-LABEL: func @subview_of_size_memcast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
-// CHECK: %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
// CHECK: %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
// CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}>
-func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
+func @subview_of_size_memcast(%arg : memref<4x6x16x32xi8>) ->
memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
%0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
%1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
@@ -16,6 +16,27 @@ func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
// -----
+// CHECK-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<(d0, d1)[s0] -> (d0 * 7 + s0 + d1)>
+// CHECK-DAG: #[[MAP1:[0-9a-z]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+#map0 = affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>
+#map1 = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+#map2 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+
+// CHECK: func @subview_of_strides_memcast
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<1x1x?xf32, #{{.*}}>
+// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][0, 0, 0] [1, 1, 4]
+// CHECK-SAME: to memref<1x4xf32, #[[MAP0]]>
+// CHECK: %[[M:.+]] = memref.cast %[[S]]
+// CHECK-SAME: to memref<1x4xf32, #[[MAP1]]>
+// CHECK: return %[[M]]
+func @subview_of_strides_memcast(%arg : memref<1x1x?xf32, #map0>) -> memref<1x4xf32, #map2> {
+ %0 = memref.cast %arg : memref<1x1x?xf32, #map0> to memref<1x1x?xf32, #map1>
+ %1 = memref.subview %0[0, 0, 0] [1, 1, 4] [1, 1, 1] : memref<1x1x?xf32, #map1> to memref<1x4xf32, #map2>
+ return %1 : memref<1x4xf32, #map2>
+}
+
+// -----
+
// CHECK-LABEL: func @subview_of_static_full_size
// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
// CHECK-NOT: memref.subview
More information about the Mlir-commits
mailing list