[Mlir-commits] [mlir] 878950b - [mlir][bufferization] Simplify `getBufferType`

Matthias Springer llvmlistbot at llvm.org
Wed Aug 16 06:05:11 PDT 2023


Author: Matthias Springer
Date: 2023-08-16T15:02:07+02:00
New Revision: 878950b82cb7727361b5ae13ea0326e39c7677fe

URL: https://github.com/llvm/llvm-project/commit/878950b82cb7727361b5ae13ea0326e39c7677fe
DIFF: https://github.com/llvm/llvm-project/commit/878950b82cb7727361b5ae13ea0326e39c7677fe.diff

LOG: [mlir][bufferization] Simplify `getBufferType`

`getBufferType` computes the bufferized type of an SSA value without bufferizing any IR. This is useful for predicting the bufferized type of iter_args of a loop.

To avoid endless recursion (e.g., in the case of "scf.for", the type of the iter_arg depends on the types of the init_arg and the yielded value, and the type of the yielded value in turn depends on the type of the iter_arg), `fixedTypes` was previously used to fall back to a "fixed" type. A simpler approach is to maintain an "invocation stack": `getBufferType` implementations can inspect it to detect repeated computations for the same SSA value (typically when computing the bufferized type of a block argument).

Also improve the error message for inconsistent memory spaces inside a loop.

Differential Revision: https://reviews.llvm.org/D158060

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
    mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 7f758e492d89eb..6fc487c1a11aa5 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -609,17 +609,18 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
                                         const BufferizationOptions &options);
 
 /// Return the buffer type for a given Value (tensor) after bufferization
-/// without bufferizing any IR. If at any point during the type computation, the
-/// type of a value in `fixedTypes` in required, the mapped type is used.
+/// without bufferizing any IR. This function (and not the other overload
+/// without `invocationStack`) can be used from `getBufferType` implementations
+/// of the `BufferizableOpInterface`.
 ///
 /// Note: It should be sufficient to call `getBuffer()->getType()` in most
 /// cases. However, when a buffer type should be predicted without modifying any
 /// IR, this function can be used.
 ///
-/// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType>
-getBufferType(Value value, const BufferizationOptions &options,
-              const DenseMap<Value, BaseMemRefType> &fixedTypes);
+/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
+FailureOr<BaseMemRefType> getBufferType(Value value,
+                                        const BufferizationOptions &options,
+                                        SmallVector<Value> &invocationStack);
 
 /// Replace an op with replacement values. The op is deleted. Tensor OpResults
 /// must be replaced with memref values.
@@ -691,7 +692,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// places.
 FailureOr<BaseMemRefType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
-                     const DenseMap<Value, BaseMemRefType> &fixedTypes);
+                     SmallVector<Value> &invocationStack);
 
 /// This is the default implementation of
 /// BufferizableOpInterface::resultBufferizesToMemoryWrite. Should not be called

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 6a349f34ef1fbd..bd7a2d8b3f1eac 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -494,18 +494,31 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
 
           This method is useful when the bufferized type of value must be
           predicted before modifying any IR.
+
+          Implementations may call `bufferization::getBufferType` to compute the
+          bufferized type of another SSA value. The same (unmodified)
+          `invocationStack` must be passed to that function. The stack contains
+          all SSA values for which a buffer type computation is currently in
+          progress. Implementations may inspect the stack to detect repetitive
+          computations for the same SSA value. (E.g., when computing the
+          bufferized types of loop iter_args.)
+
+          Note: This interface method should never be called directly from user
+          code. Always use `bufferization::getBufferType`.
         }],
         /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,
-                      "const ::mlir::DenseMap<::mlir::Value, ::mlir::BaseMemRefType>":$fixedTypes),
+                      "::llvm::SmallVector<::mlir::Value> &":$invocationStack),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           assert(getOwnerOfValue(value) == $_op.getOperation() &&
                  "expected that value belongs to this op");
+          assert(invocationStack.back() == value &&
+                 "inconsistant invocation stack");
           return ::mlir::bufferization::detail::defaultGetBufferType(
-              value, options, fixedTypes);
+              value, options, invocationStack);
         }]
       >,
       InterfaceMethod<

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index d62f781a01b003..fec07af349b3a8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -108,7 +108,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
 
     FailureOr<BaseMemRefType> getBufferType(
         Value value, const BufferizationOptions &options,
-        const DenseMap<Value, BaseMemRefType> &fixedTypes);
+        SmallVector<Value> &invocationStack);
 
     RankedTensorType getType() {
       return ::llvm::cast<RankedTensorType>(getResult().getType());
@@ -388,7 +388,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     FailureOr<BaseMemRefType> getBufferType(
         Value value, const BufferizationOptions &options,
-        const DenseMap<Value, BaseMemRefType> &fixedTypes) {
+        SmallVector<Value> &invocationStack) {
       return ::llvm::cast<BaseMemRefType>(getMemref().getType());
     }
   }];

