[Mlir-commits] [mlir] fbce985 - [mlir][NFC] Move conversion of scf to spir-v ops in their own file

Thomas Raoux llvmlistbot at llvm.org
Wed Jul 1 17:07:12 PDT 2020


Author: Thomas Raoux
Date: 2020-07-01T17:06:50-07:00
New Revision: fbce9855e9d5483f724d231dd4ecc2b79807d217

URL: https://github.com/llvm/llvm-project/commit/fbce9855e9d5483f724d231dd4ecc2b79807d217
DIFF: https://github.com/llvm/llvm-project/commit/fbce9855e9d5483f724d231dd4ecc2b79807d217.diff

LOG: [mlir][NFC] Move conversion of scf to spir-v ops in their own file

Move patterns for scf to spir-v ops in their own file/folder.

Differential Revision: https://reviews.llvm.org/D82914

Added: 
    mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
    mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt
    mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp

Modified: 
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
    mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
    mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
new file mode 100644
index 000000000000..95173717ceec
--- /dev/null
+++ b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
@@ -0,0 +1,32 @@
+//===------------ SCFToSPIRV.h - Pass entrypoint ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides patterns for lowering SCF ops to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRV_H_
+#define MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRV_H_
+
+#include <memory>
+
+namespace mlir {
+class MLIRContext;
+class Pass;
+
+// Owning list of rewriting patterns.
+class OwningRewritePatternList;
+class SPIRVTypeConverter;
+
+/// Collects a set of patterns to lower from scf.for, scf.if, and
+/// loop.terminator to CFG operations within the SPIR-V dialect.
+void populateSCFToSPIRVPatterns(MLIRContext *context,
+                                SPIRVTypeConverter &typeConverter,
+                                OwningRewritePatternList &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRV_H_

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e63b44cff782..57602881ab7c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -9,6 +9,7 @@ add_subdirectory(LinalgToLLVM)
 add_subdirectory(LinalgToSPIRV)
 add_subdirectory(LinalgToStandard)
 add_subdirectory(SCFToGPU)
+add_subdirectory(SCFToSPIRV)
 add_subdirectory(SCFToStandard)
 add_subdirectory(ShapeToSCF)
 add_subdirectory(ShapeToStandard)

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
index fbf6b0787091..cce793fe5a6e 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_conversion_library(MLIRGPUToSPIRVTransforms
   MLIRGPU
   MLIRIR
   MLIRPass
+  MLIRSCFToSPIRV
   MLIRSPIRV
   MLIRStandardOps
   MLIRStandardToSPIRVTransforms

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
index 2b4829adcdeb..2845611a920a 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -11,7 +11,6 @@
 //===----------------------------------------------------------------------===//
 #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
-#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
@@ -20,41 +19,6 @@
 using namespace mlir;
 
 namespace {
-
-/// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
-class ForOpConversion final : public SPIRVOpLowering<scf::ForOp> {
-public:
-  using SPIRVOpLowering<scf::ForOp>::SPIRVOpLowering;
-
-  LogicalResult
-  matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Pattern to convert a scf::IfOp within kernel functions into
-/// spirv::SelectionOp.
-class IfOpConversion final : public SPIRVOpLowering<scf::IfOp> {
-public:
-  using SPIRVOpLowering<scf::IfOp>::SPIRVOpLowering;
-
-  LogicalResult
-  matchAndRewrite(scf::IfOp IfOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Pattern to erase a scf::YieldOp.
-class TerminatorOpConversion final : public SPIRVOpLowering<scf::YieldOp> {
-public:
-  using SPIRVOpLowering<scf::YieldOp>::SPIRVOpLowering;
-
-  LogicalResult
-  matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.eraseOp(terminatorOp);
-    return success();
-  }
-};
-
 /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
 /// builtin variables.
 template <typename SourceOp, spirv::BuiltIn builtin>
@@ -128,134 +92,6 @@ class GPUReturnOpConversion final : public SPIRVOpLowering<gpu::ReturnOp> {
 
 } // namespace
 
-//===----------------------------------------------------------------------===//
-// scf::ForOp.
-//===----------------------------------------------------------------------===//
-
-LogicalResult
-ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
-                                 ConversionPatternRewriter &rewriter) const {
-  // scf::ForOp can be lowered to the structured control flow represented by
-  // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
-  // latch and the merge block the exit block. The resulting spirv::LoopOp has a
-  // single back edge from the continue to header block, and a single exit from
-  // header to merge.
-  scf::ForOpAdaptor forOperands(operands);
-  auto loc = forOp.getLoc();
-  auto loopControl = rewriter.getI32IntegerAttr(
-      static_cast<uint32_t>(spirv::LoopControl::None));
-  auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
-  loopOp.addEntryAndMergeBlock();
-
-  OpBuilder::InsertionGuard guard(rewriter);
-  // Create the block for the header.
-  auto header = new Block();
-  // Insert the header.
-  loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
-
-  // Create the new induction variable to use.
-  BlockArgument newIndVar =
-      header->addArgument(forOperands.lowerBound().getType());
-  Block *body = forOp.getBody();
-
-  // Apply signature conversion to the body of the forOp. It has a single block,
-  // with argument which is the induction variable. That has to be replaced with
-  // the new induction variable.
-  TypeConverter::SignatureConversion signatureConverter(
-      body->getNumArguments());
-  signatureConverter.remapInput(0, newIndVar);
-  FailureOr<Block *> newBody = rewriter.convertRegionTypes(
-      &forOp.getLoopBody(), typeConverter, &signatureConverter);
-  if (failed(newBody))
-    return failure();
-  body = *newBody;
-
-  // Delete the loop terminator.
-  rewriter.eraseOp(body->getTerminator());
-
-  // Move the blocks from the forOp into the loopOp. This is the body of the
-  // loopOp.
-  rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(),
-                              std::next(loopOp.body().begin(), 2));
-
-  // Branch into it from the entry.
-  rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
-  rewriter.create<spirv::BranchOp>(loc, header, forOperands.lowerBound());
-
-  // Generate the rest of the loop header.
-  rewriter.setInsertionPointToEnd(header);
-  auto mergeBlock = loopOp.getMergeBlock();
-  auto cmpOp = rewriter.create<spirv::SLessThanOp>(
-      loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
-  rewriter.create<spirv::BranchConditionalOp>(
-      loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
-
-  // Generate instructions to increment the step of the induction variable and
-  // branch to the header.
-  Block *continueBlock = loopOp.getContinueBlock();
-  rewriter.setInsertionPointToEnd(continueBlock);
-
-  // Add the step to the induction variable and branch to the header.
-  Value updatedIndVar = rewriter.create<spirv::IAddOp>(
-      loc, newIndVar.getType(), newIndVar, forOperands.step());
-  rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
-
-  rewriter.eraseOp(forOp);
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// scf::IfOp.
-//===----------------------------------------------------------------------===//
-
-LogicalResult
-IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
-                                ConversionPatternRewriter &rewriter) const {
-  // When lowering `scf::IfOp` we explicitly create a selection header block
-  // before the control flow diverges and a merge block where control flow
-  // subsequently converges.
-  scf::IfOpAdaptor ifOperands(operands);
-  auto loc = ifOp.getLoc();
-
-  // Create `spv.selection` operation, selection header block and merge block.
-  auto selectionControl = rewriter.getI32IntegerAttr(
-      static_cast<uint32_t>(spirv::SelectionControl::None));
-  auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl);
-  selectionOp.addMergeBlock();
-  auto *mergeBlock = selectionOp.getMergeBlock();
-
-  OpBuilder::InsertionGuard guard(rewriter);
-  auto *selectionHeaderBlock = new Block();
-  selectionOp.body().getBlocks().push_front(selectionHeaderBlock);
-
-  // Inline `then` region before the merge block and branch to it.
-  auto &thenRegion = ifOp.thenRegion();
-  auto *thenBlock = &thenRegion.front();
-  rewriter.setInsertionPointToEnd(&thenRegion.back());
-  rewriter.create<spirv::BranchOp>(loc, mergeBlock);
-  rewriter.inlineRegionBefore(thenRegion, mergeBlock);
-
-  auto *elseBlock = mergeBlock;
-  // If `else` region is not empty, inline that region before the merge block
-  // and branch to it.
-  if (!ifOp.elseRegion().empty()) {
-    auto &elseRegion = ifOp.elseRegion();
-    elseBlock = &elseRegion.front();
-    rewriter.setInsertionPointToEnd(&elseRegion.back());
-    rewriter.create<spirv::BranchOp>(loc, mergeBlock);
-    rewriter.inlineRegionBefore(elseRegion, mergeBlock);
-  }
-
-  // Create a `spv.BranchConditional` operation for selection header block.
-  rewriter.setInsertionPointToEnd(selectionHeaderBlock);
-  rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(),
-                                              thenBlock, ArrayRef<Value>(),
-                                              elseBlock, ArrayRef<Value>());
-
-  rewriter.eraseOp(ifOp);
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Builtins.
 //===----------------------------------------------------------------------===//
@@ -479,8 +315,7 @@ void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
                                       OwningRewritePatternList &patterns) {
   populateWithGenerated(context, &patterns);
   patterns.insert<
-      ForOpConversion, GPUFuncOpConversion, GPUModuleConversion,
-      GPUReturnOpConversion, IfOpConversion,
+      GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
       LaunchConfigConversion<gpu::ThreadIdOp,
@@ -491,5 +326,5 @@ void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
                                       spirv::BuiltIn::NumSubgroups>,
       SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
                                       spirv::BuiltIn::SubgroupSize>,
-      TerminatorOpConversion, WorkGroupSizeConversion>(context, typeConverter);
+      WorkGroupSizeConversion>(context, typeConverter);
 }

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
index 1f486b96e86c..c3bda25f0347 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"
 #include "../PassDetail.h"
 #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
+#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/SCF/SCF.h"
@@ -59,6 +60,7 @@ void GPUToSPIRVPass::runOnOperation() {
   SPIRVTypeConverter typeConverter(targetAttr);
   OwningRewritePatternList patterns;
   populateGPUToSPIRVPatterns(context, typeConverter, patterns);
+  populateSCFToSPIRVPatterns(context, typeConverter, patterns);
   populateStandardToSPIRVPatterns(context, typeConverter, patterns);
 
   if (failed(applyFullConversion(kernelModules, *target, patterns)))

diff  --git a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt
new file mode 100644
index 000000000000..6d95813d717f
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_conversion_library(MLIRSCFToSPIRV
+  SCFToSPIRV.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToSPIRV
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAffineOps
+  MLIRAffineToStandard
+  MLIRSPIRV
+  MLIRIR
+  MLIRLinalgOps
+  MLIRPass
+  MLIRStandardOps
+  MLIRSupport
+  MLIRTransforms
+  )

diff  --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
new file mode 100644
index 000000000000..a6a08b10fc63
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -0,0 +1,191 @@
+//===- SCFToSPIRV.cpp - Convert SCF ops to SPIR-V dialect -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the conversion patterns from SCF ops to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/Module.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
+class ForOpConversion final : public SPIRVOpLowering<scf::ForOp> {
+public:
+  using SPIRVOpLowering<scf::ForOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Pattern to convert a scf::IfOp within kernel functions into
+/// spirv::SelectionOp.
+class IfOpConversion final : public SPIRVOpLowering<scf::IfOp> {
+public:
+  using SPIRVOpLowering<scf::IfOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Pattern to erase a scf::YieldOp.
+class TerminatorOpConversion final : public SPIRVOpLowering<scf::YieldOp> {
+public:
+  using SPIRVOpLowering<scf::YieldOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.eraseOp(terminatorOp);
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// scf::ForOp.
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
+                                 ConversionPatternRewriter &rewriter) const {
+  // scf::ForOp can be lowered to the structured control flow represented by
+  // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
+  // latch and the merge block the exit block. The resulting spirv::LoopOp has a
+  // single back edge from the continue to header block, and a single exit from
+  // header to merge.
+  scf::ForOpAdaptor forOperands(operands);
+  auto loc = forOp.getLoc();
+  auto loopControl = rewriter.getI32IntegerAttr(
+      static_cast<uint32_t>(spirv::LoopControl::None));
+  auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
+  loopOp.addEntryAndMergeBlock();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  // Create the block for the header.
+  auto *header = new Block();
+  // Insert the header.
+  loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
+
+  // Create the new induction variable to use.
+  BlockArgument newIndVar =
+      header->addArgument(forOperands.lowerBound().getType());
+  Block *body = forOp.getBody();
+
+  // Apply signature conversion to the body of the forOp. It has a single block,
+  // with argument which is the induction variable. That has to be replaced with
+  // the new induction variable.
+  TypeConverter::SignatureConversion signatureConverter(
+      body->getNumArguments());
+  signatureConverter.remapInput(0, newIndVar);
+  FailureOr<Block *> newBody = rewriter.convertRegionTypes(
+      &forOp.getLoopBody(), typeConverter, &signatureConverter);
+  if (failed(newBody))
+    return failure();
+  body = *newBody;
+
+  // Delete the loop terminator.
+  rewriter.eraseOp(body->getTerminator());
+
+  // Move the blocks from the forOp into the loopOp. This is the body of the
+  // loopOp.
+  rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(),
+                              std::next(loopOp.body().begin(), 2));
+
+  // Branch into it from the entry.
+  rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
+  rewriter.create<spirv::BranchOp>(loc, header, forOperands.lowerBound());
+
+  // Generate the rest of the loop header.
+  rewriter.setInsertionPointToEnd(header);
+  auto *mergeBlock = loopOp.getMergeBlock();
+  auto cmpOp = rewriter.create<spirv::SLessThanOp>(
+      loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
+  rewriter.create<spirv::BranchConditionalOp>(
+      loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
+
+  // Generate instructions to increment the step of the induction variable and
+  // branch to the header.
+  Block *continueBlock = loopOp.getContinueBlock();
+  rewriter.setInsertionPointToEnd(continueBlock);
+
+  // Add the step to the induction variable and branch to the header.
+  Value updatedIndVar = rewriter.create<spirv::IAddOp>(
+      loc, newIndVar.getType(), newIndVar, forOperands.step());
+  rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
+
+  rewriter.eraseOp(forOp);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// scf::IfOp.
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
+                                ConversionPatternRewriter &rewriter) const {
+  // When lowering `scf::IfOp` we explicitly create a selection header block
+  // before the control flow diverges and a merge block where control flow
+  // subsequently converges.
+  scf::IfOpAdaptor ifOperands(operands);
+  auto loc = ifOp.getLoc();
+
+  // Create `spv.selection` operation, selection header block and merge block.
+  auto selectionControl = rewriter.getI32IntegerAttr(
+      static_cast<uint32_t>(spirv::SelectionControl::None));
+  auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl);
+  selectionOp.addMergeBlock();
+  auto *mergeBlock = selectionOp.getMergeBlock();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  auto *selectionHeaderBlock = new Block();
+  selectionOp.body().getBlocks().push_front(selectionHeaderBlock);
+
+  // Inline `then` region before the merge block and branch to it.
+  auto &thenRegion = ifOp.thenRegion();
+  auto *thenBlock = &thenRegion.front();
+  rewriter.setInsertionPointToEnd(&thenRegion.back());
+  rewriter.create<spirv::BranchOp>(loc, mergeBlock);
+  rewriter.inlineRegionBefore(thenRegion, mergeBlock);
+
+  auto *elseBlock = mergeBlock;
+  // If `else` region is not empty, inline that region before the merge block
+  // and branch to it.
+  if (!ifOp.elseRegion().empty()) {
+    auto &elseRegion = ifOp.elseRegion();
+    elseBlock = &elseRegion.front();
+    rewriter.setInsertionPointToEnd(&elseRegion.back());
+    rewriter.create<spirv::BranchOp>(loc, mergeBlock);
+    rewriter.inlineRegionBefore(elseRegion, mergeBlock);
+  }
+
+  // Create a `spv.BranchConditional` operation for selection header block.
+  rewriter.setInsertionPointToEnd(selectionHeaderBlock);
+  rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(),
+                                              thenBlock, ArrayRef<Value>(),
+                                              elseBlock, ArrayRef<Value>());
+
+  rewriter.eraseOp(ifOp);
+  return success();
+}
+
+void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
+                                      SPIRVTypeConverter &typeConverter,
+                                      OwningRewritePatternList &patterns) {
+  patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
+      context, typeConverter);
+}


        


More information about the Mlir-commits mailing list