[Mlir-commits] [mlir] 0ed2d4c - [mlir][linalg] Allow promotion to use callbacks for

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 26 21:53:51 PDT 2020


Author: MaheshRavishankar
Date: 2020-05-26T21:33:57-07:00
New Revision: 0ed2d4c7cba8fb15e51d0f6f4e9011027c17085c

URL: https://github.com/llvm/llvm-project/commit/0ed2d4c7cba8fb15e51d0f6f4e9011027c17085c
DIFF: https://github.com/llvm/llvm-project/commit/0ed2d4c7cba8fb15e51d0f6f4e9011027c17085c.diff

LOG: [mlir][linalg] Allow promotion to use callbacks for
alloc/dealloc/copies.

Add options to LinalgPromotion to use callbacks for implementing the
allocation and deallocation of buffers used for the promoted subviews,
and for copying data between the original subviews and the allocated
buffers.
Also some miscellaneous cleanup of the code.

Differential Revision: https://reviews.llvm.org/D80365

Added: 
    mlir/test/Dialect/Linalg/promotion_options.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6d34a0943e5e..2da631956572 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -57,18 +57,27 @@ Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
 /// (i.e. `[1,1,2]` is an invalid permutation).
 LinalgOp interchange(LinalgOp op, ArrayRef<unsigned> interchangeVector);
 
-/// Promotes the `subViews` into a new buffer allocated at the insertion point
-/// `b`. Promotion occurs in 3 steps:
-///   1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
-///   2. Take a full view on the buffer and `linalg.fill` it with zeros (use
-///      float zero for now).
-///   3. Take a partial slice of the full view in step 2. and copy into it.
-/// Infers statically sized buffers from subViews unless `dynamicBuffers` is
-/// true.
-///
-/// Returns a list of PromotionInfo which hold the promoted buffer and the
-/// full and partial views indexing into the buffer.
-// TODO: revisit dynamicBuffers option.
+/// Callback function type used to perform the allocation for the promoted
+/// `subView`. In `boundingSubViewSize` a best attempt is made to find the
+/// smallest constant value for the size of the buffer needed for each
+/// dimension; if that is not possible, it holds the dynamic size of the
+/// subview. The callback returns the buffer to use, or None on failure.
+using AllocBufferCallbackFn = std::function<Optional<Value>(
+    OpBuilder &b, SubViewOp subView, ArrayRef<Value> boundingSubViewSize,
+    OperationFolder *folder)>;
+
+/// Callback function type used to deallocate a promoted buffer; `buffer` is
+/// the value previously returned by the allocation callback.
+using DeallocBufferCallbackFn =
+    std::function<LogicalResult(OpBuilder &b, Value buffer)>;
+
+/// Callback function type used to insert a copy from `src` to `dst`: from the
+/// original subview into (a subview of) the promoted buffer for each promoted
+/// operand (copy in), and from the promoted buffer back to the original
+/// subview for each promoted output operand (copy out).
+using CopyCallbackFn =
+    std::function<LogicalResult(OpBuilder &b, Value src, Value dst)>;
+
 struct LinalgPromotionOptions {
   /// Indices of subViews to promote. If `None`, try to promote all operands.
   Optional<DenseSet<unsigned>> operandsToPromote = None;
@@ -111,10 +120,44 @@ struct LinalgPromotionOptions {
     alignment = align;
     return *this;
   }
+  /// Callback function to do the allocation of the promoted buffer. If None,
+  /// then the default allocation scheme of allocating a memref<?xi8> buffer
+  /// followed by a view operation is used.
+  Optional<AllocBufferCallbackFn> allocationFn = None;
+  Optional<DeallocBufferCallbackFn> deallocationFn = None;
+  LinalgPromotionOptions &
+  setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn,
+                               DeallocBufferCallbackFn const &deallocFn) {
+    allocationFn = allocFn;
+    deallocationFn = deallocFn;
+    return *this;
+  }
+
+  /// Callback function to do the copy of data to and from the promoted
+  /// subview. If None then a linalg.copy is used.
+  Optional<CopyCallbackFn> copyInFn = None;
+  Optional<CopyCallbackFn> copyOutFn = None;
+  LinalgPromotionOptions &setCopyInOutFns(CopyCallbackFn const &copyIn,
+                                          CopyCallbackFn const &copyOut) {
+    copyInFn = copyIn;
+    copyOutFn = copyOut;
+    return *this;
+  }
 };
