[Mlir-commits] [mlir] 731676b - [mlir][spirv] Fix nested control flow serialization

Lei Zhang llvmlistbot at llvm.org
Sat Dec 11 11:47:35 PST 2021


Author: Lei Zhang
Date: 2021-12-11T14:47:19-05:00
New Revision: 731676b10dfe684d7f166cb0f85cd6ede1660119

URL: https://github.com/llvm/llvm-project/commit/731676b10dfe684d7f166cb0f85cd6ede1660119
DIFF: https://github.com/llvm/llvm-project/commit/731676b10dfe684d7f166cb0f85cd6ede1660119.diff

LOG: [mlir][spirv] Fix nested control flow serialization

If we have a `spv.mlir.selection` op nested in a `spv.mlir.loop`
op, when serializing the loop's block, we might need to jump
from the selection op's merge block, which might be different
from the immediate MLIR IR predecessor block. But we still need
to get the block arguments from the MLIR IR predecessor block.

Also, if the `spv.mlir.selection` is in the `spv.mlir.loop`'s
header block, we need to make sure `OpLoopMerge` is emitted
in the current block before we start processing the nested
selection op. Otherwise we'll see the `OpLoopMerge` in the wrong
SPIR-V basic block.

Reviewed By: Hardcode84

Differential Revision: https://reviews.llvm.org/D115560

Added: 
    

Modified: 
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
    mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.h
    mlir/test/Target/SPIRV/loop.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 8d4b7f854b95f..6d2e7384ea2f8 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1733,6 +1733,7 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
       LLVM_DEBUG(llvm::dbgs()
                  << "[cf] block " << block << " is a function entry block\n");
     }
+
     for (auto &op : *block)
       newBlock->push_back(op.clone(mapper));
   }
@@ -1746,9 +1747,8 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
       if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
         succOp.set(mappedOp);
   };
-  for (auto &block : body) {
+  for (auto &block : body)
     block.walk(remapOperands);
-  }
 
   // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
   // the selection/loop construct into its region. Next we need to fix the
@@ -1758,8 +1758,12 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
   // SelectionOp/LoopOp resides right now.
   headerBlock->replaceAllUsesWith(mergeBlock);
 
+  LLVM_DEBUG(llvm::dbgs() << "[cf] after cloning and fixing references:\n");
+  LLVM_DEBUG(llvm::dbgs() << *headerBlock->getParentOp());
+  LLVM_DEBUG(llvm::dbgs() << "\n");
+
   if (isLoop) {
-    // The loop selection/loop header block may have block arguments. Since now
+    // The selection/loop header block may have block arguments. Since now
     // we place the selection/loop op inside the old merge block, we need to
     // make sure the old merge block has the same block argument list.
     assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported");

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index e151aa40f0d4f..5e5dccdf76bd2 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -350,9 +350,8 @@ class Deserializer {
   //    guarantees that we enter and exit in structured ways and the construct
   //    is nestable.
   // 3. Put the new spv.mlir.selection/spv.mlir.loop op at the beginning of the
-  // old merge
-  //    block and redirect all branches to the old header block to the old
-  //    merge block (which contains the spv.mlir.selection/spv.mlir.loop op
+  //    old merge block and redirect all branches to the old header block to the
+  //    old merge block (which contains the spv.mlir.selection/spv.mlir.loop op
   //    now).
 
   /// For OpPhi instructions, we use block arguments to represent them. OpPhi

diff  --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index ffb3b51f59cf8..128a4c9d0ac51 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -406,6 +406,9 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
   // instruction to start a new SPIR-V block for ops following this SelectionOp.
   // The block should use the <id> for the merge block.
   encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
+  LLVM_DEBUG(llvm::dbgs() << "done merge ");
+  LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << "\n");
   return success();
 }
 
@@ -414,10 +417,9 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
   // properly. We don't need to assign for the entry block, which is just for
   // satisfying MLIR region's structural requirement.
   auto &body = loopOp.body();
-  for (Block &block :
-       llvm::make_range(std::next(body.begin(), 1), body.end())) {
+  for (Block &block : llvm::make_range(std::next(body.begin(), 1), body.end()))
     getOrCreateBlockID(&block);
-  }
+
   auto *headerBlock = loopOp.getHeaderBlock();
   auto *continueBlock = loopOp.getContinueBlock();
   auto *mergeBlock = loopOp.getMergeBlock();
