[Mlir-commits] [mlir] 996834e - [mlir][SCF] Fix scf.while bufferization

Matthias Springer llvmlistbot at llvm.org
Tue May 17 15:38:41 PDT 2022


Author: Matthias Springer
Date: 2022-05-18T00:35:50+02:00
New Revision: 996834e6813ab5481a58e42e7a11f57d243a3a99

URL: https://github.com/llvm/llvm-project/commit/996834e6813ab5481a58e42e7a11f57d243a3a99
DIFF: https://github.com/llvm/llvm-project/commit/996834e6813ab5481a58e42e7a11f57d243a3a99.diff

LOG: [mlir][SCF] Fix scf.while bufferization

Before this fix, the bufferization implementation made the incorrect assumption that the values yielded from the "before" region must match the values yielded from the "after" region in number and type.

Differential Revision: https://reviews.llvm.org/D125835

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index ea0a3e6fa85a3..4ca73feb85da2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -543,11 +543,11 @@ struct BufferizationState {
             Optional<ForceInPlacability> overrideInPlace = None,
             Optional<Operation *> customCopyInsertionPoint = None);
 
-  /// Return the buffer type for a given OpOperand (tensor) after bufferization.
+  /// Return the buffer type for a given Value (tensor) after bufferization.
   ///
   /// Note: Op implementations should preferrably call `getBuffer()->getType()`.
   /// This function should only be used if `getBuffer` cannot be used.
-  BaseMemRefType getBufferType(OpOperand &opOperand) const;
+  BaseMemRefType getBufferType(Value value) const;
 
   /// Return a reference to the BufferizationOptions.
   const BufferizationOptions &getOptions() const {

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index be7cf209d46a6..a1c9fd66ff4a4 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -333,13 +333,12 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
   return resultBuffer;
 }
 
