[Mlir-commits] [mlir] c888a0c - [mlir][MemRef] Rewrite multi-buffering with proper composable abstractions

Nicolas Vasilache llvmlistbot at llvm.org
Wed Mar 1 07:25:39 PST 2023


Author: Nicolas Vasilache
Date: 2023-03-01T07:25:31-08:00
New Revision: c888a0ce8846e7ebf30914d4959125da80b3f566

URL: https://github.com/llvm/llvm-project/commit/c888a0ce8846e7ebf30914d4959125da80b3f566
DIFF: https://github.com/llvm/llvm-project/commit/c888a0ce8846e7ebf30914d4959125da80b3f566.diff

LOG: [mlir][MemRef] Rewrite multi-buffering with proper composable abstractions

Rewrite and document multi-buffering properly:
1. Use IndexingUtils / StaticValueUtils instead of duplicating functionality
2. Properly plumb RewriterBase through.
3. Add support for DeallocOp users of the multi-buffered allocation.
4. Better debug messages.

This revision is otherwise almost NFC, if it weren't for the extra DeallocOp
support: previously, a DeallocOp user would make multi-buffering fail.

Depends on: D145036

Differential Revision: https://reviews.llvm.org/D145055

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/Utils/Utils.h
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
    mlir/include/mlir/Dialect/SCF/IR/SCF.h
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
    mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp
    mlir/test/Dialect/MemRef/transform-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index bfb470231e986..d0dd5de078c81 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -76,18 +76,18 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
 Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                       OpFoldResult ofr);
 
-/// Create a cast from an index-like value (index or integer) to another
-/// index-like value. If the value type and the target type are the same, it
-/// returns the original value.
-Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
-                                      Type targetType, Value value);
-
 /// Similar to the other overload, but converts multiple OpFoldResults into
 /// Values.
 SmallVector<Value>
 getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                 ArrayRef<OpFoldResult> valueOrAttrVec);
 
+/// Create a cast from an index-like value (index or integer) to another
+/// index-like value. If the value type and the target type are the same, it
+/// returns the original value.
+Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
+                                      Type targetType, Value value);
+
 /// Converts a scalar value `operand` to type `toType`. If the value doesn't
 /// convert, a warning will be issued and the operand is returned as is (which
 /// will presumably yield a verification issue downstream).

diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 80ecc10f4adfb..9f5b095eef8cf 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -19,6 +19,7 @@ namespace mlir {
 
 class AffineDialect;
 class ModuleOp;
+class RewriterBase;
 
 namespace arith {
 class WideIntEmulationConverter;
@@ -102,6 +103,11 @@ void populateMemRefWideIntEmulationConversions(
 ///   "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> ()
 /// }
 /// ```
+FailureOr<memref::AllocOp> multiBuffer(RewriterBase &rewriter,
+                                       memref::AllocOp allocOp,
+                                       unsigned multiplier,
+                                       bool skipOverrideAnalysis = false);
+/// Call into `multiBuffer` with  locally constructed IRRewriter.
 FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
                                        unsigned multiplier,
                                        bool skipOverrideAnalysis = false);

diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index 9ae71bc73d02f..7f714d0a07646 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SCF_SCF_H
 #define MLIR_DIALECT_SCF_SCF_H
 
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"

diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index e27e8a75f37d7..96ec62d27df6f 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -573,17 +573,17 @@ def ForallOp : SCF_Op<"forall", [
 
     /// Get lower bounds as values.
     SmallVector<Value> getLowerBound(OpBuilder &b) {
-      return getAsValues(b, getLoc(), getMixedLowerBound());
+      return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedLowerBound());
     }
 
     /// Get upper bounds as values.
     SmallVector<Value> getUpperBound(OpBuilder &b) {
-      return getAsValues(b, getLoc(), getMixedUpperBound());
+      return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedUpperBound());
     }
 
     /// Get steps as values.
     SmallVector<Value> getStep(OpBuilder &b) {
-      return getAsValues(b, getLoc(), getMixedStep());
+      return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedStep());
     }
 
     int64_t getRank() { return getStaticLowerBound().size(); }

diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 100699c7f7fd8..c97c0834f588d 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -85,12 +85,20 @@ bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
 bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
                                     ArrayRef<OpFoldResult> ofrs2);
 
