[Mlir-commits] [mlir] 4cf9bf6 - [mlir][MemRef] Compute unused dimensions of a rank-reducing subviews using strides as well.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 20 11:05:55 PDT 2021


Author: MaheshRavishankar
Date: 2021-09-20T11:05:30-07:00
New Revision: 4cf9bf6c9f64cca1111134acc9f84efe8f27e8d1

URL: https://github.com/llvm/llvm-project/commit/4cf9bf6c9f64cca1111134acc9f84efe8f27e8d1
DIFF: https://github.com/llvm/llvm-project/commit/4cf9bf6c9f64cca1111134acc9f84efe8f27e8d1.diff

LOG: [mlir][MemRef] Compute unused dimensions of a rank-reducing subviews using strides as well.

For `memref.subview` operations, when there is more than one
unit dimension, the strides need to be used to figure out which of
the unit dims are actually dropped.

Differential Revision: https://reviews.llvm.org/D109418

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Interfaces/ViewLikeInterface.td
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
    mlir/test/Dialect/Linalg/loops.mlir
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/MemRef/fold-subview-ops.mlir
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index e0cb3816efafe..dd8455a7f9190 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1379,6 +1379,10 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
     /// Return the number of leading operands before the `offsets`, `sizes` and
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+
+    /// Return the dimensions of the source type that are dropped when
+    /// the result is rank-reduced.
+    llvm::SmallDenseSet<unsigned> getDroppedDims();
   }];
 
   let hasCanonicalizer = 1;

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index f5227361165b3..50ebeaa44a5c3 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -66,7 +66,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
   let cppNamespace = "::mlir";
 
   let methods = [
-    InterfaceMethod<
+    StaticInterfaceMethod<
       /*desc=*/[{
         Return the number of leading operands before the `offsets`, `sizes` and
         and `strides` operands.

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index ebca204ab8486..85c05f04b07f3 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1272,12 +1272,8 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
         extracted);
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
 
-    auto shape = viewMemRefType.getShape();
-    auto inferredShape = inferredType.getShape();
-    size_t inferredShapeRank = inferredShape.size();
-    size_t resultShapeRank = shape.size();
-    llvm::SmallDenseSet<unsigned> unusedDims =
-        computeRankReductionMask(inferredShape, shape).getValue();
+    size_t inferredShapeRank = inferredType.getRank();
+    size_t resultShapeRank = viewMemRefType.getRank();
 
     // Extract strides needed to compute offset.
     SmallVector<Value, 4> strideValues;
@@ -1315,6 +1311,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
     assert(mixedSizes.size() == mixedStrides.size() &&
            "expected sizes and strides of equal length");
+    llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
          i >= 0 && j >= 0; --i) {
       if (unusedDims.contains(i))

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f0bf8c639e4a4..f80d373c41e0d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -690,6 +690,92 @@ static LogicalResult verify(DimOp op) {
   return success();
 }
 
+/// Return a map with key being elements in `vals` and data being number of
+/// occurrences of it. Use std::map, since the `vals` here are strides and the
+/// dynamic stride value is the same as the tombstone value for
+/// `DenseMap<int64_t>`.
+static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
+  std::map<int64_t, unsigned> numOccurences;
+  for (auto val : vals)
+    numOccurences[val]++;
+  return numOccurences;
+}
+
+/// Given the type of the un-rank reduced subview result type and the
+/// rank-reduced result type, computes the dropped dimensions. This accounts for
+/// cases where there are multiple unit-dims, but only a subset of those are
+/// dropped. For MemRefTypes these can be disambiguated using the strides. If a
+/// dimension is dropped the stride must be dropped too.
+static llvm::Optional<llvm::SmallDenseSet<unsigned>>
+computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
+                               ArrayAttr staticSizes) {
+  llvm::SmallDenseSet<unsigned> unusedDims;
+  if (originalType.getRank() == reducedType.getRank())
+    return unusedDims;
+
+  for (auto dim : llvm::enumerate(staticSizes))
+    if (dim.value().cast<IntegerAttr>().getInt() == 1)
+      unusedDims.insert(dim.index());
+  SmallVector<int64_t> originalStrides, candidateStrides;
+  int64_t originalOffset, candidateOffset;
+  if (failed(
+          getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
+      failed(
+          getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
+    return llvm::None;
+
+  // For memrefs, a dimension is truly dropped if its corresponding stride is
+  // also dropped. This is particularly important when more than one of the dims
+  // is 1. Track the number of occurrences of the strides in the original type
+  // and the candidate type. For each unused dim that stride should not be
+  // present in the candidate type. Note that there could be multiple dimensions
+  // that have the same size. We don't need to exactly figure out which dim
+  // corresponds to which stride, we just need to verify that the number of
+  // repetitions of a stride in the original == number of dropped dims with
+  // that stride + number of repetitions of that stride in the candidate.
+  std::map<int64_t, unsigned> currUnaccountedStrides =
+      getNumOccurences(originalStrides);
+  std::map<int64_t, unsigned> candidateStridesNumOccurences =
+      getNumOccurences(candidateStrides);
+  llvm::SmallDenseSet<unsigned> prunedUnusedDims;
+  for (unsigned dim : unusedDims) {
+    int64_t originalStride = originalStrides[dim];
+    if (currUnaccountedStrides[originalStride] >
+        candidateStridesNumOccurences[originalStride]) {
+      // This dim can be treated as dropped.
+      currUnaccountedStrides[originalStride]--;
+      continue;
+    }
+    if (currUnaccountedStrides[originalStride] ==
+        candidateStridesNumOccurences[originalStride]) {
+      // The stride for this is not dropped. Keep as is.
+      prunedUnusedDims.insert(dim);
+      continue;
+    }
+    if (currUnaccountedStrides[originalStride] <
+        candidateStridesNumOccurences[originalStride]) {
+      // This should never happen. Can't have a stride in the reduced rank type
+      // that wasn't in the original one.
+      return llvm::None;
+    }
+  }
+
+  for (auto prunedDim : prunedUnusedDims)
+    unusedDims.erase(prunedDim);
+  if (unusedDims.size() + reducedType.getRank() != originalType.getRank())
+    return llvm::None;
+  return unusedDims;
+}
+
+llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
+  MemRefType sourceType = getSourceType();
+  MemRefType resultType = getType();
+  llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
+      computeMemRefRankReductionMask(sourceType, resultType, static_sizes());
+  assert(unusedDims && "unable to find unused dims of subview");
+  return *unusedDims;
+}
+
 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
   // All forms of folding require a known index.
   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
@@ -725,6 +811,25 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
     return *(view.getDynamicSizes().begin() +
              memrefType.getDynamicDimIndex(unsignedIndex));
 
+  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
+    llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims();
+    unsigned resultIndex = 0;
+    unsigned sourceRank = subview.getSourceType().getRank();
+    unsigned sourceIndex = 0;
+    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
+      if (unusedDims.count(i))
+        continue;
+      if (resultIndex == unsignedIndex) {
+        sourceIndex = i;
+        break;
+      }
+      resultIndex++;
+    }
+    assert(subview.isDynamicSize(sourceIndex) &&
+           "expected dynamic subview size");
+    return subview.getDynamicSize(sourceIndex);
+  }
+
   if (auto sizeInterface =
           dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
     assert(sizeInterface.isDynamicSize(unsignedIndex) &&
@@ -1887,7 +1992,7 @@ enum SubViewVerificationResult {
 /// not matching dimension must be 1.
 static SubViewVerificationResult
 isRankReducedType(Type originalType, Type candidateReducedType,
-                  std::string *errMsg = nullptr) {
+                  ArrayAttr staticSizes, std::string *errMsg = nullptr) {
   if (originalType == candidateReducedType)
     return SubViewVerificationResult::Success;
   if (!originalType.isa<MemRefType>())
@@ -1908,8 +2013,11 @@ isRankReducedType(Type originalType, Type candidateReducedType,
   if (candidateReducedRank > originalRank)
     return SubViewVerificationResult::RankTooLarge;
 
+  MemRefType original = originalType.cast<MemRefType>();
+  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
+
   auto optionalUnusedDimsMask =
-      computeRankReductionMask(originalShape, candidateReducedShape);
+      computeMemRefRankReductionMask(original, candidateReduced, staticSizes);
 
   // Sizes cannot be matched in case empty vector is returned.
   if (!optionalUnusedDimsMask.hasValue())
@@ -1920,42 +2028,8 @@ isRankReducedType(Type originalType, Type candidateReducedType,
     return SubViewVerificationResult::ElemTypeMismatch;
 
   // Strided layout logic is relevant for MemRefType only.
-  MemRefType original = originalType.cast<MemRefType>();
-  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
     return SubViewVerificationResult::MemSpaceMismatch;
-
-  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
-  auto inferredType =
-      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
-  AffineMap candidateLayout;
-  if (candidateReduced.getAffineMaps().empty())
-    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
-  else
-    candidateLayout = candidateReduced.getAffineMaps().front();
-  assert(inferredType.getNumResults() == 1 &&
-         candidateLayout.getNumResults() == 1);
-  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
-      inferredType.getNumDims() != candidateLayout.getNumDims()) {
-    if (errMsg) {
-      llvm::raw_string_ostream os(*errMsg);
-      os << "inferred type: " << inferredType;
-    }
-    return SubViewVerificationResult::AffineMapMismatch;
-  }
-  // Check that the difference of the affine maps simplifies to 0.
-  AffineExpr diffExpr =
-      inferredType.getResult(0) - candidateLayout.getResult(0);
-  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
-                                inferredType.getNumSymbols());
-  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
-  if (!(cst && cst.getValue() == 0)) {
-    if (errMsg) {
-      llvm::raw_string_ostream os(*errMsg);
-      os << "inferred type: " << inferredType;
-    }
-    return SubViewVerificationResult::AffineMapMismatch;
-  }
   return SubViewVerificationResult::Success;
 }
 
@@ -2012,7 +2086,8 @@ static LogicalResult verify(SubViewOp op) {
       extractFromI64ArrayAttr(op.static_strides()));
 
   std::string errMsg;
-  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
+  auto result =
+      isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg);
   return produceSubViewErrorMsg(result, op, expectedType, errMsg);
 }
 

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
index 4e1424083e96b..17ec4a1ba7fe6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -49,18 +49,13 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
   SmallVector<Value> useIndices;
   // Check if this is rank-reducing case. Then for every unit-dim size add a
   // zero to the indices.
-  ArrayRef<int64_t> resultShape = subViewOp.getType().getShape();
   unsigned resultDim = 0;
-  for (auto size : llvm::enumerate(mixedSizes)) {
-    auto attr = size.value().dyn_cast<Attribute>();
-    // Check if this dimension has been dropped, i.e. the size is 1, but the
-    // associated dimension is not 1.
-    if (attr && attr.cast<IntegerAttr>().getInt() == 1 &&
-        (resultDim >= resultShape.size() || resultShape[resultDim] != 1))
+  llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
+  for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
+    if (unusedDims.count(dim))
       useIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
-    else if (resultDim < resultShape.size()) {
+    else
       useIndices.push_back(indices[resultDim++]);
-    }
   }
   if (useIndices.size() != mixedOffsets.size())
     return failure();
@@ -104,6 +99,25 @@ static Value getMemRefOperand(vector::TransferWriteOp op) {
   return op.source();
 }
 
+/// Given the permutation map of the original
+/// `vector.transfer_read`/`vector.transfer_write` operations compute the
+/// permutation map to use after the subview is folded with it.
+static AffineMap getPermutationMap(MLIRContext *context,
+                                   memref::SubViewOp subViewOp,
+                                   AffineMap currPermutationMap) {
+  llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
+  SmallVector<AffineExpr> exprs;
+  unsigned resultIdx = 0;
+  int64_t sourceRank = subViewOp.getSourceType().getRank();
+  for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
+    if (unusedDims.count(dim))
+      continue;
+    exprs.push_back(getAffineDimExpr(resultIdx++, context));
+  }
+  auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
+  return currPermutationMap.compose(resultDimToSourceDimMap);
+}
+
 //===----------------------------------------------------------------------===//
 // Patterns
 //===----------------------------------------------------------------------===//