-/// Return the buffer type for a given OpOperand (tensor) after bufferization.
-BaseMemRefType BufferizationState::getBufferType(OpOperand &opOperand) const {
-  Value tensor = opOperand.get();
-  auto tensorType = tensor.getType().dyn_cast<TensorType>();
+/// Return the buffer type for a given Value (tensor) after bufferization.
+BaseMemRefType BufferizationState::getBufferType(Value value) const {
+  auto tensorType = value.getType().dyn_cast<TensorType>();
   assert(tensorType && "unexpected non-tensor type");
 
-  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
+  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
     return toTensorOp.memref().getType().cast<BaseMemRefType>();
 
   return getMemRefType(tensorType, getOptions());

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 638deb6a68c4b..d0c3c381dbfa1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -276,14 +276,14 @@ static DenseSet<int64_t> getTensorIndices(ValueRange values) {
 DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                        ValueRange yieldedValues,
                                        const AnalysisState &state) {
+  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
   DenseSet<int64_t> result;
-  int64_t counter = 0;
-  for (const auto &it : llvm::zip(bbArgs, yieldedValues)) {
-    if (!std::get<0>(it).getType().isa<TensorType>())
+  for (unsigned int i = 0; i < minSize; ++i) {
+    if (!bbArgs[i].getType().isa<TensorType>() ||
+        !yieldedValues[i].getType().isa<TensorType>())
       continue;
-    if (state.areEquivalentBufferizedValues(std::get<0>(it), std::get<1>(it)))
-      result.insert(counter);
-    counter++;
+    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
+      result.insert(i);
   }
   return result;
 }
@@ -486,8 +486,6 @@ struct ForOpInterface
     // The new memref init_args of the loop.
     SmallVector<Value> initArgs =
         getBuffers(rewriter, forOp.getIterOpOperands(), state);
-    if (initArgs.size() != indices.size())
-      return failure();
 
     // Construct a new scf.for op with memref instead of tensor values.
     auto newForOp = rewriter.create<scf::ForOp>(
@@ -578,7 +576,16 @@ struct WhileOpInterface
   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                             const AnalysisState &state) const {
     auto whileOp = cast<scf::WhileOp>(op);
-    return {whileOp->getResult(opOperand.getOperandNumber())};
+    unsigned int idx = opOperand.getOperandNumber();
+
+    // The OpResults and OpOperands may not match. They may not even have the
+    // same type. The number of OpResults and OpOperands can also differ.
+    if (idx >= op->getNumResults() ||
+        opOperand.get().getType() != op->getResult(idx).getType())
+      return {};
+
+    // The only aliasing OpResult may be the one at the same index.
+    return {whileOp->getResult(idx)};
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
@@ -589,6 +596,13 @@ struct WhileOpInterface
     unsigned int resultNumber = opResult.getResultNumber();
     auto whileOp = cast<scf::WhileOp>(op);
 
+    // The "before" region bbArgs and the OpResults may not match.
+    if (resultNumber >= whileOp.getBeforeArguments().size())
+      return BufferRelation::None;
+    if (opResult.getType() !=
+        whileOp.getBeforeArguments()[resultNumber].getType())
+      return BufferRelation::None;
+
     auto conditionOp = whileOp.getConditionOp();
     BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
     Value conditionOperand = conditionOp.getArgs()[resultNumber];
@@ -627,9 +641,12 @@ struct WhileOpInterface
            "regions with multiple blocks not supported");
     Block *afterBody = &whileOp.getAfter().front();
 
-    // Indices of all iter_args that have tensor type. These are the ones that
-    // are bufferized.
-    DenseSet<int64_t> indices = getTensorIndices(whileOp.getInits());
+    // Indices of all bbArgs that have tensor type. These are the ones that
+    // are bufferized. The "before" and "after" regions may have different args.
+    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
+    DenseSet<int64_t> indicesAfter =
+        getTensorIndices(whileOp.getAfterArguments());
+
     // For every yielded value, is the value equivalent to its corresponding
     // bbArg?
     DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
@@ -642,51 +659,64 @@ struct WhileOpInterface
     // The new memref init_args of the loop.
     SmallVector<Value> initArgs =
         getBuffers(rewriter, whileOp->getOpOperands(), state);
-    if (initArgs.size() != indices.size())
-      return failure();
+
+    // The result types of a WhileOp are the same as the "after" bbArg types.
+    SmallVector<Type> argsTypesAfter = llvm::to_vector(
+        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
+          return state.getBufferType(bbArg).cast<Type>();
+        }));
 
     // Construct a new scf.while op with memref instead of tensor values.
-    ValueRange argsRange(initArgs);
-    TypeRange argsTypes(argsRange);
-    auto newWhileOp =
-        rewriter.create<scf::WhileOp>(whileOp.getLoc(), argsTypes, initArgs);
+    ValueRange argsRangeBefore(initArgs);
+    TypeRange argsTypesBefore(argsRangeBefore);
+    auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
+                                                    argsTypesAfter, initArgs);
+
     // Add before/after regions to the new op.
-    SmallVector<Location> bbArgLocs(initArgs.size(), whileOp.getLoc());
+    SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
+    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
+                                         whileOp.getLoc());
     Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
-    newWhileOp.getBefore().addArguments(argsTypes, bbArgLocs);
+    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
     Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
-    newWhileOp.getAfter().addArguments(argsTypes, bbArgLocs);
+    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);
 
     // Set up new iter_args and move the loop condition block to the new op.
     // The old block uses tensors, so wrap the (memref) bbArgs of the new block
     // in ToTensorOps.
     rewriter.setInsertionPointToStart(newBeforeBody);
     SmallVector<Value> newBeforeArgs = getBbArgReplacements(
-        rewriter, newWhileOp.getBeforeArguments(), indices);
+        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
     rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
 
     // Update scf.condition of new loop.
     auto newConditionOp = newWhileOp.getConditionOp();
     rewriter.setInsertionPoint(newConditionOp);
