[Mlir-commits] [mlir] ab5eae0 - [mlir][spirv][NFC] Clean up scf-to-spirv pass

Jakub Kuderski llvmlistbot at llvm.org
Tue Mar 14 15:45:41 PDT 2023


Author: Jakub Kuderski
Date: 2023-03-14T18:44:35-04:00
New Revision: ab5eae01646e2a83356ec8fe300bf727dadc87dd

URL: https://github.com/llvm/llvm-project/commit/ab5eae01646e2a83356ec8fe300bf727dadc87dd
DIFF: https://github.com/llvm/llvm-project/commit/ab5eae01646e2a83356ec8fe300bf727dadc87dd.diff

LOG: [mlir][spirv][NFC] Clean up scf-to-spirv pass

This is a cleanup before fixing issues identified in this pass by
https://github.com/llvm/llvm-project/issues/61380 and similar issues.

- Move pattern definitions closer to their declarations.
- Simplify pattern definitions.
- Drop the hand-written pass constructor in favor of an auto-generated one.
- Fix typos in pass description.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D146077

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
    mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 690958b24a608..a958b967bb4b6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -812,13 +812,12 @@ def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> {
 def SCFToSPIRV : Pass<"convert-scf-to-spirv"> {
   let summary = "Convert SCF dialect to SPIR-V dialect.";
   let description = [{
-    This pass converts SCF ops into SPIR-V structured control flow ops.
-    SPIR-V structured control flow ops does not support yielding values.
+    Converts SCF ops into SPIR-V structured control flow ops.
+    SPIR-V structured control flow ops do not support yielding values.
     So for SCF ops yielding values, SPIR-V variables are created for
     holding the values and load/store operations are emitted for updating
     them.
   }];
-  let constructor = "mlir::createConvertSCFToSPIRVPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 

diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 00062ef4a7a91..2572bbcbd6bb5 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -41,16 +41,57 @@ struct ScfToSPIRVContextImpl {
 /// StoreOp cannot be created earlier as they may use a different type than
 /// yield operands.
 ScfToSPIRVContext::ScfToSPIRVContext() {
-  impl = std::make_unique<ScfToSPIRVContextImpl>();
+  impl = std::make_unique<::ScfToSPIRVContextImpl>();
 }
 
 ScfToSPIRVContext::~ScfToSPIRVContext() = default;
 
+namespace {
+
 //===----------------------------------------------------------------------===//
-// Pattern Declarations
+// Helper Functions
+//===----------------------------------------------------------------------===//
+
+/// Replaces SCF op outputs with SPIR-V variable loads.
+/// We create VariableOp to handle the results value of the control flow region.
+/// spirv.mlir.loop/spirv.mlir.selection currently don't yield value. Right
+/// after the loop we load the value from the allocation and use it as the SCF
+/// op result.
+template <typename ScfOp, typename OpTy>
+void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
+                           ConversionPatternRewriter &rewriter,
+                           ScfToSPIRVContextImpl *scfToSPIRVContext,
+                           ArrayRef<Type> returnTypes) {
+
+  Location loc = scfOp.getLoc();
+  auto &allocas = scfToSPIRVContext->outputVars[newOp];
+  // Clearing the allocas is necessary in case a dialect conversion path failed
+  // previously, and this is the second attempt of this conversion.
+  allocas.clear();
+  SmallVector<Value, 8> resultValue;
+  for (Type convertedType : returnTypes) {
+    auto pointerType =
+        spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
+    rewriter.setInsertionPoint(newOp);
+    auto alloc = rewriter.create<spirv::VariableOp>(
+        loc, pointerType, spirv::StorageClass::Function,
+        /*initializer=*/nullptr);
+    allocas.push_back(alloc);
+    rewriter.setInsertionPointAfter(newOp);
+    Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
+    resultValue.push_back(loadResult);
+  }
+  rewriter.replaceOp(scfOp, resultValue);
+}
+
+Region::iterator getBlockIt(Region &region, unsigned index) {
+  return std::next(region.begin(), index);
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion Patterns
 //===----------------------------------------------------------------------===//
 
-namespace {
 /// Common class for all vector to GPU patterns.
 template <typename OpTy>
 class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
@@ -79,356 +120,306 @@ class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
   SPIRVTypeConverter &typeConverter;
 };
 
+//===----------------------------------------------------------------------===//
+// scf::ForOp
+//===----------------------------------------------------------------------===//
+
 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
-class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> {
-public:
-  using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;
+struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
+  using SCFToSPIRVPattern::SCFToSPIRVPattern;
 
   LogicalResult
   matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
+                  ConversionPatternRewriter &rewriter) const override {
+    // scf::ForOp can be lowered to the structured control flow represented by
+    // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
+    // latch and the merge block the exit block. The resulting spirv::LoopOp has
+    // a single back edge from the continue to header block, and a single exit
+    // from header to merge.
+    auto loc = forOp.getLoc();
+    auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+    loopOp.addEntryAndMergeBlock();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    // Create the block for the header.
+    auto *header = new Block();
+    // Insert the header.
+    loopOp.getBody().getBlocks().insert(getBlockIt(loopOp.getBody(), 1),
+                                        header);
+
+    // Create the new induction variable to use.
+    Value adapLowerBound = adaptor.getLowerBound();
+    BlockArgument newIndVar =
+        header->addArgument(adapLowerBound.getType(), adapLowerBound.getLoc());
+    for (Value arg : adaptor.getInitArgs())
+      header->addArgument(arg.getType(), arg.getLoc());
+    Block *body = forOp.getBody();
+
+    // Apply signature conversion to the body of the forOp. It has a single
+    // block, with argument which is the induction variable. That has to be
+    // replaced with the new induction variable.
+    TypeConverter::SignatureConversion signatureConverter(
+        body->getNumArguments());
+    signatureConverter.remapInput(0, newIndVar);
+    for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
+      signatureConverter.remapInput(i, header->getArgument(i));
+    body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
+                                             signatureConverter);
+
+    // Move the blocks from the forOp into the loopOp. This is the body of the
+    // loopOp.
+    rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
+                                getBlockIt(loopOp.getBody(), 2));
+
+    SmallVector<Value, 8> args(1, adaptor.getLowerBound());
+    args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
+    // Branch into it from the entry.
+    rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
+    rewriter.create<spirv::BranchOp>(loc, header, args);
+
+    // Generate the rest of the loop header.
+    rewriter.setInsertionPointToEnd(header);
+    auto *mergeBlock = loopOp.getMergeBlock();
+    auto cmpOp = rewriter.create<spirv::SLessThanOp>(
+        loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound());
+
+    rewriter.create<spirv::BranchConditionalOp>(
+        loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
+
+    // Generate instructions to increment the step of the induction variable and
+    // branch to the header.
+    Block *continueBlock = loopOp.getContinueBlock();
+    rewriter.setInsertionPointToEnd(continueBlock);
+
+    // Add the step to the induction variable and branch to the header.
+    Value updatedIndVar = rewriter.create<spirv::IAddOp>(
+        loc, newIndVar.getType(), newIndVar, adaptor.getStep());
+    rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
+
+    // Infer the return types from the init operands. Vector type may get
+    // converted to CooperativeMatrix or to Vector type, to avoid having complex
+    // extra logic to figure out the right type we just infer it from the Init
+    // operands.
+    SmallVector<Type, 8> initTypes;
+    for (auto arg : adaptor.getInitArgs())
+      initTypes.push_back(arg.getType());
+    replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext,
+                          initTypes);
+    return success();
+  }
 };
 
+//===----------------------------------------------------------------------===//
+// scf::IfOp
+//===----------------------------------------------------------------------===//
+
 /// Pattern to convert a scf::IfOp within kernel functions into
 /// spirv::SelectionOp.
-class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> {
-public:
-  using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;
+struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
+  using SCFToSPIRVPattern::SCFToSPIRVPattern;
 
   LogicalResult
   matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
-public:
-  using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;
+                  ConversionPatternRewriter &rewriter) const override {
+    // When lowering `scf::IfOp` we explicitly create a selection header block
+    // before the control flow diverges and a merge block where control flow
+    // subsequently converges.
+    auto loc = ifOp.getLoc();
+
+    // Create `spirv.selection` operation, selection header block and merge
+    // block.
+    auto selectionOp =
+        rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+    auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
+                                            selectionOp.getBody().end());
+    rewriter.create<spirv::MergeOp>(loc);
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto *selectionHeaderBlock =
+        rewriter.createBlock(&selectionOp.getBody().front());
+
+    // Inline `then` region before the merge block and branch to it.
+    auto &thenRegion = ifOp.getThenRegion();
+    auto *thenBlock = &thenRegion.front();
+    rewriter.setInsertionPointToEnd(&thenRegion.back());
+    rewriter.create<spirv::BranchOp>(loc, mergeBlock);
+    rewriter.inlineRegionBefore(thenRegion, mergeBlock);
+
+    auto *elseBlock = mergeBlock;
+    // If `else` region is not empty, inline that region before the merge block
+    // and branch to it.
+    if (!ifOp.getElseRegion().empty()) {
+      auto &elseRegion = ifOp.getElseRegion();
+      elseBlock = &elseRegion.front();
+      rewriter.setInsertionPointToEnd(&elseRegion.back());
+      rewriter.create<spirv::BranchOp>(loc, mergeBlock);
+      rewriter.inlineRegionBefore(elseRegion, mergeBlock);
+    }
 
-  LogicalResult
-  matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
+    // Create a `spirv.BranchConditional` operation for selection header block.
+    rewriter.setInsertionPointToEnd(selectionHeaderBlock);
+    rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(),
+                                                thenBlock, ArrayRef<Value>(),
+                                                elseBlock, ArrayRef<Value>());
+
+    SmallVector<Type, 8> returnTypes;
+    for (auto result : ifOp.getResults()) {
+      auto convertedType = typeConverter.convertType(result.getType());
+      if (!convertedType)
+        return rewriter.notifyMatchFailure(
+            loc,
+            llvm::formatv("failed to convert type '{0}'", result.getType()));
+
+      returnTypes.push_back(convertedType);
+    }
+    replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
+                          returnTypes);
+    return success();
+  }
 };
 
