[Mlir-commits] [mlir] a5d09c6 - [mlir][scf] Implement BufferizableOpInterface for scf::WhileOp

Matthias Springer llvmlistbot at llvm.org
Fri May 6 01:28:46 PDT 2022


Author: Matthias Springer
Date: 2022-05-06T17:24:33+09:00
New Revision: a5d09c637261252393a015e7858efd85c9166e32

URL: https://github.com/llvm/llvm-project/commit/a5d09c637261252393a015e7858efd85c9166e32
DIFF: https://github.com/llvm/llvm-project/commit/a5d09c637261252393a015e7858efd85c9166e32.diff

LOG: [mlir][scf] Implement BufferizableOpInterface for scf::WhileOp

This follows the same implementation strategy as scf::ForOp and common functionality is extracted into helper functions.

This implementation works well in cases where each yielded value (from either body/condition region) is equivalent to the corresponding bbArg of the parent block. In that case, each OpResult of the loop may be aliasing with the corresponding OpOperand of the loop (and with no other OpOperand).

In the absence of said equivalence relationship, new buffer copies must be inserted, so that the aliasing OpOperand/OpResult contract of scf::WhileOp is honored. In essence, by yielding a newly allocated buffer, we can enforce the specified may-alias relationship. (Newly allocated buffers cannot alias with any OpOperands of the loop.)

Differential Revision: https://reviews.llvm.org/D124929

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index b6d2001403bd1..39af7d337d054 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -271,7 +271,7 @@ static DenseSet<int64_t> getTensorIndices(ValueRange values) {
 
 /// Helper function for loop bufferization. Return the indices of all
 /// bbArg/yielded value pairs whose buffer relation is "Equivalent".
-DenseSet<int64_t> getEquivalentBuffers(ValueRange bbArgs,
+DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                        ValueRange yieldedValues,
                                        const AnalysisState &state) {
   DenseSet<int64_t> result;
@@ -403,6 +403,18 @@ SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
       });
 }
 
+/// Helper function for loop bufferization. Given a list of bbArgs of the new
+/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
+/// ToTensorOps, so that the old (tensor-based) body can be moved to the new op.
+SmallVector<Value>
+getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
+                     const DenseSet<int64_t> &tensorIndices) {
+  return convertTensorValues(
+      bbArgs, tensorIndices, [&](Value val, int64_t index) {
+        return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
+      });
+}
+
 /// Bufferization of scf.for. Replace with a new scf.for that operates on
 /// memrefs.
 struct ForOpInterface
@@ -486,10 +498,8 @@ struct ForOpInterface
     // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
     // iter_args of the new loop in ToTensorOps.
     rewriter.setInsertionPointToStart(loopBody);
-    SmallVector<Value> iterArgs = convertTensorValues(
-        newForOp.getRegionIterArgs(), indices, [&](Value val, int64_t index) {
-          return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
-        });
+    SmallVector<Value> iterArgs =
+        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
     iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
 
     // Erase terminator if present.
@@ -546,6 +556,187 @@ struct ForOpInterface
   }
 };
 
