[Mlir-commits] [mlir] 86974e3 - [mlir][SCF][bufferize] Support different iter_arg/init_arg types (scf.while)

Matthias Springer llvmlistbot at llvm.org
Tue Aug 30 07:58:38 PDT 2022


Author: Matthias Springer
Date: 2022-08-30T16:58:21+02:00
New Revision: 86974e32a4f7c80fd6503c89b065cccf3a91300a

URL: https://github.com/llvm/llvm-project/commit/86974e32a4f7c80fd6503c89b065cccf3a91300a
DIFF: https://github.com/llvm/llvm-project/commit/86974e32a4f7c80fd6503c89b065cccf3a91300a.diff

LOG: [mlir][SCF][bufferize] Support different iter_arg/init_arg types (scf.while)

This change implements the same functionality as D132860, but for scf.while.

Differential Revision: https://reviews.llvm.org/D132927

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 7e458a1f3762..a13badaba7cd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -27,16 +27,78 @@ namespace mlir {
 namespace scf {
 namespace {
 
-// bufferization.to_memref is not allowed to change the rank.
-static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
-#ifndef NDEBUG
-  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
-  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
-                                rankedTensorType.getRank())) &&
-         "to_memref would be invalid: mismatching ranks");
-#endif
+/// Helper function for loop bufferization. Cast the given buffer to the given
+/// memref type.
+static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
+  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
+  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
+  // If the buffer already has the correct type, no cast is needed.
+  if (buffer.getType() == type)
+    return buffer;
+  // TODO: In case `type` has a layout map that is not the fully dynamic
+  // one, we may not be able to cast the buffer. In that case, the loop
+  // iter_arg's layout map must be changed (see uses of `castBuffer`).
+  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
+         "scf.while op bufferization: cast incompatible");
+  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
 }
 
+/// Bufferization of scf.condition.
+struct ConditionOpInterface
+    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
+                                                    scf::ConditionOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    return {};
+  }
+
+  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
+                            const AnalysisState &state) const {
+    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
+    // may be generated inside the block. We should not return/yield allocations
+    // when possible.
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto conditionOp = cast<scf::ConditionOp>(op);
+    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
+
+    SmallVector<Value> newArgs;
+    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
+      Value value = it.value();
+      if (value.getType().isa<TensorType>()) {
+        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
+        if (failed(maybeBuffer))
+          return failure();
+        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+            whileOp.getAfterArguments()[it.index()], options);
+        if (failed(resultType))
+          return failure();
+        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
+        newArgs.push_back(buffer);
+      } else {
+        newArgs.push_back(value);
+      }
+    }
+
+    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
+        rewriter, op, conditionOp.getCondition(), newArgs);
+    return success();
+  }
+};
+
 /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
 /// fully implemented at the moment.
 struct ExecuteRegionOpInterface
@@ -283,22 +345,6 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
   return result;
 }
 
-/// Helper function for loop bufferization. Cast the given buffer to the given
-/// memref type.
-static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
-  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
-  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
-  // If the buffer already has the correct type, no cast is needed.
-  if (buffer.getType() == type)
-    return buffer;
-  // TODO: In case `type` has a layout map that is not the fully dynamic
-  // one, we may not be able to cast the buffer. In that case, the loop
-  // iter_arg's layout map must be changed (see uses of `castBuffer`).
-  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
-         "scf.while op bufferization: cast incompatible");
-  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
-}
-
 /// Helper function for loop bufferization. Return the bufferized values of the
 /// given OpOperands. If an operand is not a tensor, return the original value.
 static FailureOr<SmallVector<Value>>
@@ -319,60 +365,10 @@ getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
   return result;
 }
 
