[Mlir-commits] [mlir] e0b99a5 - [mlir] Add SubViewOp::getOrCreateRanges and fix folding pattern

Nicolas Vasilache llvmlistbot at llvm.org
Wed May 13 07:14:42 PDT 2020


Author: Nicolas Vasilache
Date: 2020-05-13T10:11:30-04:00
New Revision: e0b99a5de4cbf00bfa46d06caf1ebf64a6456537

URL: https://github.com/llvm/llvm-project/commit/e0b99a5de4cbf00bfa46d06caf1ebf64a6456537
DIFF: https://github.com/llvm/llvm-project/commit/e0b99a5de4cbf00bfa46d06caf1ebf64a6456537.diff

LOG: [mlir] Add SubViewOp::getOrCreateRanges and fix folding pattern

The existing implementation of SubViewOp::getRanges relies on all
offsets/sizes/strides being dynamic values and does not work in
combination with canonicalization. This revision adds
SubViewOp::getOrCreateRanges, which creates the missing constants in
the canonicalized case.

This allows reactivating the fused pass with staged pattern
applications.

However, another issue surfaces: the SubViewOp verifier is now too
strict to allow folding. The existing folding pattern is therefore
turned into a canonicalization pattern which rewrites
memref_cast + subview into subview + memref_cast.

The transform-patterns-matmul-to-vector can then be reactivated.

Differential Revision: https://reviews.llvm.org/D79759

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 30b5d438dd29..e978323d0283 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2676,8 +2676,18 @@ def SubViewOp : Std_Op<"subview", [
     struct Range {
       Value offset, size, stride;
     };
-    // TODO: retire `getRanges`.
-    SmallVector<Range, 8> getRanges();
+    /// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each
+    /// Range entry contains either the dynamic value or a ConstantIndexOp
+    /// constructed with `b` at location `loc`.
+    SmallVector<Range, 8> getOrCreateRanges(OpBuilder &b, Location loc);
+
+    /// A subview result type can be fully inferred from the source type and the
+    /// static representation of offsets, sizes and strides. Special sentinels
+    /// encode the dynamic case.
+    static Type inferSubViewResultType(MemRefType sourceMemRefType,
+                                       ArrayRef<int64_t> staticOffsets,
+                                       ArrayRef<int64_t> staticSizes,
+                                       ArrayRef<int64_t> staticStrides);
 
     /// Return the rank of the result MemRefType.
     unsigned getRank() { return getType().getRank(); }
@@ -2750,7 +2760,6 @@ def SubViewOp : Std_Op<"subview", [
   }];
 
   let hasCanonicalizer = 1;
-  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index d541ed2a4f2d..34fe059415ff 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -184,15 +184,16 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
   unsigned nWin = producer.getNumWindowLoops();
   SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
 
+  OpBuilder b(consumer.getOperation());
+  auto loc = consumer.getLoc();
   // Iterate over dimensions identified by the producer map for `producerIdx`.
   // This defines a subset of the loop ranges that we need to complete later.
   for (auto en : llvm::enumerate(producerMap.getResults())) {
     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
-    loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
+    loopRanges[posInProducerLoop] =
+        subView.getOrCreateRanges(b, loc)[en.index()];
   }
 
-  OpBuilder b(consumer.getOperation());
-  auto loc = consumer.getLoc();
   // Iterate over all dimensions. For the dimensions not identified by the
   // producer map for `producerIdx`, we need to explicitly compute the view that
   // defines the loop ranges using the `producer`.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 03f8d9e3fd18..5cbaa2f426db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -153,7 +153,7 @@ static PromotionInfo promoteSubviewAsNewBuffer(OpBuilder &b, Location loc,
   SmallVector<Value, 8> fullSizes, partialSizes;
   fullSizes.reserve(rank);
   partialSizes.reserve(rank);