-class WhileOpConversion final : public SCFToSPIRVPattern<scf::WhileOp> {
+//===----------------------------------------------------------------------===//
+// scf::YieldOp
+//===----------------------------------------------------------------------===//
+
+struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
 public:
-  using SCFToSPIRVPattern<scf::WhileOp>::SCFToSPIRVPattern;
+  using SCFToSPIRVPattern::SCFToSPIRVPattern;
 
   LogicalResult
-  matchAndRewrite(scf::WhileOp forOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-/// Helper function to replaces SCF op outputs with SPIR-V variable loads.
-/// We create VariableOp to handle the results value of the control flow region.
-/// spirv.mlir.loop/spirv.mlir.selection currently don't yield value. Right
-/// after the loop we load the value from the allocation and use it as the SCF
-/// op result.
-template <typename ScfOp, typename OpTy>
-static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
-                                  ConversionPatternRewriter &rewriter,
-                                  ScfToSPIRVContextImpl *scfToSPIRVContext,
-                                  ArrayRef<Type> returnTypes) {
-
-  Location loc = scfOp.getLoc();
-  auto &allocas = scfToSPIRVContext->outputVars[newOp];
-  // Clearing the allocas is necessary in case a dialect conversion path failed
-  // previously, and this is the second attempt of this conversion.
-  allocas.clear();
-  SmallVector<Value, 8> resultValue;
-  for (Type convertedType : returnTypes) {
-    auto pointerType =
-        spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
-    rewriter.setInsertionPoint(newOp);
-    auto alloc = rewriter.create<spirv::VariableOp>(
-        loc, pointerType, spirv::StorageClass::Function,
-        /*initializer=*/nullptr);
-    allocas.push_back(alloc);
-    rewriter.setInsertionPointAfter(newOp);
-    Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
-    resultValue.push_back(loadResult);
+  matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ValueRange operands = adaptor.getOperands();
+
+    // If the region is return values, store each value into the associated
+    // VariableOp created during lowering of the parent region.
+    if (!operands.empty()) {
+      auto &allocas =
+          scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
+      if (allocas.size() != operands.size())
+        return failure();
+
+      auto loc = terminatorOp.getLoc();
+      for (unsigned i = 0, e = operands.size(); i < e; i++)
+        rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
+      if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
+        // For loops we also need to update the branch jumping back to the
+        // header.
+        auto br = cast<spirv::BranchOp>(
+            rewriter.getInsertionBlock()->getTerminator());
+        SmallVector<Value, 8> args(br.getBlockArguments());
+        args.append(operands.begin(), operands.end());
+        rewriter.setInsertionPoint(br);
+        rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
+                                         args);
+        rewriter.eraseOp(br);
+      }
+    }
+    rewriter.eraseOp(terminatorOp);
+    return success();
   }
