[Mlir-commits] [mlir] 526b71e - [mlir] spirv: Add scf.while spirv conversion

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 16 02:20:12 PST 2021


Author: Butygin
Date: 2021-11-16T13:19:34+03:00
New Revision: 526b71e44acd4f792e36df62c4654706f3e26efe

URL: https://github.com/llvm/llvm-project/commit/526b71e44acd4f792e36df62c4654706f3e26efe
DIFF: https://github.com/llvm/llvm-project/commit/526b71e44acd4f792e36df62c4654706f3e26efe.diff

LOG: [mlir] spirv: Add scf.while spirv conversion

* It works similarly to the scf.for conversion, but converts the condition and yield ops as part of the scf.while pattern, so it doesn't need to maintain external state

Differential Revision: https://reviews.llvm.org/D113007

Added: 
    mlir/test/Conversion/SCFToSPIRV/while.mlir

Modified: 
    mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index f581a0a2dfd5f..ffa50cac392fa 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -107,6 +107,15 @@ class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
   matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
+
+class WhileOpConversion final : public SCFToSPIRVPattern<scf::WhileOp> {
+public:
+  using SCFToSPIRVPattern<scf::WhileOp>::SCFToSPIRVPattern;
+  /// Rewrites `whileOp` as a spirv::LoopOp; fails if operands cannot be remapped.
+  LogicalResult
+  matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
 } // namespace
 
 /// Helper function to replaces SCF op outputs with SPIR-V variable loads.
@@ -141,6 +150,10 @@ static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
   rewriter.replaceOp(scfOp, resultValue);
 }
 
+static Region::iterator getBlockIt(Region &region, unsigned index) {
+  return std::next(region.begin(), index); // block at 0-based `index`
+}
+
 //===----------------------------------------------------------------------===//
 // scf::ForOp
 //===----------------------------------------------------------------------===//
@@ -161,7 +174,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
   // Create the block for the header.
   auto *header = new Block();
   // Insert the header.
-  loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
+  loopOp.body().getBlocks().insert(getBlockIt(loopOp.body(), 1), header);
 
   // Create the new induction variable to use.
   BlockArgument newIndVar = header->addArgument(adaptor.lowerBound().getType());
@@ -183,7 +196,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
   // Move the blocks from the forOp into the loopOp. This is the body of the
   // loopOp.
   rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(),
-                              std::next(loopOp.body().begin(), 2));
+                              getBlockIt(loopOp.body(), 2));
 
   SmallVector<Value, 8> args(1, adaptor.lowerBound());
   args.append(adaptor.initArgs().begin(), adaptor.initArgs().end());