-LinalgOp promoteSubViews(OpBuilder &b, LinalgOp op,
-                         LinalgPromotionOptions options,
-                         OperationFolder *folder = nullptr);
+
+/// Promotes the `subViews` into a new buffer allocated at the insertion point
+/// `b`. Promotion occurs in 3 steps:
+///   1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
+///   2. Take a full view on the buffer.
+///   3. Take a partial slice of the full view in step 2. and copy into it.
+/// Allocation, copies and deallocation use the callbacks set in `options`;
+/// default implementations are used for the ones that are not set.
+///
+/// Returns the modified linalg op (the modification happens in place), or None
+/// if promotion failed.
+Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
+                                   LinalgPromotionOptions options,
+                                   OperationFolder *folder = nullptr);
 
 /// Emit a suitable vector form for a Linalg op with fully static shape.
 void vectorizeLinalgOp(OpBuilder &builder, Operation *op);

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index c8a5d83438f5..235dedd60401 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -117,28 +117,6 @@ SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
                                        AffineMap map, ArrayRef<Value> values,
                                        OperationFolder *folder = nullptr);
 
-struct PromotionInfo {
-  Value buffer;
-  Value fullLocalView;
-  Value partialLocalView;
-};
-
-/// Promotes the `subViews` into a new buffer allocated at the insertion point
-/// `b`. For now, promotion occurs in 3 steps:
-///   1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
-///   2. Take a full view on the buffer and `linalg.fill` it with zeros (use
-///      float zero for now).
-///   3. Take a partial slice of the full view in step 2. and copy into it.
-/// Infers statically sized buffers from subViews unless `dynamicBuffers` is
-/// true.
-///
-/// Returns a list of PromotionInfo which hold the promoted buffer and the
-/// full and partial views indexing into the buffer.
-SmallVector<PromotionInfo, 8>
-promoteSubViews(OpBuilder &b, Location loc, ArrayRef<Value> subViews,
-                bool dynamicBuffers = false, int64_t alignment = 0,
-                OperationFolder *folder = nullptr);
-
 /// Returns all the operands of `linalgOp` that are not views.
 /// Asserts that these operands are value types to allow transformations like
 /// tiling to just use the values when cloning `linalgOp`.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 44de2a1021c2..de8514f0fa41 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -25,8 +25,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/FoldUtils.h"
-
-#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/Support/CommandLine.h"
 
 using namespace mlir;
@@ -35,7 +34,7 @@ using namespace mlir::edsc::intrinsics;
 using namespace mlir::linalg;
 using namespace mlir::scf;
 
-using llvm::SetVector;
+using llvm::MapVector;
 
 using folded_affine_min = FoldedValueBuilder<AffineMinOp>;
 using folded_linalg_range = FoldedValueBuilder<linalg::RangeOp>;
@@ -45,6 +44,87 @@ using folded_std_view = FoldedValueBuilder<ViewOp>;
 
 #define DEBUG_TYPE "linalg-promotion"
 
