[Mlir-commits] [mlir] 2e7ed78 - [mlir][spirv] Add instruction OpGroupNonUniformRotateKHR (#133428)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 3 03:00:32 PDT 2025
Author: Hsiangkai Wang
Date: 2025-04-03T11:00:29+01:00
New Revision: 2e7ed78cff0ad3e3535443ce8c0c3c0e0925ff73
URL: https://github.com/llvm/llvm-project/commit/2e7ed78cff0ad3e3535443ce8c0c3c0e0925ff73
DIFF: https://github.com/llvm/llvm-project/commit/2e7ed78cff0ad3e3535443ce8c0c3c0e0925ff73.diff
LOG: [mlir][spirv] Add instruction OpGroupNonUniformRotateKHR (#133428)
Add an instruction under the extension SPV_KHR_subgroup_rotate.
The specification for the extension is here:
https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_subgroup_rotate.html
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d5359da2a590e..cd5d201c3d5da 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4489,6 +4489,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo
def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
+def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
def SPIRV_OC_OpUDot : I32EnumAttrCase<"OpUDot", 4451>;
@@ -4598,7 +4599,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
+ SPIRV_OC_OpGroupNonUniformRotateKHR,
SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat,
SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat,
SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 98e435c18d3d7..2dd3dbd28d436 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -1361,4 +1361,78 @@ def SPIRV_GroupNonUniformBallotBitCountOp : SPIRV_Op<"GroupNonUniformBallotBitCo
// -----
+def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
+ Pure, AllTypesMatch<["value", "result"]>]> {
+ let summary = [{
+ Rotate values across invocations within a subgroup.
+ }];
+
+ let description = [{
+ Return the Value of the invocation whose id within the group is calculated
+ as follows:
+
+ LocalId = SubgroupLocalInvocationId if Execution is Subgroup or
+ LocalInvocationId if Execution is Workgroup
+ RotationGroupSize = ClusterSize when ClusterSize is present, otherwise
+ RotationGroupSize = SubgroupMaxSize if the Kernel capability is declared
+ and SubgroupSize if not.
+ Invocation ID = ( (LocalId + Delta) & (RotationGroupSize - 1) ) +
+ (LocalId & ~(RotationGroupSize - 1))
+
+ Result Type must be a scalar or vector of floating-point type, integer
+ type, or Boolean type.
+
+ Execution is a Scope. It must be either Workgroup or Subgroup.
+
+ The type of Value must be the same as Result Type.
+
+ Delta must be a scalar of integer type, whose Signedness operand is 0.
+ Delta must be dynamically uniform within Execution.
+
+ Delta is treated as unsigned and the resulting value is undefined if the
+ selected lane is inactive.
+
+ ClusterSize is the size of cluster to use. ClusterSize must be a scalar of
+ integer type, whose Signedness operand is 0. ClusterSize must come from a
+ constant instruction. Behavior is undefined unless ClusterSize is at least
+ 1 and a power of 2. If ClusterSize is greater than the declared
+ SubGroupSize, executing this instruction results in undefined behavior.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %four = spirv.Constant 4 : i32
+ %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %value, %delta : f32, i32 -> f32
+ %1 = spirv.GroupNonUniformRotateKHR <Workgroup>, %value, %delta,
+ cluster_size(%four) : f32, i32, i32 -> f32
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_3>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[]>,
+ Capability<[SPIRV_C_GroupNonUniformRotateKHR]>
+ ];
+
+ let arguments = (ins
+ SPIRV_ScopeAttr:$execution_scope,
+ AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$value,
+ SPIRV_SignlessOrUnsignedInt:$delta,
+ Optional<SPIRV_SignlessOrUnsignedInt>:$cluster_size
+ );
+
+ let results = (outs
+ AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$result
+ );
+
+ let assemblyFormat = [{
+ $execution_scope `,` $value `,` $delta (`,` `cluster_size` `(` $cluster_size^ `)`)? attr-dict `:` type($value) `,` type($delta) (`,` type($cluster_size)^)? `->` type(results)
+ }];
+}
+
+// -----
+
#endif // MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index 8aeafda0eb755..461d037134dae 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -304,6 +304,39 @@ LogicalResult GroupNonUniformLogicalXorOp::verify() {
return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this);
}
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformRotateKHR
+//===----------------------------------------------------------------------===//
+
+/// Verifies spirv.GroupNonUniformRotateKHR beyond its ODS type constraints:
+///  - the execution scope must be Workgroup or Subgroup;
+///  - when present, the cluster size must be produced by a spirv.Constant
+///    holding an integer whose value is a power of two (zero is rejected).
+LogicalResult GroupNonUniformRotateKHROp::verify() {
+  spirv::Scope scope = getExecutionScope();
+  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
+
+  if (Value clusterSizeVal = getClusterSize()) {
+    // Inspect the constant's full-width APInt instead of narrowing it into an
+    // int32_t: an int32_t cannot hold every i64 constant (2^32 + 4 would wrap
+    // to 4), and narrow types may be sign-extended on the way.
+    auto constOp =
+        dyn_cast_or_null<spirv::ConstantOp>(clusterSizeVal.getDefiningOp());
+    auto clusterSizeAttr =
+        constOp ? dyn_cast<IntegerAttr>(constOp.getValue()) : IntegerAttr();
+    if (!clusterSizeAttr)
+      return emitOpError("cluster size operand must come from a constant op");
+
+    // The operand is signless or unsigned, so the bits are read as unsigned;
+    // APInt::isPowerOf2 is false for zero.
+    if (!clusterSizeAttr.getValue().isPowerOf2())
+      return emitOpError("cluster size operand must be a power of two");
+  }
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// Group op verification
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 60ae1584d29fb..bf383d3837b6e 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -604,3 +604,75 @@ func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {
%0 = spirv.GroupNonUniformLogicalXor <Workgroup> <Reduce> %val : i32 -> i32
return %0: i32
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformRotateKHR
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_rotate_khr
+func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
+  // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup>, %{{.+}}, %{{.+}} : f32, i32 -> f32
+  %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta : f32, i32 -> f32
+  return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @group_non_uniform_rotate_khr_cluster_size
+func.func @group_non_uniform_rotate_khr_cluster_size(%val: f32, %delta: i32) -> f32 {
+  // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Workgroup>, %{{.+}}, %{{.+}}, cluster_size(%{{.+}}) : f32, i32, i32 -> f32
+  %four = spirv.Constant 4 : i32
+  %0 = spirv.GroupNonUniformRotateKHR <Workgroup>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
+  return %0: f32
+}
+
+// -----
+
+// Only Workgroup and Subgroup execution scopes are accepted.
+func.func @group_non_uniform_rotate_khr_invalid_scope(%val: f32, %delta: i32) -> f32 {
+  %four = spirv.Constant 4 : i32
+  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  %0 = spirv.GroupNonUniformRotateKHR <Device>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
+  return %0: f32
+}
+
+// -----
+
+// Delta must be a signless or unsigned integer.
+func.func @group_non_uniform_rotate_khr_signed_delta(%val: f32, %delta: si32) -> f32 {
+  %four = spirv.Constant 4 : i32
+  // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit signless/unsigned integer, but got 'si32'}}
+  %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, si32, i32 -> f32
+  return %0: f32
+}
+
+// -----
+
+// The cluster size must be a signless or unsigned integer.
+func.func @group_non_uniform_rotate_khr_signed_cluster_size(%val: f32, %delta: i32) -> f32 {
+  %four = spirv.Constant 4 : si32
+  // expected-error @+1 {{op operand #2 must be 8/16/32/64-bit signless/unsigned integer, but got 'si32'}}
+  %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, i32, si32 -> f32
+  return %0: f32
+}
+
+// -----
+
+// The cluster size must be defined by a constant op, not a block argument.
+func.func @group_non_uniform_rotate_khr_non_constant_cluster_size(%val: f32, %delta: i32, %four: i32) -> f32 {
+  // expected-error @+1 {{cluster size operand must come from a constant op}}
+  %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
+  return %0: f32
+}
+
+// -----
+
+// The constant cluster size must be a power of two.
+func.func @group_non_uniform_rotate_khr_non_power_of_two_cluster_size(%val: f32, %delta: i32) -> f32 {
+  %five = spirv.Constant 5 : i32
+  // expected-error @+1 {{cluster size operand must be a power of two}}
+  %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%five) : f32, i32, i32 -> f32
+  return %0: f32
+}
More information about the Mlir-commits
mailing list