@@ -293,9 +306,11 @@ LogicalResult TerminatorOpConversion::matchAndRewrite(
   // If the region is return values, store each value into the associated
   // VariableOp created during lowering of the parent region.
   if (!operands.empty()) {
-    auto loc = terminatorOp.getLoc();
     auto &allocas = scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
-    assert(allocas.size() == operands.size());
+    if (allocas.size() != operands.size())
+      return failure();
+
+    auto loc = terminatorOp.getLoc();
     for (unsigned i = 0, e = operands.size(); i < e; i++)
       rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
     if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
@@ -314,6 +329,97 @@ LogicalResult TerminatorOpConversion::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// scf::WhileOp
+//===----------------------------------------------------------------------===//
+/// Lowers scf.while to spirv::LoopOp; results are passed out via variables.
+LogicalResult
+WhileOpConversion::matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
+                                   ConversionPatternRewriter &rewriter) const {
+  auto loc = whileOp.getLoc();
+  auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+  loopOp.addEntryAndMergeBlock();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+
+  Region &beforeRegion = whileOp.before();
+  Region &afterRegion = whileOp.after();
+
+  Block &entryBlock = *loopOp.getEntryBlock();
+  Block &beforeBlock = beforeRegion.front();
+  Block &afterBlock = afterRegion.front();
+  Block &mergeBlock = *loopOp.getMergeBlock();
+
+  auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator());
+  SmallVector<Value> condArgs;
+  if (failed(rewriter.getRemappedValues(cond.args(), condArgs)))
+    return failure();
+
+  Value conditionVal = rewriter.getRemappedValue(cond.condition());
+  if (!conditionVal)
+    return failure();
+
+  auto yield = cast<scf::YieldOp>(afterBlock.getTerminator());
+  SmallVector<Value> yieldArgs;
+  if (failed(rewriter.getRemappedValues(yield.results(), yieldArgs)))
+    return failure();
+
+  // Move the scf.while "before" region in as the loop header block.
+  rewriter.inlineRegionBefore(beforeRegion, loopOp.body(),
+                              getBlockIt(loopOp.body(), 1));
+
+  // Move the scf.while "after" region in as the loop body block.
+  rewriter.inlineRegionBefore(afterRegion, loopOp.body(),
+                              getBlockIt(loopOp.body(), 2));
+
+  // Jump from the loop entry block to the loop header block.
+  rewriter.setInsertionPointToEnd(&entryBlock);
+  rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.inits());
+
+  auto condLoc = cond.getLoc();
+
+  SmallVector<Value> resultValues(condArgs.size());
+
+  // For other SCF ops, the scf.yield op yields the value for the whole SCF op.
+  // So we use the scf.yield op as the anchor to create/load/store SPIR-V local
+  // variables. But for the scf.while op, the scf.yield op yields a value for
+  // the before region, which may not match the whole op's result. Instead,
+  // the scf.condition op returns values matching the whole op's results. So we
+  // need to create/load/store variables according to that.
+  for (auto it : llvm::enumerate(condArgs)) {
+    auto res = it.value();
+    auto i = it.index();
+    auto pointerType =
+        spirv::PointerType::get(res.getType(), spirv::StorageClass::Function);
+
+    // Create the local variable right before the new loop op.
+    rewriter.setInsertionPoint(loopOp);
+    auto alloc = rewriter.create<spirv::VariableOp>(
+        condLoc, pointerType, spirv::StorageClass::Function,
+        /*initializer=*/nullptr);
+
+    // Load the final result value right after the new loop op.
+    rewriter.setInsertionPointAfter(loopOp);
+    auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc);
+    resultValues[i] = loadResult;
+
+    // Store the current iteration's result value in the header block.
+    rewriter.setInsertionPointToEnd(&beforeBlock);
+    rewriter.create<spirv::StoreOp>(condLoc, alloc, res);
+  }
+
+  rewriter.setInsertionPointToEnd(&beforeBlock);
+  rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
+      cond, conditionVal, &afterBlock, condArgs, &mergeBlock, llvm::None);
+
+  // Convert the scf.yield op to a branch back to the header block.
+  rewriter.setInsertionPointToEnd(&afterBlock);
+  rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock, yieldArgs);
+
+  rewriter.replaceOp(whileOp, resultValues);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Hooks
 //===----------------------------------------------------------------------===//
@@ -321,6 +427,7 @@ LogicalResult TerminatorOpConversion::matchAndRewrite(
 void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                       ScfToSPIRVContext &scfToSPIRVContext,
                                       RewritePatternSet &patterns) {
-  patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
-      patterns.getContext(), typeConverter, scfToSPIRVContext.getImpl());
+  patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
+               WhileOpConversion>(patterns.getContext(), typeConverter,
+                                  scfToSPIRVContext.getImpl());
 }

