[Mlir-commits] [mlir] fd15e2b - [mlir][Linalg] Use rank-reduced versions of subtensor and subtensor insert when possible.

llvmlistbot at llvm.org
Mon May 3 12:51:37 PDT 2021


Author: MaheshRavishankar
Date: 2021-05-03T12:51:24-07:00
New Revision: fd15e2b825f26dd7eac3b4a52aab36c88e52850a

URL: https://github.com/llvm/llvm-project/commit/fd15e2b825f26dd7eac3b4a52aab36c88e52850a
DIFF: https://github.com/llvm/llvm-project/commit/fd15e2b825f26dd7eac3b4a52aab36c88e52850a.diff

LOG: [mlir][Linalg] Use rank-reduced versions of subtensor and subtensor insert when possible.

Convert subtensor and subtensor_insert operations to use their
rank-reduced versions to drop unit dimensions.
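
For illustration, a rough sketch of the rewrite (value names and shapes are
illustrative, and the trailing reshape is expected to fold away with other
unit-dimension reshapes during canonicalization): a slice that merely carries
unit dimensions along, such as

  %1 = subtensor %0[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
      : tensor<1x?x1xf32> to tensor<1x?x1xf32>

is now emitted as a rank-reduced subtensor plus a reshape back to the original
result type:

  %1 = subtensor %0[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
      : tensor<1x?x1xf32> to tensor<?xf32>
  %2 = linalg.tensor_reshape %1 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
      : tensor<?xf32> into tensor<1x?x1xf32>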

Differential Revision: https://reviews.llvm.org/D101495

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
    mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
    mlir/test/Dialect/MemRef/fold-subview-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index 18be136ac6d18..d98d510a134a2 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -18,7 +18,9 @@ def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
     from/to the original memref.
   }];
   let constructor = "mlir::memref::createFoldSubViewOpsPass()";
-  let dependentDialects = ["memref::MemRefDialect", "vector::VectorDialect"];
+  let dependentDialects = [
+      "AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
+  ];
 }
 
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 5d8a664ef9646..f9320f358ab54 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -544,77 +544,87 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
     return success();
   }
 };
+} // namespace
 
-/// Pattern to fold subtensors that are just taking a slice of unit-dimension
-/// tensor. For example
-///
-/// %1 = subtensor %0[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
-///     : tensor<1x?x1xf32> to tensor<1x?x1xf32>
-///
-/// can be replaced with
-///
-/// %0 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
-///     : tensor<1x?x1xf32> into tensor<?xf32>
-/// %1 = subtensor %0[%o1] [%s1] [1] : tensor<?xf32> to tensor<?xf32>
-/// %2 = linalg.tensor_reshape %1 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
-///     : tensor<?xf32> into tensor<1x?x1xf32>
-///
-/// The additional tensor_reshapes will hopefully get canonicalized away with
-/// other reshapes that drop unit dimensions. Three conditions to fold a
-/// dimension
-/// - The offset must be 0
-/// - The size must be 1
-/// - The dimension of the source type must be 1.
-struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> {
+/// Get the reassociation maps to fold the result of a subtensor (or source of
+/// a subtensor_insert) operation with the given offsets and sizes to its
+/// rank-reduced version. This is only done for the cases where the size is 1
+/// and the offset is 0. Strictly speaking, a zero offset is not required in
+/// general, but non-zero offsets are not handled by the SPIR-V backend at this
+/// point (and potentially cannot be handled).
+static Optional<SmallVector<ReassociationIndices>>
+getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
+  SmallVector<ReassociationIndices> reassociation;
+  ReassociationIndices curr;
+  for (auto it : llvm::enumerate(mixedSizes)) {
+    auto dim = it.index();
+    auto size = it.value();
+    curr.push_back(dim);
+    auto attr = size.dyn_cast<Attribute>();
+    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
+      continue;
+    reassociation.emplace_back(ReassociationIndices{});
+    std::swap(reassociation.back(), curr);
+  }
+  if (!curr.empty())
+    reassociation.back().append(curr.begin(), curr.end());
+  return reassociation;
+}
+
+namespace {
+/// Convert `subtensor` operations to rank-reduced versions.
+struct UseRankReducedSubTensorOp : public OpRewritePattern<SubTensorOp> {
   using OpRewritePattern<SubTensorOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<OpFoldResult> mixedOffsets = subTensorOp.getMixedOffsets();
-    SmallVector<OpFoldResult> mixedSizes = subTensorOp.getMixedSizes();
-    SmallVector<OpFoldResult> mixedStrides = subTensorOp.getMixedStrides();
-    auto hasValue = [](OpFoldResult valueOrAttr, int64_t val) {
-      auto attr = valueOrAttr.dyn_cast<Attribute>();
-      return attr && attr.cast<IntegerAttr>().getInt() == val;
-    };
-
-    if (llvm::any_of(mixedStrides, [&](OpFoldResult valueOrAttr) {
-          return !hasValue(valueOrAttr, 1);
-        }))
+    RankedTensorType resultType = subTensorOp.getType();
+    SmallVector<OpFoldResult> offsets = subTensorOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = subTensorOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = subTensorOp.getMixedStrides();
+    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+    if (!reassociation ||
+        reassociation->size() == static_cast<size_t>(resultType.getRank()))
       return failure();
+    auto rankReducedType =
+        SubTensorOp::inferRankReducedResultType(reassociation->size(),
+                                                subTensorOp.getSourceType(),
+                                                offsets, sizes, strides)
+            .cast<RankedTensorType>();
+
+    Location loc = subTensorOp.getLoc();
+    Value newSubTensor = rewriter.create<SubTensorOp>(
+        loc, rankReducedType, subTensorOp.source(), offsets, sizes, strides);
+    rewriter.replaceOpWithNewOp<TensorReshapeOp>(subTensorOp, resultType,
+                                                 newSubTensor, *reassociation);
+    return success();
+  }
+};
 
-    // Find the expanded unit dimensions.
-    SmallVector<ReassociationIndices> reassociation;
-    SmallVector<OpFoldResult> newOffsets, newSizes;
-    ArrayRef<int64_t> sourceShape = subTensorOp.getSourceType().getShape();
-    ReassociationIndices curr;
-    for (int64_t dim : llvm::seq<int64_t>(0, mixedOffsets.size())) {
-      curr.push_back(dim);
-      if (sourceShape[dim] == 1 && hasValue(mixedOffsets[dim], 0) &&
-          hasValue(mixedSizes[dim], 1)) {
-        continue;
-      }
-      newOffsets.push_back(mixedOffsets[dim]);
-      newSizes.push_back(mixedSizes[dim]);
-      reassociation.emplace_back(ReassociationIndices{});
-      std::swap(reassociation.back(), curr);
-    }
-    if (newOffsets.size() == mixedOffsets.size())
+/// Convert `subtensor_insert` operations to rank-reduced versions.
+struct UseRankReducedSubTensorInsertOp
+    : public OpRewritePattern<SubTensorInsertOp> {
+  using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubTensorInsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType sourceType = insertOp.getSourceType();
+    SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
+    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+    if (!reassociation ||
+        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
       return failure();
-    reassociation.back().append(curr.begin(), curr.end());
-    SmallVector<OpFoldResult> newStrides(newOffsets.size(),
-                                         rewriter.getI64IntegerAttr(1));
-    Location loc = subTensorOp->getLoc();
-    auto srcReshape = rewriter.create<TensorReshapeOp>(
-        loc, subTensorOp.source(), reassociation);
-    auto newSubTensorOp = rewriter.create<SubTensorOp>(
-        loc, srcReshape, newOffsets, newSizes, newStrides);
-    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
-        subTensorOp, subTensorOp.getType(), newSubTensorOp, reassociation);
+    Location loc = insertOp.getLoc();
+    auto reshapedSource = rewriter.create<TensorReshapeOp>(
+        loc, insertOp.source(), *reassociation);
+    rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
+        insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
+        insertOp.getMixedSizes(), insertOp.getMixedStrides());
     return success();
   }
 };
-
 } // namespace
 
 /// Patterns that are used to canonicalize the use of unit-extent dims for
@@ -623,8 +633,10 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
   patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
-               FoldUnitDimSubTensorOp, ReplaceUnitExtentTensors<GenericOp>,
-               ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
+               ReplaceUnitExtentTensors<GenericOp>,
+               ReplaceUnitExtentTensors<IndexedGenericOp>,
+               UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
+      context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldReshapeOpWithUnitExtent>(context);
 }
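
A similarly hedged sketch for the new subtensor_insert pattern (the %src, %dst,
%o1 and %s1 names and the destination shape are made up for illustration): the
source of

  %1 = subtensor_insert %src into %dst[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
      : tensor<1x?x1xf32> into tensor<4x?x8xf32>

is first collapsed to its rank-reduced type and then inserted directly:

  %0 = linalg.tensor_reshape %src [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
      : tensor<1x?x1xf32> into tensor<?xf32>
  %1 = subtensor_insert %0 into %dst[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
      : tensor<?xf32> into tensor<4x?x8xf32>

The reassociation groups come from getReassociationMapForFoldingUnitDims: for
the 7-D sizes [1, %arg5, %arg6, 1, %arg7, 1, 1] in the updated
drop-unit-extent-dims.mlir test they are [0, 1], [2], [3, 4, 5, 6], i.e. the
#MAP0/#MAP1/#MAP2 triple in the CHECK lines.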

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index cb27354f36dfe..e795a86f69d74 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   MLIRMemRefPassIncGen
 
   LINK_LIBS PUBLIC
+  MLIRAffine
   MLIRMemRef
   MLIRPass
   MLIRStandard

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
index ae76966ba25d6..4e1424083e96b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -41,27 +42,53 @@ static LogicalResult
 resolveSourceIndices(Location loc, PatternRewriter &rewriter,
                      memref::SubViewOp subViewOp, ValueRange indices,
                      SmallVectorImpl<Value> &sourceIndices) {
-  // TODO: Aborting when the offsets are static. There might be a way to fold
-  // the subview op with load even if the offsets have been canonicalized
-  // away.
-  SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
-  if (opRanges.size() != indices.size()) {
-    // For the rank-reduced cases, we can only handle the folding when the
-    // offset is zero, size is 1 and stride is 1.
-    return failure();
+  SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
+  SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
+  SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
+
+  SmallVector<Value> useIndices;
+  // Check if this is a rank-reducing case. For every unit-dim size whose
+  // dimension has been dropped from the result, use a constant zero index.
+  ArrayRef<int64_t> resultShape = subViewOp.getType().getShape();
+  unsigned resultDim = 0;
+  for (auto size : llvm::enumerate(mixedSizes)) {
+    auto attr = size.value().dyn_cast<Attribute>();
+    // Check if this dimension has been dropped, i.e. the size is 1 but the
+    // corresponding result dimension is not 1.
+    if (attr && attr.cast<IntegerAttr>().getInt() == 1 &&
+        (resultDim >= resultShape.size() || resultShape[resultDim] != 1))
+      useIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
+    else if (resultDim < resultShape.size()) {
+      useIndices.push_back(indices[resultDim++]);
+    }
   }
-  auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
-  auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
-
-  // New indices for the load are the current indices * subview_stride +
-  // subview_offset.
-  sourceIndices.resize(indices.size());
-  for (auto index : llvm::enumerate(indices)) {
-    auto offset = *(opOffsets.begin() + index.index());
-    auto stride = *(opStrides.begin() + index.index());
-    auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
-    sourceIndices[index.index()] =
-        rewriter.create<AddIOp>(loc, offset, mul).getResult();
+  if (useIndices.size() != mixedOffsets.size())
+    return failure();
+  sourceIndices.resize(useIndices.size());
+  for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
+    SmallVector<Value> dynamicOperands;
+    AffineExpr expr = rewriter.getAffineDimExpr(0);
+    unsigned numSymbols = 0;
+    dynamicOperands.push_back(useIndices[index]);
+
+    // Multiply by the stride.
+    if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
+      expr = expr * attr.cast<IntegerAttr>().getInt();
+    } else {
+      dynamicOperands.push_back(mixedStrides[index].get<Value>());
+      expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
+    }
+
+    // Add the offset.
+    if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
+      expr = expr + attr.cast<IntegerAttr>().getInt();
+    } else {
+      dynamicOperands.push_back(mixedOffsets[index].get<Value>());
+      expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
+    }
+    Location loc = subViewOp.getLoc();
+    sourceIndices[index] = rewriter.create<AffineApplyOp>(
+        loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
   }
   return success();
 }
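
A hedged sketch of what the new resolveSourceIndices produces for a load
through a dynamically strided subview (value names are illustrative; the map
shape matches the CHECK lines in the updated fold-subview-ops.mlir test):

  %0 = memref.subview %src[%off0, %off1][4, 4][%st0, %st1]
      : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
  %1 = memref.load %0[%i, %j] : memref<4x4xf32, offset:?, strides: [?, ?]>

now folds to affine.apply ops computing index * stride + offset on the
original memref:

  %i0 = affine.apply affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>(%i)[%st0, %off0]
  %i1 = affine.apply affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>(%j)[%st1, %off1]
  %1 = memref.load %src[%i0, %i1] : memref<12x32xf32>

For rank-reduced subviews, the dimensions dropped from the result get a
constant zero index before the same expansion is applied, as exercised by the
new fold_rank_reducing_subview_with_load test below.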

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index e9dd74faad64b..2c6ab57782dd2 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -476,67 +476,32 @@ func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32>
 // -----
 
 func @fold_subtensor(
-    %arg0 : tensor<1x?x?x1x?x1x1xf32>, %arg1 : index, %arg2 : index,
-    %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index)
-    -> tensor<1x?x?x1x?x1x1xf32> {
-  %0 = subtensor %arg0[0, %arg1, %arg2, 0, %arg3, 0, 0]
-                      [1, %arg4, %arg5, 1, %arg6, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
+    %arg0 : tensor<1x?x?x1x?x1x1xf32>, %arg1 : tensor<1x?x?x?x?x1x1xf32>,
+    %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index,
+    %arg6 : index, %arg7 : index) -> (tensor<1x?x?x1x?x1x1xf32>, tensor<1x?x?x1x?x1x1xf32>) {
+  %0 = subtensor %arg0[0, %arg2, %arg3, 0, %arg4, 0, 0]
+                      [1, %arg5, %arg6, 1, %arg7, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
       tensor<1x?x?x1x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
-  return %0 : tensor<1x?x?x1x?x1x1xf32>
+  %1 = subtensor %arg1[%arg2, 0, %arg3, 0, 0, %arg4, 0]
+                      [1, %arg5, %arg6, 1, %arg7, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
+      tensor<1x?x?x?x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
+  return %0, %1 : tensor<1x?x?x1x?x1x1xf32>, tensor<1x?x?x1x?x1x1xf32>
 }
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2)>
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
 //      CHECK: func @fold_subtensor
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x?x1x?x1x1xf32>
-// CHECK-SAME:   %[[ARG1:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG2:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG3:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG4:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG5:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG6:[a-z0-9]+]]: index
-//      CHECK:   %[[SRC_RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME:   %[[ARG1:.+]]: tensor<1x?x?x?x?x1x1xf32>
+//      CHECK:   %[[SUBTENSOR1:.+]] = subtensor %[[ARG0]]
+// CHECK-SAME:       to tensor<?x?x?xf32>
+//      CHECK:   %[[RESULT1:.+]] = linalg.tensor_reshape %[[SUBTENSOR1]]
 // CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-//      CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[SRC_RESHAPE]]
-// CHECK-SAME:       [%[[ARG1]], %[[ARG2]], %[[ARG3]]]
-// CHECK-SAME:       [%[[ARG4]], %[[ARG5]], %[[ARG6]]]
-//      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[SUBTENSOR]]
+//      CHECK:   %[[SUBTENSOR2:.+]] = subtensor %[[ARG1]]
+// CHECK-SAME:       to tensor<?x?x?xf32>
+//      CHECK:   %[[RESULT2:.+]] = linalg.tensor_reshape %[[SUBTENSOR2]]
 // CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-//      CHECK:   return %[[RESULT_RESHAPE]]
-
-// -----
-
-func @no_fold_subtensor(
-    %arg0 : tensor<1x?x?x?x?x1x1xf32>, %arg1 : index, %arg2 : index,
-    %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index)
-    -> tensor<1x?x?x1x?x1x1xf32> {
-  %0 = subtensor %arg0[%arg1, 0, %arg2, 0, 0, %arg3, 0]
-                      [1, %arg4, %arg5, 1, %arg6, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
-      tensor<1x?x?x?x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
-  return %0 : tensor<1x?x?x1x?x1x1xf32>
-}
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1)>
-//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2)>
-//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3)>
-//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4)>
-//  CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6)>
-//      CHECK: func @no_fold_subtensor
-// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x?x?x?x1x1xf32>
-// CHECK-SAME:   %[[ARG1:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG2:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG3:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG4:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG5:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG6:[a-z0-9]+]]: index
-//      CHECK:   %[[SRC_RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]], #[[MAP5]]]
-//      CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[SRC_RESHAPE]]
-// CHECK-SAME:       [%[[ARG1]], 0, %[[ARG2]], 0, 0, %[[ARG3]]]
-// CHECK-SAME:       [1, %[[ARG4]], %[[ARG5]], 1, %[[ARG6]], 1]
-//      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[SUBTENSOR]]
-// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]], #[[MAP5]]]
-//      CHECK:   return %[[RESULT_RESHAPE]]
+//      CHECK:   return %[[RESULT1]], %[[RESULT2]]
 
 // -----
 

diff  --git a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
index 2cddeb93dc301..246c0b3552947 100644
--- a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
@@ -1,99 +1,162 @@
-// RUN: mlir-opt -fold-memref-subview-ops -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-opt -fold-memref-subview-ops -split-input-file %s -o - | FileCheck %s
 
-// CHECK-LABEL: @fold_static_stride_subview_with_load
-// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index
 func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
-  // CHECK-NOT: memref.subview
-  // CHECK: [[C2:%.*]] = constant 2 : index
-  // CHECK: [[C3:%.*]] = constant 3 : index
-  // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
-  // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
-  // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
-  // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
-  // CHECK: memref.load [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
   %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
   %1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
   return %1 : f32
 }
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 * 3 + s0)>
+//      CHECK: func @fold_static_stride_subview_with_load
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP0]](%[[ARG3]])[%[[ARG1]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP1]](%[[ARG4]])[%[[ARG2]]]
+//      CHECK:   memref.load %[[ARG0]][%[[I1]], %[[I2]]]
+
+// -----
 
-// CHECK-LABEL: @fold_dynamic_stride_subview_with_load
-// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: index, [[ARG6:%.*]]: index
 func @fold_dynamic_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) -> f32 {
-  // CHECK-NOT: memref.subview
-  // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[ARG5]] : index
-  // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
-  // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[ARG6]] : index
-  // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
-  // CHECK: memref.load [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
   %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] :
     memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
   %1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]>
   return %1 : f32
 }
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+//      CHECK: func @fold_dynamic_stride_subview_with_load
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]](%[[ARG3]])[%[[ARG5]], %[[ARG1]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]](%[[ARG4]])[%[[ARG6]], %[[ARG2]]]
+//      CHECK:   memref.load %[[ARG0]][%[[I1]], %[[I2]]]
+
+// -----
 
-// CHECK-LABEL: @fold_static_stride_subview_with_store
-// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: f32
 func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
-  // CHECK-NOT: memref.subview
-  // CHECK: [[C2:%.*]] = constant 2 : index
-  // CHECK: [[C3:%.*]] = constant 3 : index
-  // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
-  // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
-  // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
-  // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
-  // CHECK: memref.store [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
   %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] :
     memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
   memref.store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
   return
 }
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 * 3 + s0)>
+//      CHECK: func @fold_static_stride_subview_with_store
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP0]](%[[ARG3]])[%[[ARG1]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP1]](%[[ARG4]])[%[[ARG2]]]
+//      CHECK:   memref.store %{{.+}}, %[[ARG0]][%[[I1]], %[[I2]]]
+
+// -----
 
-// CHECK-LABEL: @fold_dynamic_stride_subview_with_store
-// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: index, [[ARG6:%.*]]: index, [[ARG7:%.*]]: f32
 func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : f32) {
-  // CHECK-NOT: memref.subview
-  // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[ARG5]] : index
-  // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
-  // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[ARG6]] : index
-  // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
-  // CHECK: memref.store [[ARG7]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
   %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] :
     memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
   memref.store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]>
   return
 }
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+//      CHECK: func @fold_dynamic_stride_subview_with_store
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]](%[[ARG3]])[%[[ARG5]], %[[ARG1]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]](%[[ARG4]])[%[[ARG6]], %[[ARG2]]]
+//      CHECK:   memref.store %{{.+}}, %[[ARG0]][%[[I1]], %[[I2]]]
+
+// -----
 
-// CHECK-LABEL: @fold_static_stride_subview_with_transfer_read
-// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index
-func @fold_static_stride_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> vector<4xf32> {
-  // CHECK-NOT: memref.subview
-  // CHECK-DAG: [[F1:%.*]] = constant 1.000000e+00 : f32
-  // CHECK-DAG: [[C2:%.*]] = constant 2 : index
-  // CHECK-DAG: [[C3:%.*]] = constant 3 : index
-  // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
-  // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
-  // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
-  // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
-  // CHECK: vector.transfer_read [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}, [[F1]] {in_bounds = [true]}
+func @fold_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) -> vector<4xf32> {
   %f1 = constant 1.0 : f32
-  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
-  %1 = vector.transfer_read %0[%arg3, %arg4], %f1 {in_bounds = [true]} : memref<4x4xf32, offset:?, strides: [64, 3]>, vector<4xf32>
+  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
+  %1 = vector.transfer_read %0[%arg3, %arg4], %f1 {in_bounds = [true]} : memref<4x4xf32, offset:?, strides: [?, ?]>, vector<4xf32>
   return %1 : vector<4xf32>
 }
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+//      CHECK: func @fold_subview_with_transfer_read
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]](%[[ARG3]])[%[[ARG5]], %[[ARG1]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]](%[[ARG4]])[%[[ARG6]], %[[ARG2]]]
+//      CHECK:   vector.transfer_read %[[ARG0]][%[[I1]], %[[I2]]]
 
-// CHECK-LABEL: @fold_static_stride_subview_with_transfer_write
-// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: vector<4xf32>
-func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : vector<4xf32>) {
-  // CHECK-NOT: memref.subview
-  // CHECK: [[C2:%.*]] = constant 2 : index
-  // CHECK: [[C3:%.*]] = constant 3 : index
-  // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
-  // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
-  // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
-  // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
-  // CHECK: vector.transfer_write [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} {in_bounds = [true]}
-  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] :
-    memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
-  vector.transfer_write %arg5, %0[%arg3, %arg4] {in_bounds = [true]} : vector<4xf32>, memref<4x4xf32, offset:?, strides: [64, 3]>
+// -----
+
+func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5: index, %arg6 : index, %arg7 : vector<4xf32>) {
+  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] :
+    memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
+  vector.transfer_write %arg7, %0[%arg3, %arg4] {in_bounds = [true]} : vector<4xf32>, memref<4x4xf32, offset:?, strides: [?, ?]>
   return
 }
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+//      CHECK: func @fold_static_stride_subview_with_transfer_write
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]](%[[ARG3]])[%[[ARG5]], %[[ARG1]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]](%[[ARG4]])[%[[ARG6]], %[[ARG2]]]
+//      CHECK:   vector.transfer_write %{{.+}}, %[[ARG0]][%[[I1]], %[[I2]]]
+
+// -----
+
+func @fold_rank_reducing_subview_with_load
+    (%arg0 : memref<?x?x?x?x?x?xf32>, %arg1 : index, %arg2 : index,
+     %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+     %arg7 : index, %arg8 : index, %arg9 : index, %arg10: index,
+     %arg11 : index, %arg12 : index, %arg13 : index, %arg14: index,
+     %arg15 : index, %arg16 : index) -> f32 {
+  %0 = memref.subview %arg0[%arg1, %arg2, %arg3, %arg4, %arg5, %arg6][4, 1, 1, 4, 1, 1][%arg7, %arg8, %arg9, %arg10, %arg11, %arg12] : memref<?x?x?x?x?x?xf32> to memref<4x1x4x1xf32, offset:?, strides: [?, ?, ?, ?]>
+  %1 = memref.load %0[%arg13, %arg14, %arg15, %arg16] : memref<4x1x4x1xf32, offset:?, strides: [?, ?, ?, ?]>
+  return %1 : f32
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+//      CHECK: func @fold_rank_reducing_subview_with_load
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?x?x?x?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG7:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG8:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG9:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG10:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG11:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG12:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG13:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG14:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG15:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG16:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]](%[[ARG13]])[%[[ARG7]], %[[ARG1]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]](%[[ARG14]])[%[[ARG8]], %[[ARG2]]]
+//  CHECK-DAG:   %[[I3:.+]] = affine.apply #[[MAP]](%[[C0]])[%[[ARG9]], %[[ARG3]]]
+//  CHECK-DAG:   %[[I4:.+]] = affine.apply #[[MAP]](%[[ARG15]])[%[[ARG10]], %[[ARG4]]]
+//  CHECK-DAG:   %[[I5:.+]] = affine.apply #[[MAP]](%[[ARG16]])[%[[ARG11]], %[[ARG5]]]
+//  CHECK-DAG:   %[[I6:.+]] = affine.apply #[[MAP]](%[[C0]])[%[[ARG12]], %[[ARG6]]]
+//      CHECK:   memref.load %[[ARG0]][%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]]


        

