[Mlir-commits] [mlir] [mlir][SPIR-V] Support spirv.loop_control attribute on scf.for and scf.while (PR #189392)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 30 07:17:30 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Arseniy Obolenskiy (aobolensk)

<details>
<summary>Changes</summary>

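Propagate the `spirv.loop_control` attribute from `scf.for` and `scf.while` operations to the generated `spirv.mlir.loop` op during SCFToSPIRV conversion. The SPIR-V dialect's operation-attribute verifier now also recognizes `spirv.loop_control` and requires it to be a `spirv::LoopControlAttr`.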

---
Full diff: https://github.com/llvm/llvm-project/pull/189392.diff


4 Files Affected:

- (modified) mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp (+10-4) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+4) 
- (modified) mlir/test/Conversion/SCFToSPIRV/for.mlir (+26) 
- (modified) mlir/test/Conversion/SCFToSPIRV/while.mlir (+17) 


``````````diff
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 55ed31ee695ff..a7dcb7045a256 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -135,8 +135,11 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
     // a single back edge from the continue to header block, and a single exit
     // from header to merge.
     auto loc = forOp.getLoc();
-    auto loopOp =
-        spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
+    auto loopControl = spirv::LoopControl::None;
+    if (auto attr =
+            forOp->getAttrOfType<spirv::LoopControlAttr>("spirv.loop_control"))
+      loopControl = attr.getValue();
+    auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
     loopOp.addEntryAndMergeBlock(rewriter);
 
     OpBuilder::InsertionGuard guard(rewriter);
@@ -348,8 +351,11 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
   matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = whileOp.getLoc();
-    auto loopOp =
-        spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
+    auto loopControl = spirv::LoopControl::None;
+    if (auto attr = whileOp->getAttrOfType<spirv::LoopControlAttr>(
+            "spirv.loop_control"))
+      loopControl = attr.getValue();
+    auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
     loopOp.addEntryAndMergeBlock(rewriter);
 
     Region &beforeRegion = whileOp.getBefore();
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index c9b22fe145d88..8f508f9f0374b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1039,6 +1039,10 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
   } else if (symbol == spirv::getTargetEnvAttrName()) {
     if (!isa<spirv::TargetEnvAttr>(attr))
       return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
+  } else if (symbol == "spirv.loop_control") {
+    if (!isa<spirv::LoopControlAttr>(attr))
+      return op->emitError("'")
+             << symbol << "' must be a spirv::LoopControlAttr";
   } else {
     return op->emitError("found unsupported '")
            << symbol << "' attribute on operation";
diff --git a/mlir/test/Conversion/SCFToSPIRV/for.mlir b/mlir/test/Conversion/SCFToSPIRV/for.mlir
index 9c552166cd72d..702bee476668f 100644
--- a/mlir/test/Conversion/SCFToSPIRV/for.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/for.mlir
@@ -89,4 +89,30 @@ func.func @loop_yield(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>
   return
 }
 
+// CHECK-LABEL: @loop_unroll
+func.func @loop_unroll(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+  // CHECK: spirv.mlir.loop control(Unroll) {
+  scf.for %arg4 = %lb to %ub step %step {
+    %1 = memref.load %arg2[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+    memref.store %1, %arg3[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+  } {spirv.loop_control = #spirv.loop_control<Unroll>}
+  return
+}
+
+// CHECK-LABEL: @loop_dont_unroll
+func.func @loop_dont_unroll(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+  // CHECK: spirv.mlir.loop control(DontUnroll) {
+  scf.for %arg4 = %lb to %ub step %step {
+    %1 = memref.load %arg2[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+    memref.store %1, %arg3[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+  } {spirv.loop_control = #spirv.loop_control<DontUnroll>}
+  return
+}
+
 } // end module
diff --git a/mlir/test/Conversion/SCFToSPIRV/while.mlir b/mlir/test/Conversion/SCFToSPIRV/while.mlir
index ff455383a7f05..29e12fc4ccbdf 100644
--- a/mlir/test/Conversion/SCFToSPIRV/while.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/while.mlir
@@ -127,4 +127,21 @@ func.func @while_loop_after_typeconv(%arg0: f32) -> index {
   return %res : index
 }
 
+// -----
+
+// CHECK-LABEL: @while_loop_unroll
+func.func @while_loop_unroll(%arg0: i32, %arg1: i32) -> i32 {
+  %c2_i32 = arith.constant 2 : i32
+  // CHECK: spirv.mlir.loop control(Unroll) {
+  %0 = scf.while (%arg3 = %arg0) : (i32) -> (i32) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : i32
+    scf.condition(%1) %arg3 : i32
+  } do {
+  ^bb0(%arg5: i32):
+    %1 = arith.muli %arg5, %c2_i32 : i32
+    scf.yield %1 : i32
+  } attributes {spirv.loop_control = #spirv.loop_control<Unroll>}
+  return %0 : i32
+}
+
 } // end module

``````````

</details>


https://github.com/llvm/llvm-project/pull/189392


More information about the Mlir-commits mailing list