[Mlir-commits] [mlir] fbce985 - [mlir][NFC] Move conversion of scf to spir-v ops in their own file
Thomas Raoux
llvmlistbot at llvm.org
Wed Jul 1 17:07:12 PDT 2020
Author: Thomas Raoux
Date: 2020-07-01T17:06:50-07:00
New Revision: fbce9855e9d5483f724d231dd4ecc2b79807d217
URL: https://github.com/llvm/llvm-project/commit/fbce9855e9d5483f724d231dd4ecc2b79807d217
DIFF: https://github.com/llvm/llvm-project/commit/fbce9855e9d5483f724d231dd4ecc2b79807d217.diff
LOG: [mlir][NFC] Move conversion of scf to spir-v ops in their own file
Move patterns for scf to spir-v ops in their own file/folder.
Differential Revision: https://reviews.llvm.org/D82914
Added:
mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
Modified:
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
new file mode 100644
index 000000000000..95173717ceec
--- /dev/null
+++ b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
@@ -0,0 +1,32 @@
+//===- SCFToSPIRV.h - SCF to SPIR-V Patterns --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides patterns for lowering SCF ops to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRV_H_
+#define MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRV_H_
+
+#include <memory>
+
+namespace mlir {
+class MLIRContext;
+class Pass;
+
+// Owning list of rewriting patterns.
+class OwningRewritePatternList;
+class SPIRVTypeConverter;
+
+/// Collects a set of patterns to lower scf.for, scf.if, and scf.yield to the
+/// structured control flow ops of the SPIR-V dialect (spirv::LoopOp and
+/// spirv::SelectionOp). The patterns are appended to `patterns` and are
+/// constructed with `context` and `typeConverter`.
+void populateSCFToSPIRVPatterns(MLIRContext *context,
+                                SPIRVTypeConverter &typeConverter,
+                                OwningRewritePatternList &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRV_H_
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e63b44cff782..57602881ab7c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -9,6 +9,7 @@ add_subdirectory(LinalgToLLVM)
add_subdirectory(LinalgToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(SCFToGPU)
+add_subdirectory(SCFToSPIRV)
add_subdirectory(SCFToStandard)
add_subdirectory(ShapeToSCF)
add_subdirectory(ShapeToStandard)
diff --git a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
index fbf6b0787091..cce793fe5a6e 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_conversion_library(MLIRGPUToSPIRVTransforms
MLIRGPU
MLIRIR
MLIRPass
+ MLIRSCFToSPIRV
MLIRSPIRV
MLIRStandardOps
MLIRStandardToSPIRVTransforms
diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
index 2b4829adcdeb..2845611a920a 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -11,7 +11,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
-#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
@@ -20,41 +19,6 @@
using namespace mlir;
namespace {
-
-/// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
-class ForOpConversion final : public SPIRVOpLowering<scf::ForOp> {
-public:
- using SPIRVOpLowering<scf::ForOp>::SPIRVOpLowering;
-
- LogicalResult
- matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Pattern to convert a scf::IfOp within kernel functions into
-/// spirv::SelectionOp.
-class IfOpConversion final : public SPIRVOpLowering<scf::IfOp> {
-public:
- using SPIRVOpLowering<scf::IfOp>::SPIRVOpLowering;
-
- LogicalResult
- matchAndRewrite(scf::IfOp IfOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Pattern to erase a scf::YieldOp.
-class TerminatorOpConversion final : public SPIRVOpLowering<scf::YieldOp> {
-public:
- using SPIRVOpLowering<scf::YieldOp>::SPIRVOpLowering;
-
- LogicalResult
- matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.eraseOp(terminatorOp);
- return success();
- }
-};
-
/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
@@ -128,134 +92,6 @@ class GPUReturnOpConversion final : public SPIRVOpLowering<gpu::ReturnOp> {
} // namespace
-//===----------------------------------------------------------------------===//
-// scf::ForOp.
-//===----------------------------------------------------------------------===//
-
-LogicalResult
-ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- // scf::ForOp can be lowered to the structured control flow represented by
- // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
- // latch and the merge block the exit block. The resulting spirv::LoopOp has a
- // single back edge from the continue to header block, and a single exit from
- // header to merge.
- scf::ForOpAdaptor forOperands(operands);
- auto loc = forOp.getLoc();
- auto loopControl = rewriter.getI32IntegerAttr(
- static_cast<uint32_t>(spirv::LoopControl::None));
- auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
- loopOp.addEntryAndMergeBlock();
-
- OpBuilder::InsertionGuard guard(rewriter);
- // Create the block for the header.
- auto header = new Block();
- // Insert the header.
- loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
-
- // Create the new induction variable to use.
- BlockArgument newIndVar =
- header->addArgument(forOperands.lowerBound().getType());
- Block *body = forOp.getBody();
-
- // Apply signature conversion to the body of the forOp. It has a single block,
- // with argument which is the induction variable. That has to be replaced with
- // the new induction variable.
- TypeConverter::SignatureConversion signatureConverter(
- body->getNumArguments());
- signatureConverter.remapInput(0, newIndVar);
- FailureOr<Block *> newBody = rewriter.convertRegionTypes(
- &forOp.getLoopBody(), typeConverter, &signatureConverter);
- if (failed(newBody))
- return failure();
- body = *newBody;
-
- // Delete the loop terminator.
- rewriter.eraseOp(body->getTerminator());
-
- // Move the blocks from the forOp into the loopOp. This is the body of the
- // loopOp.
- rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(),
- std::next(loopOp.body().begin(), 2));
-
- // Branch into it from the entry.
- rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
- rewriter.create<spirv::BranchOp>(loc, header, forOperands.lowerBound());
-
- // Generate the rest of the loop header.
- rewriter.setInsertionPointToEnd(header);
- auto mergeBlock = loopOp.getMergeBlock();
- auto cmpOp = rewriter.create<spirv::SLessThanOp>(
- loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
- rewriter.create<spirv::BranchConditionalOp>(
- loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
-
- // Generate instructions to increment the step of the induction variable and
- // branch to the header.
- Block *continueBlock = loopOp.getContinueBlock();
- rewriter.setInsertionPointToEnd(continueBlock);
-
- // Add the step to the induction variable and branch to the header.
- Value updatedIndVar = rewriter.create<spirv::IAddOp>(
- loc, newIndVar.getType(), newIndVar, forOperands.step());
- rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
-
- rewriter.eraseOp(forOp);
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// scf::IfOp.
-//===----------------------------------------------------------------------===//
-
-LogicalResult
-IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- // When lowering `scf::IfOp` we explicitly create a selection header block
- // before the control flow diverges and a merge block where control flow
- // subsequently converges.
- scf::IfOpAdaptor ifOperands(operands);
- auto loc = ifOp.getLoc();
-
- // Create `spv.selection` operation, selection header block and merge block.
- auto selectionControl = rewriter.getI32IntegerAttr(
- static_cast<uint32_t>(spirv::SelectionControl::None));
- auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl);
- selectionOp.addMergeBlock();
- auto *mergeBlock = selectionOp.getMergeBlock();
-
- OpBuilder::InsertionGuard guard(rewriter);
- auto *selectionHeaderBlock = new Block();
- selectionOp.body().getBlocks().push_front(selectionHeaderBlock);
-
- // Inline `then` region before the merge block and branch to it.
- auto &thenRegion = ifOp.thenRegion();
- auto *thenBlock = &thenRegion.front();
- rewriter.setInsertionPointToEnd(&thenRegion.back());
- rewriter.create<spirv::BranchOp>(loc, mergeBlock);
- rewriter.inlineRegionBefore(thenRegion, mergeBlock);
-
- auto *elseBlock = mergeBlock;
- // If `else` region is not empty, inline that region before the merge block
- // and branch to it.
- if (!ifOp.elseRegion().empty()) {
- auto &elseRegion = ifOp.elseRegion();
- elseBlock = &elseRegion.front();
- rewriter.setInsertionPointToEnd(&elseRegion.back());
- rewriter.create<spirv::BranchOp>(loc, mergeBlock);
- rewriter.inlineRegionBefore(elseRegion, mergeBlock);
- }
-
- // Create a `spv.BranchConditional` operation for selection header block.
- rewriter.setInsertionPointToEnd(selectionHeaderBlock);
- rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(),
- thenBlock, ArrayRef<Value>(),
- elseBlock, ArrayRef<Value>());
-
- rewriter.eraseOp(ifOp);
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Builtins.
//===----------------------------------------------------------------------===//
@@ -479,8 +315,7 @@ void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
populateWithGenerated(context, &patterns);
patterns.insert<
- ForOpConversion, GPUFuncOpConversion, GPUModuleConversion,
- GPUReturnOpConversion, IfOpConversion,
+ GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
LaunchConfigConversion<gpu::ThreadIdOp,
@@ -491,5 +326,5 @@ void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
spirv::BuiltIn::NumSubgroups>,
SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
spirv::BuiltIn::SubgroupSize>,
- TerminatorOpConversion, WorkGroupSizeConversion>(context, typeConverter);
+ WorkGroupSizeConversion>(context, typeConverter);
}
diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
index 1f486b96e86c..c3bda25f0347 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"
#include "../PassDetail.h"
#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
+#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/SCF/SCF.h"
@@ -59,6 +60,7 @@ void GPUToSPIRVPass::runOnOperation() {
SPIRVTypeConverter typeConverter(targetAttr);
OwningRewritePatternList patterns;
populateGPUToSPIRVPatterns(context, typeConverter, patterns);
+ populateSCFToSPIRVPatterns(context, typeConverter, patterns);
populateStandardToSPIRVPatterns(context, typeConverter, patterns);
if (failed(applyFullConversion(kernelModules, *target, patterns)))
diff --git a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt
new file mode 100644
index 000000000000..6d95813d717f
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_conversion_library(MLIRSCFToSPIRV
+  SCFToSPIRV.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToSPIRV
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAffineOps
+  MLIRAffineToStandard
+  MLIRSPIRV
+  MLIRIR
+  MLIRLinalgOps
+  MLIRPass
+  MLIRSCF # SCFToSPIRV.cpp lowers scf::ForOp / scf::IfOp / scf::YieldOp.
+  MLIRStandardOps
+  MLIRSupport
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
new file mode 100644
index 000000000000..a6a08b10fc63
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -0,0 +1,191 @@
+//===- SCFToSPIRV.cpp - Convert SCF ops to SPIR-V dialect -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the conversion patterns from SCF ops to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/Module.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Conversion pattern lowering an scf::ForOp into a spirv::LoopOp. The
+/// rewrite itself is defined out of line below.
+class ForOpConversion final : public SPIRVOpLowering<scf::ForOp> {
+public:
+  using SPIRVOpLowering::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Conversion pattern lowering an scf::IfOp into a spirv::SelectionOp. The
+/// rewrite itself is defined out of line below.
+class IfOpConversion final : public SPIRVOpLowering<scf::IfOp> {
+public:
+  using SPIRVOpLowering::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Conversion pattern that simply erases an scf::YieldOp. No replacement is
+/// created here: the lowering of the enclosing op provides the control flow
+/// (e.g. the spirv::BranchOp that IfOpConversion appends to each region).
+class TerminatorOpConversion final : public SPIRVOpLowering<scf::YieldOp> {
+public:
+  using SPIRVOpLowering::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(scf::YieldOp yieldOp, ArrayRef<Value> /*operands*/,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.eraseOp(yieldOp);
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// scf::ForOp.
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
+                                 ConversionPatternRewriter &rewriter) const {
+  // scf::ForOp can be lowered to the structured control flow represented by
+  // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
+  // latch and the merge block the exit block. The resulting spirv::LoopOp has a
+  // single back edge from the continue to header block, and a single exit from
+  // header to merge.
+  //
+  // Only the induction variable is carried across iterations by this lowering:
+  // the signature conversion below remaps just block argument 0, and the op's
+  // results are never replaced before the op is erased. Reject loops with
+  // loop-carried values (iter_args / results) up front, before any IR is
+  // created, rather than leaving dangling uses behind.
+  if (forOp.getNumResults() != 0)
+    return failure();
+
+  scf::ForOpAdaptor forOperands(operands);
+  auto loc = forOp.getLoc();
+  auto loopControl = rewriter.getI32IntegerAttr(
+      static_cast<uint32_t>(spirv::LoopControl::None));
+  auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
+  loopOp.addEntryAndMergeBlock();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  // Create the loop header block and insert it right after the entry block.
+  auto *header = new Block();
+  loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
+
+  // Create the new induction variable to use.
+  BlockArgument newIndVar =
+      header->addArgument(forOperands.lowerBound().getType());
+  Block *body = forOp.getBody();
+
+  // Apply signature conversion to the body of the forOp. It has a single block,
+  // with argument which is the induction variable. That has to be replaced with
+  // the new induction variable.
+  TypeConverter::SignatureConversion signatureConverter(
+      body->getNumArguments());
+  signatureConverter.remapInput(0, newIndVar);
+  FailureOr<Block *> newBody = rewriter.convertRegionTypes(
+      &forOp.getLoopBody(), typeConverter, &signatureConverter);
+  if (failed(newBody))
+    return failure();
+  body = *newBody;
+
+  // Delete the loop terminator.
+  rewriter.eraseOp(body->getTerminator());
+
+  // Move the blocks from the forOp into the loopOp, after the entry and header
+  // blocks. This is the body of the loopOp.
+  rewriter.inlineRegionBefore(forOp.getLoopBody(), loopOp.body(),
+                              std::next(loopOp.body().begin(), 2));
+
+  // Branch from the entry block into the header, passing the lower bound as
+  // the initial induction variable value.
+  rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
+  rewriter.create<spirv::BranchOp>(loc, header, forOperands.lowerBound());
+
+  // Generate the rest of the loop header: enter the body while the induction
+  // variable is (signed) less than the upper bound, otherwise exit to merge.
+  rewriter.setInsertionPointToEnd(header);
+  auto *mergeBlock = loopOp.getMergeBlock();
+  auto cmpOp = rewriter.create<spirv::SLessThanOp>(
+      loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
+  rewriter.create<spirv::BranchConditionalOp>(
+      loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
+
+  // In the continue block, add the step to the induction variable and branch
+  // back to the header.
+  Block *continueBlock = loopOp.getContinueBlock();
+  rewriter.setInsertionPointToEnd(continueBlock);
+  Value updatedIndVar = rewriter.create<spirv::IAddOp>(
+      loc, newIndVar.getType(), newIndVar, forOperands.step());
+  rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
+
+  rewriter.eraseOp(forOp);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// scf::IfOp.
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
+                                ConversionPatternRewriter &rewriter) const {
+  // When lowering `scf::IfOp` we explicitly create a selection header block
+  // before the control flow diverges and a merge block where control flow
+  // subsequently converges.
+  //
+  // Values yielded out of the branches are not supported: the op's results
+  // are never replaced below, so erasing an op that has results could leave
+  // dangling uses. Reject such ops before any IR is created.
+  if (ifOp.getNumResults() != 0)
+    return failure();
+
+  scf::IfOpAdaptor ifOperands(operands);
+  auto loc = ifOp.getLoc();
+
+  // Create `spv.selection` operation, selection header block and merge block.
+  auto selectionControl = rewriter.getI32IntegerAttr(
+      static_cast<uint32_t>(spirv::SelectionControl::None));
+  auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl);
+  selectionOp.addMergeBlock();
+  auto *mergeBlock = selectionOp.getMergeBlock();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  auto *selectionHeaderBlock = new Block();
+  selectionOp.body().getBlocks().push_front(selectionHeaderBlock);
+
+  // Inline `then` region before the merge block and branch to it. The branch
+  // is appended at the end of the region's last block, i.e. after its existing
+  // terminator (presumably an scf.yield, left for TerminatorOpConversion).
+  auto &thenRegion = ifOp.thenRegion();
+  auto *thenBlock = &thenRegion.front();
+  rewriter.setInsertionPointToEnd(&thenRegion.back());
+  rewriter.create<spirv::BranchOp>(loc, mergeBlock);
+  rewriter.inlineRegionBefore(thenRegion, mergeBlock);
+
+  // Without an `else` region the false edge goes straight to the merge block;
+  // otherwise inline that region before the merge block and branch to it.
+  auto *elseBlock = mergeBlock;
+  if (!ifOp.elseRegion().empty()) {
+    auto &elseRegion = ifOp.elseRegion();
+    elseBlock = &elseRegion.front();
+    rewriter.setInsertionPointToEnd(&elseRegion.back());
+    rewriter.create<spirv::BranchOp>(loc, mergeBlock);
+    rewriter.inlineRegionBefore(elseRegion, mergeBlock);
+  }
+
+  // Create a `spv.BranchConditional` operation for selection header block.
+  rewriter.setInsertionPointToEnd(selectionHeaderBlock);
+  rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(),
+                                              thenBlock, ArrayRef<Value>(),
+                                              elseBlock, ArrayRef<Value>());
+
+  rewriter.eraseOp(ifOp);
+  return success();
+}
+
+/// Appends the scf.for, scf.if and scf.yield lowerings defined in this file to
+/// `patterns`; every pattern is constructed from `context` and `typeConverter`.
+void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
+                                      SPIRVTypeConverter &typeConverter,
+                                      OwningRewritePatternList &patterns) {
+  patterns.insert<ForOpConversion>(context, typeConverter);
+  patterns.insert<IfOpConversion>(context, typeConverter);
+  patterns.insert<TerminatorOpConversion>(context, typeConverter);
+}
More information about the Mlir-commits
mailing list