[Mlir-commits] [mlir] 6c85a49 - [mlir][memref] Use current source type in getCanonicalSubViewResultType.

llvmlistbot at llvm.org
Mon Dec 13 06:50:54 PST 2021


Author: gysit
Date: 2021-12-13T14:50:41Z
New Revision: 6c85a49e22021a8709d3d4f2a976b299c7cbcaa6

URL: https://github.com/llvm/llvm-project/commit/6c85a49e22021a8709d3d4f2a976b299c7cbcaa6
DIFF: https://github.com/llvm/llvm-project/commit/6c85a49e22021a8709d3d4f2a976b299c7cbcaa6.diff

LOG: [mlir][memref] Use current source type in getCanonicalSubViewResultType.

Use the current source type instead of the new source type to compute the rank-reduction map in getCanonicalSubViewResultType. Otherwise, computing the rank-reduction map fails when folding a cast into a subview, since the strides of the new source type cannot be related to the strides of the current result type.
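
As an illustration, consider a rank-reducing subview whose source is produced by a cast (this mirrors the @subview_of_strides_memcast test added below, where #map0 has static strides and #map1/#map2 have dynamic strides):

  %0 = memref.cast %arg : memref<1x1x?xf32, #map0> to memref<1x1x?xf32, #map1>
  %1 = memref.subview %0[0, 0, 0] [1, 1, 4] [1, 1, 1]
         : memref<1x1x?xf32, #map1> to memref<1x4xf32, #map2>

Folding the cast rebuilds the subview directly on %arg : memref<1x1x?xf32, #map0>. The dropped unit dimensions therefore have to be derived from the current source type memref<1x1x?xf32, #map1>, whose dynamic strides can be related to the dynamic strides of the current result type memref<1x4xf32, #map2>; the static strides of #map0 cannot.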

Depends On D115428

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D115446

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 78b4f67f82569..4badc0b31ddb6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1833,18 +1833,23 @@ SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
   return res;
 }
 
-/// Infer the canonical type of the result of a subview operation. Returns a
-/// type with rank `resultRank` that is either the rank of the rank-reduced
-/// type, or the non-rank-reduced type.
+/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
+/// deduce the result type for the given `sourceType`. Additionally, reduce the
+/// rank of the inferred result type if `currentResultType` is lower rank than
+/// `currentSourceType`. Use this signature if `sourceType` is updated together
+/// with the result type. In this case, it is important to compute the dropped
+/// dimensions using `currentSourceType` whose strides align with
+/// `currentResultType`.
 static MemRefType getCanonicalSubViewResultType(
-    MemRefType currentResultType, MemRefType sourceType,
-    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
-    ArrayRef<OpFoldResult> mixedStrides) {
+    MemRefType currentResultType, MemRefType currentSourceType,
+    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
+    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
   auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                                        mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
   llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
-      computeMemRefRankReductionMask(sourceType, currentResultType, mixedSizes);
+      computeMemRefRankReductionMask(currentSourceType, currentResultType,
+                                     mixedSizes);
   // Return nullptr as failure mode.
   if (!unusedDims)
     return nullptr;
@@ -1861,6 +1866,18 @@ static MemRefType getCanonicalSubViewResultType(
                          nonRankReducedType.getMemorySpace());
 }
 
+/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
+/// deduce the result type. Additionally, reduce the rank of the inferred result
+/// type if `currentResultType` is lower rank than `sourceType`.
+static MemRefType getCanonicalSubViewResultType(
+    MemRefType currentResultType, MemRefType sourceType,
+    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
+    ArrayRef<OpFoldResult> mixedStrides) {
+  return getCanonicalSubViewResultType(currentResultType, sourceType,
+                                       sourceType, mixedOffsets, mixedSizes,
+                                       mixedStrides);
+}
+
 namespace {
 /// Pattern to rewrite a subview op with MemRefCast arguments.
 /// This essentially pushes memref.cast past its consuming subview when
@@ -1897,13 +1914,18 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
     if (!CastOp::canFoldIntoConsumerOp(castOp))
       return failure();
 
-    /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
-    /// the cast source operand type and the SubViewOp static information. This
-    /// is the resulting type if the MemRefCastOp were folded.
+    // Compute the SubViewOp result type after folding the MemRefCastOp. Use the
+    // MemRefCastOp source operand type to infer the result type and the current
+    // SubViewOp source operand type to compute the dropped dimensions if the
+    // operation is rank-reducing.
     auto resultType = getCanonicalSubViewResultType(
-        subViewOp.getType(), castOp.source().getType().cast<MemRefType>(),
+        subViewOp.getType(), subViewOp.getSourceType(),
+        castOp.source().getType().cast<MemRefType>(),
         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
         subViewOp.getMixedStrides());
+    if (!resultType)
+      return failure();
+
     Value newSubView = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a568d5f3887db..251658fac7653 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1,11 +1,11 @@
 // RUN: mlir-opt %s -canonicalize --split-input-file -allow-unregistered-dialect | FileCheck %s
 
-// CHECK-LABEL: func @subview_of_memcast
+// CHECK-LABEL: func @subview_of_size_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
-//       CHECK:   %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+//       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
 //       CHECK:   %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
 //       CHECK:   return %[[M]] : memref<16x32xi8, #{{.*}}>
-func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
+func @subview_of_size_memcast(%arg : memref<4x6x16x32xi8>) ->
   memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
   %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
   %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
@@ -16,6 +16,27 @@ func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
 
 // -----
 
+//   CHECK-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<(d0, d1)[s0] -> (d0 * 7 + s0 + d1)>
+//   CHECK-DAG: #[[MAP1:[0-9a-z]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+#map0 = affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>
+#map1 = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+#map2 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+
+//       CHECK: func @subview_of_strides_memcast
+//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<1x1x?xf32, #{{.*}}>
+//       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 0, 0] [1, 1, 4]
+//  CHECK-SAME:                    to memref<1x4xf32, #[[MAP0]]>
+//       CHECK:   %[[M:.+]] = memref.cast %[[S]]
+//  CHECK-SAME:                    to memref<1x4xf32, #[[MAP1]]>
+//       CHECK:   return %[[M]]
+func @subview_of_strides_memcast(%arg : memref<1x1x?xf32, #map0>) -> memref<1x4xf32, #map2> {
+  %0 = memref.cast %arg : memref<1x1x?xf32, #map0> to memref<1x1x?xf32, #map1>
+  %1 = memref.subview %0[0, 0, 0] [1, 1, 4] [1, 1, 1] : memref<1x1x?xf32, #map1> to memref<1x4xf32, #map2>
+  return %1 : memref<1x4xf32, #map2>
+}
+
+// -----
+
 // CHECK-LABEL: func @subview_of_static_full_size
 // CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
 // CHECK-NOT: memref.subview

