[Mlir-commits] [mlir] [mlir][spirv] Add spirv-to-llvm conversion for group operations (PR #115501)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 8 08:03:00 PST 2024
================
@@ -1089,6 +1096,186 @@ class ControlBarrierPattern
}
};
+namespace {
+
+/// Returns the Itanium C++ ABI builtin-type code for a scalar of type `type`,
+/// e.g. "f" for f32, "Dh" for f16, "j" for an unsigned 32-bit integer.
+///
+/// `isSigned` selects between the signed and unsigned integer codes. It is
+/// ignored for floating-point types and for i1 (mangled as `bool`, "b").
+///
+/// Unsupported types (other float kinds, integer widths other than
+/// 1/8/16/32/64) trigger an assertion in debug builds and yield an empty
+/// string when assertions are disabled.
+StringRef getTypeMangling(Type type, bool isSigned) {
+  return llvm::TypeSwitch<Type, StringRef>(type)
+      .Case<Float16Type>([](auto) { return "Dh"; })
+      .Case<Float32Type>([](auto) { return "f"; })
+      .Case<Float64Type>([](auto) { return "d"; })
+      .Case<IntegerType>([isSigned](IntegerType intTy) {
+        // Itanium builtin-type codes, signed/unsigned pairs:
+        //   a/h = signed/unsigned char, s/t = short, i/j = int, l/m = long.
+        // Plain `char` ("c") is a distinct type from both `signed char` and
+        // `unsigned char`, so it must not be used for the unsigned 8-bit case.
+        // NOTE(review): 64-bit maps to `long`, which assumes an LP64-style
+        // device ABI; confirm against the builtin library being targeted.
+        switch (intTy.getWidth()) {
+        case 1:
+          return "b";
+        case 8:
+          return isSigned ? "a" : "h";
+        case 16:
+          return isSigned ? "s" : "t";
+        case 32:
+          return isSigned ? "i" : "j";
+        case 64:
+          return isSigned ? "l" : "m";
+        default:
+          assert(false && "Unsupported integer width");
+          return "";
+        }
+      })
+      .Default([](auto) {
+        assert(false && "No mangling defined");
+        return "";
+      });
+}
+
+/// Returns the Itanium-mangled name (`_Z<length><name>` plus parameter codes)
+/// of the SPIR-V group builtin implementing `ReduceOp`. The trailing "ii"
+/// encodes two leading `int` parameters (presumably the execution scope and
+/// the group operation; confirm against the builtin declarations). The
+/// mangling of the reduced value's type is NOT included.
+///
+/// Only the explicit specializations below are valid. Instantiating the
+/// primary template for any other op is a compile-time error, so a missing
+/// mapping cannot silently produce an empty builtin name in release builds.
+template <typename ReduceOp>
+constexpr StringLiteral getGroupFuncName() {
+  // Always false, but dependent on `ReduceOp`, so the assertion only fires
+  // when this primary template is actually instantiated.
+  static_assert(sizeof(ReduceOp) == 0, "No builtin defined for this op");
+  return "";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
+  return "_Z17__spirv_GroupIAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
+  return "_Z17__spirv_GroupFAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
+  return "_Z17__spirv_GroupSMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
+  return "_Z17__spirv_GroupUMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
+  return "_Z17__spirv_GroupFMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
+  return "_Z17__spirv_GroupSMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
+  return "_Z17__spirv_GroupUMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
+  return "_Z17__spirv_GroupFMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
+  return "_Z27__spirv_GroupNonUniformIAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
+  return "_Z27__spirv_GroupNonUniformFAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
+  return "_Z27__spirv_GroupNonUniformIMulii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
+  return "_Z27__spirv_GroupNonUniformFMulii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
+  return "_Z27__spirv_GroupNonUniformSMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
+  return "_Z27__spirv_GroupNonUniformUMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
+  return "_Z27__spirv_GroupNonUniformFMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
+  return "_Z27__spirv_GroupNonUniformSMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
+  return "_Z27__spirv_GroupNonUniformUMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
+  return "_Z27__spirv_GroupNonUniformFMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
+  return "_Z33__spirv_GroupNonUniformBitwiseAndii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
+  return "_Z32__spirv_GroupNonUniformBitwiseOrii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
+  return "_Z33__spirv_GroupNonUniformBitwiseXorii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
+  return "_Z33__spirv_GroupNonUniformLogicalAndii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
+  return "_Z32__spirv_GroupNonUniformLogicalOrii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
+  return "_Z33__spirv_GroupNonUniformLogicalXorii";
+}
+} // namespace
+
+template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
+class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
+public:
+ using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ Type retTy = op.getResult().getType();
+ if (!retTy.isIntOrFloat()) {
+ return failure();
+ }
+ SmallString<20> funcName = getGroupFuncName<ReduceOp>();
----------------
FMarno wrote:
```suggestion
SmallString<36> funcName = getGroupFuncName<ReduceOp>();
```
NIT: Should cover all cases
https://github.com/llvm/llvm-project/pull/115501
More information about the Mlir-commits mailing list