-/// Helper function for loop bufferization. Compute the buffer that should be
-/// yielded from a loop block (loop body or loop condition).
-static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
-                                         BaseMemRefType type,
-                                         const BufferizationOptions &options) {
-  assert(tensor.getType().isa<TensorType>() && "expected tensor");
-  ensureToMemrefOpIsValid(tensor, type);
-  FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
-  if (failed(yieldedVal))
-    return failure();
-  return castBuffer(rewriter, *yieldedVal, type);
-}
-
-/// Helper function for loop bufferization. Given a range of values, apply
-/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
-/// value in the result vector.
-static FailureOr<SmallVector<Value>>
-convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
-                    llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
-  SmallVector<Value> result;
-  for (const auto &it : llvm::enumerate(values)) {
-    size_t idx = it.index();
-    Value val = it.value();
-    if (tensorIndices.contains(idx)) {
-      FailureOr<Value> maybeVal = func(val, idx);
-      if (failed(maybeVal))
-        return failure();
-      result.push_back(*maybeVal);
-    } else {
-      result.push_back(val);
-    }
-  }
-  return result;
-}
-
-/// Helper function for loop bufferization. Given a list of pre-bufferization
-/// yielded values, compute the list of bufferized yielded values.
-FailureOr<SmallVector<Value>>
-getYieldedValues(RewriterBase &rewriter, ValueRange values,
-                 TypeRange bufferizedTypes,
-                 const DenseSet<int64_t> &tensorIndices,
-                 const BufferizationOptions &options) {
-  return convertTensorValues(
-      values, tensorIndices, [&](Value val, int64_t index) {
-        return getYieldedBuffer(rewriter, val,
-                                bufferizedTypes[index].cast<BaseMemRefType>(),
-                                options);
-      });
-}
-
 /// Helper function for loop bufferization. Given a list of bbArgs of the new
 /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
 /// ToTensorOps, so that the block body can be moved over to the new op.
-SmallVector<Value>
+static SmallVector<Value>
 getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                      const DenseSet<int64_t> &tensorIndices) {
   SmallVector<Value> result;
@@ -390,6 +386,74 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
   return result;
 }
 
+/// Compute the bufferized type of a loop iter_arg. This type must be equal to
+/// the bufferized type of the corresponding init_arg and the bufferized type
+/// of the corresponding yielded value.
+///
+/// This function uses bufferization::getBufferType to compute the bufferized
+/// type of the init_arg and of the yielded value. (The computation of the
+/// latter usually requires computing the bufferized type of the corresponding
+/// iter_arg; the implementation of getBufferType traces back the use-def chain
+/// of the given value and computes a buffer type along the way.) If both buffer
+/// types are equal, no casts are needed and the computed buffer type can be
+/// used directly. Otherwise, the buffer types can only differ in their layout
+/// map and a cast must be inserted.
+static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+    BlockArgument iterArg, Value initArg, Value yieldedValue,
+    const BufferizationOptions &options,
+    const DenseMap<Value, BaseMemRefType> &fixedTypes) {
+  // Determine the buffer type of the init_arg.
+  auto initArgBufferType =
+      bufferization::getBufferType(initArg, options, fixedTypes);
+  if (failed(initArgBufferType))
+    return failure();
+
+  // Fix the iter_arg type, so that recursive lookups return the buffer type
+  // of the init_arg. This is to avoid infinite loops when calculating the
+  // buffer type of the yielded value.
+  //
+  // Note: For more precise layout map computation, a fixpoint iteration could
+  // be done (i.e., re-computing the yielded buffer type until the bufferized
+  // iter_arg type no longer changes). This current implementation immediately
+  // switches to a fully dynamic layout map when a mismatch between bufferized
+  // init_arg type and bufferized yield value type is detected.
+  DenseMap<Value, BaseMemRefType> newFixedTypes(fixedTypes);
+  newFixedTypes[iterArg] = *initArgBufferType;
+
+  // Compute the buffer type of the yielded value.
+  BaseMemRefType yieldedValueBufferType;
+  if (yieldedValue.getType().isa<BaseMemRefType>()) {
+    // scf.yield was already bufferized.
+    yieldedValueBufferType = yieldedValue.getType().cast<BaseMemRefType>();
+  } else {
+    auto maybeBufferType =
+        bufferization::getBufferType(yieldedValue, options, newFixedTypes);
+    if (failed(maybeBufferType))
+      return failure();
+    yieldedValueBufferType = *maybeBufferType;
+  }
+
+  // If yielded type and init_arg type are the same, use that type directly.
+  if (*initArgBufferType == yieldedValueBufferType)
+    return yieldedValueBufferType;
+
+  // If there is a mismatch between the yielded buffer type and the iter_arg
+  // buffer type, the buffer type must be promoted to a fully dynamic layout
+  // map.
+  auto yieldedRanked = yieldedValueBufferType.cast<MemRefType>();
+#ifndef NDEBUG
+  auto iterRanked = initArgBufferType->cast<MemRefType>();
+  assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
+         "expected same shape");
+  assert(yieldedRanked.getMemorySpaceAsInt() ==
+             iterRanked.getMemorySpaceAsInt() &&
+         "expected same memory space");
+#endif // NDEBUG
+  return getMemRefTypeWithFullyDynamicLayout(
+      iterArg.getType().cast<RankedTensorType>(),
+      yieldedRanked.getMemorySpaceAsInt());
+}
+
 /// Bufferization of scf.for. Replace with a new scf.for that operates on
 /// memrefs.
 struct ForOpInterface