-  rewriter.replaceOp(scfOp, resultValue);
-}
-
-static Region::iterator getBlockIt(Region &region, unsigned index) {
-  return std::next(region.begin(), index);
-}
+};
 
 //===----------------------------------------------------------------------===//
-// scf::ForOp
+// scf::WhileOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
-                                 ConversionPatternRewriter &rewriter) const {
-  // scf::ForOp can be lowered to the structured control flow represented by
-  // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
-  // latch and the merge block the exit block. The resulting spirv::LoopOp has a
-  // single back edge from the continue to header block, and a single exit from
-  // header to merge.
-  auto loc = forOp.getLoc();
-  auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
-  loopOp.addEntryAndMergeBlock();
-
-  OpBuilder::InsertionGuard guard(rewriter);
-  // Create the block for the header.
-  auto *header = new Block();
-  // Insert the header.
-  loopOp.getBody().getBlocks().insert(getBlockIt(loopOp.getBody(), 1), header);
-
-  // Create the new induction variable to use.
-  Value adapLowerBound = adaptor.getLowerBound();
-  BlockArgument newIndVar =
-      header->addArgument(adapLowerBound.getType(), adapLowerBound.getLoc());
-  for (Value arg : adaptor.getInitArgs())
-    header->addArgument(arg.getType(), arg.getLoc());
-  Block *body = forOp.getBody();
-
-  // Apply signature conversion to the body of the forOp. It has a single block,
-  // with argument which is the induction variable. That has to be replaced with
-  // the new induction variable.
-  TypeConverter::SignatureConversion signatureConverter(
-      body->getNumArguments());
-  signatureConverter.remapInput(0, newIndVar);
-  for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
-    signatureConverter.remapInput(i, header->getArgument(i));
-  body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
-                                           signatureConverter);
-
-  // Move the blocks from the forOp into the loopOp. This is the body of the
-  // loopOp.
-  rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
-                              getBlockIt(loopOp.getBody(), 2));
-
-  SmallVector<Value, 8> args(1, adaptor.getLowerBound());
-  args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
-  // Branch into it from the entry.
-  rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
-  rewriter.create<spirv::BranchOp>(loc, header, args);
-
-  // Generate the rest of the loop header.
-  rewriter.setInsertionPointToEnd(header);
-  auto *mergeBlock = loopOp.getMergeBlock();
-  auto cmpOp = rewriter.create<spirv::SLessThanOp>(
-      loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound());
-
-  rewriter.create<spirv::BranchConditionalOp>(
-      loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
-
-  // Generate instructions to increment the step of the induction variable and
-  // branch to the header.
-  Block *continueBlock = loopOp.getContinueBlock();
-  rewriter.setInsertionPointToEnd(continueBlock);
-
-  // Add the step to the induction variable and branch to the header.
-  Value updatedIndVar = rewriter.create<spirv::IAddOp>(
-      loc, newIndVar.getType(), newIndVar, adaptor.getStep());
-  rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
-
-  // Infer the return types from the init operands. Vector type may get
-  // converted to CooperativeMatrix or to Vector type, to avoid having complex
-  // extra logic to figure out the right type we just infer it from the Init
-  // operands.
-  SmallVector<Type, 8> initTypes;
-  for (auto arg : adaptor.getInitArgs())
-    initTypes.push_back(arg.getType());
-  replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, initTypes);
-  return success();
-}
+struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
+  using SCFToSPIRVPattern::SCFToSPIRVPattern;
 