-/// Helper function to convert a vector of `OpFoldResult`s into a vector of
-/// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold
-/// result if it casts to  a `Value` or create an index-type constant if it
-/// casts to `IntegerAttr`. No other attribute types are supported.
-SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
-                               ArrayRef<OpFoldResult> valueOrAttrVec);
+// To convert an OpFoldResult to a Value of index type, see:
+//   mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+// TODO: find a better common landing place.
+//
+// Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
+//                                       OpFoldResult ofr);
+
+// To convert an OpFoldResult to a Value of index type, see:
+//   mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+// TODO: find a better common landing place.
+//
+// SmallVector<Value>
+// getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
+//                                 ArrayRef<OpFoldResult> valueOrAttrVec);
 
 /// Return a vector of OpFoldResults with the same size a staticValues, but
 /// all elements for which ShapedType::isDynamic is true, will be replaced by

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 50f89cfeec145..dda0f491c6aaa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -674,7 +674,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
         return !isConstantIntValue(ofr, 0);
       }));
   SmallVector<Value> materializedNonZeroNumThreads =
-      getAsValues(b, loc, nonZeroNumThreads);
+      getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
 
   // 2. Create the ForallOp with an empty region.
   scf::ForallOp forallOp = b.create<scf::ForallOp>(

diff  --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index e4b3b4523ac7a..ae721fe641a84 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -15,9 +15,13 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "llvm/Support/Debug.h"
 
 using namespace mlir;
 
+#define DEBUG_TYPE "memref-transforms"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+
 //===----------------------------------------------------------------------===//
 // MemRefMultiBufferOp
 //===----------------------------------------------------------------------===//
@@ -27,25 +31,36 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
     transform::TransformState &state) {
   SmallVector<Operation *> results;
   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+  IRRewriter rewriter(getContext());
   for (auto *op : payloadOps) {
     bool canApplyMultiBuffer = true;
     auto target = cast<memref::AllocOp>(op);
+    LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";);
     // Skip allocations not used in a loop.
     for (Operation *user : target->getUsers()) {
+      if (isa<memref::DeallocOp>(user))
+        continue;
       auto loop = user->getParentOfType<LoopLikeOpInterface>();
       if (!loop) {
+        LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n";
+                   DBGS() << "----due to user: " << *user;);
         canApplyMultiBuffer = false;
         break;
       }
     }
-    if (!canApplyMultiBuffer)
+    if (!canApplyMultiBuffer) {
+      LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";);
       continue;
+    }
 
     auto newBuffer =
-        memref::multiBuffer(target, getFactor(), getSkipAnalysis());
-    if (failed(newBuffer))
+        memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis());
+
+    if (failed(newBuffer)) {
+      LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";);
       return emitSilenceableFailure(target->getLoc())
              << "op failed to multibuffer";
+    }
 
     results.push_back(*newBuffer);
   }

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index c93bda2e0be25..660d1022e0182 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -11,10 +11,16 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 
 using namespace mlir;