diff  --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index ddc5ce54040f65..1bfb0c7e102cd3 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -170,13 +170,13 @@ struct SelectOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+                SmallVector<Value> &invocationStack) const {
     auto selectOp = cast<arith::SelectOp>(op);
     assert(value == selectOp.getResult() && "invalid value");
     auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
-                                                 options, fixedTypes);
+                                                 options, invocationStack);
     auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
-                                                  options, fixedTypes);
+                                                  options, invocationStack);
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1427037619e591..a96cfedc9a4527 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
 
 //===----------------------------------------------------------------------===//
@@ -728,27 +729,25 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
 /// Return the buffer type for a given Value (tensor) after bufferization.
 FailureOr<BaseMemRefType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options) {
-  DenseMap<Value, BaseMemRefType> fixedTypes;
-  return getBufferType(value, options, fixedTypes);
+  SmallVector<Value> invocationStack;
+  return getBufferType(value, options, invocationStack);
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType> bufferization::getBufferType(
-    Value value, const BufferizationOptions &options,
-    const DenseMap<Value, BaseMemRefType> &fixedTypes) {
+FailureOr<BaseMemRefType>
+bufferization::getBufferType(Value value, const BufferizationOptions &options,
+                             SmallVector<Value> &invocationStack) {
   assert(llvm::isa<TensorType>(value.getType()) &&
          "unexpected non-tensor type");
-
-  // If the `value` is in `fixedTypes`, return the mapped type.
-  const auto &it = fixedTypes.find(value);
-  if (it != fixedTypes.end())
-    return it->second;
+  invocationStack.push_back(value);
+  auto popFromStack =
+      llvm::make_scope_exit([&]() { invocationStack.pop_back(); });
 
   // Try querying BufferizableOpInterface.
   Operation *op = getOwnerOfValue(value);
   auto bufferizableOp = options.dynCastBufferizableOp(op);
   if (bufferizableOp)
-    return bufferizableOp.getBufferType(value, options, fixedTypes);
+    return bufferizableOp.getBufferType(value, options, invocationStack);
 
   // Op is not bufferizable.
   if (!options.defaultMemorySpace.has_value())
@@ -996,7 +995,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
 
 FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     Value value, const BufferizationOptions &options,
-    const DenseMap<Value, BaseMemRefType> &fixedTypes) {
+    SmallVector<Value> &invocationStack) {
   assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
 
   // No further analysis is possible for a block argument.
@@ -1013,7 +1012,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
-    return getBufferType(equivalentOperand, options, fixedTypes);
+    return getBufferType(equivalentOperand, options, invocationStack);
   }
 
   // If we do not know the memory space and there is no default memory space,

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index e16cfcead1c37d..c8681374ccae11 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -230,9 +230,9 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
   return {};
 }
 
-FailureOr<BaseMemRefType> AllocTensorOp::getBufferType(
-    Value value, const BufferizationOptions &options,
-    const DenseMap<Value, BaseMemRefType> &fixedTypes) {
+FailureOr<BaseMemRefType>
+AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
+                             SmallVector<Value> &invocationStack) {
   assert(value == getResult() && "invalid value");
 
   // Compute memory space of this allocation.
@@ -241,7 +241,7 @@ FailureOr<BaseMemRefType> AllocTensorOp::getBufferType(
     memorySpace = *getMemorySpace();
   } else if (getCopy()) {
     auto copyBufferType =
-        bufferization::getBufferType(getCopy(), options, fixedTypes);
+        bufferization::getBufferType(getCopy(), options, invocationStack);
     if (failed(copyBufferType))
       return failure();
     memorySpace = copyBufferType->getMemorySpace();

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index b72c4f1999401b..10c704fc64dd51 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -199,7 +199,7 @@ struct CallOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+                SmallVector<Value> &invocationStack) const {
     auto callOp = cast<func::CallOp>(op);
     FuncOp funcOp = getCalledFunction(callOp);
     assert(funcOp && "expected CallOp to a FuncOp");
@@ -321,7 +321,7 @@ struct FuncOpInterface
     : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+                SmallVector<Value> &invocationStack) const {
     auto funcOp = cast<FuncOp>(op);
     auto bbArg = cast<BlockArgument>(value);
     // Unstructured control flow is not supported.

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 11ee5415bb4cbc..ac01d264eb8fba 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct IfOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+                SmallVector<Value> &invocationStack) const {
     auto ifOp = cast<scf::IfOp>(op);
     auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
     auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
@@ -227,7 +227,7 @@ struct IfOpInterface
       thenBufferType = cast<BaseMemRefType>(thenValue.getType());
     } else {
       auto maybeBufferType =
-          bufferization::getBufferType(thenValue, options, fixedTypes);
+          bufferization::getBufferType(thenValue, options, invocationStack);
       if (failed(maybeBufferType))
         return failure();
       thenBufferType = *maybeBufferType;
@@ -237,7 +237,7 @@ struct IfOpInterface
       elseBufferType = cast<BaseMemRefType>(elseValue.getType());
     } else {
       auto maybeBufferType =
-          bufferization::getBufferType(elseValue, options, fixedTypes);
+          bufferization::getBufferType(elseValue, options, invocationStack);
       if (failed(maybeBufferType))
         return failure();
       elseBufferType = *maybeBufferType;
@@ -331,33 +331,34 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
 ///
 /// This function uses bufferization::getBufferType to compute the bufferized
 /// type of the init_arg and of the yielded value. (The computation of the
-/// usually requires computing the bufferized type of the corresponding
-/// iter_arg; the implementation of getBufferType traces back the use-def chain
-/// of the given value and computes a buffer type along the way.) If both buffer
-/// types are equal, no casts are needed the computed buffer type can be used
-/// directly. Otherwise, the buffer types can only differ in their layout map
-/// and a cast must be inserted.
+/// bufferized yielded value type usually requires computing the bufferized type
+/// of the iter_arg again; the implementation of getBufferType traces back the
+/// use-def chain of the given value and computes a buffer type along the way.)
+/// If both buffer types are equal, no casts are needed and the computed buffer
+/// type can be used directly. Otherwise, the buffer types can only differ in
+/// their layout map and a cast must be inserted.
 static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
-    BlockArgument iterArg, Value initArg, Value yieldedValue,
-    const BufferizationOptions &options,
-    const DenseMap<Value, BaseMemRefType> &fixedTypes) {
+    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
+    const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
   // Determine the buffer type of the init_arg.
   auto initArgBufferType =
-      bufferization::getBufferType(initArg, options, fixedTypes);
+      bufferization::getBufferType(initArg, options, invocationStack);
   if (failed(initArgBufferType))
     return failure();
 
-  // Fix the iter_arg type, so that recursive lookups return the buffer type
-  // of the init_arg. This is to avoid infinite loops when calculating the
-  // buffer type of the yielded value.
-  //
-  // Note: For more precise layout map computation, a fixpoint iteration could
-  // be done (i.e., re-computing the yielded buffer type until the bufferized
-  // iter_arg type no longer changes). This current implementation immediately
-  // switches to a fully dynamic layout map when a mismatch between bufferized
-  // init_arg type and bufferized yield value type is detected.
-  DenseMap<Value, BaseMemRefType> newFixedTypes(fixedTypes);
-  newFixedTypes[iterArg] = *initArgBufferType;
+  if (llvm::count(invocationStack, iterArg) >= 2) {
+    // If the iter_arg already appears twice on the invocation stack, return
+    // the buffer type of the init_arg. This avoids infinite recursion when
+    // calculating the buffer type, and will most likely result in computing a
+    // memref type with a fully dynamic layout map.
+
+    // Note: For more precise layout map computation, a fixpoint iteration could
+    // be done (i.e., re-computing the yielded buffer type until the bufferized
+    // iter_arg type no longer changes). This current implementation immediately
+    // switches to a fully dynamic layout map when a mismatch between bufferized
+    // init_arg type and bufferized yield value type is detected.
+    return *initArgBufferType;
+  }
 
   // Compute the buffer type of the yielded value.
   BaseMemRefType yieldedValueBufferType;
@@ -365,8 +366,10 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
     // scf.yield was already bufferized.
     yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
   } else {
+    // Note: This typically triggers a recursive call for the buffer type of
+    // the iter_arg.
     auto maybeBufferType =
-        bufferization::getBufferType(yieldedValue, options, newFixedTypes);
+        bufferization::getBufferType(yieldedValue, options, invocationStack);
     if (failed(maybeBufferType))
       return failure();
     yieldedValueBufferType = *maybeBufferType;
@@ -376,20 +379,26 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
   if (*initArgBufferType == yieldedValueBufferType)
     return yieldedValueBufferType;
 
-  // If there is a mismatch between the yielded buffer type and the iter_arg
+  // If there is a mismatch between the yielded buffer type and the init_arg
   // buffer type, the buffer type must be promoted to a fully dynamic layout
   // map.
-  auto yieldedRanked = cast<MemRefType>(yieldedValueBufferType);
+  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
+  auto iterTensorType = cast<TensorType>(iterArg.getType());
+  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
+  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
+    return loopOp->emitOpError(
+        "init_arg and yielded value bufferize to inconsistent memory spaces");
 #ifndef NDEBUG
-  auto iterRanked = llvm::cast<MemRefType>(*initArgBufferType);
-  assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
-         "expected same shape");
-  assert(yieldedRanked.getMemorySpace() == iterRanked.getMemorySpace() &&
-         "expected same memory space");
+  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
+    assert(
+        llvm::all_equal({yieldedRankedBufferType.getShape(),
+                         cast<MemRefType>(initBufferType).getShape(),
+                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
+        "expected same shape");
+  }
 #endif // NDEBUG
   return getMemRefTypeWithFullyDynamicLayout(
-      cast<RankedTensorType>(iterArg.getType()),
-      yieldedRanked.getMemorySpace());
+      iterTensorType, yieldedBufferType.getMemorySpace());
 }
 
 /// Return `true` if the given loop may have 0 iterations.
@@ -513,29 +522,32 @@ struct ForOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+                SmallVector<Value> &invocationStack) const {
     auto forOp = cast<scf::ForOp>(op);
     assert(getOwnerOfValue(value) == op && "invalid value");
     assert(isa<TensorType>(value.getType()) && "expected tensor type");
 
-    // Get result/argument number.
-    unsigned resultNum;
-    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
-      resultNum =
-          forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg))
-              .getResultNumber();
-    } else {
-      resultNum = cast<OpResult>(value).getResultNumber();
+    if (auto opResult = dyn_cast<OpResult>(value)) {
+      // The type of an OpResult must match the corresponding iter_arg type.
+      BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(
+          forOp.getOpOperandForResult(opResult));
+      return bufferization::getBufferType(bbArg, options, invocationStack);
     }
 
+    // Compute result/argument number.
+    BlockArgument bbArg = cast<BlockArgument>(value);
+    unsigned resultNum =
+        forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg))
+            .getResultNumber();
+
     // Compute the bufferized type.
     auto yieldOp =
         cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
     Value yieldedValue = yieldOp.getOperand(resultNum);
     BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
     Value initArg = forOp.getInitArgs()[resultNum];