-//===----------------------------------------------------------------------===//
-// scf::IfOp
-//===----------------------------------------------------------------------===//
+  LogicalResult
+  matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = whileOp.getLoc();
+    auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+    loopOp.addEntryAndMergeBlock();
 
-LogicalResult
-IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
-                                ConversionPatternRewriter &rewriter) const {
-  // When lowering `scf::IfOp` we explicitly create a selection header block
-  // before the control flow diverges and a merge block where control flow
-  // subsequently converges.
-  auto loc = ifOp.getLoc();
-
-  // Create `spirv.selection` operation, selection header block and merge block.
-  auto selectionOp =
-      rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
-  auto *mergeBlock =
-      rewriter.createBlock(&selectionOp.getBody(), selectionOp.getBody().end());
-  rewriter.create<spirv::MergeOp>(loc);
-
-  OpBuilder::InsertionGuard guard(rewriter);
-  auto *selectionHeaderBlock =
-      rewriter.createBlock(&selectionOp.getBody().front());
-
-  // Inline `then` region before the merge block and branch to it.
-  auto &thenRegion = ifOp.getThenRegion();
-  auto *thenBlock = &thenRegion.front();
-  rewriter.setInsertionPointToEnd(&thenRegion.back());
-  rewriter.create<spirv::BranchOp>(loc, mergeBlock);
-  rewriter.inlineRegionBefore(thenRegion, mergeBlock);
-
-  auto *elseBlock = mergeBlock;
-  // If `else` region is not empty, inline that region before the merge block
-  // and branch to it.
-  if (!ifOp.getElseRegion().empty()) {
-    auto &elseRegion = ifOp.getElseRegion();
-    elseBlock = &elseRegion.front();
-    rewriter.setInsertionPointToEnd(&elseRegion.back());
-    rewriter.create<spirv::BranchOp>(loc, mergeBlock);
-    rewriter.inlineRegionBefore(elseRegion, mergeBlock);
-  }
+    OpBuilder::InsertionGuard guard(rewriter);
 
-  // Create a `spirv.BranchConditional` operation for selection header block.
-  rewriter.setInsertionPointToEnd(selectionHeaderBlock);
-  rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(),
-                                              thenBlock, ArrayRef<Value>(),
-                                              elseBlock, ArrayRef<Value>());
+    Region &beforeRegion = whileOp.getBefore();
+    Region &afterRegion = whileOp.getAfter();
 