@@ -507,60 +571,14 @@ struct ForOpInterface
       resultNum = value.cast<OpResult>().getResultNumber();
     }
 
-    // Determine the buffer type of the init_arg.
-    Value initArg = forOp.getInitArgs()[resultNum];
-    auto initArgBufferType =
-        bufferization::getBufferType(initArg, options, fixedTypes);
-    if (failed(initArgBufferType))
-      return failure();
-
-    // Fix the iter_arg type, so that recursive lookups return the buffer type
-    // of the init_arg. This is to avoid infinite loops when calculating the
-    // buffer type of the yielded value.
-    //
-    // Note: For more precise layout map computation, a fixpoint iteration could
-    // be done (i.e., re-computing the yielded buffer type until the bufferized
-    // iter_arg type no longer changes). This current implementation immediately
-    // switches to a fully dynamic layout map when a mismatch between bufferized
-    // init_arg type and bufferized yield value type is detected.
-    DenseMap<Value, BaseMemRefType> newFixedTypes(fixedTypes);
-    newFixedTypes[forOp.getRegionIterArgs()[resultNum]] = *initArgBufferType;
-
-    // Compute the buffer type of the yielded value.
+    // Compute the bufferized type.
     auto yieldOp =
         cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
     Value yieldedValue = yieldOp.getOperand(resultNum);
-    BaseMemRefType yieldedValueBufferType;
-    if (yieldedValue.getType().isa<BaseMemRefType>()) {
-      // scf.yield was already bufferized.
-      yieldedValueBufferType = yieldedValue.getType().cast<BaseMemRefType>();
-    } else {
-      auto maybeBufferType =
-          bufferization::getBufferType(yieldedValue, options, newFixedTypes);
-      if (failed(maybeBufferType))
-        return failure();
-      yieldedValueBufferType = *maybeBufferType;
-    }
-
-    // If yielded type and init_arg type are the same, use that type directly.
-    if (*initArgBufferType == yieldedValueBufferType)
-      return yieldedValueBufferType;
-
-    // If there is a mismatch between the yielded buffer type and the iter_arg
-    // buffer type, the buffer type must be promoted to a fully dynamic layout
-    // map.
-    auto yieldedRanked = yieldedValueBufferType.cast<MemRefType>();
-#ifndef NDEBUG
-    auto iterRanked = initArgBufferType->cast<MemRefType>();
-    assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
-           "expected same shape");
-    assert(yieldedRanked.getMemorySpaceAsInt() ==
-               iterRanked.getMemorySpaceAsInt() &&
-           "expected same memory space");
-#endif // NDEBUG
-    return getMemRefTypeWithFullyDynamicLayout(
-        forOp.getRegionIterArgs()[resultNum].getType().cast<RankedTensorType>(),
-        yieldedRanked.getMemorySpaceAsInt());
+    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
+    Value initArg = forOp.getInitArgs()[resultNum];
+    return computeLoopRegionIterArgBufferType(iterArg, initArg, yieldedValue,
+                                              options, fixedTypes);
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -800,8 +818,6 @@ struct WhileOpInterface
     return success();
   }
 
-  // TODO: Implement getBufferType interface method and infer buffer types.
-
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto whileOp = cast<scf::WhileOp>(op);
@@ -826,6 +842,21 @@ struct WhileOpInterface
       return failure();
     SmallVector<Value> initArgs = *maybeInitArgs;
 
