[Mlir-commits] [mlir] [mlir][SPIR-V] Support spirv.loop_control attribute on scf.for and scf.while (PR #189392)

Arseniy Obolenskiy llvmlistbot at llvm.org
Mon Mar 30 21:03:08 PDT 2026


https://github.com/aobolensk updated https://github.com/llvm/llvm-project/pull/189392

>From f330c2b5fb8e055f8c29d16f1c282f213db48119 Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Mon, 30 Mar 2026 16:13:58 +0200
Subject: [PATCH 1/2] [mlir][SPIR-V] Support spirv.loop_control attribute on
 scf.for and scf.while

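Propagate the `spirv.loop_control` attribute from `scf.for` and `scf.while` operations to the generated `spirv.mlir.loop` op during SCF-to-SPIR-V conversion, and teach the SPIR-V dialect attribute verifier to accept `spirv.loop_control` when it holds a `spirv::LoopControlAttr`.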
---
 mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp | 14 +++++++---
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    |  4 +++
 mlir/test/Conversion/SCFToSPIRV/for.mlir      | 26 +++++++++++++++++++
 mlir/test/Conversion/SCFToSPIRV/while.mlir    | 17 ++++++++++++
 4 files changed, 57 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 55ed31ee695ff..a7dcb7045a256 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -135,8 +135,11 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
     // a single back edge from the continue to header block, and a single exit
     // from header to merge.
     auto loc = forOp.getLoc();
-    auto loopOp =
-        spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
+    auto loopControl = spirv::LoopControl::None;
+    if (auto attr =
+            forOp->getAttrOfType<spirv::LoopControlAttr>("spirv.loop_control"))
+      loopControl = attr.getValue();
+    auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
     loopOp.addEntryAndMergeBlock(rewriter);
 
     OpBuilder::InsertionGuard guard(rewriter);
@@ -348,8 +351,11 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
   matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = whileOp.getLoc();
-    auto loopOp =
-        spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
+    auto loopControl = spirv::LoopControl::None;
+    if (auto attr = whileOp->getAttrOfType<spirv::LoopControlAttr>(
+            "spirv.loop_control"))
+      loopControl = attr.getValue();
+    auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
     loopOp.addEntryAndMergeBlock(rewriter);
 
     Region &beforeRegion = whileOp.getBefore();
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index c9b22fe145d88..8f508f9f0374b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1039,6 +1039,10 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
   } else if (symbol == spirv::getTargetEnvAttrName()) {
     if (!isa<spirv::TargetEnvAttr>(attr))
       return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
+  } else if (symbol == "spirv.loop_control") {
+    if (!isa<spirv::LoopControlAttr>(attr))
+      return op->emitError("'")
+             << symbol << "' must be a spirv::LoopControlAttr";
   } else {
     return op->emitError("found unsupported '")
            << symbol << "' attribute on operation";
diff --git a/mlir/test/Conversion/SCFToSPIRV/for.mlir b/mlir/test/Conversion/SCFToSPIRV/for.mlir
index 9c552166cd72d..702bee476668f 100644
--- a/mlir/test/Conversion/SCFToSPIRV/for.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/for.mlir
@@ -89,4 +89,30 @@ func.func @loop_yield(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>
   return
 }
 
+// CHECK-LABEL: @loop_unroll
+func.func @loop_unroll(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+  // CHECK: spirv.mlir.loop control(Unroll) {
+  scf.for %arg4 = %lb to %ub step %step {
+    %1 = memref.load %arg2[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+    memref.store %1, %arg3[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+  } {spirv.loop_control = #spirv.loop_control<Unroll>}
+  return
+}
+
+// CHECK-LABEL: @loop_dont_unroll
+func.func @loop_dont_unroll(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+  // CHECK: spirv.mlir.loop control(DontUnroll) {
+  scf.for %arg4 = %lb to %ub step %step {
+    %1 = memref.load %arg2[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+    memref.store %1, %arg3[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+  } {spirv.loop_control = #spirv.loop_control<DontUnroll>}
+  return
+}
+
 } // end module
diff --git a/mlir/test/Conversion/SCFToSPIRV/while.mlir b/mlir/test/Conversion/SCFToSPIRV/while.mlir
index ff455383a7f05..29e12fc4ccbdf 100644
--- a/mlir/test/Conversion/SCFToSPIRV/while.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/while.mlir
@@ -127,4 +127,21 @@ func.func @while_loop_after_typeconv(%arg0: f32) -> index {
   return %res : index
 }
 
+// -----
+
+// CHECK-LABEL: @while_loop_unroll
+func.func @while_loop_unroll(%arg0: i32, %arg1: i32) -> i32 {
+  %c2_i32 = arith.constant 2 : i32
+  // CHECK: spirv.mlir.loop control(Unroll) {
+  %0 = scf.while (%arg3 = %arg0) : (i32) -> (i32) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : i32
+    scf.condition(%1) %arg3 : i32
+  } do {
+  ^bb0(%arg5: i32):
+    %1 = arith.muli %arg5, %c2_i32 : i32
+    scf.yield %1 : i32
+  } attributes {spirv.loop_control = #spirv.loop_control<Unroll>}
+  return %0 : i32
+}
+
 } // end module

>From 233201ffd4e57e411d8d6459e6bd648a78ce3845 Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Tue, 31 Mar 2026 06:02:48 +0200
Subject: [PATCH 2/2] Add spirv::getLoopControlAttrName() helper for the attribute name

---
 mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h | 3 +++
 mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp     | 7 ++++---
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp        | 2 +-
 mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp        | 2 ++
 4 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h
index 24574bfaf6199..7e11eb653c126 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h
@@ -109,6 +109,9 @@ DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op);
 /// "Table 46. Required Limits" of the Vulkan spec.
 ResourceLimitsAttr getDefaultResourceLimits(MLIRContext *context);
 
+/// Returns the attribute name for specifying loop control.
+StringRef getLoopControlAttrName();
+
 /// Returns the attribute name for specifying SPIR-V target environment.
 StringRef getTargetEnvAttrName();
 
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index a7dcb7045a256..a9c6f7db847d3 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -136,8 +137,8 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
     // from header to merge.
     auto loc = forOp.getLoc();
     auto loopControl = spirv::LoopControl::None;
-    if (auto attr =
-            forOp->getAttrOfType<spirv::LoopControlAttr>("spirv.loop_control"))
+    if (auto attr = forOp->getAttrOfType<spirv::LoopControlAttr>(
+            spirv::getLoopControlAttrName()))
       loopControl = attr.getValue();
     auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
     loopOp.addEntryAndMergeBlock(rewriter);
@@ -353,7 +354,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
     auto loc = whileOp.getLoc();
     auto loopControl = spirv::LoopControl::None;
     if (auto attr = whileOp->getAttrOfType<spirv::LoopControlAttr>(
-            "spirv.loop_control"))
+            spirv::getLoopControlAttrName()))
       loopControl = attr.getValue();
     auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
     loopOp.addEntryAndMergeBlock(rewriter);
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 8f508f9f0374b..5782b42dba026 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1039,7 +1039,7 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
   } else if (symbol == spirv::getTargetEnvAttrName()) {
     if (!isa<spirv::TargetEnvAttr>(attr))
       return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
-  } else if (symbol == "spirv.loop_control") {
+  } else if (symbol == spirv::getLoopControlAttrName()) {
     if (!isa<spirv::LoopControlAttr>(attr))
       return op->emitError("'")
              << symbol << "' must be a spirv::LoopControlAttr";
diff --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
index 8c52ba8b85835..270cb6df20415 100644
--- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
@@ -165,6 +165,8 @@ spirv::getDefaultResourceLimits(MLIRContext *context) {
       /*cooperative_matrix_properties_nv=*/ArrayAttr{});
 }
 
+StringRef spirv::getLoopControlAttrName() { return "spirv.loop_control"; }
+
 StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env"; }
 
 spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {



More information about the Mlir-commits mailing list