[Mlir-commits] [mlir] [mlir][spirv] Simplify gpu reduction to spirv logic. NFC. (PR #73546)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 27 09:06:04 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Jakub Kuderski (kuhar)

<details>
<summary>Changes</summary>

Check the argument's element type only once, then select the op handler that matches both the reduction kind and that element type.

---
Full diff: https://github.com/llvm/llvm-project/pull/73546.diff


1 Files Affected:

- (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp (+30-30) 


``````````diff
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 6536bbe1f4dba47..655cd1f7a2c795e 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -485,20 +485,22 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
                                                 Location loc, Value arg,
                                                 gpu::AllReduceOperation opType,
                                                 bool isGroup, bool isUniform) {
+  enum class ElemType {
+    Float, Boolean, Integer
+  };
   using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
   struct OpHandler {
-    gpu::AllReduceOperation type;
-    FuncT intFunc;
-    FuncT floatFunc;
+    gpu::AllReduceOperation kind;
+    ElemType elemType;
+    FuncT func;
   };
 
   Type type = arg.getType();
-  using MembptrT = FuncT OpHandler::*;
-  MembptrT handlerPtr;
+  ElemType elementType;
   if (isa<FloatType>(type)) {
-    handlerPtr = &OpHandler::floatFunc;
-  } else if (isa<IntegerType>(type)) {
-    handlerPtr = &OpHandler::intFunc;
+    elementType = ElemType::Float;
+  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
+    elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean : ElemType::Integer;
   } else {
     return std::nullopt;
   }
@@ -510,48 +512,46 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
 
   using ReduceType = gpu::AllReduceOperation;
   const OpHandler handlers[] = {
-      {ReduceType::ADD,
+      {ReduceType::ADD, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupIAddOp,
-                                spirv::GroupNonUniformIAddOp>,
+                                spirv::GroupNonUniformIAddOp>},
+      {ReduceType::ADD, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFAddOp,
                                 spirv::GroupNonUniformFAddOp>},
-      {ReduceType::MUL,
+      {ReduceType::MUL, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
-                                spirv::GroupNonUniformIMulOp>,
+                                spirv::GroupNonUniformIMulOp>},
+      {ReduceType::MUL, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
                                 spirv::GroupNonUniformFMulOp>},
-      {ReduceType::MINUI,
+      {ReduceType::MINUI, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupUMinOp,
-                                spirv::GroupNonUniformUMinOp>,
-       nullptr},
-      {ReduceType::MINSI,
+                                spirv::GroupNonUniformUMinOp>},
+      {ReduceType::MINSI, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupSMinOp,
-                                spirv::GroupNonUniformSMinOp>,
-       nullptr},
-      {ReduceType::MINF, nullptr,
+                                spirv::GroupNonUniformSMinOp>},
+      {ReduceType::MINF, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                 spirv::GroupNonUniformFMinOp>},
-      {ReduceType::MAXUI,
+      {ReduceType::MAXUI, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupUMaxOp,
-                                spirv::GroupNonUniformUMaxOp>,
-       nullptr},
-      {ReduceType::MAXSI,
+                                spirv::GroupNonUniformUMaxOp>},
+      {ReduceType::MAXSI, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupSMaxOp,
-                                spirv::GroupNonUniformSMaxOp>,
-       nullptr},
-      {ReduceType::MAXF, nullptr,
+                                spirv::GroupNonUniformSMaxOp>},
+      {ReduceType::MAXF, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                 spirv::GroupNonUniformFMaxOp>},
-      {ReduceType::MINIMUMF, nullptr,
+      {ReduceType::MINIMUMF, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                 spirv::GroupNonUniformFMinOp>},
-      {ReduceType::MAXIMUMF, nullptr,
+      {ReduceType::MAXIMUMF, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                 spirv::GroupNonUniformFMaxOp>}};
 
   for (const OpHandler &handler : handlers)
-    if (handler.type == opType)
-      return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);
+    if (handler.kind == opType && elementType == handler.elemType)
+      return handler.func(builder, loc, arg, isGroup, isUniform);
 
   return std::nullopt;
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/73546


More information about the Mlir-commits mailing list