-    return computeLoopRegionIterArgBufferType(iterArg, initArg, yieldedValue,
-                                              options, fixedTypes);
+    return computeLoopRegionIterArgBufferType(
+        op, iterArg, initArg, yieldedValue, options, invocationStack);
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -838,7 +850,7 @@ struct WhileOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+                SmallVector<Value> &invocationStack) const {
     auto whileOp = cast<scf::WhileOp>(op);
     assert(getOwnerOfValue(value) == op && "invalid value");
     assert(isa<TensorType>(value.getType()) && "expected tensor type");
@@ -849,8 +861,8 @@ struct WhileOpInterface
         Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
         auto yieldOp = whileOp.getYieldOp();
         Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
-        return computeLoopRegionIterArgBufferType(bbArg, initArg, yieldedValue,
-                                                  options, fixedTypes);
+        return computeLoopRegionIterArgBufferType(
+            op, bbArg, initArg, yieldedValue, options, invocationStack);
       }
     }
 
@@ -872,7 +884,7 @@ struct WhileOpInterface
       return cast<BaseMemRefType>(conditionYieldedVal.getType());
     }
     return bufferization::getBufferType(conditionYieldedVal, options,
-                                        fixedTypes);
+                                        invocationStack);
   }
 
   /// Assert that yielded values of an scf.while op are equivalent to their