+/// Bufferization of scf.while. Replace with a new scf.while that operates on
+/// memrefs.
+struct WhileOpInterface
+    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
+                                                    scf::WhileOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    // Tensor iter_args of scf::WhileOps are always considered as a read.
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    // Tensor iter_args of scf::WhileOps are always considered as a write.
+    return true;
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    auto whileOp = cast<scf::WhileOp>(op);
+    return {whileOp->getResult(opOperand.getOperandNumber())};
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    // WhileOp results are equivalent to their corresponding init_args if the
+    // corresponding iter_args and yield values are equivalent (for both the
+    // "before" and the "after" block).
+    unsigned int resultNumber = opResult.getResultNumber();
+    auto whileOp = cast<scf::WhileOp>(op);
+
+    auto conditionOp = whileOp.getConditionOp();
+    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
+    Value conditionOperand = conditionOp.getArgs()[resultNumber];
+    bool equivCondition =
+        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);
+
+    auto yieldOp = whileOp.getYieldOp();
+    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
+    Value yieldOperand = yieldOp.getOperand(resultNumber);
+    bool equivYield =
+        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);
+
+    return equivCondition && equivYield ? BufferRelation::Equivalent
+                                        : BufferRelation::None;
+  }
+
+  bool isWritable(Operation *op, Value value,
+                  const AnalysisState &state) const {
+    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
+    // inplace from the perspective of ops nested under:
+    //   1. Either the matching iter operand is not bufferized inplace and an
+    //      alloc + optional copy makes the bbArg itself inplaceable.
+    //   2. Or the matching iter operand is bufferized inplace and bbArg just
+    //      bufferizes to that too.
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          BufferizationState &state) const {
+    auto whileOp = cast<scf::WhileOp>(op);
+
+    assert(whileOp.getBefore().getBlocks().size() == 1 &&
+           "regions with multiple blocks not supported");
+    Block *beforeBody = &whileOp.getBefore().front();
+    assert(whileOp.getAfter().getBlocks().size() == 1 &&
+           "regions with multiple blocks not supported");
+    Block *afterBody = &whileOp.getAfter().front();
+
+    // Indices of all iter_args that have tensor type. These are the ones that
+    // are bufferized.
+    DenseSet<int64_t> indices = getTensorIndices(whileOp.getInits());
+    // For every yielded value, is the value equivalent to its corresponding
+    // bbArg?
+    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
+        whileOp.getBeforeArguments(), whileOp.getConditionOp().getArgs(),
+        state.getAnalysisState());
+    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
+        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(),
+        state.getAnalysisState());
+
+    // The new memref init_args of the loop.
+    SmallVector<Value> initArgs =
+        getBuffers(rewriter, whileOp->getOpOperands(), state);
+    if (initArgs.size() != indices.size())
+      return failure();
+
+    // Construct a new scf.while op with memref instead of tensor values.
+    ValueRange argsRange(initArgs);
+    TypeRange argsTypes(argsRange);
+    auto newWhileOp =
+        rewriter.create<scf::WhileOp>(whileOp.getLoc(), argsTypes, initArgs);
+    // Add before/after regions to the new op.
+    SmallVector<Location> bbArgLocs(initArgs.size(), whileOp.getLoc());
+    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
+    newWhileOp.getBefore().addArguments(argsTypes, bbArgLocs);
+    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
+    newWhileOp.getAfter().addArguments(argsTypes, bbArgLocs);
+
+    // Set up new iter_args and move the loop condition block to the new op.
+    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
+    // in ToTensorOps.
+    rewriter.setInsertionPointToStart(newBeforeBody);
+    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
+        rewriter, newWhileOp.getBeforeArguments(), indices);
+    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
+
+    // Update scf.condition of new loop.
+    auto newConditionOp = newWhileOp.getConditionOp();
+    rewriter.setInsertionPoint(newConditionOp);
+    SmallVector<Value> newConditionArgs =
+        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypes, indices,
+                         equivalentYieldsBefore, state);
+    newConditionOp.getArgsMutable().assign(newConditionArgs);
+
+    // Set up new iter_args and move the loop body block to the new op.
+    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
+    // in ToTensorOps.
+    rewriter.setInsertionPointToStart(newAfterBody);
+    SmallVector<Value> newAfterArgs =
+        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(), indices);
+    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);
+
+    // Update scf.yield of the new loop.
+    auto newYieldOp = newWhileOp.getYieldOp();
+    rewriter.setInsertionPoint(newYieldOp);
+    SmallVector<Value> newYieldValues =
+        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypes, indices,
+                         equivalentYieldsAfter, state);
+    newYieldOp.getResultsMutable().assign(newYieldValues);
+
+    // Replace loop results.
+    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
+
+    return success();
+  }
+
+  /// Assert that yielded values of an scf.while op are equivalent to their
+  /// corresponding bbArgs. In that case, the buffer relations of the
+  /// corresponding OpResults are "Equivalent".
+  ///
+  /// If this is not the case, allocs+copies are inserted and yielded from
+  /// the loop. This could be a performance problem, so it must be explicitly
+  /// activated with `allow-return-allocs`.
+  ///
+  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
+  /// equivalence condition must be checked for both.
+  LogicalResult verifyAnalysis(Operation *op,
+                               const AnalysisState &state) const {
+    auto whileOp = cast<scf::WhileOp>(op);
+    const auto &options =
+        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
+    if (options.allowReturnAllocs)
+      return success();
+
+    auto conditionOp = whileOp.getConditionOp();
+    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
+      if (!it.value().getType().isa<TensorType>())
+        continue;
+      if (!state.areEquivalentBufferizedValues(
+              it.value(), conditionOp->getBlock()->getArgument(it.index())))
+        return conditionOp->emitError()
+               << "Condition arg #" << it.index()
+               << " is not equivalent to the corresponding iter bbArg";
+    }
+
+    auto yieldOp = whileOp.getYieldOp();
+    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
+      if (!it.value().getType().isa<TensorType>())
+        continue;
+      if (!state.areEquivalentBufferizedValues(
+              it.value(), yieldOp->getBlock()->getArgument(it.index())))
+        return yieldOp->emitError()
+               << "Yield operand #" << it.index()
+               << " is not equivalent to the corresponding iter bbArg";
+    }
+
+    return success();
+  }
+};
+
 /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
 /// this is for analysis only.
 struct YieldOpInterface