+/// If `size` is produced by an AffineMinOp with at least one constant result,
+/// returns a new index constant equal to the smallest such constant. If `size`
+/// is itself an index-typed constant, it is re-materialized as a new index
+/// constant. Negative bounds are ignored. In all other cases `size` is
+/// returned unchanged.
+static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
+                                                 Value size) {
+  Optional<int64_t> bound;
+  if (auto minOp = size.getDefiningOp<AffineMinOp>()) {
+    for (AffineExpr expr : minOp.getAffineMap().getResults())
+      if (auto c = expr.dyn_cast<AffineConstantExpr>())
+        bound = bound ? std::min(*bound, c.getValue()) : c.getValue();
+  } else if (auto cstOp = size.getDefiningOp<ConstantOp>()) {
+    if (cstOp.getType().isa<IndexType>())
+      bound = cstOp.value().cast<IntegerAttr>().getInt();
+  }
+  if (!bound || *bound < 0)
+    return size;
+  return b.create<ConstantIndexOp>(loc, *bound);
+}
+
+/// Allocates a new i8 buffer large enough to hold `size` elements of
+/// `elementType`. When `dynamicBuffers` is false and `size` is a constant the
+/// buffer gets a static shape; otherwise it is dynamically sized. When given,
+/// `alignment` is attached to the alloc.
+static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
+                         OperationFolder *folder,
+                         Optional<unsigned> alignment = None) {
+  MLIRContext *ctx = size.getContext();
+  Type i8Type = IntegerType::get(8, ctx);
+  auto widthInBytes = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
+  IntegerAttr alignmentAttr =
+      alignment ? IntegerAttr::get(IntegerType::get(64, ctx), *alignment)
+                : IntegerAttr();
+  auto cstSize = size.getDefiningOp<ConstantIndexOp>();
+  if (!dynamicBuffers && cstSize)
+    return std_alloc(MemRefType::get(widthInBytes * cstSize.getValue(), i8Type),
+                     ValueRange{}, alignmentAttr);
+  Value numBytes = folded_std_muli(
+      folder, folded_std_constant_index(folder, widthInBytes), size);
+  return std_alloc(MemRefType::get(-1, i8Type), numBytes, alignmentAttr);
+}
+
+/// Default allocation callback, used when no allocation callback is provided.
+/// Allocates a memref<..xi8> buffer large enough for the product of the
+/// `boundingSubViewSize` entries and returns a view of it with a dynamically
+/// shaped memref type of rank `boundingSubViewSize.size()`.
+static Optional<Value>
+allocBufferCallBack(OpBuilder &builder, SubViewOp subView,
+                    ArrayRef<Value> boundingSubViewSize, bool dynamicBuffers,
+                    Optional<unsigned> alignment, OperationFolder *folder) {
+  ShapedType viewType = subView.getType();
+  int64_t rank = viewType.getRank();
+  (void)rank; // Only used by the assert below.
+  assert(rank > 0 && boundingSubViewSize.size() == static_cast<size_t>(rank));
+  auto zero = folded_std_constant_index(folder, 0);
+  auto one = folded_std_constant_index(folder, 1);
+  // The buffer must hold the product of all bounding sizes.
+  Value allocSize = one;
+  for (Value size : boundingSubViewSize)
+    allocSize = folded_std_muli(folder, allocSize, size);
+  Value buffer = allocBuffer(viewType.getElementType(), allocSize,
+                             dynamicBuffers, folder, alignment);
+  SmallVector<int64_t, 4> dynSizes(boundingSubViewSize.size(),
+                                   ShapedType::kDynamicSize);
+  Value view = folded_std_view(
+      folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
+      zero, boundingSubViewSize);
+  return view;
+}
+
+/// Default deallocation callback. Expects `fullLocalView` to be the value
+/// returned by the default allocation callback, i.e. the result of a ViewOp,
+/// and deallocates the buffer that view was taken on.
+static LogicalResult deallocCallBack(OpBuilder &b, Value fullLocalView) {
+  ViewOp view = fullLocalView.getDefiningOp<ViewOp>();
+  assert(view && "expected full local view to be a ViewOp");
+  std_dealloc(view.source());
+  return success();
+}
+
 namespace {
 
 /// Helper struct that captures the information required to apply the
@@ -55,81 +135,65 @@ struct LinalgOpInstancePromotionOptions {
   LinalgOpInstancePromotionOptions(LinalgOp op,
                                    const LinalgPromotionOptions &options);
   /// SubViews to promote.
-  SetVector<Value> subViews;
+  MapVector<unsigned, Value> subViews;
   /// True if the full view should be used for the promoted buffer.
   DenseMap<Value, bool> useFullTileBuffers;
+
+  /// Callback functions for allocation and deallocation of promoted buffers, as
+  /// well as to copy the data into and out of these buffers.
+  AllocBufferCallbackFn allocationFn;
+  DeallocBufferCallbackFn deallocationFn;
+  CopyCallbackFn copyInFn;
+  CopyCallbackFn copyOutFn;
+
   /// Allow the use of dynamicaly-sized buffers.
   bool dynamicBuffers;
   /// Alignment of promoted buffer.
   Optional<unsigned> alignment;
 };
+
+struct PromotionInfo {
+  Value fullLocalView;
+  Value partialLocalView;
+};
 } // namespace
 
 LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
     LinalgOp linalgOp, const LinalgPromotionOptions &options)