+    // Cast init_args if necessary. Non-tensor init_args are passed through.
+    SmallVector<Value> castedInitArgs;
+    for (const auto &it : llvm::enumerate(initArgs)) {
+      Value initArg = it.value();
+      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
+      if (!beforeArg.getType().isa<TensorType>()) {
+        castedInitArgs.push_back(initArg);
+        continue;
+      }
+      auto targetType = bufferization::getBufferType(beforeArg, options);
+      if (failed(targetType))
+        return failure();
+      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
+    }
+
     // The result types of a WhileOp are the same as the "after" bbArg types.
     SmallVector<Type> argsTypesAfter = llvm::to_vector(
         llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
@@ -834,13 +865,14 @@ struct WhileOpInterface
         }));
 
     // Construct a new scf.while op with memref instead of tensor values.
-    ValueRange argsRangeBefore(initArgs);
+    ValueRange argsRangeBefore(castedInitArgs);
     TypeRange argsTypesBefore(argsRangeBefore);
-    auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
-                                                    argsTypesAfter, initArgs);
+    auto newWhileOp = rewriter.create<scf::WhileOp>(
+        whileOp.getLoc(), argsTypesAfter, castedInitArgs);
 
     // Add before/after regions to the new op.
-    SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
+    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
+                                          whileOp.getLoc());
     SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                          whileOp.getLoc());
     Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
@@ -856,19 +888,6 @@ struct WhileOpInterface
         rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
     rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
 
-    // Update scf.condition of new loop.
-    auto newConditionOp = newWhileOp.getConditionOp();
-    rewriter.setInsertionPoint(newConditionOp);
-    // Only equivalent buffers or new buffer allocations may be yielded to the
-    // "after" region.
-    // TODO: This could be relaxed for better bufferization results.
-    FailureOr<SmallVector<Value>> newConditionArgs =
-        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
-                         indicesAfter, options);
-    if (failed(newConditionArgs))
-      return failure();
-    newConditionOp.getArgsMutable().assign(*newConditionArgs);
-
     // Set up new iter_args and move the loop body block to the new op.
     // The old block uses tensors, so wrap the (memref) bbArgs of the new block
     // in ToTensorOps.
@@ -877,25 +896,51 @@ struct WhileOpInterface
         rewriter, newWhileOp.getAfterArguments(), indicesAfter);
     rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);
 
-    // Update scf.yield of the new loop.
-    auto newYieldOp = newWhileOp.getYieldOp();
-    rewriter.setInsertionPoint(newYieldOp);
-    // Only equivalent buffers or new buffer allocations may be yielded to the
-    // "before" region.
-    // TODO: This could be relaxed for better bufferization results.
-    FailureOr<SmallVector<Value>> newYieldValues =
-        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
-                         indicesBefore, options);
-    if (failed(newYieldValues))
-      return failure();
-    newYieldOp.getResultsMutable().assign(*newYieldValues);
-
     // Replace loop results.
     replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
 
     return success();
   }
 
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto whileOp = cast<scf::WhileOp>(op);
+    assert(getOwnerOfValue(value) == op && "invalid value");
+    assert(value.getType().isa<TensorType>() && "expected tensor type");
+
+    // Case 1: Block argument of the "before" region.
+    if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
+        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
+        auto yieldOp = whileOp.getYieldOp();
+        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
+        return computeLoopRegionIterArgBufferType(bbArg, initArg, yieldedValue,
+                                                  options, fixedTypes);
+      }
+    }
+
+    // Case 2: OpResult of the loop or block argument of the "after" region.
+    // The bufferized "after" bbArg type can be directly computed from the
+    // bufferized "before" bbArg type.
+    unsigned resultNum;
+    if (auto opResult = value.dyn_cast<OpResult>()) {
+      resultNum = opResult.getResultNumber();
+    } else if (value.cast<BlockArgument>().getOwner()->getParent() ==
+               &whileOp.getAfter()) {
+      resultNum = value.cast<BlockArgument>().getArgNumber();
+    } else {
+      llvm_unreachable("invalid value");
+    }
+    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
+    if (!conditionYieldedVal.getType().isa<TensorType>()) {
+      // scf.condition was already bufferized.
+      return conditionYieldedVal.getType().cast<BaseMemRefType>();
+    }
+    return bufferization::getBufferType(conditionYieldedVal, options,
+                                        fixedTypes);
+  }
+
   /// Assert that yielded values of an scf.while op are equivalent to their
   /// corresponding bbArgs. In that case, the buffer relations of the
   /// corresponding OpResults are "Equivalent".
