[Mlir-commits] [mlir] [mlir][SPIR-V] Support spirv.loop_control attribute on scf.for and scf.while (PR #189392)
Arseniy Obolenskiy
llvmlistbot at llvm.org
Mon Mar 30 21:03:08 PDT 2026
https://github.com/aobolensk updated https://github.com/llvm/llvm-project/pull/189392
>From f330c2b5fb8e055f8c29d16f1c282f213db48119 Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Mon, 30 Mar 2026 16:13:58 +0200
Subject: [PATCH 1/2] [mlir][SPIR-V] Support spirv.loop_control attribute on
scf.for and scf.while
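Propagate the `spirv.loop_control` attribute from `scf.for` and `scf.while` operations to the generated `spirv.mlir.loop` during SCF-to-SPIR-V conversion. The SPIR-V dialect verifier now also accepts this attribute and checks that its value is a `#spirv.loop_control` attribute.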
---
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp | 14 +++++++---
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 4 +++
mlir/test/Conversion/SCFToSPIRV/for.mlir | 26 +++++++++++++++++++
mlir/test/Conversion/SCFToSPIRV/while.mlir | 17 ++++++++++++
4 files changed, 57 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 55ed31ee695ff..a7dcb7045a256 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -135,8 +135,11 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
// a single back edge from the continue to header block, and a single exit
// from header to merge.
auto loc = forOp.getLoc();
- auto loopOp =
- spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
+ auto loopControl = spirv::LoopControl::None;
+ if (auto attr =
+ forOp->getAttrOfType<spirv::LoopControlAttr>("spirv.loop_control"))
+ loopControl = attr.getValue();
+ auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
loopOp.addEntryAndMergeBlock(rewriter);
OpBuilder::InsertionGuard guard(rewriter);
@@ -348,8 +351,11 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = whileOp.getLoc();
- auto loopOp =
- spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
+ auto loopControl = spirv::LoopControl::None;
+ if (auto attr = whileOp->getAttrOfType<spirv::LoopControlAttr>(
+ "spirv.loop_control"))
+ loopControl = attr.getValue();
+ auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
loopOp.addEntryAndMergeBlock(rewriter);
Region &beforeRegion = whileOp.getBefore();
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index c9b22fe145d88..8f508f9f0374b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1039,6 +1039,10 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
} else if (symbol == spirv::getTargetEnvAttrName()) {
if (!isa<spirv::TargetEnvAttr>(attr))
return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
+ } else if (symbol == "spirv.loop_control") {
+ if (!isa<spirv::LoopControlAttr>(attr))
+ return op->emitError("'")
+ << symbol << "' must be a spirv::LoopControlAttr";
} else {
return op->emitError("found unsupported '")
<< symbol << "' attribute on operation";
diff --git a/mlir/test/Conversion/SCFToSPIRV/for.mlir b/mlir/test/Conversion/SCFToSPIRV/for.mlir
index 9c552166cd72d..702bee476668f 100644
--- a/mlir/test/Conversion/SCFToSPIRV/for.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/for.mlir
@@ -89,4 +89,30 @@ func.func @loop_yield(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>
return
}
+// CHECK-LABEL: @loop_unroll
+func.func @loop_unroll(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 10 : index
+ %step = arith.constant 1 : index
+ // CHECK: spirv.mlir.loop control(Unroll) {
+ scf.for %arg4 = %lb to %ub step %step {
+ %1 = memref.load %arg2[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+ memref.store %1, %arg3[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+ } {spirv.loop_control = #spirv.loop_control<Unroll>}
+ return
+}
+
+// CHECK-LABEL: @loop_dont_unroll
+func.func @loop_dont_unroll(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 10 : index
+ %step = arith.constant 1 : index
+ // CHECK: spirv.mlir.loop control(DontUnroll) {
+ scf.for %arg4 = %lb to %ub step %step {
+ %1 = memref.load %arg2[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+ memref.store %1, %arg3[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+ } {spirv.loop_control = #spirv.loop_control<DontUnroll>}
+ return
+}
+
} // end module
diff --git a/mlir/test/Conversion/SCFToSPIRV/while.mlir b/mlir/test/Conversion/SCFToSPIRV/while.mlir
index ff455383a7f05..29e12fc4ccbdf 100644
--- a/mlir/test/Conversion/SCFToSPIRV/while.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/while.mlir
@@ -127,4 +127,21 @@ func.func @while_loop_after_typeconv(%arg0: f32) -> index {
return %res : index
}
+// -----
+
+// CHECK-LABEL: @while_loop_unroll
+func.func @while_loop_unroll(%arg0: i32, %arg1: i32) -> i32 {
+ %c2_i32 = arith.constant 2 : i32
+ // CHECK: spirv.mlir.loop control(Unroll) {
+ %0 = scf.while (%arg3 = %arg0) : (i32) -> (i32) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : i32
+ scf.condition(%1) %arg3 : i32
+ } do {
+ ^bb0(%arg5: i32):
+ %1 = arith.muli %arg5, %c2_i32 : i32
+ scf.yield %1 : i32
+ } attributes {spirv.loop_control = #spirv.loop_control<Unroll>}
+ return %0 : i32
+}
+
} // end module
>From 233201ffd4e57e411d8d6459e6bd648a78ce3845 Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Tue, 31 Mar 2026 06:02:48 +0200
Subject: [PATCH 2/2] Add spirv::getLoopControlAttrName() and use it instead of the "spirv.loop_control" string literal
---
mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h | 3 +++
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp | 7 ++++---
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 2 +-
mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp | 2 ++
4 files changed, 10 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h
index 24574bfaf6199..7e11eb653c126 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h
@@ -109,6 +109,9 @@ DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op);
/// "Table 46. Required Limits" of the Vulkan spec.
ResourceLimitsAttr getDefaultResourceLimits(MLIRContext *context);
+/// Returns the attribute name for specifying loop control.
+StringRef getLoopControlAttrName();
+
/// Returns the attribute name for specifying SPIR-V target environment.
StringRef getTargetEnvAttrName();
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index a7dcb7045a256..a9c6f7db847d3 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -13,6 +13,7 @@
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
@@ -136,8 +137,8 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
// from header to merge.
auto loc = forOp.getLoc();
auto loopControl = spirv::LoopControl::None;
- if (auto attr =
- forOp->getAttrOfType<spirv::LoopControlAttr>("spirv.loop_control"))
+ if (auto attr = forOp->getAttrOfType<spirv::LoopControlAttr>(
+ spirv::getLoopControlAttrName()))
loopControl = attr.getValue();
auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
loopOp.addEntryAndMergeBlock(rewriter);
@@ -353,7 +354,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
auto loc = whileOp.getLoc();
auto loopControl = spirv::LoopControl::None;
if (auto attr = whileOp->getAttrOfType<spirv::LoopControlAttr>(
- "spirv.loop_control"))
+ spirv::getLoopControlAttrName()))
loopControl = attr.getValue();
auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
loopOp.addEntryAndMergeBlock(rewriter);
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 8f508f9f0374b..5782b42dba026 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1039,7 +1039,7 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
} else if (symbol == spirv::getTargetEnvAttrName()) {
if (!isa<spirv::TargetEnvAttr>(attr))
return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
- } else if (symbol == "spirv.loop_control") {
+ } else if (symbol == spirv::getLoopControlAttrName()) {
if (!isa<spirv::LoopControlAttr>(attr))
return op->emitError("'")
<< symbol << "' must be a spirv::LoopControlAttr";
diff --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
index 8c52ba8b85835..270cb6df20415 100644
--- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
@@ -165,6 +165,8 @@ spirv::getDefaultResourceLimits(MLIRContext *context) {
/*cooperative_matrix_properties_nv=*/ArrayAttr{});
}
+StringRef spirv::getLoopControlAttrName() { return "spirv.loop_control"; }
+
StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env"; }
spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
More information about the Mlir-commits
mailing list