[Mlir-commits] [mlir] [mlir][spirv] Simplify gpu reduction to spirv logic. NFC. (PR #73546)

Jakub Kuderski llvmlistbot at llvm.org
Mon Nov 27 09:18:48 PST 2023


================
@@ -485,20 +485,22 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
                                                 Location loc, Value arg,
                                                 gpu::AllReduceOperation opType,
                                                 bool isGroup, bool isUniform) {
+  enum class ElemType {
+    Float, Boolean, Integer
+  };
   using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
   struct OpHandler {
-    gpu::AllReduceOperation type;
-    FuncT intFunc;
-    FuncT floatFunc;
+    gpu::AllReduceOperation kind;
+    ElemType elemType;
+    FuncT func;
   };
 
   Type type = arg.getType();
-  using MembptrT = FuncT OpHandler::*;
-  MembptrT handlerPtr;
+  ElemType elementType;
   if (isa<FloatType>(type)) {
-    handlerPtr = &OpHandler::floatFunc;
-  } else if (isa<IntegerType>(type)) {
-    handlerPtr = &OpHandler::intFunc;
+    elementType = ElemType::Float;
+  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
+    elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean : ElemType::Integer;
----------------
kuhar wrote:

Good point, I added tests. In SPIR-V, booleans are not considered integer types, so we can't use them in most arithmetic ops.
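To illustrate why `i1` needs its own `ElemType::Boolean` bucket, here is a minimal, self-contained sketch of the table-driven dispatch the diff moves to. It does not use MLIR. `ReduceKind`, `classify`, `kHandlers`, and `lookupGroupOp` are hypothetical stand-ins for `gpu::AllReduceOperation`, the `isa`/`dyn_cast` type checks, and the `OpHandler` table. The SPIR-V op names in the table are real SPIR-V group instructions, but which combinations the patch actually handles is an assumption here.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for gpu::AllReduceOperation (subset only).
enum class ReduceKind { Add, Mul, And, Or, Xor };
enum class ElemType { Float, Boolean, Integer };

// Mirrors the classification in the diff: floats -> Float,
// 1-bit integers -> Boolean, all other integers -> Integer.
ElemType classify(bool isFloat, unsigned bitWidth) {
  if (isFloat)
    return ElemType::Float;
  return bitWidth == 1 ? ElemType::Boolean : ElemType::Integer;
}

// One entry per (reduction kind, element type) pair, as in the new OpHandler.
struct OpHandler {
  ReduceKind kind;
  ElemType elemType;
  const char *spirvOp; // SPIR-V group instruction this pair would lower to
};

// Illustrative table (assumed, not copied from the patch). Booleans only get
// logical ops: SPIR-V's OpTypeBool is not an integer type, so IAdd/IMul and
// the bitwise ops do not accept it.
constexpr OpHandler kHandlers[] = {
    {ReduceKind::Add, ElemType::Integer, "OpGroupNonUniformIAdd"},
    {ReduceKind::Add, ElemType::Float, "OpGroupNonUniformFAdd"},
    {ReduceKind::Mul, ElemType::Integer, "OpGroupNonUniformIMul"},
    {ReduceKind::Mul, ElemType::Float, "OpGroupNonUniformFMul"},
    {ReduceKind::And, ElemType::Integer, "OpGroupNonUniformBitwiseAnd"},
    {ReduceKind::And, ElemType::Boolean, "OpGroupNonUniformLogicalAnd"},
    {ReduceKind::Or, ElemType::Integer, "OpGroupNonUniformBitwiseOr"},
    {ReduceKind::Or, ElemType::Boolean, "OpGroupNonUniformLogicalOr"},
    {ReduceKind::Xor, ElemType::Integer, "OpGroupNonUniformBitwiseXor"},
    {ReduceKind::Xor, ElemType::Boolean, "OpGroupNonUniformLogicalXor"},
};

// Linear search over the table; returns nullopt when no handler matches,
// analogous to createGroupReduceOp returning std::nullopt.
std::optional<std::string> lookupGroupOp(ReduceKind kind, ElemType elemType) {
  for (const OpHandler &h : kHandlers)
    if (h.kind == kind && h.elemType == elemType)
      return std::string(h.spirvOp);
  return std::nullopt;
}
```

For example, `lookupGroupOp(ReduceKind::And, classify(false, 1))` selects the logical AND, while `lookupGroupOp(ReduceKind::Add, classify(false, 1))` finds no handler. Folding `i1` into `ElemType::Integer` would instead silently pick `OpGroupNonUniformIAdd` for a boolean, which is the kind of case boolean tests would exercise.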

https://github.com/llvm/llvm-project/pull/73546