@@ -1104,20 +1116,20 @@ struct ForallOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+                SmallVector<Value> &invocationStack) const {
     auto forallOp = cast<ForallOp>(op);
 
     if (auto bbArg = dyn_cast<BlockArgument>(value))
       // A tensor block argument has the same bufferized type as the
       // corresponding output operand.
       return bufferization::getBufferType(
-          forallOp.getTiedOpOperand(bbArg)->get(), options, fixedTypes);
+          forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);
 
     // The bufferized result type is the same as the bufferized type of the
     // corresponding output operand.
     return bufferization::getBufferType(
         forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
-        fixedTypes);
+        invocationStack);
   }
 
   bool isRepetitiveRegion(Operation *op, unsigned index) const {

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index efcea2f6b45ca9..a67ea0334b22b9 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -48,10 +48,10 @@ struct CastOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+                SmallVector<Value> &invocationStack) const {
     auto castOp = cast<tensor::CastOp>(op);
-    auto maybeSrcBufferType =
-        bufferization::getBufferType(castOp.getSource(), options, fixedTypes);
+    auto maybeSrcBufferType = bufferization::getBufferType(
+        castOp.getSource(), options, invocationStack);
     if (failed(maybeSrcBufferType))
       return failure();
     Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
@@ -133,10 +133,10 @@ struct CollapseShapeOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+                SmallVector<Value> &invocationStack) const {
     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
     auto maybeSrcBufferType = bufferization::getBufferType(
-        collapseShapeOp.getSrc(), options, fixedTypes);
+        collapseShapeOp.getSrc(), options, invocationStack);
     if (failed(maybeSrcBufferType))
       return failure();
     auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