@@ -979,11 +1024,6 @@ struct YieldOpInterface
             yieldOp->getParentOp()))
       return yieldOp->emitError("unsupported scf::YieldOp parent");
 
-    // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
-    // together with scf.while.)
-    if (isa<scf::WhileOp>(yieldOp->getParentOp()))
-      return success();
-
     SmallVector<Value> newResults;
     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
       Value value = it.value();
@@ -992,15 +1032,20 @@ struct YieldOpInterface
         if (failed(maybeBuffer))
           return failure();
         Value buffer = *maybeBuffer;
-        // In case of scf::ForOp / scf::IfOp, we may have to cast the value
-        // before yielding it.
-        // TODO: Do the same for scf::WhileOp.
+        // We may have to cast the value before yielding it.
         if (isa<scf::ForOp, scf::IfOp>(yieldOp->getParentOp())) {
           FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
               yieldOp->getParentOp()->getResult(it.index()), options);
           if (failed(resultType))
             return failure();
           buffer = castBuffer(rewriter, buffer, *resultType);
+        } else if (auto whileOp =
+                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
+          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+              whileOp.getBeforeArguments()[it.index()], options);
+          if (failed(resultType))
+            return failure();
+          buffer = castBuffer(rewriter, buffer, *resultType);
         }
         newResults.push_back(buffer);
       } else {
@@ -1103,6 +1148,7 @@ struct PerformConcurrentlyOpInterface
 void mlir::scf::registerBufferizableOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
+    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
     ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
     ForOp::attachInterface<ForOpInterface>(*ctx);
     IfOp::attachInterface<IfOpInterface>(*ctx);

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index c8bf6125e252..61ec6faae86a 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -388,23 +388,23 @@ func.func @scf_while_non_equiv_condition(%arg0: tensor<5xi1>,
     // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
     // CHECK: memref.copy %[[w0]], %[[a0]]
     // CHECK: memref.dealloc %[[w0]]
-    // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
-    // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
-    // CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
-    // CHECK: memref.dealloc %[[a0]]
-    // CHECK: %[[cloned1:.*]] = bufferization.clone %[[casted1]]
+    // CHECK: %[[cloned1:.*]] = bufferization.clone %[[a1]]
     // CHECK: memref.dealloc %[[a1]]
+    // CHECK: %[[cloned0:.*]] = bufferization.clone %[[a0]]
+    // CHECK: memref.dealloc %[[a0]]
     // CHECK: scf.condition(%[[condition]]) %[[cloned1]], %[[cloned0]]
     %condition = tensor.extract %w0[%idx] : tensor<5xi1>
     scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
   } do {
   ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
     // CHECK: } do {
-    // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1, #{{.*}}>, %[[b1:.*]]: memref<5xi1, #{{.*}}):
+    // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1>, %[[b1:.*]]: memref<5xi1>):
     // CHECK: memref.store %{{.*}}, %[[b0]]
-    // CHECK: %[[cloned2:.*]] = bufferization.clone %[[b1]]
+    // CHECK: %[[casted0:.*]] = memref.cast %[[b0]] : memref<5xi1> to memref<5xi1, #{{.*}}>
+    // CHECK: %[[casted1:.*]] = memref.cast %[[b1]] : memref<5xi1> to memref<5xi1, #{{.*}}>
+    // CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted1]]
     // CHECK: memref.dealloc %[[b1]]
-    // CHECK: %[[cloned3:.*]] = bufferization.clone %[[b0]]
+    // CHECK: %[[cloned3:.*]] = bufferization.clone %[[casted0]]
     // CHECK: memref.dealloc %[[b0]]
     // CHECK: scf.yield %[[cloned3]], %[[cloned2]]
     // CHECK: }
@@ -441,25 +441,24 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
     // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
     // CHECK: memref.copy %[[w0]], %[[a0]]
     // CHECK: memref.dealloc %[[w0]]
-    // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
-    // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
-    // CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
-    // CHECK: memref.dealloc %[[a0]]
-    // CHECK: %[[cloned1:.*]] = bufferization.clone %[[casted1]]
+    // CHECK: %[[cloned1:.*]] = bufferization.clone %[[a1]]
     // CHECK: memref.dealloc %[[a1]]
+    // CHECK: %[[cloned0:.*]] = bufferization.clone %[[a0]]
+    // CHECK: memref.dealloc %[[a0]]
     // CHECK: scf.condition(%[[condition]]) %[[cloned1]], %[[cloned0]]
     %condition = tensor.extract %w0[%idx] : tensor<5xi1>
     scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
   } do {
   ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
     // CHECK: } do {
-    // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1, #{{.*}}>, %[[b1:.*]]: memref<5xi1, #{{.*}}):
+    // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1>, %[[b1:.*]]: memref<5xi1>):
     // CHECK: memref.store %{{.*}}, %[[b0]]
     // CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1>
     // CHECK: memref.copy %[[b1]], %[[a3]]
     // CHECK: memref.dealloc %[[b1]]
     // CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
     // CHECK: memref.copy %[[b0]], %[[a2]]
+    // CHECK: memref.dealloc %[[b0]]
     // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
     // CHECK: %[[casted2:.*]] = memref.cast %[[a2]]
     // CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted2]]
@@ -764,3 +763,56 @@ func.func @scf_for_buffer_type_mismatch(%sz: index, %sz2: index) -> f32 {
   %x = tensor.extract %r[%c1] : tensor<?xf32>
   return %x : f32
 }
+
+// -----
+
+// We just check that this example bufferizes to valid IR.
+
+// CHECK-LABEL: func @scf_while_buffer_type_mismatch
+func.func @scf_while_buffer_type_mismatch(%sz: index, %sz2: index) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %cst = arith.constant 5.5 : f32
+  %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
+  %e2 = tensor.extract_slice %0[1][%sz2][1] : tensor<?xf32> to tensor<?xf32>
+  // init_arg and iter_arg have different buffer types. This must be resolved
+  // with casts.
+  %r = scf.while (%t = %e2) : (tensor<?xf32>) -> (tensor<?xf32>) {
+    %c = "test.condition"() : () -> (i1)
+    %s = "test.dummy"() : () -> (index)
+    %e = tensor.extract_slice %t[1][%s][1] : tensor<?xf32> to tensor<?xf32>
+    scf.condition(%c) %e : tensor<?xf32>
+  } do {
+  ^bb0(%b0: tensor<?xf32>):
+    %s2 = "test.dummy"() : () -> (index)
+    %n = tensor.insert %cst into %b0[%s2] : tensor<?xf32>
+    scf.yield %n : tensor<?xf32>
+  }
+  %x = tensor.extract %r[%c1] : tensor<?xf32>
+  return %x : f32
+}
+
+// -----
+
+// Non-tensor iter_args must be passed through unchanged. We just check that
+// this example bufferizes to valid IR.
+
+// CHECK-LABEL: func @scf_while_non_tensor_iter_arg
+func.func @scf_while_non_tensor_iter_arg(%sz: index) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 5.5 : f32
+  %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
+  %r:2 = scf.while (%t = %0, %i = %c0) : (tensor<?xf32>, index) -> (tensor<?xf32>, index) {
+    %c = arith.cmpi slt, %i, %sz : index
+    scf.condition(%c) %t, %i : tensor<?xf32>, index
+  } do {
+  ^bb0(%b0: tensor<?xf32>, %b1: index):
+    %n = tensor.insert %cst into %b0[%b1] : tensor<?xf32>
+    %i2 = arith.addi %b1, %c1 : index
+    scf.yield %n, %i2 : tensor<?xf32>, index
+  }
+  %x = tensor.extract %r#0[%c0] : tensor<?xf32>
+  return %x : f32
+}


        


More information about the Mlir-commits mailing list