[Mlir-commits] [mlir] [mlir][spirv] Add instruction OpGroupNonUniformRotateKHR (PR #133428)

Hsiangkai Wang llvmlistbot at llvm.org
Thu Apr 3 02:22:44 PDT 2025


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/133428

>From 8038300e9c5b317f7a4fd3bd0461cdc391b39d0c Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 27 Mar 2025 16:50:20 +0000
Subject: [PATCH 1/5] [mlir][spirv] Add instruction OpGroupNonUniformRotateKHR

Add an instruction under the extension SPV_KHR_subgroup_rotate.

The specification for the extension is here:
https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_subgroup_rotate.html
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  4 +-
 .../Dialect/SPIRV/IR/SPIRVNonUniformOps.td    | 75 +++++++++++++++++++
 .../Dialect/SPIRV/IR/non-uniform-ops.mlir     | 23 ++++++
 3 files changed, 101 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d5359da2a590e..cd5d201c3d5da 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4489,6 +4489,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor      : I32EnumAttrCase<"OpGroupNonUnifo
 def SPIRV_OC_OpGroupNonUniformLogicalAnd      : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
 def SPIRV_OC_OpGroupNonUniformLogicalOr       : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
 def SPIRV_OC_OpGroupNonUniformLogicalXor      : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
+def SPIRV_OC_OpGroupNonUniformRotateKHR       : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
 def SPIRV_OC_OpSubgroupBallotKHR              : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 def SPIRV_OC_OpSDot                           : I32EnumAttrCase<"OpSDot", 4450>;
 def SPIRV_OC_OpUDot                           : I32EnumAttrCase<"OpUDot", 4451>;
@@ -4598,7 +4599,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
       SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
       SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
-      SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
+      SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpGroupNonUniformRotateKHR,
+      SPIRV_OC_OpSubgroupBallotKHR,
       SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat,
       SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat,
       SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 98e435c18d3d7..f195adfc0e73d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -1361,4 +1361,79 @@ def SPIRV_GroupNonUniformBallotBitCountOp : SPIRV_Op<"GroupNonUniformBallotBitCo
 
 // -----
 
+def SPIRV_GroupNonUniformRotateKHR : SPIRV_Op<"GroupNonUniformRotateKHR", []> {
+  let summary = [{
+    Rotate values across invocations within a subgroup.
+  }];
+
+  let description = [{
+    Return the Value of the invocation whose id within the group is calculated
+    as follows:
+
+    LocalId = SubgroupLocalInvocationId if Execution is Subgroup or
+              LocalInvocationId if Execution is Workgroup
+    RotationGroupSize = ClusterSize when ClusterSize is present, otherwise
+    RotationGroupSize = SubgroupMaxSize if the Kernel capability is declared
+                        and SubgroupSize if not.
+    Invocation ID = ( (LocalId + Delta) & (RotationGroupSize - 1) ) +
+                    (LocalId & ~(RotationGroupSize - 1))
+
+    Result Type must be a scalar or vector of floating-point type, integer
+    type, or Boolean type.
+
+    Execution is a Scope. It must be either Workgroup or Subgroup.
+
+    The type of Value must be the same as Result Type.
+
+    Delta must be a scalar of integer type, whose Signedness operand is 0.
+    Delta must be dynamically uniform within Execution.
+
+    Delta is treated as unsigned and the resulting value is undefined if the
+    selected lane is inactive.
+
+    ClusterSize is the size of cluster to use. ClusterSize must be a scalar of
+    integer type, whose Signedness operand is 0. ClusterSize must come from a
+    constant instruction. Behavior is undefined unless ClusterSize is at least
+    1 and a power of 2. If ClusterSize is greater than the declared
+    SubGroupSize, executing this instruction results in undefined behavior.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %four = spirv.Constant 4 : i32
+    %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %value, %delta : f32, i32 -> f32
+    %1 = spirv.GroupNonUniformRotateKHR <Workgroup>, %value, %delta,
+         cluster_size(%four) : f32, i32, i32 -> f32
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_3>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[SPIRV_C_GroupNonUniformRotateKHR]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScopeAttr:$execution_scope,
+    SPIRV_Type:$value,
+    SPIRV_Integer:$delta,
+    Optional<SPIRV_Integer>:$cluster_size
+  );
+
+  let results = (outs
+    SPIRV_Type:$result
+  );
+
+  let hasVerifier = 0;
+
+  let assemblyFormat = [{
+    $execution_scope `,` $value `,` $delta (`,` `cluster_size` `(` $cluster_size^ `)`)? attr-dict `:` type($value) `,` type($delta) (`,` type($cluster_size)^)? `->` type(results)
+  }];
+}
+
+// -----
+
 #endif // MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 60ae1584d29fb..60b99d51363e9 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -604,3 +604,26 @@ func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {
   %0 = spirv.GroupNonUniformLogicalXor <Workgroup> <Reduce> %val : i32 -> i32
   return %0: i32
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformRotateKHR
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_rotate_khr
+func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
+  // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup>, %{{.+}} : f32, i32 -> f32
+  %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta : f32, i32 -> f32
+  return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @group_non_uniform_rotate_khr
+func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
+  // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Workgroup>, %{{.+}} : f32, i32, i32 -> f32
+  %four = spirv.Constant 4 : i32
+  %0 = spirv.GroupNonUniformRotateKHR <Workgroup>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
+  return %0: f32
+}

>From b9d6130b72bbc587534c8f19bb93718ee6145df2 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Tue, 1 Apr 2025 10:36:58 +0100
Subject: [PATCH 2/5] Add operand and result type constraint

---
 .../Dialect/SPIRV/IR/SPIRVNonUniformOps.td    |  7 ++-
 mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp        | 32 ++++++++++++++
 .../Dialect/SPIRV/IR/non-uniform-ops.mlir     | 44 +++++++++++++++++++
 3 files changed, 79 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index f195adfc0e73d..10609ea0c53d0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -1361,7 +1361,8 @@ def SPIRV_GroupNonUniformBallotBitCountOp : SPIRV_Op<"GroupNonUniformBallotBitCo
 
 // -----
 
-def SPIRV_GroupNonUniformRotateKHR : SPIRV_Op<"GroupNonUniformRotateKHR", []> {
+def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
+  Pure, AllTypesMatch<["value", "result"]>]> {
   let summary = [{
     Rotate values across invocations within a subgroup.
   }];
@@ -1424,11 +1425,9 @@ def SPIRV_GroupNonUniformRotateKHR : SPIRV_Op<"GroupNonUniformRotateKHR", []> {
   );
 
   let results = (outs
-    SPIRV_Type:$result
+    AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$result
   );
 
-  let hasVerifier = 0;
-
   let assemblyFormat = [{
     $execution_scope `,` $value `,` $delta (`,` `cluster_size` `(` $cluster_size^ `)`)? attr-dict `:` type($value) `,` type($delta) (`,` type($cluster_size)^)? `->` type(results)
   }];
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index 8aeafda0eb755..f466d4d09854f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -304,6 +304,38 @@ LogicalResult GroupNonUniformLogicalXorOp::verify() {
   return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformRotateKHR
+//===----------------------------------------------------------------------===//
+
+LogicalResult GroupNonUniformRotateKHROp::verify() {
+  spirv::Scope scope = getExecutionScope();
+  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
+
+  if (getDelta().getType().isSignedInteger())
+    return emitOpError("delta must be a singless/unsigned integer");
+
+  auto clusterSizeVal = getClusterSize();
+  if (clusterSizeVal) {
+    if (clusterSizeVal.getType().isSignedInteger())
+      return emitOpError("cluster size must be a singless/unsigned integer");
+
+    mlir::Operation *defOp = clusterSizeVal.getDefiningOp();
+    int32_t clusterSize = 0;
+
+    if (failed(extractValueFromConstOp(defOp, clusterSize)))
+      return emitOpError(
+          "cluster size operand must come from a constant op");
+
+    if (!llvm::isPowerOf2_32(clusterSize))
+      return emitOpError(
+          "cluster size operand must be a power of two");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Group op verification
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 60b99d51363e9..714fb8c057fac 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -627,3 +627,47 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
   %0 = spirv.GroupNonUniformRotateKHR <Workgroup>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
   return %0: f32
 }
+
+// -----
+
+func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
+  %four = spirv.Constant 4 : i32
+  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  %0 = spirv.GroupNonUniformRotateKHR <Device>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
+  return %0: f32
+}
+
+// -----
+
+func.func @group_non_uniform_rotate_khr(%val: f32, %delta: si32) -> f32 {
+  %four = spirv.Constant 4 : i32
+  // expected-error @+1 {{delta must be a singless/unsigned integer}}
+  %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, si32, i32 -> f32
+  return %0: f32
+}
+
+// -----
+
+func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
+  %four = spirv.Constant 4 : si32
+  // expected-error @+1 {{cluster size must be a singless/unsigned integer}}
+  %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, i32, si32 -> f32
+  return %0: f32
+}
+
+// -----
+
+func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32, %four: i32) -> f32 {
+  // expected-error @+1 {{cluster size operand must come from a constant op}}
+  %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
+  return %0: f32
+}
+
+// -----
+
+func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
+  %five = spirv.Constant 5 : i32
+  // expected-error @+1 {{cluster size operand must be a power of two}}
+  %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%five) : f32, i32, i32 -> f32
+  return %0: f32
+}

