[Mlir-commits] [mlir] [mlir][spirv] Add spirv-to-llvm conversion for group operations (PR #115501)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 8 08:03:00 PST 2024


================
@@ -1089,6 +1096,186 @@ class ControlBarrierPattern
   }
 };
 
+namespace {
+
+/// Maps a scalar type to its short type-mangling string.
+///
+/// Float types map to "Dh" (f16), "f" (f32) and "d" (f64). Integer types map
+/// by bit width: 1-bit integers always yield "b", while for widths 8, 16, 32
+/// and 64 `isSigned` picks between the two spellings. Any other type or
+/// width trips an assertion; when assertions are compiled out an empty string
+/// is returned instead.
+///
+/// NOTE(review): in the Itanium C++ ABI "h" is `unsigned char` and "c" is
+/// plain `char`; confirm that "c" is the intended unsigned 8-bit spelling.
+StringRef getTypeMangling(Type type, bool isSigned) {
+  // Chooses the signed or unsigned spelling for an integer of a given width.
+  auto mangleInteger = [isSigned](IntegerType integerType) {
+    const unsigned bitWidth = integerType.getWidth();
+    if (bitWidth == 1)
+      return "b";
+    if (bitWidth == 8)
+      return isSigned ? "a" : "c";
+    if (bitWidth == 16)
+      return isSigned ? "s" : "t";
+    if (bitWidth == 32)
+      return isSigned ? "i" : "j";
+    if (bitWidth == 64)
+      return isSigned ? "l" : "m";
+    assert(false && "Unsupported integer width");
+    return "";
+  };
+
+  llvm::TypeSwitch<Type, StringRef> dispatcher(type);
+  return dispatcher
+      .Case<Float16Type>([](Float16Type) { return "Dh"; })
+      .Case<Float32Type>([](Float32Type) { return "f"; })
+      .Case<Float64Type>([](Float64Type) { return "d"; })
+      .Case<IntegerType>(mangleInteger)
+      .Default([](Type) {
+        assert(false && "No mangling defined");
+        return "";
+      });
+}
+
+/// Returns the base name of the builtin function that implements the group
+/// reduction `ReduceOp`.
+///
+/// Every name below has the shape `_Z<len>__spirv_<Op>ii`: an Itanium-style
+/// length-prefixed identifier followed by the code for two `int` parameters.
+/// NOTE(review): callers presumably append the mangling of the remaining
+/// operand types to this prefix -- confirm against the use site.
+///
+/// The primary template is deleted on purpose. An op without a specialization
+/// is now rejected at compile time; the previous `assert(false)` body silently
+/// returned an empty name in builds where assertions are compiled out.
+template <typename ReduceOp>
+constexpr StringLiteral getGroupFuncName() = delete;
+
+// Uniform group reductions.
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
+  return "_Z17__spirv_GroupIAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
+  return "_Z17__spirv_GroupFAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
+  return "_Z17__spirv_GroupSMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
+  return "_Z17__spirv_GroupUMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
+  return "_Z17__spirv_GroupFMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
+  return "_Z17__spirv_GroupSMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
+  return "_Z17__spirv_GroupUMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
+  return "_Z17__spirv_GroupFMaxii";
+}
+
+// Non-uniform arithmetic reductions.
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
+  return "_Z27__spirv_GroupNonUniformIAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
+  return "_Z27__spirv_GroupNonUniformFAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
+  return "_Z27__spirv_GroupNonUniformIMulii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
+  return "_Z27__spirv_GroupNonUniformFMulii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
+  return "_Z27__spirv_GroupNonUniformSMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
+  return "_Z27__spirv_GroupNonUniformUMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
+  return "_Z27__spirv_GroupNonUniformFMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
+  return "_Z27__spirv_GroupNonUniformSMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
+  return "_Z27__spirv_GroupNonUniformUMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
+  return "_Z27__spirv_GroupNonUniformFMaxii";
+}
+
+// Non-uniform bitwise reductions.
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
+  return "_Z33__spirv_GroupNonUniformBitwiseAndii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
+  return "_Z32__spirv_GroupNonUniformBitwiseOrii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
+  return "_Z33__spirv_GroupNonUniformBitwiseXorii";
+}
+
+// Non-uniform logical reductions.
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
+  return "_Z33__spirv_GroupNonUniformLogicalAndii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
+  return "_Z32__spirv_GroupNonUniformLogicalOrii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
+  return "_Z33__spirv_GroupNonUniformLogicalXorii";
+}
+} // namespace
+
+template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
+class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
+public:
+  using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Type retTy = op.getResult().getType();
+    if (!retTy.isIntOrFloat()) {
+      return failure();
+    }
+    SmallString<20> funcName = getGroupFuncName<ReduceOp>();
----------------
FMarno wrote:

```suggestion
    SmallString<36> funcName = getGroupFuncName<ReduceOp>();
```
NIT: Should cover all cases

https://github.com/llvm/llvm-project/pull/115501


More information about the Mlir-commits mailing list