-  SmallVector<Type, 8> returnTypes;
-  for (auto result : ifOp.getResults()) {
-    auto convertedType = typeConverter.convertType(result.getType());
-    if (!convertedType)
-      return rewriter.notifyMatchFailure(
-          loc, llvm::formatv("failed to convert type '{0}'", result.getType()));
+    Block &entryBlock = *loopOp.getEntryBlock();
+    Block &beforeBlock = beforeRegion.front();
+    Block &afterBlock = afterRegion.front();
+    Block &mergeBlock = *loopOp.getMergeBlock();
 
-    returnTypes.push_back(convertedType);
-  }
-  replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
-                        returnTypes);
-  return success();
-}
+    auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator());
+    SmallVector<Value> condArgs;
+    if (failed(rewriter.getRemappedValues(cond.getArgs(), condArgs)))
+      return failure();
 
-//===----------------------------------------------------------------------===//
-// scf::YieldOp
-//===----------------------------------------------------------------------===//
+    Value conditionVal = rewriter.getRemappedValue(cond.getCondition());
+    if (!conditionVal)
+      return failure();
 
-/// Yield is lowered to stores to the VariableOp created during lowering of the
-/// parent region. For loops we also need to update the branch looping back to
-/// the header with the loop carried values.
-LogicalResult TerminatorOpConversion::matchAndRewrite(
-    scf::YieldOp terminatorOp, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  ValueRange operands = adaptor.getOperands();
-
-  // If the region is return values, store each value into the associated
-  // VariableOp created during lowering of the parent region.
-  if (!operands.empty()) {
-    auto &allocas = scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
-    if (allocas.size() != operands.size())
+    auto yield = cast<scf::YieldOp>(afterBlock.getTerminator());
+    SmallVector<Value> yieldArgs;
+    if (failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs)))
       return failure();
 