@@ -153,7 +167,9 @@ void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
   rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
       loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
-      loadOp.permutation_map(), loadOp.padding(), loadOp.in_boundsAttr());
+      getPermutationMap(rewriter.getContext(), subViewOp,
+                        loadOp.permutation_map()),
+      loadOp.padding(), loadOp.in_boundsAttr());
 }
 
 template <>
@@ -170,7 +186,9 @@ void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
       transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
-      sourceIndices, transferWriteOp.permutation_map(),
+      sourceIndices,
+      getPermutationMap(rewriter.getContext(), subViewOp,
+                        transferWriteOp.permutation_map()),
       transferWriteOp.in_boundsAttr());
 }
 } // namespace

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 718dd2f9a789f..747471623c248 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1418,3 +1418,28 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
 //       CHECKPARALLEL:         %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
 //       CHECKPARALLEL:         %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 //       CHECKPARALLEL:         store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+
+// -----
+
+func @lower_to_loops_with_rank_reducing_subviews(
+    %arg0 : memref<?xi32>, %arg1 : memref<?x?xi32>, %arg2 : index,
+    %arg3 : index, %arg4 : index) {
+  %0 = memref.subview %arg0[%arg2] [%arg3] [1]
+      : memref<?xi32> to memref<?xi32, offset: ?, strides: [1]>
+  %1 = memref.subview %arg1[0, %arg4] [1, %arg3] [1, 1]
+      : memref<?x?xi32> to memref<?xi32, offset: ?, strides : [1]>
+  linalg.copy(%0, %1)
+      : memref<?xi32, offset: ?, strides: [1]>, memref<?xi32, offset: ?, strides: [1]>
+  return
+}
+// CHECK-LABEL: func @lower_to_loops_with_rank_reducing_subviews
+//       CHECK:   scf.for %[[IV:.+]] = %{{.+}} to %{{.+}} step %{{.+}} {
+//       CHECK:     %[[VAL:.+]] = memref.load %{{.+}}[%[[IV]]]
+//       CHECK:     memref.store %[[VAL]], %{{.+}}[%[[IV]]]
+//       CHECK:   }
+
+// CHECKPARALLEL-LABEL: func @lower_to_loops_with_rank_reducing_subviews
+//       CHECKPARALLEL:   scf.parallel (%[[IV:.+]]) = (%{{.+}}) to (%{{.+}}) step (%{{.+}}) {
+//       CHECKPARALLEL:     %[[VAL:.+]] = memref.load %{{.+}}[%[[IV]]]
+//       CHECKPARALLEL:     memref.store %[[VAL]], %{{.+}}[%[[IV]]]
+//       CHECKPARALLEL:   }

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index d73bb136025ce..ec57845a00f14 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -159,6 +159,63 @@ func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : inde
 //       CHECK:   %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
 //       CHECK:   return %[[RESULT]]
 