@@ -581,7 +772,7 @@ struct YieldOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto yieldOp = cast<scf::YieldOp>(op);
-    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
+    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
             yieldOp->getParentOp()))
       return yieldOp->emitError("unsupported scf::YieldOp parent");
     return success();
@@ -598,6 +789,7 @@ void mlir::scf::registerBufferizableOpInterfaceExternalModels(
     ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
     ForOp::attachInterface<ForOpInterface>(*ctx);
     IfOp::attachInterface<IfOpInterface>(*ctx);
+    WhileOp::attachInterface<WhileOpInterface>(*ctx);
     YieldOp::attachInterface<YieldOpInterface>(*ctx);
   });
 }

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index d92eaff1b2522..8ab28773d7f7c 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -110,6 +110,54 @@ func.func @scf_for(%A : tensor<?xf32>,
 
 // -----
 
+func.func @scf_while_non_equiv_condition(%arg0: tensor<5xi1>,
+                                         %arg1: tensor<5xi1>,
+                                         %idx: index) -> (i1, i1)
+{
+  %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
+      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+    // expected-error @+1 {{Condition arg #0 is not equivalent to the corresponding iter bbArg}}
+    scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
+  } do {
+  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+    %pos = "dummy.some_op"() : () -> (index)
+    %val = "dummy.another_op"() : () -> (i1)
+    %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1>
+    scf.yield %1, %b1 : tensor<5xi1>, tensor<5xi1>
+  }
+
+  %v0 = tensor.extract %r0[%idx] : tensor<5xi1>
+  %v1 = tensor.extract %r1[%idx] : tensor<5xi1>
+  return %v0, %v1 : i1, i1
+}
+
+// -----
+
+func.func @scf_while_non_equiv_yield(%arg0: tensor<5xi1>,
+                                     %arg1: tensor<5xi1>,
+                                     %idx: index) -> (i1, i1)
+{
+  %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
+      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+    scf.condition(%condition) %w0, %w1 : tensor<5xi1>, tensor<5xi1>
+  } do {
+  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+    %pos = "dummy.some_op"() : () -> (index)
+    %val = "dummy.another_op"() : () -> (i1)
+    %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1>
+    // expected-error @+1 {{Yield operand #0 is not equivalent to the corresponding iter bbArg}}
+    scf.yield %b1, %1 : tensor<5xi1>, tensor<5xi1>
+  }
+
+  %v0 = tensor.extract %r0[%idx] : tensor<5xi1>
+  %v1 = tensor.extract %r1[%idx] : tensor<5xi1>
+  return %v0, %v1 : i1, i1
+}
+
+// -----
+
 func.func private @fun_with_side_effects(%A: tensor<?xf32> {bufferization.writable = true})
 
 func.func @foo(%A: tensor<?xf32> {bufferization.writable = true}) -> (tensor<?xf32>) {

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 1b6fd99147970..22b5e41364c03 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -1,12 +1,12 @@
-// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -split-input-file | FileCheck %s
 
 // Run fuzzer with different seeds.
-// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=23 bufferize-function-boundaries" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=59 bufferize-function-boundaries" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=91 bufferize-function-boundaries" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=23 bufferize-function-boundaries" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=59 bufferize-function-boundaries" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=91 bufferize-function-boundaries" -split-input-file -o /dev/null
 
 // Test bufferization using memref types that have no layout map.
-// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs fully-dynamic-layout-maps=0 bufferize-function-boundaries" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs fully-dynamic-layout-maps=0 bufferize-function-boundaries" -split-input-file -o /dev/null
 
 // CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 
@@ -328,3 +328,124 @@ func.func @scf_for_swapping_yields(
 //       CHECK:     return %[[r0]], %[[r1]]
   return %f0, %f1: f32, f32
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_while(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xi1, #{{.*}}>
+func.func @scf_while(%arg0: tensor<?xi1>, %idx: index) -> tensor<?xi1> {
+  // CHECK: scf.while : () -> () {
+  %res = scf.while (%arg1 = %arg0) : (tensor<?xi1>) -> tensor<?xi1> {
+    // CHECK: %[[condition:.*]] = memref.load %[[arg0]]
+    // CHECK: scf.condition(%[[condition]])
+    %condition = tensor.extract %arg1[%idx] : tensor<?xi1>
+    scf.condition(%condition) %arg1 : tensor<?xi1>
+  } do {
+  ^bb0(%arg2: tensor<?xi1>):
+    // CHECK: } do {
+    // CHECK: memref.store %{{.*}}, %[[arg0]]
+    // CHECK: scf.yield
+    // CHECK: }
+    %pos = "dummy.some_op"() : () -> (index)
+    %val = "dummy.another_op"() : () -> (i1)
+    %1 = tensor.insert %val into %arg2[%pos] : tensor<?xi1>
+    scf.yield %1 : tensor<?xi1>
+  }
+
+  // CHECK: return
+  return %res : tensor<?xi1>
+}
+
+// -----
+
+// The loop condition yields non-equivalent buffers.
+
+// CHECK-LABEL: func @scf_while_non_equiv_condition(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<5xi1, #{{.*}}>, %[[arg1:.*]]: memref<5xi1, #{{.*}}>
+func.func @scf_while_non_equiv_condition(%arg0: tensor<5xi1>,
+                                         %arg1: tensor<5xi1>,
+                                         %idx: index)
+  -> (tensor<5xi1>, tensor<5xi1>)
+{
+  // These allocations used to be inside the scf.while loop, but they were
+  // hoisted.
+  // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+  // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+  // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[arg0]], %[[w1:.*]] = %[[arg1]]) {{.*}} {
+  %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
+      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+    // CHECK: %[[condition:.*]] = memref.load %[[w0]]
+    // CHECK: memref.copy %[[w1]], %[[a1]]
+    // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
+    // CHECK: memref.copy %[[w0]], %[[a0]]
+    // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
+    // CHECK: scf.condition(%[[condition]]) %[[casted1]], %[[casted0]]
+    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+    scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
+  } do {
+  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+    // CHECK: } do {
+    // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1, #{{.*}}>, %[[b1:.*]]: memref<5xi1, #{{.*}}):
+    // CHECK: memref.store %{{.*}}, %[[b0]]
+    // CHECK: scf.yield %[[b0]], %[[b1]]
+    // CHECK: }
+    %pos = "dummy.some_op"() : () -> (index)
+    %val = "dummy.another_op"() : () -> (i1)
+    %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1>
+    scf.yield %1, %b1 : tensor<5xi1>, tensor<5xi1>
+  }
+
+  // CHECK: return %[[loop]]#0, %[[loop]]#1
+  return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
+}
+
+// -----
+
+// Both the loop condition and the loop body yield non-equivalent buffers.
+
+// CHECK-LABEL: func @scf_while_non_equiv_condition_and_body(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<5xi1, #{{.*}}>, %[[arg1:.*]]: memref<5xi1, #{{.*}}>
+func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
+                                                  %arg1: tensor<5xi1>,
+                                                  %idx: index)
+  -> (tensor<5xi1>, tensor<5xi1>)
+{
+  // These allocations used to be inside the scf.while loop, but they were
+  // hoisted.
+  // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+  // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+  // CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+  // CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+  // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[arg0]], %[[w1:.*]] = %[[arg1]]) {{.*}} {
+  %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
+      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+    // CHECK: %[[condition:.*]] = memref.load %[[w0]]
+    // CHECK: memref.copy %[[w1]], %[[a3]]
+    // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
+    // CHECK: memref.copy %[[w0]], %[[a2]]
+    // CHECK: %[[casted2:.*]] = memref.cast %[[a2]]
+    // CHECK: scf.condition(%[[condition]]) %[[casted3]], %[[casted2]]
+    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+    scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
+  } do {
+  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+    // CHECK: } do {
+    // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1, #{{.*}}>, %[[b1:.*]]: memref<5xi1, #{{.*}}):
+    // CHECK: memref.store %{{.*}}, %[[b0]]
+    // CHECK: memref.copy %[[b1]], %[[a1]]
+    // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
+    // CHECK: memref.copy %[[b0]], %[[a0]]
+    // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
+    // CHECK: scf.yield %[[casted1]], %[[casted0]]
+    // CHECK: }
+    %pos = "dummy.some_op"() : () -> (index)
+    %val = "dummy.another_op"() : () -> (i1)
+    %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1>
+    scf.yield %b1, %1 : tensor<5xi1>, tensor<5xi1>
+  }
+
+  // CHECK-DAG: memref.dealloc %[[a0]]
+  // CHECK-DAG: memref.dealloc %[[a1]]
+  // CHECK: return %[[loop]]#0, %[[loop]]#1
+  return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
+}


        


More information about the Mlir-commits mailing list