+    // Only equivalent buffers or new buffer allocations may be yielded to the
+    // "after" region.
+    // TODO: This could be relaxed for better bufferization results.
     SmallVector<Value> newConditionArgs =
-        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypes, indices,
-                         equivalentYieldsBefore, state);
+        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
+                         indicesAfter, equivalentYieldsBefore, state);
     newConditionOp.getArgsMutable().assign(newConditionArgs);
 
     // Set up new iter_args and move the loop body block to the new op.
     // The old block uses tensors, so wrap the (memref) bbArgs of the new block
     // in ToTensorOps.
     rewriter.setInsertionPointToStart(newAfterBody);
-    SmallVector<Value> newAfterArgs =
-        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(), indices);
+    SmallVector<Value> newAfterArgs = getBbArgReplacements(
+        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
     rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);
 
     // Update scf.yield of the new loop.
     auto newYieldOp = newWhileOp.getYieldOp();
     rewriter.setInsertionPoint(newYieldOp);
+    // Only equivalent buffers or new buffer allocations may be yielded to the
+    // "before" region.
+    // TODO: This could be relaxed for better bufferization results.
     SmallVector<Value> newYieldValues =
-        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypes, indices,
-                         equivalentYieldsAfter, state);
+        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
+                         indicesBefore, equivalentYieldsAfter, state);
     newYieldOp.getResultsMutable().assign(newYieldValues);
 
     // Replace loop results.

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 238d5365fb87e..0a48baab17e7d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -111,7 +111,7 @@ struct CollapseShapeOpInterface
     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
     RankedTensorType tensorResultType = collapseShapeOp.getResultType();
     OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/;
-    auto bufferType = state.getBufferType(srcOperand).cast<MemRefType>();
+    auto bufferType = state.getBufferType(srcOperand.get()).cast<MemRefType>();
 
     if (tensorResultType.getRank() == 0) {
       // 0-d collapses must go through a different op builder.

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 09fcd8192765a..4f38fb2ce93d7 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -449,3 +449,36 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
   // CHECK: return %[[loop]]#0, %[[loop]]#1
   return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_while_iter_arg_result_mismatch(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<5xi1, #{{.*}}>, %[[arg1:.*]]: memref<5xi1, #{{.*}}>
+//       CHECK:   %[[alloc1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+//       CHECK:   %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+//       CHECK:   scf.while (%[[arg3:.*]] = %[[arg1]]) : (memref<5xi1, #{{.*}}) -> () {
+//       CHECK:     %[[load:.*]] = memref.load %[[arg0]]
+//       CHECK:     scf.condition(%[[load]])
+//       CHECK:   } do {
+//       CHECK:     memref.copy %[[arg0]], %[[alloc2]]
+//       CHECK:     memref.store %{{.*}}, %[[alloc2]]
+//       CHECK:     memref.copy %[[alloc2]], %[[alloc1]]
+//       CHECK:     %[[casted:.*]] = memref.cast %[[alloc1]] : memref<5xi1> to memref<5xi1, #{{.*}}>
+//       CHECK:     scf.yield %[[casted]]
+//       CHECK:   }
+//   CHECK-DAG:   memref.dealloc %[[alloc1]]
+//   CHECK-DAG:   memref.dealloc %[[alloc2]]
+func.func @scf_while_iter_arg_result_mismatch(%arg0: tensor<5xi1>,
+                                              %arg1: tensor<5xi1>,
+                                              %arg2: index) {
+  scf.while (%arg3 = %arg1) : (tensor<5xi1>) -> () {
+    %0 = tensor.extract %arg0[%arg2] : tensor<5xi1>
+    scf.condition(%0)
+  } do {
+    %0 = "dummy.some_op"() : () -> index
+    %1 = "dummy.another_op"() : () -> i1
+    %2 = tensor.insert %1 into %arg0[%0] : tensor<5xi1>
+    scf.yield %2 : tensor<5xi1>
+  }
+  return
+}


        


More information about the Mlir-commits mailing list