+// -----
+
+func @multiple_reducing_dims(%arg0 : memref<1x384x384xf32>,
+    %arg1 : index, %arg2 : index, %arg3 : index) -> memref<?xf32, offset: ?, strides: [1]>
+{
+  %c1 = constant 1 : index
+  %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref<1x384x384xf32> to memref<?x?xf32, offset: ?, strides: [384, 1]>
+  %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<?x?xf32, offset: ?, strides: [384, 1]> to memref<?xf32, offset: ?, strides: [1]>
+  return %1 : memref<?xf32, offset: ?, strides: [1]>
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 384 + s0 + d1)>
+//       CHECK: func @multiple_reducing_dims
+//       CHECK:   %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1]
+//  CHECK-SAME:       : memref<1x384x384xf32> to memref<1x?xf32, #[[MAP1]]>
+//       CHECK:   %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
+//  CHECK-SAME:       : memref<1x?xf32, #[[MAP1]]> to memref<?xf32, #[[MAP0]]>
+
+// -----
+
+func @multiple_reducing_dims_dynamic(%arg0 : memref<?x?x?xf32>,
+    %arg1 : index, %arg2 : index, %arg3 : index) -> memref<?xf32, offset: ?, strides: [1]>
+{
+  %c1 = constant 1 : index
+  %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref<?x?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+  %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?xf32, offset: ?, strides: [1]>
+  return %1 : memref<?xf32, offset: ?, strides: [1]>
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+//       CHECK: func @multiple_reducing_dims_dynamic
+//       CHECK:   %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1]
+//  CHECK-SAME:       : memref<?x?x?xf32> to memref<1x?xf32, #[[MAP1]]>
+//       CHECK:   %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
+//  CHECK-SAME:       : memref<1x?xf32, #[[MAP1]]> to memref<?xf32, #[[MAP0]]>
+
+// -----
+
+func @multiple_reducing_dims_all_dynamic(%arg0 : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>,
+    %arg1 : index, %arg2 : index, %arg3 : index) -> memref<?xf32, offset: ?, strides: [?]>
+{
+  %c1 = constant 1 : index
+  %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1]
+      : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+  %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?xf32, offset: ?, strides: [?]>
+  return %1 : memref<?xf32, offset: ?, strides: [?]>
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+//       CHECK: func @multiple_reducing_dims_all_dynamic
+//       CHECK:   %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1]
+//  CHECK-SAME:       : memref<?x?x?xf32, #[[MAP2]]> to memref<1x?xf32, #[[MAP1]]>
+//       CHECK:   %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
+//  CHECK-SAME:       : memref<1x?xf32, #[[MAP1]]> to memref<?xf32, #[[MAP0]]>
+
+
 // -----
 
 // CHECK-LABEL: @clone_before_dealloc
@@ -567,4 +624,3 @@ func @collapse_after_memref_cast(%arg0 : memref<?x512x1x?xf32>) -> memref<?x?xf3
   %collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
   return %collapsed : memref<?x?xf32>
 }