diff  --git a/mlir/test/Conversion/SCFToSPIRV/while.mlir b/mlir/test/Conversion/SCFToSPIRV/while.mlir
new file mode 100644
index 0000000000000..d44ea1e98404e
--- /dev/null
+++ b/mlir/test/Conversion/SCFToSPIRV/while.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-spirv %s -o - | FileCheck %s
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader, Int64], [SPV_KHR_storage_buffer_storage_class]>, {}>
+} {
+
+// CHECK-LABEL: @while_loop1
+func @while_loop1(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK-SAME: (%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+  // CHECK: %[[INITVAR:.*]] = spv.Constant 2 : i32
+  // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<i32, Function>
+  // CHECK: spv.mlir.loop {
+  // CHECK:   spv.Branch ^[[HEADER:.*]](%[[ARG1]] : i32)
+  // CHECK: ^[[HEADER]](%[[INDVAR1:.*]]: i32):
+  // CHECK:   %[[CMP:.*]] = spv.SLessThan %[[INDVAR1]], %[[ARG2]] : i32
+  // CHECK:   spv.Store "Function" %[[VAR1]], %[[INDVAR1]] : i32
+  // CHECK:   spv.BranchConditional %[[CMP]], ^[[BODY:.*]](%[[INDVAR1]] : i32), ^[[MERGE:.*]]
+  // CHECK: ^[[BODY]](%[[INDVAR2:.*]]: i32):
+  // CHECK:   %[[UPDATED:.*]] = spv.IMul %[[INDVAR2]], %[[INITVAR]] : i32
+  // CHECK: spv.Branch ^[[HEADER]](%[[UPDATED]] : i32)
+  // CHECK: ^[[MERGE]]:
+  // CHECK:   spv.mlir.merge
+  // CHECK: }
+  %c2_i32 = arith.constant 2 : i32
+  %0 = scf.while (%arg3 = %arg0) : (i32) -> (i32) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : i32
+    scf.condition(%1) %arg3 : i32
+  } do {
+  ^bb0(%arg5: i32):
+    %1 = arith.muli %arg5, %c2_i32 : i32
+    scf.yield %1 : i32
+  }
+  // CHECK: %[[OUT:.*]] = spv.Load "Function" %[[VAR1]] : i32
+  // CHECK: spv.ReturnValue %[[OUT]] : i32
+  return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @while_loop2
+func @while_loop2(%arg0: f32) -> i64 {
+  // CHECK-SAME: (%[[ARG:.*]]: f32)
+  // CHECK: %[[VAR:.*]] = spv.Variable : !spv.ptr<i64, Function>
+  // CHECK: spv.mlir.loop {
+  // CHECK:   spv.Branch ^[[HEADER:.*]](%[[ARG]] : f32)
+  // CHECK: ^[[HEADER]](%[[INDVAR1:.*]]: f32):
+  // CHECK:   %[[SHARED:.*]] = "foo.shared_compute"(%[[INDVAR1]]) : (f32) -> i64
+  // CHECK:   %[[CMP:.*]] = "foo.evaluate_condition"(%[[INDVAR1]], %[[SHARED]]) : (f32, i64) -> i1
+  // CHECK:   spv.Store "Function" %[[VAR]], %[[SHARED]] : i64
+  // CHECK:   spv.BranchConditional %[[CMP]], ^[[BODY:.*]](%[[SHARED]] : i64), ^[[MERGE:.*]]
+  // CHECK: ^[[BODY]](%[[INDVAR2:.*]]: i64):
+  // CHECK:   %[[UPDATED:.*]] = "foo.payload"(%[[INDVAR2]]) : (i64) -> f32
+  // CHECK: spv.Branch ^[[HEADER]](%[[UPDATED]] : f32)
+  // CHECK: ^[[MERGE]]:
+  // CHECK:   spv.mlir.merge
+  // CHECK: }
+  %res = scf.while (%arg1 = %arg0) : (f32) -> i64 {
+    %shared = "foo.shared_compute"(%arg1) : (f32) -> i64
+    %condition = "foo.evaluate_condition"(%arg1, %shared) : (f32, i64) -> i1
+    scf.condition(%condition) %shared : i64
+  } do {
+  ^bb0(%arg2: i64):
+    %res = "foo.payload"(%arg2) : (i64) -> f32
+    scf.yield %res : f32
+  }
+  // CHECK: %[[OUT:.*]] = spv.Load "Function" %[[VAR]] : i64
+  // CHECK: spv.ReturnValue %[[OUT]] : i64
+  return %res : i64
+}
+
+} // end module


        


More information about the Mlir-commits mailing list