[Mlir-commits] [mlir] [mlir][spirv] Simplify gpu reduction to spirv logic. NFC. (PR #73546)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Nov 27 09:05:33 PST 2023
https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/73546
Check the type only once and then specialize op handlers based on it.
From b83ea3bc0d8901e4dc473e47c5524e29f4a186c1 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 27 Nov 2023 12:03:38 -0500
Subject: [PATCH] [mlir][spirv] Simplify gpu reduction to spirv logic. NFC.
Check the type only once and then specialize op handlers based on it.
---
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 60 +++++++++----------
1 file changed, 30 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 6536bbe1f4dba47..655cd1f7a2c795e 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -485,20 +485,22 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
Location loc, Value arg,
gpu::AllReduceOperation opType,
bool isGroup, bool isUniform) {
+ enum class ElemType {
+ Float, Boolean, Integer
+ };
using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
struct OpHandler {
- gpu::AllReduceOperation type;
- FuncT intFunc;
- FuncT floatFunc;
+ gpu::AllReduceOperation kind;
+ ElemType elemType;
+ FuncT func;
};
Type type = arg.getType();
- using MembptrT = FuncT OpHandler::*;
- MembptrT handlerPtr;
+ ElemType elementType;
if (isa<FloatType>(type)) {
- handlerPtr = &OpHandler::floatFunc;
- } else if (isa<IntegerType>(type)) {
- handlerPtr = &OpHandler::intFunc;
+ elementType = ElemType::Float;
+ } else if (auto intTy = dyn_cast<IntegerType>(type)) {
+ elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean : ElemType::Integer;
} else {
return std::nullopt;
}
@@ -510,48 +512,46 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
using ReduceType = gpu::AllReduceOperation;
const OpHandler handlers[] = {
- {ReduceType::ADD,
+ {ReduceType::ADD, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupIAddOp,
- spirv::GroupNonUniformIAddOp>,
+ spirv::GroupNonUniformIAddOp>},
+ {ReduceType::ADD, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFAddOp,
spirv::GroupNonUniformFAddOp>},
- {ReduceType::MUL,
+ {ReduceType::MUL, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupIMulKHROp,
- spirv::GroupNonUniformIMulOp>,
+ spirv::GroupNonUniformIMulOp>},
+ {ReduceType::MUL, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMulKHROp,
spirv::GroupNonUniformFMulOp>},
- {ReduceType::MINUI,
+ {ReduceType::MINUI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupUMinOp,
- spirv::GroupNonUniformUMinOp>,
- nullptr},
- {ReduceType::MINSI,
+ spirv::GroupNonUniformUMinOp>},
+ {ReduceType::MINSI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupSMinOp,
- spirv::GroupNonUniformSMinOp>,
- nullptr},
- {ReduceType::MINF, nullptr,
+ spirv::GroupNonUniformSMinOp>},
+ {ReduceType::MINF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMinOp,
spirv::GroupNonUniformFMinOp>},
- {ReduceType::MAXUI,
+ {ReduceType::MAXUI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupUMaxOp,
- spirv::GroupNonUniformUMaxOp>,
- nullptr},
- {ReduceType::MAXSI,
+ spirv::GroupNonUniformUMaxOp>},
+ {ReduceType::MAXSI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupSMaxOp,
- spirv::GroupNonUniformSMaxOp>,
- nullptr},
- {ReduceType::MAXF, nullptr,
+ spirv::GroupNonUniformSMaxOp>},
+ {ReduceType::MAXF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
spirv::GroupNonUniformFMaxOp>},
- {ReduceType::MINIMUMF, nullptr,
+ {ReduceType::MINIMUMF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMinOp,
spirv::GroupNonUniformFMinOp>},
- {ReduceType::MAXIMUMF, nullptr,
+ {ReduceType::MAXIMUMF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
spirv::GroupNonUniformFMaxOp>}};
for (const OpHandler &handler : handlers)
- if (handler.type == opType)
- return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);
+ if (handler.kind == opType && elementType == handler.elemType)
+ return handler.func(builder, loc, arg, isGroup, isUniform);
return std::nullopt;
}
More information about the Mlir-commits mailing list