-  for (auto en : llvm::enumerate(subView.getRanges())) {
+  for (auto en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
     auto rank = en.index();
     auto rangeValue = en.value();
     // Try to extract a tight constant.
@@ -169,7 +169,7 @@ static PromotionInfo promoteSubviewAsNewBuffer(OpBuilder &b, Location loc,
                             dynamicBuffers, folder, alignment);
   auto fullLocalView = folded_std_view(
       folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
-      folded_std_constant_index(folder, 0), fullSizes);
+      zero, fullSizes);
   SmallVector<Value, 4> zeros(fullSizes.size(), zero);
   SmallVector<Value, 4> ones(fullSizes.size(), one);
   auto partialLocalView =

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 7ca5e79f69a0..9cd97c3b337e 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2275,10 +2275,10 @@ Wrapper operator*(Wrapper a, int64_t b) {
 /// A subview result type can be fully inferred from the source type and the
 /// static representation of offsets, sizes and strides. Special sentinels
 /// encode the dynamic case.
-static Type inferSubViewResultType(MemRefType sourceMemRefType,
-                                   ArrayRef<int64_t> staticOffsets,
-                                   ArrayRef<int64_t> staticSizes,
-                                   ArrayRef<int64_t> staticStrides) {
+Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType,
+                                       ArrayRef<int64_t> staticOffsets,
+                                       ArrayRef<int64_t> staticSizes,
+                                       ArrayRef<int64_t> staticStrides) {
   unsigned rank = sourceMemRefType.getRank();
   (void)rank;
   assert(staticOffsets.size() == rank &&
@@ -2474,7 +2474,7 @@ static LogicalResult verify(SubViewOp op) {
     return failure();
 
   // Verify result type against inferred type.
-  auto expectedType = inferSubViewResultType(
+  auto expectedType = SubViewOp::inferSubViewResultType(
       op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));
@@ -2489,16 +2489,6 @@ raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
             << range.stride;
 }
 
-SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
-  SmallVector<Range, 8> res;
-  unsigned rank = getType().getRank();
-  res.reserve(rank);
-  for (unsigned i = 0; i < rank; ++i)
-    res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i),
-                           *(strides().begin() + i)});
-  return res;
-}
-
 static unsigned getNumDynamicEntriesUpToIdx(
     ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
   return std::count_if(attr.getValue().begin(), attr.getValue().begin() + idx,
@@ -2540,6 +2530,29 @@ unsigned SubViewOp::getIndexOfDynamicStride(unsigned idx) {
   return 1 + offsets().size() + sizes().size() + numDynamic;
 }
 
+/// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each Range
+/// entry contains either the dynamic value or a ConstantIndexOp constructed
+/// with `b` at location `loc`.
+SmallVector<SubViewOp::Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b,
+                                                              Location loc) {
+  SmallVector<Range, 8> res;
+  unsigned rank = getType().getRank();
+  res.reserve(rank);
+  for (unsigned idx = 0; idx < rank; ++idx) {
+    auto offset = isDynamicOffset(idx)
+                      ? getDynamicOffset(idx)
+                      : b.create<ConstantIndexOp>(loc, getStaticOffset(idx));
+    auto size = isDynamicSize(idx)
+                    ? getDynamicSize(idx)
+                    : b.create<ConstantIndexOp>(loc, getStaticSize(idx));
+    auto stride = isDynamicStride(idx)
+                      ? getDynamicStride(idx)
+                      : b.create<ConstantIndexOp>(loc, getStaticStride(idx));
+    res.emplace_back(Range{offset, size, stride});
+  }
+  return res;
+}
+
 LogicalResult
 SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
   if (!strides().empty())
@@ -2583,7 +2596,8 @@ void canonicalizeSubViewPart(SmallVectorImpl<Value> &values,
 }
 
 /// Pattern to rewrite a subview op with constant arguments.
-class SubViewOpFolder final : public OpRewritePattern<SubViewOp> {
+class SubViewOpConstantArgumentFolder final
+    : public OpRewritePattern<SubViewOp> {
 public:
   using OpRewritePattern<SubViewOp>::OpRewritePattern;
 
@@ -2718,27 +2732,63 @@ bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
   return true;
 }
 
-OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
-  auto folds = [](Operation *op) {
-    bool folded = false;
-    for (OpOperand &operand : op->getOpOperands()) {
-      auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
-      if (castOp && canFoldIntoConsumerOp(castOp)) {
-        operand.set(castOp.getOperand());
-        folded = true;
-      }
-    }
-    return folded ? success() : failure();
-  };
+/// Pattern to rewrite a subview op with MemRefCast arguments.
+/// This essentially pushes memref_cast past its consuming subview when
+/// `canFoldIntoConsumerOp` is true.
+///
+/// Example:
+/// ```
+///   %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
+///   %1 = subview %0[0, 0][3, 4][1, 1] :
+///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
+/// ```
+/// is rewritten into:
+/// ```
+///   %0 = subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
+///   %1 = memref_cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
+///     memref<3x4xf32, offset:?, strides:[?, 1]>
+/// ```
+class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
+public:
+  using OpRewritePattern<SubViewOp>::OpRewritePattern;
 
-  if (succeeded(folds(*this)))
-    return getResult();
-  return {};
-}
+  LogicalResult matchAndRewrite(SubViewOp subViewOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail on any constant operand: SubViewOpConstantArgumentFolder handles it.
+    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
+          return matchPattern(operand, m_ConstantIndex());
+        }))
+      return failure();
+
+    auto castOp = subViewOp.source().getDefiningOp<MemRefCastOp>();
+    if (!castOp)
+      return failure();
+
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
+    /// the cast source operand type and the SubViewOp static information. This
+    /// is the resulting type if the MemRefCastOp were folded.
+    Type resultType = SubViewOp::inferSubViewResultType(
+        castOp.source().getType().cast<MemRefType>(),
+        extractFromI64ArrayAttr(subViewOp.static_offsets()),
+        extractFromI64ArrayAttr(subViewOp.static_sizes()),
+        extractFromI64ArrayAttr(subViewOp.static_strides()));
+    Value newSubView = rewriter.create<SubViewOp>(
+        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
+        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
+        subViewOp.static_sizes(), subViewOp.static_strides());
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, subViewOp.getType(),
+                                              newSubView);
+    return success();
+  }
+};
 
 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
-  results.insert<SubViewOpFolder>(context);
+  results.insert<SubViewOpConstantArgumentFolder, SubViewOpMemRefCastFolder>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
index 29ea43aa540b..73c72ba1c6ef 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -1,7 +1,5 @@
-// TODO: this needs a fix to land before being reactivated.
-// RUN: ls
-// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
-// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
 
 func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
                   %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index f97cf21c14fb..76bd6b48f543 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -941,3 +941,19 @@ func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<
   return %1: memref<?x?xf32, offset:? , strides: [?, ?]>
 }
 
+// -----
+
+// CHECK-DAG: #[[map0:.*]] = affine_map<(d0, d1) -> (d0 * 16 + d1)>
+// CHECK-DAG: #[[map1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+
+// CHECK-LABEL: func @memref_cast_folding_subview_static(
+func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: index)
+  -> memref<3x4xf32, offset:?, strides:[?, 1]>
+{
+  %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
+  %1 = subview %0[0, 0][3, 4][1, 1] : memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
+
+  // CHECK:  subview{{.*}}: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
+  // CHECK:  memref_cast{{.*}}: memref<3x4xf32, #[[map0]]> to memref<3x4xf32, #[[map1]]>
+  return %1: memref<3x4xf32, offset:?, strides:[?, 1]>
+}


        


More information about the Mlir-commits mailing list