@@ -35,46 +41,52 @@ static bool overrideBuffer(Operation *op, Value buffer) {
 /// propagate the type change. Changing the memref type may require propagating
 /// it through subview ops so we cannot just do a replaceAllUse but need to
 /// propagate the type change and erase old subview ops.
-static void replaceUsesAndPropagateType(Operation *oldOp, Value val,
-                                        OpBuilder &builder) {
-  SmallVector<Operation *> opToDelete;
+static void replaceUsesAndPropagateType(RewriterBase &rewriter,
+                                        Operation *oldOp, Value val) {
+  SmallVector<Operation *> opsToDelete;
   SmallVector<OpOperand *> operandsToReplace;
+
+  // Save the operand to replace / delete later (avoid iterator invalidation).
+  // TODO: can we use an early_inc iterator?
   for (OpOperand &use : oldOp->getUses()) {
+    // Non-subview ops will be replaced by `val`.
     auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
     if (!subviewUse) {
-      // Save the operand to and replace outside the loop to not invalidate the
-      // iterator.
       operandsToReplace.push_back(&use);
       continue;
     }
-    builder.setInsertionPoint(subviewUse);
+
+    // `subview(old_op)` is replaced by a new `subview(val)`.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(subviewUse);
     Type newType = memref::SubViewOp::inferRankReducedResultType(
         subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
         subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
         subviewUse.getStaticStrides());
-    Value newSubview = builder.create<memref::SubViewOp>(
+    Value newSubview = rewriter.create<memref::SubViewOp>(
         subviewUse->getLoc(), newType.cast<MemRefType>(), val,
         subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
         subviewUse.getMixedStrides());
-    replaceUsesAndPropagateType(subviewUse, newSubview, builder);
-    opToDelete.push_back(use.getOwner());
+
+    // Ouch recursion ... is this really necessary?
+    replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
+
+    opsToDelete.push_back(use.getOwner());
   }
-  for (OpOperand *operand : operandsToReplace)
+
+  // Perform late replacement.
+  // TODO: can we use an early_inc iterator?
+  for (OpOperand *operand : operandsToReplace) {
+    Operation *op = operand->getOwner();
+    rewriter.startRootUpdate(op);
     operand->set(val);
-  // Clean up old subview ops.
-  for (Operation *op : opToDelete)
-    op->erase();
-}
+    rewriter.finalizeRootUpdate(op);
+  }
 
-/// Helper to convert get a value from an OpFoldResult or create it at the
-/// builder insert point.
-static Value getOrCreateValue(OpFoldResult res, OpBuilder &builder,
-                              Location loc) {
-  Value value = res.dyn_cast<Value>();
-  if (value)
-    return value;
-  return builder.create<arith::ConstantIndexOp>(
-      loc, res.dyn_cast<Attribute>().cast<IntegerAttr>().getInt());
+  // Perform late op erasure.
+  // TODO: can we use an early_inc iterator?
+  for (Operation *op : opsToDelete)
+    rewriter.eraseOp(op);
 }
 
 // Transformation to do multi-buffering/array expansion to remove dependencies
@@ -83,28 +95,37 @@ static Value getOrCreateValue(OpFoldResult res, OpBuilder &builder,
 // This is not a pattern as it requires propagating the new memref type to its
 // uses and requires updating subview ops.
 FailureOr<memref::AllocOp>
-mlir::memref::multiBuffer(memref::AllocOp allocOp, unsigned multiplier,
+mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
+                          unsigned multiBufferingFactor,
                           bool skipOverrideAnalysis) {
-  LLVM_DEBUG(DBGS() << "Try multibuffer: " << allocOp << "\n");
+  LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n");
   DominanceInfo dom(allocOp->getParentOp());
   LoopLikeOpInterface candidateLoop;
   for (Operation *user : allocOp->getUsers()) {
     auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
     if (!parentLoop) {
-      LLVM_DEBUG(DBGS() << "Skip user: no parent loop\n");
+      if (isa<memref::DeallocOp>(user)) {
+        // Allow dealloc outside of any loop.
+        // TODO: The whole precondition function here is very brittle and will
+        // need to rethought an isolated into a cleaner analysis.
+        continue;
+      }
+      LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n");
+      LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n");
       return failure();
     }
     if (!skipOverrideAnalysis) {
       /// Make sure there is no loop-carried dependency on the allocation.
       if (!overrideBuffer(user, allocOp.getResult())) {
-        LLVM_DEBUG(DBGS() << "Skip user: found loop-carried dependence\n");
+        LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n");
         continue;
       }
       // If this user doesn't dominate all the other users keep looking.
       if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
             return !dom.dominates(user, otherUser);
           })) {
-        LLVM_DEBUG(DBGS() << "Skip user: does not dominate all other users\n");
+        LLVM_DEBUG(
+            DBGS() << "--Skip user: does not dominate all other users\n");
         continue;
       }
     } else {
@@ -114,17 +135,19 @@ mlir::memref::multiBuffer(memref::AllocOp allocOp, unsigned multiplier,
           })) {
         LLVM_DEBUG(
             DBGS()
-            << "Skip user: not all other users are in the parent loop\n");
+            << "--Skip user: not all other users are in the parent loop\n");
         continue;
       }
     }
     candidateLoop = parentLoop;
     break;
   }
+
   if (!candidateLoop) {
     LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n");
     return failure();
   }
+
   std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
   std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
   std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
@@ -138,51 +161,89 @@ mlir::memref::multiBuffer(memref::AllocOp allocOp, unsigned multiplier,
     return failure();
   }
 