@@ -469,6 +471,9 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
   // start a new SPIR-V block for ops following this LoopOp. The block should
   // use the <id> for the merge block.
   encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
+  LLVM_DEBUG(llvm::dbgs() << "done merge ");
+  LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << "\n");
   return success();
 }
 

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index a306257eaed66..d6df234ecf87a 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -921,16 +921,26 @@ uint32_t Serializer::getOrCreateBlockID(Block *block) {
   return blockIDMap[block] = getNextID();
 }
 
+#ifndef NDEBUG
+void Serializer::printBlock(Block *block, raw_ostream &os) {
+  os << "block " << block << " (id = ";
+  if (uint32_t id = getBlockID(block))
+    os << id;
+  else
+    os << "unknown";
+  os << ")\n";
+}
+#endif
+
 LogicalResult
 Serializer::processBlock(Block *block, bool omitLabel,
-                         function_ref<LogicalResult()> actionBeforeTerminator) {
+                         function_ref<LogicalResult()> emitMerge) {
   LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
   LLVM_DEBUG(block->print(llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << '\n');
   if (!omitLabel) {
     uint32_t blockID = getOrCreateBlockID(block);
-    LLVM_DEBUG(llvm::dbgs()
-               << "[block] " << block << " (id = " << blockID << ")\n");
+    LLVM_DEBUG(printBlock(block, llvm::dbgs()));
 
     // Emit OpLabel for this block.
     encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
@@ -940,6 +950,24 @@ Serializer::processBlock(Block *block, bool omitLabel,
   if (failed(emitPhiForBlockArguments(block)))
     return failure();
 
+  // If we need to emit merge instructions, it must happen in this block. Check
+  // whether we have other structured control flow ops, which will be expanded
+  // into multiple basic blocks. If that's the case, we need to emit the merge
+  // right now and then create new blocks for further serialization of the ops
+  // in this block.
+  if (emitMerge && llvm::any_of(block->getOperations(), [](Operation &op) {
+        return isa<spirv::LoopOp, spirv::SelectionOp>(op);
+      })) {
+    if (failed(emitMerge()))
+      return failure();
+    emitMerge = nullptr;
+
+    // Start a new block for further serialization.
+    uint32_t blockID = getNextID();
+    encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
+    encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
+  }
+
   // Process each op in this block except the terminator.
   for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
     if (failed(processOperation(&op)))
@@ -947,8 +975,8 @@ Serializer::processBlock(Block *block, bool omitLabel,
   }
 
   // Process the terminator.
-  if (actionBeforeTerminator)
-    if (failed(actionBeforeTerminator()))
+  if (emitMerge)
+    if (failed(emitMerge()))
       return failure();
   if (failed(processOperation(&block->back())))
     return failure();
@@ -962,14 +990,19 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
   if (block->args_empty() || block->isEntryBlock())
     return success();
 
+  LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");
+
   // If the block has arguments, we need to create SPIR-V OpPhi instructions.
   // A SPIR-V OpPhi instruction is of the syntax:
   //   OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
   // So we need to collect all predecessor blocks and the arguments they send
   // to this block.
   SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
-  for (Block *predecessor : block->getPredecessors()) {
-    auto *terminator = predecessor->getTerminator();
+  for (Block *mlirPredecessor : block->getPredecessors()) {
+    auto *terminator = mlirPredecessor->getTerminator();
+    LLVM_DEBUG(llvm::dbgs() << "  mlir predecessor ");
+    LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
+    LLVM_DEBUG(llvm::dbgs() << "    terminator: " << *terminator << "\n");
     // The predecessor here is the immediate one according to MLIR's IR
     // structure. It does not directly map to the incoming parent block for the
     // OpPhi instructions at SPIR-V binary level. This is because structured
@@ -977,26 +1010,32 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
     // spv.mlir.selection/spv.mlir.loop op in the MLIR predecessor block, the
     // branch op jumping to the OpPhi's block then resides in the previous
     // structured control flow op's merge block.
-    predecessor = getPhiIncomingBlock(predecessor);
+    Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor);
+    LLVM_DEBUG(llvm::dbgs() << "  spirv predecessor ");
+    LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
     if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
-      predecessors.emplace_back(predecessor, branchOp.getOperands());
+      predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
     } else if (auto branchCondOp =
                    dyn_cast<spirv::BranchConditionalOp>(terminator)) {
       Optional<OperandRange> blockOperands;
+      if (branchCondOp.trueTarget() == block) {
+        blockOperands = branchCondOp.trueTargetOperands();
+      } else {
+        assert(branchCondOp.falseTarget() == block);
+        blockOperands = branchCondOp.falseTargetOperands();
+      }
 
-      for (auto successorIdx :
-           llvm::seq<unsigned>(0, predecessor->getNumSuccessors()))
-        if (predecessor->getSuccessors()[successorIdx] == block) {
-          blockOperands = branchCondOp.getSuccessorOperands(successorIdx);
-          break;
-        }
-
-      assert(blockOperands && !blockOperands->empty() &&
+      assert(!blockOperands->empty() &&
              "expected non-empty block operand range");
-      predecessors.emplace_back(predecessor, *blockOperands);
+      predecessors.emplace_back(spirvPredecessor, *blockOperands);
     } else {
       return terminator->emitError("unimplemented terminator for Phi creation");
     }
+    LLVM_DEBUG({
+      llvm::dbgs() << "    block arguments:\n";
+      for (Value v : predecessors.back().second)
+        llvm::dbgs() << "      " << v << "\n";
+    });
   }
 
   // Then create OpPhi instruction for each of the block argument.

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 0f053c0eb5e3f..69d9461cd9c16 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -238,14 +238,18 @@ class Serializer {
   /// assigns the next available <id>
   uint32_t getOrCreateBlockID(Block *block);
 
+#ifndef NDEBUG
+  /// (For debugging) prints the block with its result <id>.
+  void printBlock(Block *block, raw_ostream &os);
+#endif
+
   /// Processes the given `block` and emits SPIR-V instructions for all ops
   /// inside. Does not emit OpLabel for this block if `omitLabel` is true.
-  /// `actionBeforeTerminator` is a callback that will be invoked before
-  /// handling the terminator op. It can be used to inject the Op*Merge
-  /// instruction if this is a SPIR-V selection/loop header block.
-  LogicalResult
-  processBlock(Block *block, bool omitLabel = false,
-               function_ref<LogicalResult()> actionBeforeTerminator = nullptr);
+  /// `emitMerge` is a callback that will be invoked before handling the
+  /// terminator op to inject the Op*Merge instruction if this is a SPIR-V
+  /// selection/loop header block.
+  LogicalResult processBlock(Block *block, bool omitLabel = false,
+                             function_ref<LogicalResult()> emitMerge = nullptr);
 
   /// Emits OpPhi instructions for the given block if it has block arguments.
   LogicalResult emitPhiForBlockArguments(Block *block);

diff  --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir
index bd6fa00092856..03d87475a4e20 100644
--- a/mlir/test/Target/SPIRV/loop.mlir
+++ b/mlir/test/Target/SPIRV/loop.mlir
@@ -4,6 +4,7 @@
 
 spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   // for (int i = 0; i < count; ++i) {}
+// CHECK-LABEL: @loop
   spv.func @loop(%count : i32) -> () "None" {
     %zero = spv.Constant 0: i32
     %one = spv.Constant 1: i32
@@ -59,9 +60,12 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
 
 // -----
 
+// Single loop with block arguments
+
 spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   spv.GlobalVariable @GV1 bind(0, 0) : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>
   spv.GlobalVariable @GV2 bind(0, 1) : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-LABEL: @loop_kernel
   spv.func @loop_kernel() "None" {
     %0 = spv.mlir.addressof @GV1 : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>
     %1 = spv.Constant 0 : i32
@@ -111,6 +115,7 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   // for (int i = 0; i < count; ++i) {
   //   for (int j = 0; j < count; ++j) { }
   // }
+// CHECK-LABEL: @loop
   spv.func @loop(%count : i32) -> () "None" {
     %zero = spv.Constant 0: i32
     %one = spv.Constant 1: i32
@@ -207,3 +212,77 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   spv.EntryPoint "GLCompute" @main
 }
 
+
+// -----
+
+// Loop with selection in its header
+
+spv.module Physical64 OpenCL requires #spv.vce<v1.0, [Kernel, Linkage, Addresses, Int64], []> {
+// CHECK-LABEL:   @kernel
+// CHECK-SAME:    (%[[INPUT0:.+]]: i64)
+  spv.func @kernel(%input: i64) "None" {
+// CHECK-NEXT:     %[[VAR:.+]] = spv.Variable : !spv.ptr<i1, Function>
+// CHECK-NEXT:     spv.Branch ^[[BB:.+]](%[[INPUT0]] : i64)
+// CHECK-NEXT:   ^[[BB]](%[[INPUT1:.+]]: i64):
+    %cst0_i64 = spv.Constant 0 : i64
+    %true = spv.Constant true
+    %false = spv.Constant false
+// CHECK-NEXT:     spv.mlir.loop {
+    spv.mlir.loop {
+// CHECK-NEXT:       spv.Branch ^[[LOOP_HEADER:.+]](%[[INPUT1]] : i64)
+      spv.Branch ^loop_header(%input : i64)
+// CHECK-NEXT:     ^[[LOOP_HEADER]](%[[ARG1:.+]]: i64):
+    ^loop_header(%arg1: i64):
+// CHECK-NEXT:       spv.Branch ^[[LOOP_BODY:.+]]
+// CHECK-NEXT:     ^[[LOOP_BODY]]:
+      %gt = spv.SGreaterThan %arg1, %cst0_i64 : i64
+      %var = spv.Variable : !spv.ptr<i1, Function>
+// CHECK-NEXT:       spv.mlir.selection {
+      spv.mlir.selection {
+// CHECK-NEXT:         %[[C0:.+]] = spv.Constant 0 : i64
+// CHECK-NEXT:         %[[GT:.+]] = spv.SGreaterThan %[[ARG1]], %[[C0]] : i64
+// CHECK-NEXT:         spv.BranchConditional %[[GT]], ^[[THEN:.+]], ^[[ELSE:.+]]
+        spv.BranchConditional %gt, ^then, ^else
+// CHECK-NEXT:       ^[[THEN]]:
+      ^then:
+// CHECK-NEXT:         %true = spv.Constant true
+// CHECK-NEXT:         spv.Store "Function" %[[VAR]], %true : i1
+        spv.Store "Function" %var, %true : i1
+// CHECK-NEXT:         spv.Branch ^[[SELECTION_MERGE:.+]]
+        spv.Branch ^selection_merge
+// CHECK-NEXT:       ^[[ELSE]]:
+      ^else:
+// CHECK-NEXT:         %false = spv.Constant false
+// CHECK-NEXT:         spv.Store "Function" %[[VAR]], %false : i1
+        spv.Store "Function" %var, %false : i1
+// CHECK-NEXT:         spv.Branch ^[[SELECTION_MERGE]]
+        spv.Branch ^selection_merge
+// CHECK-NEXT:       ^[[SELECTION_MERGE]]:
+      ^selection_merge:
+// CHECK-NEXT:         spv.mlir.merge
+        spv.mlir.merge
+// CHECK-NEXT:       }
+      }
+// CHECK-NEXT:       %[[LOAD:.+]] = spv.Load "Function" %[[VAR]] : i1
+      %load = spv.Load "Function" %var : i1
+// CHECK-NEXT:       spv.BranchConditional %[[LOAD]], ^[[CONTINUE:.+]](%[[ARG1]] : i64), ^[[LOOP_MERGE:.+]]
+      spv.BranchConditional %load, ^continue(%arg1 : i64), ^loop_merge
+// CHECK-NEXT:     ^[[CONTINUE]](%[[ARG2:.+]]: i64):
+    ^continue(%arg2: i64):
+// CHECK-NEXT:       %[[C0:.+]] = spv.Constant 0 : i64
+// CHECK-NEXT:       %[[LT:.+]] = spv.SLessThan %[[ARG2]], %[[C0]] : i64
+      %lt = spv.SLessThan %arg2, %cst0_i64 : i64
+// CHECK-NEXT:       spv.Store "Function" %[[VAR]], %[[LT]] : i1
+      spv.Store "Function" %var, %lt : i1
+// CHECK-NEXT:       spv.Branch ^[[LOOP_HEADER]](%[[ARG2]] : i64)
+      spv.Branch ^loop_header(%arg2 : i64)
+// CHECK-NEXT:     ^[[LOOP_MERGE]]:
+    ^loop_merge:
+// CHECK-NEXT:       spv.mlir.merge
+      spv.mlir.merge
+// CHECK-NEXT:     }
+    }
+// CHECK-NEXT:     spv.Return
+    spv.Return
+  }
+}


        


More information about the Mlir-commits mailing list