[Mlir-commits] [mlir] 417e1c7 - [mlir][scf][bufferize][NFC] Split ForOp bufferization into smaller functions

Matthias Springer llvmlistbot at llvm.org
Thu May 5 00:56:25 PDT 2022


Author: Matthias Springer
Date: 2022-05-05T16:55:44+09:00
New Revision: 417e1c7d520c5fb1868794341e8926f8037ef2a0

URL: https://github.com/llvm/llvm-project/commit/417e1c7d520c5fb1868794341e8926f8037ef2a0
DIFF: https://github.com/llvm/llvm-project/commit/417e1c7d520c5fb1868794341e8926f8037ef2a0.diff

LOG: [mlir][scf][bufferize][NFC] Split ForOp bufferization into smaller functions

This is in preparation for WhileOp bufferization, which reuses these functions.

Differential Revision: https://reviews.llvm.org/D124933
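
In summary, ForOp's bufferize() now delegates to these file-local helpers
(signatures abridged from the diff below):

    // Indices of the tensor-typed values in a range.
    static DenseSet<int64_t> getTensorIndices(ValueRange values);
    // Indices of bbArg/yielded value pairs whose buffer relation is
    // "Equivalent".
    static DenseSet<int64_t> getEquivalentBuffers(
        ValueRange bbArgs, ValueRange yieldedValues, const AnalysisState &state);
    // memref.cast a buffer to the given type, if needed.
    static Value castBuffer(OpBuilder &b, Value buffer, Type type);
    // Bufferized values of the given operands; non-tensors pass through.
    static SmallVector<Value> getBuffers(RewriterBase &rewriter,
                                         MutableArrayRef<OpOperand> operands,
                                         BufferizationState &state);
    // Buffer to yield for one tensor: yielded directly if equivalent to its
    // bbArg, otherwise via a new allocation and copy.
    static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
                                  BaseMemRefType type, bool isEquivalent,
                                  BufferizationState &state);
    // Apply `func` to the values marked in `tensorIndices`; pass the rest
    // through unmodified.
    static SmallVector<Value> convertTensorValues(
        ValueRange values, const DenseSet<int64_t> &tensorIndices,
        llvm::function_ref<Value(Value, int64_t)> func);
    // The bufferized yielded values of a loop block.
    static SmallVector<Value> getYieldedValues(
        RewriterBase &rewriter, ValueRange values, TypeRange bufferizedTypes,
        const DenseSet<int64_t> &tensorIndices,
        const DenseSet<int64_t> &equivalentTensors, BufferizationState &state);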

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index cc4c7288654bc..7dade29cc32d0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -259,6 +259,150 @@ struct IfOpInterface
   }
 };
 
+/// Helper function for loop bufferization. Return the indices of all values
+/// that have a tensor type.
+static DenseSet<int64_t> getTensorIndices(ValueRange values) {
+  DenseSet<int64_t> result;
+  for (const auto &it : llvm::enumerate(values))
+    if (it.value().getType().isa<TensorType>())
+      result.insert(it.index());
+  return result;
+}
+
+/// Helper function for loop bufferization. Return the indices of all
+/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
+static DenseSet<int64_t> getEquivalentBuffers(ValueRange bbArgs,
+                                              ValueRange yieldedValues,
+                                              const AnalysisState &state) {
+  DenseSet<int64_t> result;
+  int64_t counter = 0;
+  for (const auto &it : llvm::zip(bbArgs, yieldedValues)) {
+    if (!std::get<0>(it).getType().isa<TensorType>())
+      continue;
+    if (state.areEquivalentBufferizedValues(std::get<0>(it), std::get<1>(it)))
+      result.insert(counter);
+    counter++;
+  }
+  return result;
+}
+
+/// Helper function for loop bufferization. Cast the given buffer to the given
+/// memref type.
+static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
+  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
+  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
+  // If the buffer already has the correct type, no cast is needed.
+  if (buffer.getType() == type)
+    return buffer;
+  // TODO: In case `type` has a layout map that is not the fully dynamic
+  // one, we may not be able to cast the buffer. In that case, the loop
+  // iter_arg's layout map must be changed (see uses of `castBuffer`).
+  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
+         "scf.while op bufferization: cast incompatible");
+  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
+}
+
+/// Helper function for loop bufferization. Return the bufferized values of the
+/// given OpOperands. If an operand is not a tensor, return the original value.
+static SmallVector<Value> getBuffers(RewriterBase &rewriter,
+                                     MutableArrayRef<OpOperand> operands,
+                                     BufferizationState &state) {
+  SmallVector<Value> result;
+  for (OpOperand &opOperand : operands) {
+    if (opOperand.get().getType().isa<TensorType>()) {
+      FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
+      if (failed(resultBuffer))
+        return {};
+      result.push_back(*resultBuffer);
+    } else {
+      result.push_back(opOperand.get());
+    }
+  }
+  return result;
+}
+
+/// Helper function for loop bufferization. Compute the buffer that should be
+/// yielded from a loop block (loop body or loop condition). If the given tensor
+/// is equivalent to the corresponding block argument (as indicated by
+/// `isEquivalent`), the buffer can be yielded directly. Otherwise, a new buffer
+/// copy must be yielded.
+///
+/// According to the `BufferizableOpInterface` implementation of scf loops, a
+/// bufferized OpResult may alias only with the corresponding bufferized
+/// init_arg and with no other buffers. I.e., the i-th OpResult may alias with
+/// the i-th init_arg, but not with any other OpOperand. If a corresponding
+/// OpResult/init_arg pair bufferizes to equivalent buffers (as indicated by
+/// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we
+/// cannot be sure and must yield a new buffer copy. (New buffer copies do not
+/// alias with any buffer.)
+static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
+                              BaseMemRefType type, bool isEquivalent,
+                              BufferizationState &state) {
+  assert(tensor.getType().isa<TensorType>() && "expected tensor");
+  ensureToMemrefOpIsValid(tensor, type);
+  Value yieldedVal =
+      bufferization::lookupBuffer(rewriter, tensor, state.getOptions());
+
+  if (isEquivalent)
+    // Yielded value is equivalent to the corresponding iter_arg bbArg.
+    // Yield the value directly. Most IR should be like that. Everything
+    // else must be resolved with copies and is potentially inefficient.
+    // By default, such problematic IR would already have been rejected
+    // during `verifyAnalysis`, unless `allow-return-allocs`.
+    return castBuffer(rewriter, yieldedVal, type);
+
+  // It is not certain that the yielded value and the iter_arg bbArg
+  // have the same buffer. Allocate a new buffer and copy. The yielded
+  // buffer will get deallocated by `deallocateBuffers`.
+
+  // TODO: There are cases in which it is not necessary to return a new
+  // buffer allocation. E.g., when equivalent values are yielded in a
+  // different order. This could be resolved with copies.
+  Optional<Value> yieldedAlloc = state.createAlloc(
+      rewriter, tensor.getLoc(), yieldedVal, /*deallocMemref=*/false);
+  // TODO: We should rollback, but for now just assume that this always
+  // succeeds.
+  assert(yieldedAlloc.hasValue() && "could not create alloc");
+  LogicalResult copyStatus = bufferization::createMemCpy(
+      rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc, state.getOptions());
+  (void)copyStatus;
+  assert(succeeded(copyStatus) && "could not create memcpy");
+
+  // The iter_arg memref type may have a layout map. Cast the new buffer
+  // to the same type if needed.
+  return castBuffer(rewriter, *yieldedAlloc, type);
+}
+
+/// Helper function for loop bufferization. Given a range of values, apply
+/// `func` to those marked in `tensorIndices` and store all other values
+/// unmodified in the result vector.
+static SmallVector<Value>
+convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
+                    llvm::function_ref<Value(Value, int64_t)> func) {
+  SmallVector<Value> result;
+  for (const auto &it : llvm::enumerate(values)) {
+    size_t idx = it.index();
+    Value val = it.value();
+    result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
+  }
+  return result;
+}
+
+/// Helper function for loop bufferization. Given a list of pre-bufferization
+/// yielded values, compute the list of bufferized yielded values.
+static SmallVector<Value> getYieldedValues(
+    RewriterBase &rewriter, ValueRange values, TypeRange bufferizedTypes,
+    const DenseSet<int64_t> &tensorIndices,
+    const DenseSet<int64_t> &equivalentTensors,
+    BufferizationState &state) {
+  return convertTensorValues(
+      values, tensorIndices, [&](Value val, int64_t index) {
+        return getYieldedBuffer(rewriter, val,
+                                bufferizedTypes[index].cast<BaseMemRefType>(),
+                                equivalentTensors.contains(index), state);
+      });
+}
+
 /// Bufferization of scf.for. Replace with a new scf.for that operates on
 /// memrefs.
 struct ForOpInterface
@@ -312,78 +456,38 @@ struct ForOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto forOp = cast<scf::ForOp>(op);
-    auto bufferizableOp = cast<BufferizableOpInterface>(op);
+    auto oldYieldOp =
+        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
     Block *oldLoopBody = &forOp.getLoopBody().front();
 
-    // Helper function for casting MemRef buffers.
-    auto castBuffer = [&](Value buffer, Type type) {
-      assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
-      assert(buffer.getType().isa<BaseMemRefType>() &&
-             "expected BaseMemRefType");
-      // If the buffer already has the correct type, no cast is needed.
-      if (buffer.getType() == type)
-        return buffer;
-      // TODO: In case `type` has a layout map that is not the fully dynamic
-      // one, we may not be able to cast the buffer. In that case, the loop
-      // iter_arg's layout map must be changed (see uses of `castBuffer`).
-      assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
-             "scf.for op bufferization: cast incompatible");
-      return rewriter.create<memref::CastOp>(buffer.getLoc(), type, buffer)
-          .getResult();
-    };
-
     // Indices of all iter_args that have tensor type. These are the ones that
     // are bufferized.
-    DenseSet<int64_t> indices;
+    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
     // For every yielded value, is the value equivalent to its corresponding
     // bbArg?
-    SmallVector<bool> equivalentYields;
-    for (const auto &it : llvm::enumerate(forOp.getInitArgs())) {
-      if (it.value().getType().isa<TensorType>()) {
-        indices.insert(it.index());
-        BufferRelation relation = bufferizableOp.bufferRelation(
-            forOp->getResult(it.index()), state.getAnalysisState());
-        equivalentYields.push_back(relation == BufferRelation::Equivalent);
-      } else {
-        equivalentYields.push_back(false);
-      }
-    }
+    DenseSet<int64_t> equivalentYields =
+        getEquivalentBuffers(forOp.getRegionIterArgs(), oldYieldOp.getResults(),
+                             state.getAnalysisState());
 
-    // Given a range of values, apply `func` to those marked in `indices`.
-    // Otherwise, store the unmodified value in the result vector.
-    auto convert = [&](ValueRange values,
-                       llvm::function_ref<Value(Value, int64_t)> func) {
-      SmallVector<Value> result;
-      for (const auto &it : llvm::enumerate(values)) {
-        size_t idx = it.index();
-        Value val = it.value();
-        result.push_back(indices.contains(idx) ? func(val, idx) : val);
-      }
-      return result;
-    };
+    // The new memref init_args of the loop.
+    SmallVector<Value> initArgs =
+        getBuffers(rewriter, forOp.getIterOpOperands(), state);
+    if (initArgs.size() != indices.size())
+      return failure();
 
     // Construct a new scf.for op with memref instead of tensor values.
-    SmallVector<Value> initArgs;
-    for (OpOperand &opOperand : forOp.getIterOpOperands()) {
-      if (opOperand.get().getType().isa<TensorType>()) {
-        FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
-        if (failed(resultBuffer))
-          return failure();
-        initArgs.push_back(*resultBuffer);
-      } else {
-        initArgs.push_back(opOperand.get());
-      }
-    }
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), initArgs);
+    ValueRange initArgsRange(initArgs);
+    TypeRange initArgsTypes(initArgsRange);
     Block *loopBody = &newForOp.getLoopBody().front();
 
     // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
     // iter_args of the new loop in ToTensorOps.
     rewriter.setInsertionPointToStart(loopBody);
-    SmallVector<Value> iterArgs =
-        convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
+    SmallVector<Value> iterArgs = convertTensorValues(
+        newForOp.getRegionIterArgs(), indices, [&](Value val, int64_t index) {
           return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
         });
     iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
@@ -399,42 +503,8 @@ struct ForOpInterface
     auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
     rewriter.setInsertionPoint(yieldOp);
     SmallVector<Value> yieldValues =
-        convert(yieldOp.getResults(), [&](Value val, int64_t index) {
-          Type initArgType = initArgs[index].getType();
-          ensureToMemrefOpIsValid(val, initArgType);
-          Value yieldedVal =
-              bufferization::lookupBuffer(rewriter, val, state.getOptions());
-
-          if (equivalentYields[index])
-            // Yielded value is equivalent to the corresponding iter_arg bbArg.
-            // Yield the value directly. Most IR should be like that. Everything
-            // else must be resolved with copies and is potentially inefficient.
-            // By default, such problematic IR would already have been rejected
-            // during `verifyAnalysis`, unless `allow-return-allocs`.
-            return castBuffer(yieldedVal, initArgType);
-
-          // It is not certain that the yielded value and the iter_arg bbArg
-          // have the same buffer. Allocate a new buffer and copy. The yielded
-          // buffer will get deallocated by `deallocateBuffers`.
-
-          // TODO: There are cases in which it is not necessary to return a new
-          // buffer allocation. E.g., when equivalent values are yielded in a
-          // different order. This could be resolved with copies.
-          Optional<Value> yieldedAlloc = state.createAlloc(
-              rewriter, val.getLoc(), yieldedVal, /*deallocMemref=*/false);
-          // TODO: We should rollback, but for now just assume that this always
-          // succeeds.
-          assert(yieldedAlloc.hasValue() && "could not create alloc");
-          LogicalResult copyStatus =
-              bufferization::createMemCpy(rewriter, val.getLoc(), yieldedVal,
-                                          *yieldedAlloc, state.getOptions());
-          (void)copyStatus;
-          assert(succeeded(copyStatus) && "could not create memcpy");
-
-          // The iter_arg memref type may have a layout map. Cast the new buffer
-          // to the same type if needed.
-          return castBuffer(*yieldedAlloc, initArgType);
-        });
+        getYieldedValues(rewriter, yieldOp.getResults(), initArgsTypes, indices,
+                         equivalentYields, state);
     yieldOp.getResultsMutable().assign(yieldValues);
 
     // Replace loop results.

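To see how these helpers are meant to compose for other loop-like ops (such as
the WhileOp bufferization this change prepares for), here is a rough sketch
mirroring the new ForOp code above. It is illustrative only: `loopOp` and
`newYieldOp` are placeholders for the concrete op and its terminator, not API
introduced by this commit.

    // Illustrative sketch, not part of this commit.
    DenseSet<int64_t> indices = getTensorIndices(loopOp.getInitArgs());
    DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
        loopOp.getRegionIterArgs(), newYieldOp.getResults(),
        state.getAnalysisState());
    // Bufferize the init operands; non-tensor operands pass through unchanged.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, loopOp.getIterOpOperands(), state);
    ValueRange initArgsRange(initArgs);
    TypeRange initArgsTypes(initArgsRange);
    // Compute the memref values to yield: equivalent buffers are yielded
    // directly, everything else goes through a new allocation plus copy.
    SmallVector<Value> yieldValues = getYieldedValues(
        rewriter, newYieldOp.getResults(), initArgsTypes, indices,
        equivalentYields, state);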

        


More information about the Mlir-commits mailing list