@@ -302,10 +302,10 @@ struct ExpandShapeOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+                SmallVector<Value> &invocationStack) const {
     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
     auto maybeSrcBufferType = bufferization::getBufferType(
-        expandShapeOp.getSrc(), options, fixedTypes);
+        expandShapeOp.getSrc(), options, invocationStack);
     if (failed(maybeSrcBufferType))
       return failure();
     auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
@@ -383,11 +383,11 @@ struct ExtractSliceOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+                SmallVector<Value> &invocationStack) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
     assert(value == extractSliceOp.getResult() && "invalid value");
     auto srcMemrefType = bufferization::getBufferType(
-        extractSliceOp.getSource(), options, fixedTypes);
+        extractSliceOp.getSource(), options, invocationStack);
     if (failed(srcMemrefType))
       return failure();
     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
@@ -853,11 +853,11 @@ struct PadOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+                SmallVector<Value> &invocationStack) const {
     // Infer memory space from the source tensor.
     auto padOp = cast<tensor::PadOp>(op);
-    auto maybeSrcBufferType =
-        bufferization::getBufferType(padOp.getSource(), options, fixedTypes);
+    auto maybeSrcBufferType = bufferization::getBufferType(
+        padOp.getSource(), options, invocationStack);
     if (failed(maybeSrcBufferType))
       return failure();
     MemRefLayoutAttrInterface layout;
@@ -1002,11 +1002,11 @@ struct ReshapeOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+                SmallVector<Value> &invocationStack) const {
     auto reshapeOp = cast<tensor::ReshapeOp>(op);
     assert(value == reshapeOp.getResult() && "unexpected value provided");
     auto maybeSourceBufferType = bufferization::getBufferType(
-        reshapeOp.getSource(), options, fixedTypes);
+        reshapeOp.getSource(), options, invocationStack);
     if (failed(maybeSourceBufferType))
       return failure();
     return getMemRefTypeWithStaticIdentityLayout(

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
index c77b6ce345e9ff..c8d6d506270a99 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
@@ -26,3 +26,16 @@ func.func @execute_region_multiple_blocks(%t: tensor<5xf32>) -> tensor<5xf32> {
   }
   func.return %0 : tensor<5xf32>
 }
+
+// -----
+
+func.func @inconsistent_memory_space_scf_for(%lb: index, %ub: index, %step: index) -> tensor<10xf32> {
+  %0 = bufferization.alloc_tensor() {memory_space = 0 : ui64} : tensor<10xf32>
+  %1 = bufferization.alloc_tensor() {memory_space = 1 : ui64} : tensor<10xf32>
+  // expected-error @below{{init_arg and yielded value bufferize to inconsistent memory spaces}}
+  %2 = scf.for %iv = %lb to %ub step %step iter_args(%arg = %0) -> tensor<10xf32> {
+    // expected-error @below {{failed to bufferize op}}
+    scf.yield %1 : tensor<10xf32>
+  }
+  return %2 : tensor<10xf32>
+}


        


More information about the Mlir-commits mailing list