-    : subViews(), useFullTileBuffers(), dynamicBuffers(options.dynamicBuffers),
+    : subViews(), dynamicBuffers(options.dynamicBuffers),
       alignment(options.alignment) {
   unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers();
   auto vUseFullTileBuffers =
       options.useFullTileBuffers.getValueOr(llvm::SmallBitVector());
   vUseFullTileBuffers.resize(nBuffers, options.useFullTileBuffersDefault);
 
-  if (options.operandsToPromote.hasValue()) {
-    for (auto it : llvm::enumerate(options.operandsToPromote.getValue())) {
-      auto *op = linalgOp.getBuffer(it.value()).getDefiningOp();
-      if (auto sv = dyn_cast_or_null<SubViewOp>(op)) {
-        subViews.insert(sv);
-        useFullTileBuffers[sv] = vUseFullTileBuffers[it.index()];
-      }
-    }
-  } else {
-    for (unsigned idx = 0; idx < nBuffers; ++idx) {
-      auto *op = linalgOp.getBuffer(idx).getDefiningOp();
-      if (auto sv = dyn_cast_or_null<SubViewOp>(op)) {
-        subViews.insert(sv);
-        useFullTileBuffers[sv] = vUseFullTileBuffers[idx];
-      }
+  for (unsigned idx = 0; idx != nBuffers; ++idx) {
+    if (options.operandsToPromote && !options.operandsToPromote->count(idx))
+      continue;
+    auto *op = linalgOp.getBuffer(idx).getDefiningOp();
+    if (auto sv = dyn_cast_or_null<SubViewOp>(op)) {
+      subViews[idx] = sv;
+      useFullTileBuffers[sv] = vUseFullTileBuffers[idx];
     }
   }
-}
-
-/// If `size` comes from an AffineMinOp and one of the values of AffineMinOp
-/// is a constant then return a new value set to the smallest such constant.
-/// Otherwise return size.
-static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
-                                                 Value size) {
-  auto affineMinOp = size.getDefiningOp<AffineMinOp>();
-  if (!affineMinOp)
-    return size;
-  int64_t minConst = std::numeric_limits<int64_t>::max();
-  for (auto e : affineMinOp.getAffineMap().getResults())
-    if (auto cst = e.dyn_cast<AffineConstantExpr>())
-      minConst = std::min(minConst, cst.getValue());
-  return (minConst == std::numeric_limits<int64_t>::max())
-             ? size
-             : b.create<ConstantIndexOp>(loc, minConst);
-}
 
-/// Alloc a new buffer of `size`. If `dynamicBuffers` is true allocate exactly
-/// the size needed, otherwise try to allocate a static bounding box.
-static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
-                         OperationFolder *folder,
-                         Optional<unsigned> alignment = None) {
-  auto *ctx = size.getContext();
-  auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
-  IntegerAttr alignment_attr;
-  if (alignment.hasValue())
-    alignment_attr =
-        IntegerAttr::get(IntegerType::get(64, ctx), alignment.getValue());
-  if (!dynamicBuffers)
-    if (auto cst = size.getDefiningOp<ConstantIndexOp>())
-      return std_alloc(
-          MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)),
-          ValueRange{}, alignment_attr);
-  Value mul =
-      folded_std_muli(folder, folded_std_constant_index(folder, width), size);
-  return std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul,
-                   alignment_attr);
+  // Capture by copy: copies of this object must not refer back to `this`.
+  allocationFn = options.allocationFn ? *options.allocationFn
+      : [dynBuffers = dynamicBuffers, align = alignment](OpBuilder &builder,
+            SubViewOp subViewOp, ArrayRef<Value> boundingSubViewSize,
+            OperationFolder *folder) {
+          return allocBufferCallBack(builder, subViewOp, boundingSubViewSize,
+                                     dynBuffers, align, folder);
+        };
+  deallocationFn =
+      (options.deallocationFn ? *(options.deallocationFn) : deallocCallBack);
+  auto defaultCopyCallBack = [](OpBuilder &builder, Value src,
+                                Value dst) -> LogicalResult {
+    linalg_copy(src, dst);
+    return success();
+  };
+  copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack);
+  copyOutFn = (options.copyOutFn ? *(options.copyOutFn) : defaultCopyCallBack);
 }
 
 // Performs promotion of a `subView` into a local buffer of the size of the