-  OpBuilder builder(candidateLoop);
-  SmallVector<int64_t, 4> newShape(1, multiplier);
-  ArrayRef<int64_t> oldShape = allocOp.getType().getShape();
-  newShape.append(oldShape.begin(), oldShape.end());
-  auto newMemref = MemRefType::get(newShape, allocOp.getType().getElementType(),
-                                   MemRefLayoutAttrInterface(),
-                                   allocOp.getType().getMemorySpace());
-  builder.setInsertionPoint(allocOp);
+  LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n");
+
+  // 1. Construct the multi-buffered memref type.
+  ArrayRef<int64_t> originalShape = allocOp.getType().getShape();
+  SmallVector<int64_t, 4> multiBufferedShape{multiBufferingFactor};
+  llvm::append_range(multiBufferedShape, originalShape);
+  LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n");
+  MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType())
+                                .setShape(multiBufferedShape)
+                                .setLayout(MemRefLayoutAttrInterface());
+  LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n");
+
+  // 2. Create the multi-buffered alloc.
   Location loc = allocOp->getLoc();
-  auto newAlloc = builder.create<memref::AllocOp>(loc, newMemref, ValueRange{},
-                                                  allocOp->getAttrs());
-  builder.setInsertionPoint(&candidateLoop.getLoopBody().front(),
-                            candidateLoop.getLoopBody().front().begin());
-
-  SmallVector<Value> operands = {*inductionVar};
-  AffineExpr induc = getAffineDimExpr(0, allocOp.getContext());
-  unsigned dimCount = 1;
-  auto getAffineExpr = [&](OpFoldResult e) -> AffineExpr {
-    if (std::optional<int64_t> constValue = getConstantIntValue(e)) {
-      return getAffineConstantExpr(*constValue, allocOp.getContext());
-    }
-    auto value = getOrCreateValue(e, builder, candidateLoop->getLoc());
-    operands.push_back(value);
-    return getAffineDimExpr(dimCount++, allocOp.getContext());
-  };
-  auto init = getAffineExpr(*lowerBound);
-  auto step = getAffineExpr(*singleStep);
-
-  AffineExpr expr = ((induc - init).floorDiv(step)) % multiplier;
-  auto map = AffineMap::get(dimCount, 0, expr);
-  Value bufferIndex = builder.create<AffineApplyOp>(loc, map, operands);
-  SmallVector<OpFoldResult> offsets, sizes, strides;
-  offsets.push_back(bufferIndex);
-  offsets.append(oldShape.size(), builder.getIndexAttr(0));
-  strides.assign(oldShape.size() + 1, builder.getIndexAttr(1));
-  sizes.push_back(builder.getIndexAttr(1));
-  for (int64_t size : oldShape)
-    sizes.push_back(builder.getIndexAttr(size));
-  auto dstMemref =
-      memref::SubViewOp::inferRankReducedResultType(
-          allocOp.getType().getShape(), newMemref, offsets, sizes, strides)
-          .cast<MemRefType>();
-  Value subview = builder.create<memref::SubViewOp>(loc, dstMemref, newAlloc,
-                                                    offsets, sizes, strides);
-  replaceUsesAndPropagateType(allocOp, subview, builder);
-  allocOp.erase();
-  return newAlloc;
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(allocOp);
+  auto mbAlloc = rewriter.create<memref::AllocOp>(
+      loc, mbMemRefType, ValueRange{}, allocOp->getAttrs());
+  LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n");
+
+  // 3. Within the loop, build the modular leading index (i.e. each loop
+  // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor).
+  rewriter.setInsertionPointToStart(&candidateLoop.getLoopBody().front());
+  Value ivVal = *inductionVar;
+  Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound);
+  Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep);
+  AffineExpr iv, lb, step;
+  bindDims(rewriter.getContext(), iv, lb, step);
+  Value bufferIndex = makeComposedAffineApply(
+      rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor,
+      {ivVal, lbVal, stepVal});
+  LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n");
+
+  // 4. Build the subview accessing the particular slice, taking modular
+  // rotation into account.
+  int64_t mbMemRefTypeRank = mbMemRefType.getRank();
+  IntegerAttr zero = rewriter.getIndexAttr(0);
+  IntegerAttr one = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult> offsets(mbMemRefTypeRank, zero);
+  SmallVector<OpFoldResult> sizes(mbMemRefTypeRank, one);
+  SmallVector<OpFoldResult> strides(mbMemRefTypeRank, one);
+  // Offset is [bufferIndex, 0 ... 0 ].
+  offsets.front() = bufferIndex;
+  // Sizes is [1, original_size_0 ... original_size_n ].
+  for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
+    sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
+  // Strides is [1, 1 ... 1 ].
+  auto dstMemref = memref::SubViewOp::inferRankReducedResultType(
+                       originalShape, mbMemRefType, offsets, sizes, strides)
+                       .cast<MemRefType>();
+  Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
+                                                     offsets, sizes, strides);
+  LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
+
+  // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to
+  // handle dealloc uses separately..
+  for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
+    auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
+    if (!deallocOp)
+      continue;
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(deallocOp);
+    auto newDeallocOp =
+        rewriter.create<memref::DeallocOp>(deallocOp->getLoc(), mbAlloc);
+    (void)newDeallocOp;
+    LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n");
+    rewriter.eraseOp(deallocOp);
+  }
+
+  // 6. RAUW with the particular slice, taking modular rotation into account.
+  replaceUsesAndPropagateType(rewriter, allocOp, subview);
+
+  // 7. Finally, erase the old allocOp.
+  rewriter.eraseOp(allocOp);
+
+  return mbAlloc;
+}
+
+FailureOr<memref::AllocOp>
+mlir::memref::multiBuffer(memref::AllocOp allocOp,
+                          unsigned multiBufferingFactor,
+                          bool skipOverrideAnalysis) {
+  IRRewriter rewriter(allocOp->getContext());
+  return multiBuffer(rewriter, allocOp, multiBufferingFactor,
+                     skipOverrideAnalysis);
 }

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 79c14f9b3288a..2abe18087107e 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -922,9 +922,10 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
         continue;
       // If the dest type of the cast does not preserve static information in
       // the source type.