-

diff  --git a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
index 246c0b3552947..558b44350af7b 100644
--- a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
@@ -160,3 +160,66 @@ func @fold_rank_reducing_subview_with_load
 //  CHECK-DAG:   %[[I5:.+]] = affine.apply #[[MAP]](%[[ARG16]])[%[[ARG11]], %[[ARG5]]]
 //  CHECK-DAG:   %[[I6:.+]] = affine.apply #[[MAP]](%[[C0]])[%[[ARG12]], %[[ARG6]]]
 //      CHECK:   memref.load %[[ARG0]][%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]]
+
+// -----
+
+func @fold_vector_transfer_read_with_rank_reduced_subview(
+    %arg0 : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>,
+    %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
+    %arg6 : index) -> vector<4xf32> {
+  %cst = constant 0.0 : f32
+  %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %arg3, %arg4] [1, 1, 1]
+      : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]> to
+        memref<?x?xf32, offset: ?, strides: [?, ?]>
+  %1 = vector.transfer_read %0[%arg5, %arg6], %cst {in_bounds = [true]}
+      : memref<?x?xf32, offset: ?, strides: [?, ?]>, vector<4xf32>
+  return %1 : vector<4xf32>
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+//       CHECK: func @fold_vector_transfer_read_with_rank_reduced_subview
+//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?xf32, #[[MAP0]]>
+//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
+//   CHECK-DAG:    %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]](%[[ARG5]])[%[[ARG1]]]
+//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]](%[[ARG6]])[%[[ARG2]]]
+//       CHECK:    vector.transfer_read %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]]
+//  CHECK-SAME:        permutation_map = #[[MAP2]]
+
+// -----
+
+func @fold_vector_transfer_write_with_rank_reduced_subview(
+    %arg0 : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>,
+    %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+    %arg5: index, %arg6 : index, %arg7 : index) {
+  %cst = constant 0.0 : f32
+  %0 = memref.subview %arg0[0, %arg2, %arg3] [1, %arg4, %arg5] [1, 1, 1]
+      : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]> to
+        memref<?x?xf32, offset: ?, strides: [?, ?]>
+  vector.transfer_write %arg1, %0[%arg6, %arg7] {in_bounds = [true]}
+      : vector<4xf32>, memref<?x?xf32, offset: ?, strides: [?, ?]>
+  return
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+//       CHECK: func @fold_vector_transfer_write_with_rank_reduced_subview
+//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?xf32, #[[MAP0]]>
+//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
+//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: index
+//   CHECK-DAG:    %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]](%[[ARG6]])[%[[ARG2]]]
+//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]](%[[ARG7]])[%[[ARG3]]]
+//   CHECK-DAG:    vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]]
+//  CHECK-SAME:        permutation_map = #[[MAP2]]

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index dcd1a6b128498..b93815533119c 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -353,3 +353,12 @@ func @collapse_shape_illegal_mixed_memref_2(%arg0 : memref<?x4x5xf32>)
       : memref<?x4x5xf32> into memref<?x?xf32>
   return %0 : memref<?x?xf32>
 }