>From 46baabd2d47b4eb87c1bb05a190baf582515b14c Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Tue, 1 Apr 2025 14:30:10 +0100
Subject: [PATCH 3/5] clang-format

---
 mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index f466d4d09854f..7ef1b1a484e03 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -325,12 +325,10 @@ LogicalResult GroupNonUniformRotateKHROp::verify() {
     int32_t clusterSize = 0;
 
     if (failed(extractValueFromConstOp(defOp, clusterSize)))
-      return emitOpError(
-          "cluster size operand must come from a constant op");
+      return emitOpError("cluster size operand must come from a constant op");
 
     if (!llvm::isPowerOf2_32(clusterSize))
-      return emitOpError(
-          "cluster size operand must be a power of two");
+      return emitOpError("cluster size operand must be a power of two");
   }
 
   return success();

>From 5ca11e3a133c359f566ba463e3ee1bcb783a4f59 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 2 Apr 2025 09:34:24 +0100
Subject: [PATCH 4/5] Address comments

---
 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td | 6 +++---
 mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp                   | 9 +--------
 mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir          | 4 ++--
 3 files changed, 6 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 10609ea0c53d0..2dd3dbd28d436 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -1419,9 +1419,9 @@ def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
 
   let arguments = (ins
     SPIRV_ScopeAttr:$execution_scope,
-    SPIRV_Type:$value,
-    SPIRV_Integer:$delta,
-    Optional<SPIRV_Integer>:$cluster_size
+    AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$value,
+    SPIRV_SignlessOrUnsignedInt:$delta,
+    Optional<SPIRV_SignlessOrUnsignedInt>:$cluster_size
   );
 
   let results = (outs
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index 7ef1b1a484e03..94881338ee12a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -313,14 +313,7 @@ LogicalResult GroupNonUniformRotateKHROp::verify() {
   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
     return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
 
-  if (getDelta().getType().isSignedInteger())
-    return emitOpError("delta must be a singless/unsigned integer");
-
-  auto clusterSizeVal = getClusterSize();
-  if (clusterSizeVal) {
-    if (clusterSizeVal.getType().isSignedInteger())
-      return emitOpError("cluster size must be a singless/unsigned integer");
-
+  if (TypedValue<Type> clusterSizeVal = getClusterSize()) {
     mlir::Operation *defOp = clusterSizeVal.getDefiningOp();
     int32_t clusterSize = 0;
 
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 714fb8c057fac..bf383d3837b6e 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -641,7 +641,7 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
 
 func.func @group_non_uniform_rotate_khr(%val: f32, %delta: si32) -> f32 {
   %four = spirv.Constant 4 : i32
-  // expected-error @+1 {{delta must be a singless/unsigned integer}}
+  // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit signless/unsigned integer, but got 'si32'}}
   %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, si32, i32 -> f32
   return %0: f32
 }
@@ -650,7 +650,7 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: si32) -> f32 {
 
 func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
   %four = spirv.Constant 4 : si32
-  // expected-error @+1 {{cluster size must be a singless/unsigned integer}}
+  // expected-error @+1 {{op operand #2 must be 8/16/32/64-bit signless/unsigned integer, but got 'si32'}}
   %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, i32, si32 -> f32
   return %0: f32
 }

>From e198ca1ee5c7fc0cd63108c1102e7b557ec68876 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 3 Apr 2025 10:18:20 +0100
Subject: [PATCH 5/5] Address comment

---
 mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index 94881338ee12a..461d037134dae 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -313,7 +313,7 @@ LogicalResult GroupNonUniformRotateKHROp::verify() {
   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
     return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
 
-  if (TypedValue<Type> clusterSizeVal = getClusterSize()) {
+  if (Value clusterSizeVal = getClusterSize()) {
     mlir::Operation *defOp = clusterSizeVal.getDefiningOp();
     int32_t clusterSize = 0;
 



More information about the Mlir-commits mailing list