[Mlir-commits] [mlir] 721c5cc - [mlir][spirv] Allow yielding values from loop regions (#135344)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 30 08:07:11 PDT 2025


Author: Igor Wodiany
Date: 2025-04-30T16:07:07+01:00
New Revision: 721c5cc327e30edbdd6cf14348ec076c79079125

URL: https://github.com/llvm/llvm-project/commit/721c5cc327e30edbdd6cf14348ec076c79079125
DIFF: https://github.com/llvm/llvm-project/commit/721c5cc327e30edbdd6cf14348ec076c79079125.diff

LOG: [mlir][spirv] Allow yielding values from loop regions (#135344)

This change extends `spirv.mlir.loop` so it can yield values, the same
as `spirv.mlir.selection`.

Added: 
    

Modified: 
    mlir/docs/Dialects/SPIR-V.md
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
    mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
    mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
    mlir/test/Target/SPIRV/loop.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index ae9afbd9fdfe5..1e8c1c7be9f6a 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -734,6 +734,18 @@ func.func @loop(%count : i32) -> () {
 }
 ```
 
+Similarly to selection, loops can also yield values using `spirv.mlir.merge`. This
+mechanism allows values defined within the loop region to be used outside of it.
+
+For example:
+
+```mlir
+%yielded = spirv.mlir.loop -> i32 {
+  // ...
+  spirv.mlir.merge %to_yield : i32
+}
+```
+
 ### Block argument for Phi
 
 There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi`

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index 039af03871411..ef6682ab3630c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -311,17 +311,27 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
     The continue block should be the second to last block and it should have a
     branch to the loop header block. The loop continue block should be the only
     block, except the entry block, branching to the header block.
+
+    Values defined inside the loop region cannot be directly used
+    outside of it; however, the loop region can yield values. These values are
+    yielded using a `spirv.mlir.merge` op and returned as a result of the loop op.
   }];
 
   let arguments = (ins
     SPIRV_LoopControlAttr:$loop_control
   );
 
-  let results = (outs);
+  let results = (outs Variadic<AnyType>:$results);
 
   let regions = (region AnyRegion:$body);
 
-  let builders = [OpBuilder<(ins)>];
+  let builders = [
+    OpBuilder<(ins)>,
+    OpBuilder<(ins "spirv::LoopControl":$loopControl),
+    [{
+      build($_builder, $_state, TypeRange(), loopControl);
+    }]>
+  ];
 
   let extraClassDeclaration = [{
     // Returns the entry block.

diff  --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 577959bbdbeaa..371456552b5b5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -229,6 +229,11 @@ ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
   if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
                                                                         result))
     return failure();
+
+  if (succeeded(parser.parseOptionalArrow()))
+    if (parser.parseTypeList(result.types))
+      return failure();
+
   return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
 }
 
