[Mlir-commits] [mlir] 1bdb2e8 - [mlir][spirv] Simplify gpu reduction to spirv logic (#73546)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 27 09:33:45 PST 2023


Author: Jakub Kuderski
Date: 2023-11-27T12:33:41-05:00
New Revision: 1bdb2e85509da95544fbad7af053196061feca8c

URL: https://github.com/llvm/llvm-project/commit/1bdb2e85509da95544fbad7af053196061feca8c
DIFF: https://github.com/llvm/llvm-project/commit/1bdb2e85509da95544fbad7af053196061feca8c.diff

LOG: [mlir][spirv] Simplify gpu reduction to spirv logic (#73546)

Check the type only once and then specialize op handlers based on it.

Do not use boolean types in group arithmetic ops that expect integers.

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
    mlir/test/Conversion/GPUToSPIRV/reductions.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 6536bbe1f4dba47..5a88ab351866bdc 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -485,20 +485,21 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
                                                 Location loc, Value arg,
                                                 gpu::AllReduceOperation opType,
                                                 bool isGroup, bool isUniform) {
+  enum class ElemType { Float, Boolean, Integer };
   using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
   struct OpHandler {
-    gpu::AllReduceOperation type;
-    FuncT intFunc;
-    FuncT floatFunc;
+    gpu::AllReduceOperation kind;
+    ElemType elemType;
+    FuncT func;
   };
 
   Type type = arg.getType();
-  using MembptrT = FuncT OpHandler::*;
-  MembptrT handlerPtr;
+  ElemType elementType;
   if (isa<FloatType>(type)) {
-    handlerPtr = &OpHandler::floatFunc;
-  } else if (isa<IntegerType>(type)) {
-    handlerPtr = &OpHandler::intFunc;
+    elementType = ElemType::Float;
+  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
+    elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
+                                                       : ElemType::Integer;
   } else {
     return std::nullopt;
   }
@@ -510,48 +511,46 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
 
   using ReduceType = gpu::AllReduceOperation;
   const OpHandler handlers[] = {
-      {ReduceType::ADD,
+      {ReduceType::ADD, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupIAddOp,
-                                spirv::GroupNonUniformIAddOp>,
+                                spirv::GroupNonUniformIAddOp>},
+      {ReduceType::ADD, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFAddOp,
                                 spirv::GroupNonUniformFAddOp>},
-      {ReduceType::MUL,
+      {ReduceType::MUL, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
-                                spirv::GroupNonUniformIMulOp>,
+                                spirv::GroupNonUniformIMulOp>},
+      {ReduceType::MUL, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
                                 spirv::GroupNonUniformFMulOp>},
-      {ReduceType::MINUI,
+      {ReduceType::MINUI, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupUMinOp,
-                                spirv::GroupNonUniformUMinOp>,
-       nullptr},
-      {ReduceType::MINSI,
+                                spirv::GroupNonUniformUMinOp>},
+      {ReduceType::MINSI, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupSMinOp,
-                                spirv::GroupNonUniformSMinOp>,
-       nullptr},
-      {ReduceType::MINF, nullptr,
+                                spirv::GroupNonUniformSMinOp>},
+      {ReduceType::MINF, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                 spirv::GroupNonUniformFMinOp>},
-      {ReduceType::MAXUI,
+      {ReduceType::MAXUI, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupUMaxOp,
-                                spirv::GroupNonUniformUMaxOp>,
-       nullptr},
-      {ReduceType::MAXSI,
+                                spirv::GroupNonUniformUMaxOp>},
+      {ReduceType::MAXSI, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupSMaxOp,
-                                spirv::GroupNonUniformSMaxOp>,
-       nullptr},
-      {ReduceType::MAXF, nullptr,
+                                spirv::GroupNonUniformSMaxOp>},
+      {ReduceType::MAXF, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                 spirv::GroupNonUniformFMaxOp>},
-      {ReduceType::MINIMUMF, nullptr,
+      {ReduceType::MINIMUMF, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                 spirv::GroupNonUniformFMinOp>},
-      {ReduceType::MAXIMUMF, nullptr,
+      {ReduceType::MAXIMUMF, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                 spirv::GroupNonUniformFMaxOp>}};
 
   for (const OpHandler &handler : handlers)
-    if (handler.type == opType)
-      return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);
+    if (handler.kind == opType && elementType == handler.elemType)
+      return handler.func(builder, loc, arg, isGroup, isUniform);
 
   return std::nullopt;
 }

diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
index feb7ee185a8a858..636078181cae726 100644
--- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -652,3 +652,102 @@ gpu.module @kernels {
 }
 
 }
+
+// -----
+
+// TODO: Handle boolean reductions.
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  gpu.func @add(%arg : i1) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %r0 = gpu.subgroup_reduce add %arg : (i1) -> (i1)
+    gpu.return
+  }
+}
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+  gpu.func @mul(%arg : i1) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %r0 = gpu.subgroup_reduce mul %arg : (i1) -> (i1)
+    gpu.return
+  }
+}
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+  gpu.func @minsi(%arg : i1) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %r0 = gpu.subgroup_reduce minsi %arg : (i1) -> (i1)
+    gpu.return
+  }
+}
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+  gpu.func @minui(%arg : i1) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %r0 = gpu.subgroup_reduce minui %arg : (i1) -> (i1)
+    gpu.return
+  }
+}
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+  gpu.func @maxsi(%arg : i1) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %r0 = gpu.subgroup_reduce maxsi %arg : (i1) -> (i1)
+    gpu.return
+  }
+}
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+  gpu.func @maxui(%arg : i1) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %r0 = gpu.subgroup_reduce maxui %arg : (i1) -> (i1)
+    gpu.return
+  }
+}
+}


        


More information about the Mlir-commits mailing list