@@ -149,45 +213,41 @@ static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
 // To account for general boundary effects, padding must be performed on the
 // boundary tiles. For now this is done with an unconditional `fill` op followed
 // by a partial `copy` op.
-static PromotionInfo promoteSubviewAsNewBuffer(OpBuilder &b, Location loc,
-                                               SubViewOp subView,
-                                               bool dynamicBuffers,
-                                               Optional<unsigned> alignment,
-                                               OperationFolder *folder) {
-  auto zero = folded_std_constant_index(folder, 0);
-  auto one = folded_std_constant_index(folder, 1);
-
+static Optional<PromotionInfo>
+promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, SubViewOp subView,
+                          LinalgOpInstancePromotionOptions const &options,
+                          OperationFolder *folder) {
   auto viewType = subView.getType();
   auto rank = viewType.getRank();
-  Value allocSize = one;
-  SmallVector<Value, 8> fullSizes, partialSizes;
+  SmallVector<Value, 4> fullSizes, partialSizes;
   fullSizes.reserve(rank);
   partialSizes.reserve(rank);
   for (auto en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
-    auto rank = en.index();
     auto rangeValue = en.value();
     // Try to extract a tight constant.
     LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
     Value size = extractSmallestConstantBoundingSize(b, loc, rangeValue.size);
     LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
-    allocSize = folded_std_muli(folder, allocSize, size);
     fullSizes.push_back(size);
-    partialSizes.push_back(folded_std_dim(folder, subView, rank));
+    partialSizes.push_back(folded_std_dim(folder, subView, en.index()));
   }
-  SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
-  auto buffer = allocBuffer(viewType.getElementType(), allocSize,
-                            dynamicBuffers, folder, alignment);
-  auto fullLocalView = folded_std_view(
-      folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
-      zero, fullSizes);
+  // Allocate through the allocation callback (the default allocator is
+  // installed when none was provided); a None result makes the promotion fail
+  // and is propagated to the caller.
+  Optional<Value> fullLocalView =
+      options.allocationFn(b, subView, fullSizes, folder);
+  if (!fullLocalView)
+    return {};
+  auto zero = folded_std_constant_index(folder, 0);
+  auto one = folded_std_constant_index(folder, 1);
   SmallVector<Value, 4> zeros(fullSizes.size(), zero);
   SmallVector<Value, 4> ones(fullSizes.size(), one);
   auto partialLocalView =
-      folded_std_subview(folder, fullLocalView, zeros, partialSizes, ones);
-  return PromotionInfo{buffer, fullLocalView, partialLocalView};
+      folded_std_subview(folder, *fullLocalView, zeros, partialSizes, ones);
+  return PromotionInfo{*fullLocalView, partialLocalView};
 }
 
