[Mlir-commits] [mlir] 2e4aa3b - [mlir][gpu][spirv] Lower gpu reduction ops to spirv

Ivan Butygin llvmlistbot at llvm.org
Fri Dec 30 08:45:47 PST 2022


Author: Ivan Butygin
Date: 2022-12-30T17:44:08+01:00
New Revision: 2e4aa3bd83faef5e89275cba97d99b1a77c4d25c

URL: https://github.com/llvm/llvm-project/commit/2e4aa3bd83faef5e89275cba97d99b1a77c4d25c
DIFF: https://github.com/llvm/llvm-project/commit/2e4aa3bd83faef5e89275cba97d99b1a77c4d25c.diff

LOG: [mlir][gpu][spirv] Lower gpu reduction ops to spirv

Supports only "add" and "mul" ops for now. More ops will be added later.

Differential Revision: https://reviews.llvm.org/D140576

Added: 
    mlir/test/Conversion/GPUToSPIRV/reductions.mlir

Modified: 
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 5c9de5ebcc955..86a71ac417255 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -426,6 +426,118 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Group ops
+//===----------------------------------------------------------------------===//
+
+/// Builds a SPIR-V group reduction (GroupOperation::Reduce) of `arg`.
+///
+/// The reduction scope is Workgroup when `isGroup` is true and Subgroup
+/// otherwise. `UniformOp` is instantiated when `isUniform` is true, and
+/// `NonUniformOp` (a GroupNonUniform* op) otherwise. The result has the same
+/// type as `arg`.
+template <typename UniformOp, typename NonUniformOp>
+static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
+                                     Value arg, bool isGroup, bool isUniform) {
+  Type type = arg.getType();
+  auto scope = spirv::ScopeAttr::get(builder.getContext(),
+                                     isGroup ? spirv::Scope::Workgroup
+                                             : spirv::Scope::Subgroup);
+  auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
+                                                spirv::GroupOperation::Reduce);
+  if (isUniform)
+    return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
+        .getResult();
+
+  // The trailing null Value leaves the non-uniform op's optional extra operand
+  // unset (presumably the ClusterSize operand, which only ClusteredReduce uses).
+  return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
+      .getResult();
+}
+
+/// Creates a SPIR-V group reduction of `arg` implementing the GPU reduction
+/// kind `opType`, at Workgroup scope (`isGroup`) or Subgroup scope
+/// (`!isGroup`), using the uniform or non-uniform SPIR-V op per `isUniform`.
+///
+/// Returns std::nullopt when the operand type or the reduction kind is not
+/// supported, so that the calling pattern can report a match failure instead
+/// of emitting invalid IR.
+static llvm::Optional<Value> createGroupReduceOp(OpBuilder &builder,
+                                                 Location loc, Value arg,
+                                                 gpu::AllReduceOperation opType,
+                                                 bool isGroup, bool isUniform) {
+  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
+  struct OpHandler {
+    gpu::AllReduceOperation type;
+    FuncT intFunc;
+    FuncT floatFunc;
+  };
+
+  // Pick the integer or float builder column of the table below based on the
+  // operand type.
+  Type type = arg.getType();
+  using MembptrT = FuncT OpHandler::*;
+  MembptrT handlerPtr;
+  if (type.isa<FloatType>()) {
+    handlerPtr = &OpHandler::floatFunc;
+  } else if (type.isa<IntegerType>() && !type.isInteger(1)) {
+    // i1 is the SPIR-V boolean type; the integer group arithmetic ops only
+    // accept 8/16/32/64-bit integers, so booleans must not reach them.
+    handlerPtr = &OpHandler::intFunc;
+  } else {
+    return std::nullopt;
+  }
+
+  using ReduceType = gpu::AllReduceOperation;
+  const OpHandler handlers[] = {
+      {ReduceType::ADD,
+       &createGroupReduceOpImpl<spirv::GroupIAddOp,
+                                spirv::GroupNonUniformIAddOp>,
+       &createGroupReduceOpImpl<spirv::GroupFAddOp,
+                                spirv::GroupNonUniformFAddOp>},
+      {ReduceType::MUL,
+       &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
+                                spirv::GroupNonUniformIMulOp>,
+       &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
+                                spirv::GroupNonUniformFMulOp>},
+  };
+
+  for (const OpHandler &handler : handlers)
+    if (handler.type == opType)
+      return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);
+
+  // Reduction kinds other than add/mul are not lowered yet.
+  return std::nullopt;
+}
+
+/// Pattern to convert a gpu.all_reduce op into a SPIR-V group reduction op at
+/// Workgroup scope. Only the form carrying a reduction-kind attribute is
+/// handled; the custom-region form is rejected with a match failure.
+class GPUAllReduceConversion final
+    : public OpConversionPattern<gpu::AllReduceOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto opType = op.getOp();
+
+    // gpu.all_reduce can have either reduction op attribute or reduction
+    // region. Only attribute version is supported.
+    if (!opType)
+      return rewriter.notifyMatchFailure(
+          op, "only the reduction-attribute form is supported");
+
+    auto result =
+        createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
+                            /*isGroup=*/true, op.getUniform());
+    if (!result)
+      return rewriter.notifyMatchFailure(
+          op, "unsupported reduction kind or operand type");
+
+    rewriter.replaceOp(op, *result);
+    return success();
+  }
+};
+
+/// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group reduction
+/// op at Subgroup scope.
+class GPUSubgroupReduceConversion final
+    : public OpConversionPattern<gpu::SubgroupReduceOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto opType = op.getOp();
+    auto result =
+        createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), opType,
+                            /*isGroup=*/false, op.getUniform());
+    if (!result)
+      return rewriter.notifyMatchFailure(
+          op, "unsupported reduction kind or operand type");
+
+    rewriter.replaceOp(op, *result);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // GPU To SPIRV Patterns.
 //===----------------------------------------------------------------------===//