+
+// -----
+
+func @static_stride_to_dynamic_stride(%arg0 : memref<?x?x?xf32>, %arg1 : index,
+    %arg2 : index) -> memref<?x?xf32, offset:?, strides: [?, ?]> {
+  // expected-error @+1 {{expected result type to be 'memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>' or a rank-reduced version. (mismatch of result sizes)}}
+  %0 = memref.subview %arg0[0, 0, 0] [1, %arg1, %arg2] [1, 1, 1] : memref<?x?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+  return %0 : memref<?x?xf32, offset: ?, strides: [?, ?]>
+}

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 30b1b411d2df8..265f095fe2272 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -960,17 +960,6 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
 
 // -----
 
-func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<8x16x4xf32>
-  // expected-error@+1 {{expected result type to be 'memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>' or a rank-reduced version. (mismatch of result affine map)}}
-  %1 = memref.subview %0[%arg0, %arg1, %arg2][%arg0, %arg1, %arg2][%arg0, %arg1, %arg2]
-    : memref<8x16x4xf32> to
-      memref<?x?x?xf32, offset: ?, strides: [64, 4, 1]>
-  return
-}
-
-// -----
-
 func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = memref.alloc() : memref<8x16x4xf32>
   // expected-error@+1 {{expected result element type to be 'f32'}}
@@ -1014,22 +1003,13 @@ func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index)
 // -----
 
 func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
-  // expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result affine map)}}
+  // expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result sizes)}}
   %0 = memref.subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32>
   return
 }
 
 // -----
 
-// The affine map affine_map<(d0)[s0, s1, s2] -> (d0 * s1 + s0)> has an extra unused symbol.
-func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
-  // expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result affine map) inferred type: (d0)[s0, s1] -> (d0 * s1 + s0)}}
-  %0 = memref.subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32, affine_map<(d0)[s0, s1, s2] -> (d0 * s1 + s0)>>
-  return
-}
-
-// -----
-
 func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) {
   // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}}
   %0 = memref.cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]>


        


More information about the Mlir-commits mailing list