-static SmallVector<PromotionInfo, 8>
+static Optional<MapVector<unsigned, PromotionInfo>>
 promoteSubViews(OpBuilder &b, Location loc,
                 LinalgOpInstancePromotionOptions options,
                 OperationFolder *folder) {
@@ -195,24 +255,18 @@ promoteSubViews(OpBuilder &b, Location loc,
     return {};
 
   ScopedContext scope(b, loc);
-  SmallVector<PromotionInfo, 8> res;
-  res.reserve(options.subViews.size());
-  DenseMap<Value, PromotionInfo> promotionInfoMap;
-  for (auto v : options.subViews) {
-    SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
-    auto promotionInfo = promoteSubviewAsNewBuffer(
-        b, loc, subView, options.dynamicBuffers, options.alignment, folder);
-    promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo));
-    res.push_back(promotionInfo);
-  }
+  MapVector<unsigned, PromotionInfo> promotionInfoMap;
 
   for (auto v : options.subViews) {
-    SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
-    auto info = promotionInfoMap.find(v);
-    if (info == promotionInfoMap.end())
-      continue;
+    SubViewOp subView = cast<SubViewOp>(v.second.getDefiningOp());
+    Optional<PromotionInfo> promotionInfo =
+        promoteSubviewAsNewBuffer(b, loc, subView, options, folder);
+    if (!promotionInfo)
+      return {};
+    promotionInfoMap[v.first] = *promotionInfo;
+
     // Only fill the buffer if the full local view is used
-    if (!options.useFullTileBuffers[v])
+    if (!options.useFullTileBuffers[v.second])
       continue;
     Value fillVal;
     if (auto t = subView.getType().getElementType().dyn_cast<FloatType>())
@@ -220,75 +274,80 @@ promoteSubViews(OpBuilder &b, Location loc,
     else if (auto t =
                  subView.getType().getElementType().dyn_cast<IntegerType>())
       fillVal = folded_std_constant_int(folder, 0, t);
-    // TODO(ntv): fill is only necessary if `promotionInfo` has a full local
-    // view that is different from the partial local view and we are on the
-    // boundary.
-    linalg_fill(info->second.fullLocalView, fillVal);
+    linalg_fill(promotionInfo->fullLocalView, fillVal);
   }
 
+  // Copy data into the promoted buffers. Use callback if provided.
   for (auto v : options.subViews) {
-    auto info = promotionInfoMap.find(v);
+    auto info = promotionInfoMap.find(v.first);
     if (info == promotionInfoMap.end())
       continue;
-    linalg_copy(cast<SubViewOp>(v.getDefiningOp()),
-                info->second.partialLocalView);
+    if (failed(options.copyInFn(b, cast<SubViewOp>(v.second.getDefiningOp()),
+                                info->second.partialLocalView)))
+      return {};
   }
-  return res;
+  return promotionInfoMap;
 }
 