@@ -448,5 +560,6 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                       spirv::BuiltIn::NumSubgroups>,
       SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
                                       spirv::BuiltIn::SubgroupSize>,
-      WorkGroupSizeConversion>(typeConverter, patterns.getContext());
+      WorkGroupSizeConversion, GPUAllReduceConversion,
+      GPUSubgroupReduceConversion>(typeConverter, patterns.getContext());
 }

diff  --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
new file mode 100644
index 0000000000000..245704571c7df
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -0,0 +1,319 @@
+// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+// Uniform gpu.all_reduce "add" on f32 is expected to lower to the uniform
+// spirv.GroupFAdd op with Workgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: f32)
+  gpu.func @test(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupFAdd <Workgroup> <Reduce> %[[ARG]] : f32
+    %reduced = gpu.all_reduce add %arg uniform {} : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Non-uniform gpu.all_reduce "add" on f32 is expected to lower to
+// spirv.GroupNonUniformFAdd with Workgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: f32)
+  gpu.func @test(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupNonUniformFAdd "Workgroup" "Reduce" %[[ARG]] : f32
+    %reduced = gpu.all_reduce add %arg {} : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Uniform gpu.all_reduce "add" on i32 is expected to lower to the uniform
+// spirv.GroupIAdd op with Workgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: i32)
+  gpu.func @test(%arg : i32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupIAdd <Workgroup> <Reduce> %[[ARG]] : i32
+    %reduced = gpu.all_reduce add %arg uniform {} : (i32) -> (i32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Non-uniform gpu.all_reduce "add" on i32 is expected to lower to
+// spirv.GroupNonUniformIAdd with Workgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: i32)
+  gpu.func @test(%arg : i32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupNonUniformIAdd "Workgroup" "Reduce" %[[ARG]] : i32
+    %reduced = gpu.all_reduce add %arg {} : (i32) -> (i32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Uniform gpu.subgroup_reduce "add" on f32 is expected to lower to the
+// uniform spirv.GroupFAdd op with Subgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: f32)
+  gpu.func @test(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupFAdd <Subgroup> <Reduce> %[[ARG]] : f32
+    %reduced = gpu.subgroup_reduce add %arg uniform : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Non-uniform gpu.subgroup_reduce "add" on f32 is expected to lower to
+// spirv.GroupNonUniformFAdd with Subgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: f32)
+  gpu.func @test(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupNonUniformFAdd "Subgroup" "Reduce" %[[ARG]] : f32
+    %reduced = gpu.subgroup_reduce add %arg : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Uniform gpu.subgroup_reduce "add" on i32 is expected to lower to the
+// uniform spirv.GroupIAdd op with Subgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: i32)
+  gpu.func @test(%arg : i32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupIAdd <Subgroup> <Reduce> %[[ARG]] : i32
+    %reduced = gpu.subgroup_reduce add %arg uniform : (i32) -> (i32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Non-uniform gpu.subgroup_reduce "add" on i32 is expected to lower to
+// spirv.GroupNonUniformIAdd with Subgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: i32)
+  gpu.func @test(%arg : i32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupNonUniformIAdd "Subgroup" "Reduce" %[[ARG]] : i32
+    %reduced = gpu.subgroup_reduce add %arg : (i32) -> (i32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Uniform gpu.all_reduce "mul" on f32 is expected to lower to the
+// spirv.KHR.GroupFMul op with Workgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: f32)
+  gpu.func @test(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.KHR.GroupFMul <Workgroup> <Reduce> %[[ARG]] : f32
+    %reduced = gpu.all_reduce mul %arg uniform {} : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Non-uniform gpu.all_reduce "mul" on f32 is expected to lower to
+// spirv.GroupNonUniformFMul with Workgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: f32)
+  gpu.func @test(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupNonUniformFMul "Workgroup" "Reduce" %[[ARG]] : f32
+    %reduced = gpu.all_reduce mul %arg {} : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Uniform gpu.all_reduce "mul" on i32 is expected to lower to the
+// spirv.KHR.GroupIMul op with Workgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: i32)
+  gpu.func @test(%arg : i32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.KHR.GroupIMul <Workgroup> <Reduce> %[[ARG]] : i32
+    %reduced = gpu.all_reduce mul %arg uniform {} : (i32) -> (i32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Non-uniform gpu.all_reduce "mul" on i32 is expected to lower to
+// spirv.GroupNonUniformIMul with Workgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: i32)
+  gpu.func @test(%arg : i32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupNonUniformIMul "Workgroup" "Reduce" %[[ARG]] : i32
+    %reduced = gpu.all_reduce mul %arg {} : (i32) -> (i32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Uniform gpu.subgroup_reduce "mul" on f32 is expected to lower to the
+// spirv.KHR.GroupFMul op with Subgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: f32)
+  gpu.func @test(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.KHR.GroupFMul <Subgroup> <Reduce> %[[ARG]] : f32
+    %reduced = gpu.subgroup_reduce mul %arg uniform : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Non-uniform gpu.subgroup_reduce "mul" on f32 is expected to lower to
+// spirv.GroupNonUniformFMul with Subgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: f32)
+  gpu.func @test(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupNonUniformFMul "Subgroup" "Reduce" %[[ARG]] : f32
+    %reduced = gpu.subgroup_reduce mul %arg : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Uniform gpu.subgroup_reduce "mul" on i32 is expected to lower to the
+// spirv.KHR.GroupIMul op with Subgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: i32)
+  gpu.func @test(%arg : i32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.KHR.GroupIMul <Subgroup> <Reduce> %[[ARG]] : i32
+    %reduced = gpu.subgroup_reduce mul %arg uniform : (i32) -> (i32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Non-uniform gpu.subgroup_reduce "mul" on i32 is expected to lower to
+// spirv.GroupNonUniformIMul with Subgroup scope.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: i32)
+  gpu.func @test(%arg : i32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupNonUniformIMul "Subgroup" "Reduce" %[[ARG]] : i32
+    %reduced = gpu.subgroup_reduce mul %arg : (i32) -> (i32)
+    gpu.return
+  }
+}
+
+}


        


More information about the Mlir-commits mailing list