[Mlir-commits] [mlir] [mlir][spirv] Simplify gpu reduction to spirv logic. NFC. (PR #73546)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 27 09:06:04 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir
Author: Jakub Kuderski (kuhar)
<details>
<summary>Changes</summary>
Check the type only once and then specialize op handlers based on it.
---
Full diff: https://github.com/llvm/llvm-project/pull/73546.diff
1 File Affected:
- (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp (+30-30)
``````````diff
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 6536bbe1f4dba47..655cd1f7a2c795e 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -485,20 +485,22 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
Location loc, Value arg,
gpu::AllReduceOperation opType,
bool isGroup, bool isUniform) {
+ enum class ElemType {
+ Float, Boolean, Integer
+ };
using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
struct OpHandler {
- gpu::AllReduceOperation type;
- FuncT intFunc;
- FuncT floatFunc;
+ gpu::AllReduceOperation kind;
+ ElemType elemType;
+ FuncT func;
};
Type type = arg.getType();
- using MembptrT = FuncT OpHandler::*;
- MembptrT handlerPtr;
+ ElemType elementType;
if (isa<FloatType>(type)) {
- handlerPtr = &OpHandler::floatFunc;
- } else if (isa<IntegerType>(type)) {
- handlerPtr = &OpHandler::intFunc;
+ elementType = ElemType::Float;
+ } else if (auto intTy = dyn_cast<IntegerType>(type)) {
+ elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean : ElemType::Integer;
} else {
return std::nullopt;
}
@@ -510,48 +512,46 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
using ReduceType = gpu::AllReduceOperation;
const OpHandler handlers[] = {
- {ReduceType::ADD,
+ {ReduceType::ADD, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupIAddOp,
- spirv::GroupNonUniformIAddOp>,
+ spirv::GroupNonUniformIAddOp>},
+ {ReduceType::ADD, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFAddOp,
spirv::GroupNonUniformFAddOp>},
- {ReduceType::MUL,
+ {ReduceType::MUL, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupIMulKHROp,
- spirv::GroupNonUniformIMulOp>,
+ spirv::GroupNonUniformIMulOp>},
+ {ReduceType::MUL, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMulKHROp,
spirv::GroupNonUniformFMulOp>},
- {ReduceType::MINUI,
+ {ReduceType::MINUI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupUMinOp,
- spirv::GroupNonUniformUMinOp>,
- nullptr},
- {ReduceType::MINSI,
+ spirv::GroupNonUniformUMinOp>},
+ {ReduceType::MINSI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupSMinOp,
- spirv::GroupNonUniformSMinOp>,
- nullptr},
- {ReduceType::MINF, nullptr,
+ spirv::GroupNonUniformSMinOp>},
+ {ReduceType::MINF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMinOp,
spirv::GroupNonUniformFMinOp>},
- {ReduceType::MAXUI,
+ {ReduceType::MAXUI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupUMaxOp,
- spirv::GroupNonUniformUMaxOp>,
- nullptr},
- {ReduceType::MAXSI,
+ spirv::GroupNonUniformUMaxOp>},
+ {ReduceType::MAXSI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupSMaxOp,
- spirv::GroupNonUniformSMaxOp>,
- nullptr},
- {ReduceType::MAXF, nullptr,
+ spirv::GroupNonUniformSMaxOp>},
+ {ReduceType::MAXF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
spirv::GroupNonUniformFMaxOp>},
- {ReduceType::MINIMUMF, nullptr,
+ {ReduceType::MINIMUMF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMinOp,
spirv::GroupNonUniformFMinOp>},
- {ReduceType::MAXIMUMF, nullptr,
+ {ReduceType::MAXIMUMF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
spirv::GroupNonUniformFMaxOp>}};
for (const OpHandler &handler : handlers)
- if (handler.type == opType)
- return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);
+ if (handler.kind == opType && elementType == handler.elemType)
+ return handler.func(builder, loc, arg, isGroup, isUniform);
return std::nullopt;
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/73546
More information about the Mlir-commits mailing list