[Mlir-commits] [mlir] 09c54a8 - [mlir][SPIR-V] Support spirv.loop_control attribute on scf.for and scf.while (#189392)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 31 07:49:43 PDT 2026


Author: Arseniy Obolenskiy
Date: 2026-03-31T16:49:37+02:00
New Revision: 09c54a8f7afcd30c83862ab2792eacdb53c77a8f

URL: https://github.com/llvm/llvm-project/commit/09c54a8f7afcd30c83862ab2792eacdb53c77a8f
DIFF: https://github.com/llvm/llvm-project/commit/09c54a8f7afcd30c83862ab2792eacdb53c77a8f.diff

LOG: [mlir][SPIR-V] Support spirv.loop_control attribute on scf.for and scf.while (#189392)

Propagate the `spirv.loop_control` attribute from `scf.for` and
`scf.while` operations to the generated `spirv.mlir.loop` during
SCFToSPIRV conversion.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h
    mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
    mlir/test/Conversion/SCFToSPIRV/for.mlir
    mlir/test/Conversion/SCFToSPIRV/while.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h
index 24574bfaf6199..7e11eb653c126 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h
@@ -109,6 +109,9 @@ DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op);
 /// "Table 46. Required Limits" of the Vulkan spec.
 ResourceLimitsAttr getDefaultResourceLimits(MLIRContext *context);
 
+/// Returns the attribute name for specifying loop control.
+StringRef getLoopControlAttrName();
+
 /// Returns the attribute name for specifying SPIR-V target environment.
 StringRef getTargetEnvAttrName();
 

diff  --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 55ed31ee695ff..a9c6f7db847d3 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -135,8 +136,11 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
     // a single back edge from the continue to header block, and a single exit
     // from header to merge.
     auto loc = forOp.getLoc();
-    auto loopOp =
-        spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
+    auto loopControl = spirv::LoopControl::None;
+    if (auto attr = forOp->getAttrOfType<spirv::LoopControlAttr>(
+            spirv::getLoopControlAttrName()))
+      loopControl = attr.getValue();
+    auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
     loopOp.addEntryAndMergeBlock(rewriter);
 
     OpBuilder::InsertionGuard guard(rewriter);
@@ -348,8 +352,11 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
   matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = whileOp.getLoc();
-    auto loopOp =
-        spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
+    auto loopControl = spirv::LoopControl::None;
+    if (auto attr = whileOp->getAttrOfType<spirv::LoopControlAttr>(
+            spirv::getLoopControlAttrName()))
+      loopControl = attr.getValue();
+    auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
     loopOp.addEntryAndMergeBlock(rewriter);
 
     Region &beforeRegion = whileOp.getBefore();

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index c9b22fe145d88..5782b42dba026 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1039,6 +1039,10 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
   } else if (symbol == spirv::getTargetEnvAttrName()) {
     if (!isa<spirv::TargetEnvAttr>(attr))
       return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
+  } else if (symbol == spirv::getLoopControlAttrName()) {
+    if (!isa<spirv::LoopControlAttr>(attr))
+      return op->emitError("'")
+             << symbol << "' must be a spirv::LoopControlAttr";
   } else {
     return op->emitError("found unsupported '")
            << symbol << "' attribute on operation";

diff  --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
index 8c52ba8b85835..270cb6df20415 100644
--- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
@@ -165,6 +165,8 @@ spirv::getDefaultResourceLimits(MLIRContext *context) {
       /*cooperative_matrix_properties_nv=*/ArrayAttr{});
 }
 
+StringRef spirv::getLoopControlAttrName() { return "spirv.loop_control"; }
+
 StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env"; }
 
 spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {

diff  --git a/mlir/test/Conversion/SCFToSPIRV/for.mlir b/mlir/test/Conversion/SCFToSPIRV/for.mlir
index 9c552166cd72d..702bee476668f 100644
--- a/mlir/test/Conversion/SCFToSPIRV/for.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/for.mlir
@@ -89,4 +89,30 @@ func.func @loop_yield(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>
   return
 }
 
+// CHECK-LABEL: @loop_unroll
+func.func @loop_unroll(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+  // CHECK: spirv.mlir.loop control(Unroll) {
+  scf.for %arg4 = %lb to %ub step %step {
+    %1 = memref.load %arg2[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+    memref.store %1, %arg3[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+  } {spirv.loop_control = #spirv.loop_control<Unroll>}
+  return
+}
+
+// CHECK-LABEL: @loop_dont_unroll
+func.func @loop_dont_unroll(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+  // CHECK: spirv.mlir.loop control(DontUnroll) {
+  scf.for %arg4 = %lb to %ub step %step {
+    %1 = memref.load %arg2[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+    memref.store %1, %arg3[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+  } {spirv.loop_control = #spirv.loop_control<DontUnroll>}
+  return
+}
+
 } // end module

diff  --git a/mlir/test/Conversion/SCFToSPIRV/while.mlir b/mlir/test/Conversion/SCFToSPIRV/while.mlir
index ff455383a7f05..29e12fc4ccbdf 100644
--- a/mlir/test/Conversion/SCFToSPIRV/while.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/while.mlir
@@ -127,4 +127,21 @@ func.func @while_loop_after_typeconv(%arg0: f32) -> index {
   return %res : index
 }
 
+// -----
+
+// CHECK-LABEL: @while_loop_unroll
+func.func @while_loop_unroll(%arg0: i32, %arg1: i32) -> i32 {
+  %c2_i32 = arith.constant 2 : i32
+  // CHECK: spirv.mlir.loop control(Unroll) {
+  %0 = scf.while (%arg3 = %arg0) : (i32) -> (i32) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : i32
+    scf.condition(%1) %arg3 : i32
+  } do {
+  ^bb0(%arg5: i32):
+    %1 = arith.muli %arg5, %c2_i32 : i32
+    scf.yield %1 : i32
+  } attributes {spirv.loop_control = #spirv.loop_control<Unroll>}
+  return %0 : i32
+}
+
 } // end module


        


More information about the Mlir-commits mailing list