-      if (!tensor::preservesStaticInformation(incomingCast.getDest().getType(),
-                                         incomingCast.getSource().getType()))
-          continue;
+      if (!tensor::preservesStaticInformation(
+              incomingCast.getDest().getType(),
+              incomingCast.getSource().getType()))
+        continue;
       if (!std::get<1>(it).hasOneUse())
         continue;
 

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index df65eee6782cc..84fbae3ea9672 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
@@ -146,7 +147,8 @@ struct ReifyExpandOrCollapseShapeOp
     auto resultShape = getReshapeOutputShapeFromInputShape(
         b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
         reshapeOp.getReassociationMaps());
-    reifiedReturnShapes.push_back(getAsValues(b, loc, resultShape));
+    reifiedReturnShapes.push_back(
+        getValueOrCreateConstantIndexOp(b, loc, resultShape));
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 45ea541660fbd..5a4b7ea4f9881 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -146,18 +146,6 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
   return true;
 }
 
-/// Helper function to convert a vector of `OpFoldResult`s into a vector of
-/// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold result
-/// if it casts to  a `Value` or create an index-type constant if it casts to
-/// `IntegerAttr`. No other attribute types are supported.
-SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
-                               ArrayRef<OpFoldResult> valueOrAttrVec) {
-  return llvm::to_vector<4>(
-      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
-        return getValueOrCreateConstantIndexOp(b, loc, value);
-      }));
-}
-
 /// Return a vector of OpFoldResults with the same size a staticValues, but all
 /// elements for which ShapedType::isDynamic is true, will be replaced by
 /// dynamicValues.

diff  --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
index 450e642913cf1..3edcc53230fa9 100644
--- a/mlir/test/Dialect/MemRef/transform-ops.mlir
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -219,3 +219,40 @@ transform.sequence failures(propagate) {
   // Verify that the returned handle is usable.
   transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation
 }
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)>
+
+// CHECK-LABEL: func @multi_buffer_dealloc
+func.func @multi_buffer_dealloc(%in: memref<16xf32>) {
+  // CHECK: %[[A:.*]] = memref.alloc() : memref<2x4xf32>
+  // expected-remark @below {{transformed}}
+  %tmp = memref.alloc() : memref<4xf32>
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[C4:.*]] = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c16 = arith.constant 16 : index
+
+  // CHECK: scf.for %[[IV:.*]] = %[[C0]]
+  scf.for %i0 = %c0 to %c16 step %c4 {
+  // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]])
+  // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>>
+    "some_write_read"(%tmp) : (memref<4xf32>) ->()
+  }
+
+  // CHECK-NOT: memref.dealloc {{.*}} : memref<4xf32>
+  // CHECK: memref.dealloc %[[A]] : memref<2x4xf32>
+  memref.dealloc %tmp : memref<4xf32>
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc">
+  %1 = transform.memref.multibuffer %0 {factor = 2 : i64, skip_analysis} : (!transform.op<"memref.alloc">) -> !pdl.operation
+  // Verify that the returned handle is usable.
+  transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation
+}


        


More information about the Mlir-commits mailing list