[Mlir-commits] [mlir] [mlir][spirv] Add spirv-to-llvm conversion for group operations (PR #115501)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 8 07:43:29 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Lukas Sommer (sommerlukas)
<details>
<summary>Changes</summary>
Add lowering for some of the uniform and non-uniform group operations, defined in section 3.52.21 of the SPIR-V specification, from the SPIR-V dialect to the LLVM dialect.
Similar to #<!-- -->111864, lower the operations to builtin functions understood by SPIR-V tools.
---
Patch is 45.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115501.diff
3 Files Affected:
- (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+235-5)
- (added) mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir (+312)
- (added) mlir/test/Conversion/SPIRVToLLVM/non-uniform-ops-to-llvm.mlir (+247)
``````````diff
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index ba31936af5510d..3fd78d2b08a6bb 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
@@ -1027,7 +1028,8 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
StringRef name,
ArrayRef<Type> paramTypes,
- Type resultType) {
+ Type resultType,
+ bool convergent = true) {
auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
SymbolTable::lookupSymbolIn(symbolTable, name));
if (func)
@@ -1038,7 +1040,9 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
symbolTable->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes));
func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
- func.setConvergent(true);
+ if (convergent) {
+ func.setConvergent(true);
+ }
func.setNoUnwind(true);
func.setWillReturn(true);
return func;
@@ -1046,10 +1050,13 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
LLVM::LLVMFuncOp func,
- ValueRange args) {
+ ValueRange args,
+ bool convergent = true) {
auto call = builder.create<LLVM::CallOp>(loc, func, args);
call.setCConv(func.getCConv());
- call.setConvergentAttr(func.getConvergentAttr());
+ if (convergent) {
+ call.setConvergentAttr(func.getConvergentAttr());
+ }
call.setNoUnwindAttr(func.getNoUnwindAttr());
call.setWillReturnAttr(func.getWillReturnAttr());
return call;
@@ -1089,6 +1096,186 @@ class ControlBarrierPattern
}
};
+namespace {
+
+StringRef getTypeMangling(Type type, bool isSigned) {
+ return llvm::TypeSwitch<Type, StringRef>(type)
+ .Case<Float16Type>([](auto) { return "Dh"; })
+ .template Case<Float32Type>([](auto) { return "f"; })
+ .template Case<Float64Type>([](auto) { return "d"; })
+ .template Case<IntegerType>([isSigned](IntegerType intTy) {
+ switch (intTy.getWidth()) {
+ case 1:
+ return "b";
+ case 8:
+ return (isSigned) ? "a" : "c";
+ case 16:
+ return (isSigned) ? "s" : "t";
+ case 32:
+ return (isSigned) ? "i" : "j";
+ case 64:
+ return (isSigned) ? "l" : "m";
+ default: {
+ assert(false && "Unsupported integer width");
+ return "";
+ }
+ }
+ })
+ .Default([](auto) {
+ assert(false && "No mangling defined");
+ return "";
+ });
+}
+
+template <typename ReduceOp>
+constexpr StringLiteral getGroupFuncName() {
+ assert(false && "No builtin defined");
+ return "";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
+ return "_Z17__spirv_GroupIAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
+ return "_Z17__spirv_GroupFAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
+ return "_Z17__spirv_GroupSMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
+ return "_Z17__spirv_GroupUMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
+ return "_Z17__spirv_GroupFMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
+ return "_Z17__spirv_GroupSMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
+ return "_Z17__spirv_GroupUMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
+ return "_Z17__spirv_GroupFMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
+ return "_Z27__spirv_GroupNonUniformIAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
+ return "_Z27__spirv_GroupNonUniformFAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
+ return "_Z27__spirv_GroupNonUniformIMulii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
+ return "_Z27__spirv_GroupNonUniformFMulii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
+ return "_Z27__spirv_GroupNonUniformSMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
+ return "_Z27__spirv_GroupNonUniformUMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
+ return "_Z27__spirv_GroupNonUniformFMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
+ return "_Z27__spirv_GroupNonUniformSMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
+ return "_Z27__spirv_GroupNonUniformUMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
+ return "_Z27__spirv_GroupNonUniformFMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
+ return "_Z33__spirv_GroupNonUniformBitwiseAndii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
+ return "_Z32__spirv_GroupNonUniformBitwiseOrii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
+ return "_Z33__spirv_GroupNonUniformBitwiseXorii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
+ return "_Z33__spirv_GroupNonUniformLogicalAndii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
+ return "_Z32__spirv_GroupNonUniformLogicalOrii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
+ return "_Z33__spirv_GroupNonUniformLogicalXorii";
+}
+} // namespace
+
+template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
+class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
+public:
+ using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ Type retTy = op.getResult().getType();
+ if (!retTy.isIntOrFloat()) {
+ return failure();
+ }
+ SmallString<20> funcName = getGroupFuncName<ReduceOp>();
+ funcName += getTypeMangling(retTy, false);
+
+ Type i32Ty = rewriter.getI32Type();
+ SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
+ if constexpr (NonUniform) {
+ if (adaptor.getClusterSize()) {
+ funcName += "j";
+ paramTypes.push_back(i32Ty);
+ }
+ }
+
+ Operation *symbolTable =
+ op->template getParentWithTrait<OpTrait::SymbolTable>();
+
+ LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
+ symbolTable, funcName, paramTypes, retTy, !NonUniform);
+
+ Location loc = op.getLoc();
+ Value scope = rewriter.create<LLVM::ConstantOp>(
+ loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
+ Value groupOp = rewriter.create<LLVM::ConstantOp>(
+ loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
+ SmallVector<Value> operands{scope, groupOp};
+ operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
+
+ auto call =
+ createSPIRVBuiltinCall(loc, rewriter, func, operands, !NonUniform);
+ rewriter.replaceOp(op, call);
+ return success();
+ }
+};
+
/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
/// should be reachable for conversion to succeed. The structure of the loop in
/// LLVM dialect will be the following:
@@ -1722,7 +1909,50 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
ReturnPattern, ReturnValuePattern,
// Barrier ops
- ControlBarrierPattern>(patterns.getContext(), typeConverter);
+ ControlBarrierPattern,
+
+ // Group reduction operations
+ GroupReducePattern<spirv::GroupIAddOp>,
+ GroupReducePattern<spirv::GroupFAddOp>,
+ GroupReducePattern<spirv::GroupFMinOp>,
+ GroupReducePattern<spirv::GroupUMinOp>,
+ GroupReducePattern<spirv::GroupSMinOp, /*Signed*/ true>,
+ GroupReducePattern<spirv::GroupFMaxOp>,
+ GroupReducePattern<spirv::GroupUMaxOp>,
+ GroupReducePattern<spirv::GroupSMaxOp, /*Signed*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed*/ true,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed*/ true,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed*/ false,
+ /*NonUniform*/ true>
+ >(patterns.getContext(), typeConverter);
patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
typeConverter);
diff --git a/mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir
new file mode 100644
index 00000000000000..8c8fc50349e795
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir
@@ -0,0 +1,312 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+
+// CHECK-LABEL: llvm.func spir_funccc @_Z17__spirv_GroupSMaxiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupUMaxiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupFMaxiif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupSMiniij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupUMiniij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupFMiniif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupFAddiif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupIAddiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+
+// CHECK-LABEL: llvm.func @group_reduce_iadd(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupIAddiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_reduce_iadd(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupIAdd <Workgroup> <Reduce> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_fadd(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFAddiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK: llvm.return %[[VAL_3]] : f32
+// CHECK: }
+spirv.func @group_reduce_fadd(%arg0: f32) -> f32 "None" {
+ %0 = spirv.GroupFAdd <Workgroup> <Reduce> %arg0 : f32
+ spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_fmin(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFMiniif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK: llvm.return %[[VAL_3]] : f32
+// CHECK: }
+spirv.func @group_reduce_fmin(%arg0: f32) -> f32 "None" {
+ %0 = spirv.GroupFMin <Workgroup> <Reduce> %arg0 : f32
+ spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_umin(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupUMiniij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_reduce_umin(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupUMin <Workgroup> <Reduce> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_smin(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupSMiniij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_reduce_smin(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupSMin <Workgroup> <Reduce> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_fmax(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFMaxiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK: llvm.return %[[VAL_3]] : f32
+// CHECK: }
+spirv.func @group_reduce_fmax(%arg0: f32) -> f32 "None" {
+ %0 = spirv.GroupFMax <Workgroup> <Reduce> %arg0 : f32
+ spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_umax(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupUMaxiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_reduce_umax(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupUMax <Workgroup> <Reduce> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_smax(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupSMaxiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_reduce_smax(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupSMax <Workgroup> <Reduce> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_inclusive_scan_iadd(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupIAddiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_inclusive_scan_iadd(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupIAdd <Workgroup> <InclusiveScan> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_inclusive_scan_fadd(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFAddiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK: llvm.return %[[VAL_3]] : f32
+// CHECK: }
+spirv.func @group_inclusive_scan_fadd(%arg0: f32) -> f32 "None" {
+ %0 = spirv.GroupFAdd <Workgroup> <InclusiveScan> %arg0 : f32
+ spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @group_inclusive_scan_fmin(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFMiniif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK: llvm.return %[[VAL_3]] : f32
+// CHECK: }
+spirv.func @group_inclusive_scan_fmin(%arg0: f32) -> f32 "None" {
+ %0 = spirv.GroupFMin <Workgroup> <InclusiveScan> %arg0 : f32
+ spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @group_inclusive_scan_umin(
+//...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/115501
More information about the Mlir-commits
mailing list