-    auto loc = terminatorOp.getLoc();
-    for (unsigned i = 0, e = operands.size(); i < e; i++)
-      rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
-    if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
-      // For loops we also need to update the branch jumping back to the header.
-      auto br =
-          cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator());
-      SmallVector<Value, 8> args(br.getBlockArguments());
-      args.append(operands.begin(), operands.end());
-      rewriter.setInsertionPoint(br);
-      rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
-                                       args);
-      rewriter.eraseOp(br);
+    // Move the while before block as the initial loop header block.
+    rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
+                                getBlockIt(loopOp.getBody(), 1));
+
+    // Move the while after block as the initial loop body block.
+    rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
+                                getBlockIt(loopOp.getBody(), 2));
+
+    // Jump from the loop entry block to the loop header block.
+    rewriter.setInsertionPointToEnd(&entryBlock);
+    rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits());
+
+    auto condLoc = cond.getLoc();
+
+    SmallVector<Value> resultValues(condArgs.size());
+
+    // For other SCF ops, the scf.yield op yields the value for the whole SCF
+    // op. So we use the scf.yield op as the anchor to create/load/store SPIR-V
+    // local variables. But for the scf.while op, the scf.yield op yields a
+    // value for the before region, which may not matching the whole op's
+    // result. Instead, the scf.condition op returns values matching the whole
+    // op's results. So we need to create/load/store variables according to
+    // that.
+    for (const auto &it : llvm::enumerate(condArgs)) {
+      auto res = it.value();
+      auto i = it.index();
+      auto pointerType =
+          spirv::PointerType::get(res.getType(), spirv::StorageClass::Function);
+
+      // Create local variables before the scf.while op.
+      rewriter.setInsertionPoint(loopOp);
+      auto alloc = rewriter.create<spirv::VariableOp>(
+          condLoc, pointerType, spirv::StorageClass::Function,
+          /*initializer=*/nullptr);
+
+      // Load the final result values after the scf.while op.
+      rewriter.setInsertionPointAfter(loopOp);
+      auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc);
+      resultValues[i] = loadResult;
+
+      // Store the current iteration's result value.
+      rewriter.setInsertionPointToEnd(&beforeBlock);
+      rewriter.create<spirv::StoreOp>(condLoc, alloc, res);
     }
-  }
-  rewriter.eraseOp(terminatorOp);
-  return success();
-}
 