-static void promoteSubViews(OpBuilder &b, LinalgOp op,
-                            LinalgOpInstancePromotionOptions options,
-                            OperationFolder *folder) {
+static Optional<LinalgOp>
+promoteSubViews(OpBuilder &b, LinalgOp op,
+                LinalgOpInstancePromotionOptions options,
+                OperationFolder *folder) {
   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
 
   if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
     // TODO(ntv): add a level of indirection to linalg.generic.
     if (convOp.padding())
-      llvm_unreachable("Unexpected conv with padding");
+      return {};
   }
 
   // 1. Promote the specified views and use them in the new op.
   auto loc = op.getLoc();
-  auto promotedBufferAndViews = promoteSubViews(b, loc, options, folder);
+  auto promotedBuffersAndViews = promoteSubViews(b, loc, options, folder);
+  if (!promotedBuffersAndViews ||
+      promotedBuffersAndViews->size() != options.subViews.size())
+    return {};
+
+  // 2. Append all other operands as they appear, this enforces that such
+  // operands are not views. This is to support cases such as FillOp taking
+  // extra scalars etc. Only promoted output buffers are written back.
   SmallVector<Value, 8> opViews;
   opViews.reserve(op.getNumInputsAndOutputs());
   SmallVector<std::pair<Value, Value>, 8> writebackViews;
-  writebackViews.reserve(promotedBufferAndViews.size());
-  unsigned promotedIdx = 0;
-  for (auto view : op.getInputsAndOutputBuffers()) {
-    if (options.subViews.count(view) != 0) {
-      if (options.useFullTileBuffers[view])
-        opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
+  writebackViews.reserve(promotedBuffersAndViews->size());
+  for (auto view : llvm::enumerate(op.getInputsAndOutputBuffers())) {
+    if (options.subViews.count(view.index()) != 0) {
+      if (options.useFullTileBuffers[view.value()])
+        opViews.push_back(
+            (*promotedBuffersAndViews)[view.index()].fullLocalView);
       else
-        opViews.push_back(promotedBufferAndViews[promotedIdx].partialLocalView);
-      writebackViews.emplace_back(std::make_pair(
-          view, promotedBufferAndViews[promotedIdx].partialLocalView));
-      promotedIdx++;
+        opViews.push_back(
+            (*promotedBuffersAndViews)[view.index()].partialLocalView);
+      if (view.index() >= op.getNumInputs())
+        writebackViews.emplace_back(std::make_pair(
+            view.value(),
+            (*promotedBuffersAndViews)[view.index()].partialLocalView));
     } else {
-      opViews.push_back(view);
+      opViews.push_back(view.value());
     }
   }
-
-  // 2. Append all other operands as they appear, this enforces that such
-  // operands are not views. This is to support cases such as FillOp taking
-  // extra scalars etc.
-  // Keep a reference to output buffers;
-  DenseSet<Value> originalOutputs(op.getOutputBuffers().begin(),
-                                  op.getOutputBuffers().end());
   op.getOperation()->setOperands(0, opViews.size(), opViews);
 
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPointAfter(op);
   ScopedContext scope(b, loc);
   // 3. Emit write-back for the promoted output views: copy the partial view.
-  for (auto viewAndPartialLocalView : writebackViews)
-    if (originalOutputs.count(viewAndPartialLocalView.first))
-      linalg_copy(viewAndPartialLocalView.second,
-                  viewAndPartialLocalView.first);
+  for (auto viewAndPartialLocalView : writebackViews) {
+    if (failed(options.copyOutFn(b, viewAndPartialLocalView.second,
+                                 viewAndPartialLocalView.first)))
+      return {};
+  }
 
   // 4. Dealloc all local buffers.
-  for (const auto &pi : promotedBufferAndViews)
-    std_dealloc(pi.buffer);
+  for (const auto &pi : *promotedBuffersAndViews)
+    if (failed(options.deallocationFn(b, pi.second.fullLocalView)))
+      return {};
+  return op;
 }
 
 LogicalResult
@@ -312,13 +371,13 @@ mlir::linalg::promoteSubviewsPrecondition(Operation *op,
   return failure();
 }
 
-LinalgOp mlir::linalg::promoteSubViews(OpBuilder &b, LinalgOp linalgOp,
-                                       LinalgPromotionOptions options,
-                                       OperationFolder *folder) {
+Optional<LinalgOp> mlir::linalg::promoteSubViews(OpBuilder &b,
+                                                 LinalgOp linalgOp,
+                                                 LinalgPromotionOptions options,
+                                                 OperationFolder *folder) {
   LinalgOpInstancePromotionOptions linalgOptions(linalgOp, options);
-  ::promoteSubViews(
-      b, linalgOp, LinalgOpInstancePromotionOptions(linalgOp, options), folder);
+  // Reuse the per-op options built above instead of constructing them twice.
+  return ::promoteSubViews(b, linalgOp, linalgOptions, folder);
-  return linalgOp;
 }
 
 namespace {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 2ce949aa034c..527d162298bf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -179,12 +179,19 @@ LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
     return failure();
   if (failed(promoteSubviewsPrecondition(op, options)))
     return failure();
-  rewriter.updateRootInPlace(op, [&]() {
-    auto promotedOp = promoteSubViews(rewriter, op, options);
-    (void)promotedOp;
-    assert(promotedOp && "Unexpected pattern failure");
-    marker.replaceLinalgMarker(rewriter, op);
-  });
+
+  // TODO: We cannot use root update here. This pattern is creating other ops,
+  // so if the promotion fails, those need to be cleaned up, which doesn't seem
+  // to be happening here. So to fail properly, we should be cloning the op and
+  // deleting the previous op. This needs more investigation.
+  rewriter.startRootUpdate(op);
+  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
+  if (!promotedOp) {
+    rewriter.cancelRootUpdate(op);
+    return op->emitError("subview promotion failed");
+  }
+  rewriter.finalizeRootUpdate(op);
+  marker.replaceLinalgMarker(rewriter, op);
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir
new file mode 100644
index 000000000000..e6c8e2158fc3
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/promotion_options.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-promotion-options -split-input-file | FileCheck %s
+
+func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
+{
+   linalg.matmul(%a, %b, %c) {__internal_linalg_transform__ = "START"}
+     : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+   return
+}
+
+//      CHECK: func @gemm
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-DAG: %[[C42:.+]] = constant 4.200000e+01 : f32
+//      CHECK: scf.for
+//      CHECK:   scf.for
+//      CHECK:     scf.for
+//      CHECK:       %[[T7:.+]] = subview %[[ARG0]]
+//      CHECK:       %[[T12:.+]] = subview %[[ARG1]]
+//      CHECK:       %[[T17:.+]] = subview %[[ARG2]]
+//      CHECK:       %[[T18:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32, 3>
+//      CHECK:       %[[T19:.+]] = subview %[[T18]]
+//      CHECK:       %[[T20:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32, 3>
+//      CHECK:       %[[T21:.+]] = subview %[[T20]]
+//      CHECK:       linalg.fill(%[[T19]], %[[C42]])
+//      CHECK:       linalg.copy(%[[T7]], %[[T19]])
+//      CHECK:       linalg.fill(%[[T21]], %[[C42]])
+//      CHECK:       linalg.copy(%[[T17]], %[[T21]])
+//      CHECK:       linalg.matmul(%[[T19]], %[[T12]], %[[T21]])
+//  CHECK-NOT:       linalg.fill
+//      CHECK:       linalg.copy(%[[T21]], %[[T17]])
+//      CHECK:       dealloc %[[T18]]
+//      CHECK:       dealloc %[[T20]]

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 7547e2953ef2..c38494fe2778 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -45,6 +45,9 @@ struct TestLinalgTransforms
           "Test a fused pass that applies patterns from matmul to vectors via "
           "2-d tiling"),
       llvm::cl::init(false)};
+  Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
+                                    llvm::cl::desc("Test promotion options"),
+                                    llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
@@ -197,10 +200,77 @@ static void fillL1TilingAndMatmulToVectorPatterns(
               LinalgVectorizationPattern<CopyOp>>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// Test promotion callbacks
+//===----------------------------------------------------------------------===//
+
+// Allocation callback: allocates a buffer in memory space 3 whose dynamic
+// sizes are the bounding sizes computed by the promotion.
+static Optional<Value> allocCallBackFn(OpBuilder &b, SubViewOp subView,
+                                       ArrayRef<Value> boundingSubViewSize,
+                                       OperationFolder *folder) {
+  // One dynamic dimension per bounding size.
+  SmallVector<int64_t, 4> shape(boundingSubViewSize.size(),
+                                ShapedType::kDynamicSize);
+  auto type = MemRefType::get(shape, subView.getType().getElementType(),
+                              /*affineMapComposition=*/{}, /*memorySpace=*/3);
+  return b.create<AllocOp>(subView.getLoc(), type, boundingSubViewSize)
+      .getResult();
+}
+
+// Deallocation callback: emits a DeallocOp for `buffer` at its location.
+static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
+  b.create<DeallocOp>(buffer.getLoc(), buffer);
+  return success();
+}
+
+// Copy-in/out callback (float only); copy-in first fills `dst` with 42.0.
+static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
+                                    bool isOutput) {
+  Type elementType = src.getType().cast<MemRefType>().getElementType();
+  if (!elementType.isa<FloatType>())
+    return failure();
+  Location loc = src.getLoc();
+  if (!isOutput)
+    b.create<FillOp>(
+        loc, dst, b.create<ConstantOp>(loc, FloatAttr::get(elementType, 42.0)));
+  b.create<CopyOp>(loc, src, dst);
+  return success();
+}
+
+// Tiles matmuls by 16 and promotes operands 0 and 2 using the test callbacks.
+static void fillPromotionCallBackPatterns(MLIRContext *context,
+                                          OwningRewritePatternList &patterns) {
+  patterns.insert<LinalgTilingPattern<MatmulOp>>(
+      context, LinalgTilingOptions().setTileSizes({16, 16, 16}),
+      LinalgMarker({"START"}, "PROMOTE"));
+  patterns.insert<LinalgPromotionPattern<MatmulOp>>(
+      context,
+      LinalgPromotionOptions()
+          .setOperandsToPromote({0, 2})
+          .setUseFullTileBuffers({false, false})
+          .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
+          .setCopyInOutFns(
+              // Propagate copy failures (e.g. non-float element types).
+              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
+                return copyCallBackFn(b, src, dst, /*isOutput=*/false);
+              },
+              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
+                return copyCallBackFn(b, src, dst, /*isOutput=*/true);
+              }),
+      LinalgMarker({"PROMOTE"}));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnFunction() {
   if (testPatterns) {
     applyPatterns(getFunction());
+    return;
+  }
+  if (testPromotionOptions) {
+    OwningRewritePatternList patterns;
+    fillPromotionCallBackPatterns(&getContext(), patterns);
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
   } else {
     SmallVector<OwningRewritePatternList, 4> stage1Patterns;
     if (testMatmulToVectorPatterns1dTiling) {


        


More information about the Mlir-commits mailing list