[Mlir-commits] [mlir] [mlir][spirv] Simplify gpu reduction to spirv logic. NFC. (PR #73546)

Jakub Kuderski llvmlistbot at llvm.org
Mon Nov 27 09:17:48 PST 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/73546

>From b83ea3bc0d8901e4dc473e47c5524e29f4a186c1 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 27 Nov 2023 12:03:38 -0500
Subject: [PATCH 1/2] [mlir][spirv] Simplify gpu reduction to spirv logic. NFC.

Check the type only once and then specialize op handlers based on it.
---
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 60 +++++++++----------
 1 file changed, 30 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 6536bbe1f4dba47..655cd1f7a2c795e 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -485,20 +485,22 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
                                                 Location loc, Value arg,
                                                 gpu::AllReduceOperation opType,
                                                 bool isGroup, bool isUniform) {
+  enum class ElemType {
+    Float, Boolean, Integer
+  };
   using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
   struct OpHandler {
-    gpu::AllReduceOperation type;
-    FuncT intFunc;
-    FuncT floatFunc;
+    gpu::AllReduceOperation kind;
+    ElemType elemType;
+    FuncT func;
   };
 
   Type type = arg.getType();
-  using MembptrT = FuncT OpHandler::*;
-  MembptrT handlerPtr;
+  ElemType elementType;
   if (isa<FloatType>(type)) {
-    handlerPtr = &OpHandler::floatFunc;
-  } else if (isa<IntegerType>(type)) {
-    handlerPtr = &OpHandler::intFunc;
+    elementType = ElemType::Float;
+  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
+    elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean : ElemType::Integer;
   } else {
     return std::nullopt;
   }
@@ -510,48 +512,46 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
 
   using ReduceType = gpu::AllReduceOperation;
   const OpHandler handlers[] = {
-      {ReduceType::ADD,
+      {ReduceType::ADD, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupIAddOp,
-                                spirv::GroupNonUniformIAddOp>,
+                                spirv::GroupNonUniformIAddOp>},
+      {ReduceType::ADD, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFAddOp,
                                 spirv::GroupNonUniformFAddOp>},
-      {ReduceType::MUL,
+      {ReduceType::MUL, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
-                                spirv::GroupNonUniformIMulOp>,
+                                spirv::GroupNonUniformIMulOp>},
+      {ReduceType::MUL, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
                                 spirv::GroupNonUniformFMulOp>},
-      {ReduceType::MINUI,
+      {ReduceType::MINUI, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupUMinOp,
-                                spirv::GroupNonUniformUMinOp>,
-       nullptr},
-      {ReduceType::MINSI,
+                                spirv::GroupNonUniformUMinOp>},
+      {ReduceType::MINSI, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupSMinOp,
-                                spirv::GroupNonUniformSMinOp>,
-       nullptr},
-      {ReduceType::MINF, nullptr,
+                                spirv::GroupNonUniformSMinOp>},
+      {ReduceType::MINF, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                 spirv::GroupNonUniformFMinOp>},
-      {ReduceType::MAXUI,
+      {ReduceType::MAXUI, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupUMaxOp,
-                                spirv::GroupNonUniformUMaxOp>,
-       nullptr},
-      {ReduceType::MAXSI,
+                                spirv::GroupNonUniformUMaxOp>},
+      {ReduceType::MAXSI, ElemType::Integer,
        &createGroupReduceOpImpl<spirv::GroupSMaxOp,
-                                spirv::GroupNonUniformSMaxOp>,
-       nullptr},
-      {ReduceType::MAXF, nullptr,
+                                spirv::GroupNonUniformSMaxOp>},
+      {ReduceType::MAXF, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                 spirv::GroupNonUniformFMaxOp>},
-      {ReduceType::MINIMUMF, nullptr,
+      {ReduceType::MINIMUMF, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                 spirv::GroupNonUniformFMinOp>},
-      {ReduceType::MAXIMUMF, nullptr,
+      {ReduceType::MAXIMUMF, ElemType::Float,
        &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                 spirv::GroupNonUniformFMaxOp>}};
 
   for (const OpHandler &handler : handlers)
-    if (handler.type == opType)
-      return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);
+    if (handler.kind == opType && elementType == handler.elemType)
+      return handler.func(builder, loc, arg, isGroup, isUniform);
 
   return std::nullopt;
 }

>From 3806820b571535724b473b93108c54e5cc244116 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 27 Nov 2023 12:17:29 -0500
Subject: [PATCH 2/2] Add tests

---
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp |   7 +-
 .../Conversion/GPUToSPIRV/reductions.mlir     | 100 ++++++++++++++++++
 2 files changed, 103 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 655cd1f7a2c795e..5a88ab351866bdc 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -485,9 +485,7 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
                                                 Location loc, Value arg,
                                                 gpu::AllReduceOperation opType,
                                                 bool isGroup, bool isUniform) {
-  enum class ElemType {
-    Float, Boolean, Integer
-  };
+  enum class ElemType { Float, Boolean, Integer };
   using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
   struct OpHandler {
     gpu::AllReduceOperation kind;
@@ -500,7 +498,8 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
   if (isa<FloatType>(type)) {
     elementType = ElemType::Float;
   } else if (auto intTy = dyn_cast<IntegerType>(type)) {
-    elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean : ElemType::Integer;
+    elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
+                                                       : ElemType::Integer;
   } else {
     return std::nullopt;
   }
diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
index feb7ee185a8a858..57803298ea27b8f 100644
--- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -647,8 +647,108 @@ gpu.module @kernels {
     // CHECK: %{{.*}} = spirv.GroupNonUniformUMax "Subgroup" "Reduce" %[[ARG]] : i32
     %r0 = gpu.subgroup_reduce maxsi %arg : (i32) -> (i32)
     %r1 = gpu.subgroup_reduce maxui %arg : (i32) -> (i32)
+    %r3 = gpu.subgroup_reduce maxui %arg : (i32) -> (i32)
     gpu.return
   }
 }
 
 }
+
+// -----
+
+// TODO: Handle boolean reductions.
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  gpu.func @add(%arg : i1) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %r0 = gpu.subgroup_reduce add %arg : (i1) -> (i1)
+    gpu.return
+  }
+}
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+  gpu.func @mul(%arg : i1) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %r0 = gpu.subgroup_reduce mul %arg : (i1) -> (i1)
+    gpu.return
+  }
+}
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+  gpu.func @minsi(%arg : i1) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %r0 = gpu.subgroup_reduce minsi %arg : (i1) -> (i1)
+    gpu.return
+  }
+}
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+  gpu.func @minui(%arg : i1) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %r0 = gpu.subgroup_reduce minui %arg : (i1) -> (i1)
+    gpu.return
+  }
+}
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+  gpu.func @maxsi(%arg : i1) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %r0 = gpu.subgroup_reduce maxsi %arg : (i1) -> (i1)
+    gpu.return
+  }
+}
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+  gpu.func @maxui(%arg : i1) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %r0 = gpu.subgroup_reduce maxui %arg : (i1) -> (i1)
+    gpu.return
+  }
+}
+}



More information about the Mlir-commits mailing list