[Mlir-commits] [mlir] [mlir][spirv] Simplify gpu reduction to spirv logic (PR #73546)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Nov 27 09:22:14 PST 2023
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/73546
>From b83ea3bc0d8901e4dc473e47c5524e29f4a186c1 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 27 Nov 2023 12:03:38 -0500
Subject: [PATCH 1/3] [mlir][spirv] Simplify gpu reduction to spirv logic. NFC.
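Check the type only once and then specialize op handlers based on it.
Note: 1-bit integers (i1) are now classified as a separate Boolean element type that has no handler yet, so boolean reductions are not converted.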
---
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 60 +++++++++----------
1 file changed, 30 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 6536bbe1f4dba47..655cd1f7a2c795e 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -485,20 +485,22 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
Location loc, Value arg,
gpu::AllReduceOperation opType,
bool isGroup, bool isUniform) {
+ enum class ElemType {
+ Float, Boolean, Integer
+ };
using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
struct OpHandler {
- gpu::AllReduceOperation type;
- FuncT intFunc;
- FuncT floatFunc;
+ gpu::AllReduceOperation kind;
+ ElemType elemType;
+ FuncT func;
};
Type type = arg.getType();
- using MembptrT = FuncT OpHandler::*;
- MembptrT handlerPtr;
+ ElemType elementType;
if (isa<FloatType>(type)) {
- handlerPtr = &OpHandler::floatFunc;
- } else if (isa<IntegerType>(type)) {
- handlerPtr = &OpHandler::intFunc;
+ elementType = ElemType::Float;
+ } else if (auto intTy = dyn_cast<IntegerType>(type)) {
+ elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean : ElemType::Integer;
} else {
return std::nullopt;
}
@@ -510,48 +512,46 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
using ReduceType = gpu::AllReduceOperation;
const OpHandler handlers[] = {
- {ReduceType::ADD,
+ {ReduceType::ADD, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupIAddOp,
- spirv::GroupNonUniformIAddOp>,
+ spirv::GroupNonUniformIAddOp>},
+ {ReduceType::ADD, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFAddOp,
spirv::GroupNonUniformFAddOp>},
- {ReduceType::MUL,
+ {ReduceType::MUL, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupIMulKHROp,
- spirv::GroupNonUniformIMulOp>,
+ spirv::GroupNonUniformIMulOp>},
+ {ReduceType::MUL, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMulKHROp,
spirv::GroupNonUniformFMulOp>},
- {ReduceType::MINUI,
+ {ReduceType::MINUI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupUMinOp,
- spirv::GroupNonUniformUMinOp>,
- nullptr},
- {ReduceType::MINSI,
+ spirv::GroupNonUniformUMinOp>},
+ {ReduceType::MINSI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupSMinOp,
- spirv::GroupNonUniformSMinOp>,
- nullptr},
- {ReduceType::MINF, nullptr,
+ spirv::GroupNonUniformSMinOp>},
+ {ReduceType::MINF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMinOp,
spirv::GroupNonUniformFMinOp>},
- {ReduceType::MAXUI,
+ {ReduceType::MAXUI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupUMaxOp,
- spirv::GroupNonUniformUMaxOp>,
- nullptr},
- {ReduceType::MAXSI,
+ spirv::GroupNonUniformUMaxOp>},
+ {ReduceType::MAXSI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupSMaxOp,
- spirv::GroupNonUniformSMaxOp>,
- nullptr},
- {ReduceType::MAXF, nullptr,
+ spirv::GroupNonUniformSMaxOp>},
+ {ReduceType::MAXF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
spirv::GroupNonUniformFMaxOp>},
- {ReduceType::MINIMUMF, nullptr,
+ {ReduceType::MINIMUMF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMinOp,
spirv::GroupNonUniformFMinOp>},
- {ReduceType::MAXIMUMF, nullptr,
+ {ReduceType::MAXIMUMF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
spirv::GroupNonUniformFMaxOp>}};
for (const OpHandler &handler : handlers)
- if (handler.type == opType)
- return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);
+ if (handler.kind == opType && elementType == handler.elemType)
+ return handler.func(builder, loc, arg, isGroup, isUniform);
return std::nullopt;
}
>From 3806820b571535724b473b93108c54e5cc244116 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 27 Nov 2023 12:17:29 -0500
Subject: [PATCH 2/3] Add tests
---
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 7 +-
.../Conversion/GPUToSPIRV/reductions.mlir | 100 ++++++++++++++++++
2 files changed, 103 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 655cd1f7a2c795e..5a88ab351866bdc 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -485,9 +485,7 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
Location loc, Value arg,
gpu::AllReduceOperation opType,
bool isGroup, bool isUniform) {
- enum class ElemType {
- Float, Boolean, Integer
- };
+ enum class ElemType { Float, Boolean, Integer };
using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
struct OpHandler {
gpu::AllReduceOperation kind;
@@ -500,7 +498,8 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
if (isa<FloatType>(type)) {
elementType = ElemType::Float;
} else if (auto intTy = dyn_cast<IntegerType>(type)) {
- elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean : ElemType::Integer;
+ elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
+ : ElemType::Integer;
} else {
return std::nullopt;
}
diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
index feb7ee185a8a858..57803298ea27b8f 100644
--- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -647,8 +647,108 @@ gpu.module @kernels {
// CHECK: %{{.*}} = spirv.GroupNonUniformUMax "Subgroup" "Reduce" %[[ARG]] : i32
%r0 = gpu.subgroup_reduce maxsi %arg : (i32) -> (i32)
%r1 = gpu.subgroup_reduce maxui %arg : (i32) -> (i32)
+ %r3 = gpu.subgroup_reduce maxui %arg : (i32) -> (i32)
gpu.return
}
}
}
+
+// -----
+
+// TODO: Handle boolean reductions.
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+ gpu.func @add(%arg : i1) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+ // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+ %r0 = gpu.subgroup_reduce add %arg : (i1) -> (i1)
+ gpu.return
+ }
+}
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+ gpu.func @mul(%arg : i1) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+ // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+ %r0 = gpu.subgroup_reduce mul %arg : (i1) -> (i1)
+ gpu.return
+ }
+}
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+ gpu.func @minsi(%arg : i1) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+ // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+ %r0 = gpu.subgroup_reduce minsi %arg : (i1) -> (i1)
+ gpu.return
+ }
+}
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+ gpu.func @minui(%arg : i1) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+ // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+ %r0 = gpu.subgroup_reduce minui %arg : (i1) -> (i1)
+ gpu.return
+ }
+}
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+ gpu.func @maxsi(%arg : i1) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+ // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+ %r0 = gpu.subgroup_reduce maxsi %arg : (i1) -> (i1)
+ gpu.return
+ }
+}
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+ gpu.func @maxui(%arg : i1) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+ // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+ %r0 = gpu.subgroup_reduce maxui %arg : (i1) -> (i1)
+ gpu.return
+ }
+}
+}
>From ed4816d6dba6ef69c8e671c1792342f6546faba0 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 27 Nov 2023 12:21:54 -0500
Subject: [PATCH 3/3] nit
---
mlir/test/Conversion/GPUToSPIRV/reductions.mlir | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
index 57803298ea27b8f..636078181cae726 100644
--- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -647,7 +647,6 @@ gpu.module @kernels {
// CHECK: %{{.*}} = spirv.GroupNonUniformUMax "Subgroup" "Reduce" %[[ARG]] : i32
%r0 = gpu.subgroup_reduce maxsi %arg : (i32) -> (i32)
%r1 = gpu.subgroup_reduce maxui %arg : (i32) -> (i32)
- %r3 = gpu.subgroup_reduce maxui %arg : (i32) -> (i32)
gpu.return
}
}
More information about the Mlir-commits mailing list