@@ -236,6 +241,10 @@ void LoopOp::print(OpAsmPrinter &printer) {
   auto control = getLoopControl();
   if (control != spirv::LoopControl::None)
     printer << " control(" << spirv::stringifyLoopControl(control) << ")";
+  if (getNumResults() > 0) {
+    printer << " -> ";
+    printer << getResultTypes();
+  }
   printer << ' ';
   printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                       /*printBlockTerminators=*/true);

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b0220bc16e15e..1e867dde51001 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2003,7 +2003,14 @@ LogicalResult ControlFlowStructurizer::structurize() {
   // block inside the selection (`body.back()`). Values produced by block
   // arguments will be yielded by the selection region. We do not update uses or
   // erase original block arguments yet. It will be done later in the code.
-  if (!isLoop) {
+  //
+  // Code below is not executed for loops as it would interfere with the logic
+  // above. Currently block arguments in the merge block are not supported, but
+  // instead, the code above copies those arguments from the header block into
+  // the merge block. As such, running the code would yield those copied
+  // arguments, which is most likely not the desired behaviour. This may need to
+  // be revisited in the future.
+  if (!isLoop)
     for (BlockArgument blockArg : mergeBlock->getArguments()) {
       // Create new block arguments in the last block ("merge block") of the
       // selection region. We create one argument for each argument in
@@ -2013,7 +2020,6 @@ LogicalResult ControlFlowStructurizer::structurize() {
       valuesToYield.push_back(body.back().getArguments().back());
       outsideUses.push_back(blockArg);
     }
-  }
 
   // All the blocks cloned into the SelectionOp/LoopOp's region can now be
   // cleaned up.
@@ -2025,32 +2031,30 @@ LogicalResult ControlFlowStructurizer::structurize() {
 
   // All internal uses should be removed from original blocks by now, so
   // whatever is left is an outside use and will need to be yielded from
-  // the newly created selection region.
-  if (!isLoop) {
-    for (Block *block : constructBlocks) {
-      for (Operation &op : *block) {
-        if (!op.use_empty())
-          for (Value result : op.getResults()) {
-            valuesToYield.push_back(mapper.lookupOrNull(result));
-            outsideUses.push_back(result);
-          }
-      }
-      for (BlockArgument &arg : block->getArguments()) {
-        if (!arg.use_empty()) {
-          valuesToYield.push_back(mapper.lookupOrNull(arg));
-          outsideUses.push_back(arg);
+  // the newly created selection / loop region.
+  for (Block *block : constructBlocks) {
+    for (Operation &op : *block) {
+      if (!op.use_empty())
+        for (Value result : op.getResults()) {
+          valuesToYield.push_back(mapper.lookupOrNull(result));
+          outsideUses.push_back(result);
         }
+    }
+    for (BlockArgument &arg : block->getArguments()) {
+      if (!arg.use_empty()) {
+        valuesToYield.push_back(mapper.lookupOrNull(arg));
+        outsideUses.push_back(arg);
       }
     }
   }
 
   assert(valuesToYield.size() == outsideUses.size());
 
-  // If we need to yield any values from the selection region we will take
-  // care of it here.
-  if (!isLoop && !valuesToYield.empty()) {
+  // If we need to yield any values from the selection / loop region we will
+  // take care of it here.
+  if (!valuesToYield.empty()) {
     LLVM_DEBUG(logger.startLine()
-               << "[cf] yielding values from the selection region\n");
+               << "[cf] yielding values from the selection / loop region\n");
 
     // Update `mlir.merge` with values to be yield.
     auto mergeOps = body.back().getOps<spirv::MergeOp>();
@@ -2059,25 +2063,40 @@ LogicalResult ControlFlowStructurizer::structurize() {
     merge->setOperands(valuesToYield);
 
     // MLIR does not allow changing the number of results of an operation, so
-    // we create a new SelectionOp with required list of results and move
-    // the region from the initial SelectionOp. The initial operation is then
-    // removed. Since we move the region to the new op all links between blocks
-    // and remapping we have previously done should be preserved.
+    // we create a new SelectionOp / LoopOp with required list of results and
+    // move the region from the initial SelectionOp / LoopOp. The initial
+    // operation is then removed. Since we move the region to the new op all
+    // links between blocks and remapping we have previously done should be
+    // preserved.
     builder.setInsertionPoint(&mergeBlock->front());
-    auto selectionOp = builder.create<spirv::SelectionOp>(
-        location, TypeRange(ValueRange(outsideUses)),
-        static_cast<spirv::SelectionControl>(control));
-    selectionOp->getRegion(0).takeBody(body);
+
+    Operation *newOp = nullptr;
+
+    if (isLoop)
+      newOp = builder.create<spirv::LoopOp>(
+          location, TypeRange(ValueRange(outsideUses)),
+          static_cast<spirv::LoopControl>(control));
+    else
+      newOp = builder.create<spirv::SelectionOp>(
+          location, TypeRange(ValueRange(outsideUses)),
+          static_cast<spirv::SelectionControl>(control));
+
+    newOp->getRegion(0).takeBody(body);
 
     // Remove initial op and swap the pointer to the newly created one.
     op->erase();
-    op = selectionOp;
+    op = newOp;
 
-    // Update all outside uses to use results of the SelectionOp and remove
-    // block arguments from the original merge block.
+    // Update all outside uses to use results of the SelectionOp / LoopOp and
+    // remove block arguments from the original merge block.
     for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
-      outsideUses[i].replaceAllUsesWith(selectionOp.getResult(i));
-    mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
+      outsideUses[i].replaceAllUsesWith(op->getResult(i));
+
+    // We do not support block arguments in the loop merge block. Also, erasing
+    // the arguments for a loop would break some of the loop-specific code above
+    // dealing with block arguments.
+    if (!isLoop)
+      mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
   }
 
   // Check that whether some op in the to-be-erased blocks still has uses. Those

diff  --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 5ed59a4134d37..ff3cc92ee8078 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -520,6 +520,13 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
   auto mergeID = getBlockID(mergeBlock);
   auto loc = loopOp.getLoc();
 
+  // Before we do anything, replace results of the loop operation with
+  // values yielded (with `mlir.merge`) from inside the region.
+  auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
+  assert(loopOp.getNumResults() == mergeOp.getNumOperands());
+  for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i)
+    loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
+
   // This LoopOp is in some MLIR block with preceding and following ops. In the
   // binary format, it should reside in separate SPIR-V blocks from its
   // preceding and following ops. So we need to emit unconditional branches to

diff  --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 107c8a3207b02..8ec0bf5bbaacf 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -426,6 +426,47 @@ func.func @only_entry_and_continue_branch_to_header() -> () {
 
 // -----
 
+func.func @loop_yield(%count : i32) -> () {
+  %zero = spirv.Constant 0: i32
+  %one = spirv.Constant 1: i32
+  %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+  // CHECK: {{%.*}} = spirv.mlir.loop -> i32 {
+  %final_i = spirv.mlir.loop -> i32 {
+    // CHECK-NEXT: spirv.Branch ^bb1({{%.*}}: i32)
+    spirv.Branch ^header(%zero: i32)
+
+  // CHECK-NEXT: ^bb1({{%.*}}: i32):
+  ^header(%i : i32):
+    %cmp = spirv.SLessThan %i, %count : i32
+    // CHECK: spirv.BranchConditional %{{.*}}, ^bb2, ^bb4
+    spirv.BranchConditional %cmp, ^body, ^merge
+
+  // CHECK-NEXT: ^bb2:
+  ^body:
+    // CHECK-NEXT: spirv.Branch ^bb3
+    spirv.Branch ^continue
+
+  // CHECK-NEXT: ^bb3:
+  ^continue:
+    %new_i = spirv.IAdd %i, %one : i32
+    // CHECK: spirv.Branch ^bb1({{%.*}}: i32)
+    spirv.Branch ^header(%new_i: i32)
+
+  // CHECK-NEXT: ^bb4:
+  ^merge:
+    // CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
+    spirv.mlir.merge %i : i32
+  }
+
+  // CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+  spirv.Store "Function" %var, %final_i : i32
+
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.merge
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir
index d89600558f56d..95b87b319ac2d 100644
--- a/mlir/test/Target/SPIRV/loop.mlir
+++ b/mlir/test/Target/SPIRV/loop.mlir
@@ -288,3 +288,50 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage, Addre
     spirv.Return
   }
 }
+
+// -----
+
+// Loop yielding values
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+  spirv.func @loop_yield(%count : i32) -> () "None" {
+    %zero = spirv.Constant 0: i32
+    %one = spirv.Constant 1: i32
+    %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+// CHECK: {{%.*}} = spirv.mlir.loop -> i32 {
+    %final_i = spirv.mlir.loop -> i32 {
+// CHECK-NEXT: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+      spirv.Branch ^header(%zero: i32)
+
+// CHECK-NEXT: ^[[HEADER:.+]]({{%.*}}: i32):
+    ^header(%i : i32):
+      %cmp = spirv.SLessThan %i, %count : i32
+// CHECK: spirv.BranchConditional %{{.*}}, ^[[BODY:.+]], ^[[MERGE:.+]]
+      spirv.BranchConditional %cmp, ^body, ^merge
+
+// CHECK-NEXT: ^[[BODY:.+]]:
+    ^body:
+// CHECK-NEXT: spirv.Branch ^[[CONTINUE:.+]]
+      spirv.Branch ^continue
+
+// CHECK-NEXT: ^[[CONTINUE:.+]]:
+    ^continue:
+      %new_i = spirv.IAdd %i, %one : i32
+// CHECK: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+      spirv.Branch ^header(%new_i: i32)
+
+// CHECK-NEXT: ^[[MERGE:.+]]:
+    ^merge:
+// CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
+      spirv.mlir.merge %i : i32
+// CHECK-NEXT: }
+    }
+
+// CHECK-NEXT: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+    spirv.Store "Function" %var, %final_i : i32
+
+// CHECK-NEXT: spirv.Return
+    spirv.Return
+  }
+}


        


More information about the Mlir-commits mailing list