-//===----------------------------------------------------------------------===//
-// scf::WhileOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult
-WhileOpConversion::matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
-                                   ConversionPatternRewriter &rewriter) const {
-  auto loc = whileOp.getLoc();
-  auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
-  loopOp.addEntryAndMergeBlock();
-
-  OpBuilder::InsertionGuard guard(rewriter);
-
-  Region &beforeRegion = whileOp.getBefore();
-  Region &afterRegion = whileOp.getAfter();
-
-  Block &entryBlock = *loopOp.getEntryBlock();
-  Block &beforeBlock = beforeRegion.front();
-  Block &afterBlock = afterRegion.front();
-  Block &mergeBlock = *loopOp.getMergeBlock();
-
-  auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator());
-  SmallVector<Value> condArgs;
-  if (failed(rewriter.getRemappedValues(cond.getArgs(), condArgs)))
-    return failure();
-
-  Value conditionVal = rewriter.getRemappedValue(cond.getCondition());
-  if (!conditionVal)
-    return failure();
-
-  auto yield = cast<scf::YieldOp>(afterBlock.getTerminator());
-  SmallVector<Value> yieldArgs;
-  if (failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs)))
-    return failure();
-
-  // Move the while before block as the initial loop header block.
-  rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
-                              getBlockIt(loopOp.getBody(), 1));
-
-  // Move the while after block as the initial loop body block.
-  rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
-                              getBlockIt(loopOp.getBody(), 2));
-
-  // Jump from the loop entry block to the loop header block.
-  rewriter.setInsertionPointToEnd(&entryBlock);
-  rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits());
-
-  auto condLoc = cond.getLoc();
-
-  SmallVector<Value> resultValues(condArgs.size());
-
-  // For other SCF ops, the scf.yield op yields the value for the whole SCF op.
-  // So we use the scf.yield op as the anchor to create/load/store SPIR-V local
-  // variables. But for the scf.while op, the scf.yield op yields a value for
-  // the before region, which may not matching the whole op's result. Instead,
-  // the scf.condition op returns values matching the whole op's results. So we
-  // need to create/load/store variables according to that.
-  for (const auto &it : llvm::enumerate(condArgs)) {
-    auto res = it.value();
-    auto i = it.index();
-    auto pointerType =
-        spirv::PointerType::get(res.getType(), spirv::StorageClass::Function);
-
-    // Create local variables before the scf.while op.
-    rewriter.setInsertionPoint(loopOp);
-    auto alloc = rewriter.create<spirv::VariableOp>(
-        condLoc, pointerType, spirv::StorageClass::Function,
-        /*initializer=*/nullptr);
-
-    // Load the final result values after the scf.while op.
-    rewriter.setInsertionPointAfter(loopOp);
-    auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc);
-    resultValues[i] = loadResult;
-
-    // Store the current iteration's result value.
     rewriter.setInsertionPointToEnd(&beforeBlock);
-    rewriter.create<spirv::StoreOp>(condLoc, alloc, res);
-  }
-
-  rewriter.setInsertionPointToEnd(&beforeBlock);
-  rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
-      cond, conditionVal, &afterBlock, condArgs, &mergeBlock, std::nullopt);
+    rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
+        cond, conditionVal, &afterBlock, condArgs, &mergeBlock, std::nullopt);
 
-  // Convert the scf.yield op to a branch back to the header block.
-  rewriter.setInsertionPointToEnd(&afterBlock);
-  rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock, yieldArgs);
+    // Convert the scf.yield op to a branch back to the header block.
+    rewriter.setInsertionPointToEnd(&afterBlock);
+    rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock,
+                                                 yieldArgs);
 
-  rewriter.replaceOp(whileOp, resultValues);
-  return success();
-}
+    rewriter.replaceOp(whileOp, resultValues);
+    return success();
+  }
+};
+} // namespace
 
 //===----------------------------------------------------------------------===//
-// Hooks
+// Public API
 //===----------------------------------------------------------------------===//
 
 void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,

diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
index a6ce99284db96..1e8fe4423a422 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
@@ -56,7 +56,3 @@ void SCFToSPIRVPass::runOnOperation() {
   if (failed(applyPartialConversion(op, *target, std::move(patterns))))
     return signalPassFailure();
 }
-
-std::unique_ptr<OperationPass<>> mlir::createConvertSCFToSPIRVPass() {
-  return std::make_unique<SCFToSPIRVPass>();
-}


        


More information about the Mlir-commits mailing list