[Mlir-commits] [mlir] dd09221 - Revert "[mlir][gpu] Align reduction operations with vector combining kinds (#73423)"
Jakub Kuderski
llvmlistbot@llvm.org
Mon Nov 27 08:30:12 PST 2023
Author: Jakub Kuderski
Date: 2023-11-27T11:29:23-05:00
New Revision: dd09221a29506031415cad8a1308998358633d48
URL: https://github.com/llvm/llvm-project/commit/dd09221a29506031415cad8a1308998358633d48
DIFF: https://github.com/llvm/llvm-project/commit/dd09221a29506031415cad8a1308998358633d48.diff
LOG: Revert "[mlir][gpu] Align reduction operations with vector combining kinds (#73423)"
This reverts commit e0aac8c88d0d30e8da0f8a240ad1e6b4d88782e0.
I'm seeing some nvidia integration test failures:
https://lab.llvm.org/buildbot/#/builders/61/builds/52334.
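
For context, a minimal sketch of the two `gpu.all_reduce` forms this revert reinstates, assembled from the op description and the test cases in the diff below; the surrounding module and function wrapper are illustrative only:

```mlir
gpu.module @kernels {
  gpu.func @reduce_example(%arg0 : f32, %arg1 : i32) kernel {
    // Accumulation named as an op attribute, using the restored `max` spelling.
    %0 = gpu.all_reduce max %arg0 uniform {} : (f32) -> (f32)
    // Accumulation spelled out as a code region instead of an op attribute.
    %1 = gpu.all_reduce %arg1 {
    ^bb(%lhs : i32, %rhs : i32):
      %sum = arith.addi %lhs, %rhs : i32
      "gpu.yield"(%sum) : (i32) -> ()
    } : (i32) -> (i32)
    gpu.return
  }
}
```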
Added:
mlir/test/Dialect/GPU/all-reduce-max.mlir
mlir/test/Dialect/GPU/all-reduce.mlir
Modified:
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/include/mlir/IR/CommonTypeConstraints.td
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
mlir/test/Conversion/GPUToSPIRV/reductions.mlir
mlir/test/Dialect/GPU/invalid.mlir
Removed:
mlir/test/Dialect/GPU/all-reduce-add.mlir
mlir/test/Dialect/GPU/all-reduce-maxf.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 7cad1cd89fd6335..826df0012fb8f0a 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -931,53 +931,38 @@ def GPU_YieldOp : GPU_Op<"yield", [Pure, Terminator]>,
}];
}
-// These mirror the reduction combining kinds from the vector dialect.
+// add, mul mirror the XLA ComparisonDirection enum.
def GPU_AllReduceOpAdd : I32EnumAttrCase<"ADD", 0, "add">;
-def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 1, "mul">;
-def GPU_AllReduceOpMinUI : I32EnumAttrCase<"MINUI", 2, "minui">;
-def GPU_AllReduceOpMinSI : I32EnumAttrCase<"MINSI", 3, "minsi">;
-// Follows the `arith.minnumf` semantics.
-def GPU_AllReduceOpMinF : I32EnumAttrCase<"MINF", 4, "minf">;
-def GPU_AllReduceOpMaxUI : I32EnumAttrCase<"MAXUI", 5, "maxui">;
-def GPU_AllReduceOpMaxSI : I32EnumAttrCase<"MAXSI", 6, "maxsi">;
-// Follows the `arith.maxnumf` semantics.
-def GPU_AllReduceOpMaxF : I32EnumAttrCase<"MAXF", 7, "maxf">;
-def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 8, "and">;
-def GPU_AllReduceOpOr : I32EnumAttrCase<"OR", 9, "or">;
-def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 10, "xor">;
-// Follows the `arith.minimumf` semantics.
-def GPU_AllReduceOpMinimumF : I32EnumAttrCase<"MINIMUMF", 11, "minimumf">;
-// Follows the `arith.maximumf` semantics.
-def GPU_AllReduceOpMaximumF : I32EnumAttrCase<"MAXIMUMF", 12, "maximumf">;
+def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 1, "and">;
+def GPU_AllReduceOpMax : I32EnumAttrCase<"MAX", 2, "max">;
+def GPU_AllReduceOpMin : I32EnumAttrCase<"MIN", 3, "min">;
+def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 4, "mul">;
+def GPU_AllReduceOpOr : I32EnumAttrCase<"OR", 5, "or">;
+def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 6, "xor">;
def GPU_AllReduceOperation : I32EnumAttr<"AllReduceOperation",
"built-in reduction operations supported by gpu.allreduce.",
[
GPU_AllReduceOpAdd,
- GPU_AllReduceOpMul,
- GPU_AllReduceOpMinUI,
- GPU_AllReduceOpMinSI,
- GPU_AllReduceOpMinF,
- GPU_AllReduceOpMaxUI,
- GPU_AllReduceOpMaxSI,
- GPU_AllReduceOpMaxF,
GPU_AllReduceOpAnd,
+ GPU_AllReduceOpMax,
+ GPU_AllReduceOpMin,
+ GPU_AllReduceOpMul,
GPU_AllReduceOpOr,
- GPU_AllReduceOpXor,
- GPU_AllReduceOpMinimumF,
- GPU_AllReduceOpMaximumF
+ GPU_AllReduceOpXor
]>{
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::gpu";
}
-
-def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">;
-
def GPU_AllReduceOperationAttr : EnumAttr<GPU_Dialect, GPU_AllReduceOperation,
"all_reduce_op">;
def GPU_AllReduceOp : GPU_Op<"all_reduce",
- [SameOperandsAndResultType, IsolatedFromAbove]> {
+ [SameOperandsAndResultType, IsolatedFromAbove]>,
+ Arguments<(ins AnyType:$value,
+ OptionalAttr<GPU_AllReduceOperationAttr>:$op,
+ UnitAttr:$uniform)>,
+ Results<(outs AnyType)> {
let summary = "Reduce values among workgroup.";
let description = [{
The `all_reduce` op reduces the value of every work item across a local
@@ -996,23 +981,12 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
compute the sum of each work item's %0 value. The first version specifies
the accumulation as operation, whereas the second version specifies the
- accumulation as code region. The reduction operation must be one of:
- * Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
- `or`, `xor`
- * Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
- `maximumf`
+ accumulation as code region. The accumulation operation must be one of:
+ `add`, `and`, `max`, `min`, `mul`, `or`, `xor`.
If `uniform` flag is set either none or all work items of a workgroup
need to execute this op in convergence.
}];
-
- let arguments = (ins
- AnyIntegerOrFloat:$value,
- OptionalAttr<GPU_AllReduceOperationAttr>:$op,
- UnitAttr:$uniform
- );
- let results = (outs AnyIntegerOrFloat:$result);
-
let regions = (region AnyRegion:$body);
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
(`uniform` $uniform^)? $body attr-dict
@@ -1022,7 +996,12 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
let hasRegionVerifier = 1;
}
-def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]> {
+def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
+ [SameOperandsAndResultType]>,
+ Arguments<(ins AnyType:$value,
+ GPU_AllReduceOperationAttr:$op,
+ UnitAttr:$uniform)>,
+ Results<(outs AnyType)> {
let summary = "Reduce values among subgroup.";
let description = [{
The `subgroup_reduce` op reduces the value of every work item across a
@@ -1035,21 +1014,8 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
```
If `uniform` flag is set either none or all work items of a subgroup
- need to execute this op in convergence. The reduction operation must be one
- of:
- * Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
- `or`, `xor`
- * Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
- `maximumf`
+ need to execute this op in convergence.
}];
-
- let arguments = (ins
- AnyIntegerOrFloat:$value,
- GPU_AllReduceOperationAttr:$op,
- UnitAttr:$uniform
- );
- let results = (outs AnyIntegerOrFloat:$result);
-
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
(`uniform` $uniform^)? attr-dict
`:` functional-type(operands, results) }];
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 03180a687523bf7..b0b5348baaad963 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -34,7 +34,7 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
!::llvm::cast<VectorType>($_self).isScalable()}]>;
// Whether a type is a scalable VectorType.
-def IsVectorTypeWithAnyDimScalablePred
+def IsVectorTypeWithAnyDimScalablePred
: CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
::llvm::cast<VectorType>($_self).isScalable()}]>;
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 4855fd187eb5861..9456784c406aebb 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -65,28 +65,17 @@ convertReduxKind(gpu::AllReduceOperation mode) {
switch (mode) {
case gpu::AllReduceOperation::ADD:
return NVVM::ReduxKind::ADD;
- case gpu::AllReduceOperation::MUL:
- return std::nullopt;
- case gpu::AllReduceOperation::MINSI:
- return NVVM::ReduxKind::MIN;
- case gpu::AllReduceOperation::MINUI:
- return std::nullopt;
- case gpu::AllReduceOperation::MINF:
- return NVVM::ReduxKind::MIN;
- case gpu::AllReduceOperation::MAXSI:
- return NVVM::ReduxKind::MAX;
- case gpu::AllReduceOperation::MAXUI:
- return std::nullopt;
- case gpu::AllReduceOperation::MAXF:
- return NVVM::ReduxKind::MAX;
case gpu::AllReduceOperation::AND:
return NVVM::ReduxKind::AND;
+ case gpu::AllReduceOperation::MAX:
+ return NVVM::ReduxKind::MAX;
+ case gpu::AllReduceOperation::MIN:
+ return NVVM::ReduxKind::MIN;
case gpu::AllReduceOperation::OR:
return NVVM::ReduxKind::OR;
case gpu::AllReduceOperation::XOR:
return NVVM::ReduxKind::XOR;
- case gpu::AllReduceOperation::MINIMUMF:
- case gpu::AllReduceOperation::MAXIMUMF:
+ case gpu::AllReduceOperation::MUL:
return std::nullopt;
}
return std::nullopt;
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 6536bbe1f4dba47..693cc3f6236b574 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -503,53 +503,26 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
return std::nullopt;
}
- // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
- // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
- // reduction ops. We should account possible precision requirements in this
- // conversion.
-
using ReduceType = gpu::AllReduceOperation;
+ namespace spv = spirv;
const OpHandler handlers[] = {
{ReduceType::ADD,
- &createGroupReduceOpImpl<spirv::GroupIAddOp,
- spirv::GroupNonUniformIAddOp>,
- &createGroupReduceOpImpl<spirv::GroupFAddOp,
- spirv::GroupNonUniformFAddOp>},
+ &createGroupReduceOpImpl<spv::GroupIAddOp, spv::GroupNonUniformIAddOp>,
+ &createGroupReduceOpImpl<spv::GroupFAddOp, spv::GroupNonUniformFAddOp>},
{ReduceType::MUL,
- &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
- spirv::GroupNonUniformIMulOp>,
- &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
- spirv::GroupNonUniformFMulOp>},
- {ReduceType::MINUI,
- &createGroupReduceOpImpl<spirv::GroupUMinOp,
- spirv::GroupNonUniformUMinOp>,
- nullptr},
- {ReduceType::MINSI,
- &createGroupReduceOpImpl<spirv::GroupSMinOp,
- spirv::GroupNonUniformSMinOp>,
- nullptr},
- {ReduceType::MINF, nullptr,
- &createGroupReduceOpImpl<spirv::GroupFMinOp,
- spirv::GroupNonUniformFMinOp>},
- {ReduceType::MAXUI,
- &createGroupReduceOpImpl<spirv::GroupUMaxOp,
- spirv::GroupNonUniformUMaxOp>,
- nullptr},
- {ReduceType::MAXSI,
- &createGroupReduceOpImpl<spirv::GroupSMaxOp,
- spirv::GroupNonUniformSMaxOp>,
- nullptr},
- {ReduceType::MAXF, nullptr,
- &createGroupReduceOpImpl<spirv::GroupFMaxOp,
- spirv::GroupNonUniformFMaxOp>},
- {ReduceType::MINIMUMF, nullptr,
- &createGroupReduceOpImpl<spirv::GroupFMinOp,
- spirv::GroupNonUniformFMinOp>},
- {ReduceType::MAXIMUMF, nullptr,
- &createGroupReduceOpImpl<spirv::GroupFMaxOp,
- spirv::GroupNonUniformFMaxOp>}};
-
- for (const OpHandler &handler : handlers)
+ &createGroupReduceOpImpl<spv::GroupIMulKHROp,
+ spv::GroupNonUniformIMulOp>,
+ &createGroupReduceOpImpl<spv::GroupFMulKHROp,
+ spv::GroupNonUniformFMulOp>},
+ {ReduceType::MIN,
+ &createGroupReduceOpImpl<spv::GroupSMinOp, spv::GroupNonUniformSMinOp>,
+ &createGroupReduceOpImpl<spv::GroupFMinOp, spv::GroupNonUniformFMinOp>},
+ {ReduceType::MAX,
+ &createGroupReduceOpImpl<spv::GroupSMaxOp, spv::GroupNonUniformSMaxOp>,
+ &createGroupReduceOpImpl<spv::GroupFMaxOp, spv::GroupNonUniformFMaxOp>},
+ };
+
+ for (auto &handler : handlers)
if (handler.type == opType)
return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index d31903ea201158f..1b6db1fb0c79f7c 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -27,9 +27,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/InliningUtils.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
@@ -488,23 +486,12 @@ static LogicalResult verifyAttributions(Operation *op,
// AllReduceOp
//===----------------------------------------------------------------------===//
-static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
- Type resType) {
- using Kind = gpu::AllReduceOperation;
- if (llvm::is_contained(
- {Kind::MINF, Kind::MAXF, Kind::MINIMUMF, Kind::MAXIMUMF}, opName)) {
- if (!isa<FloatType>(resType))
- return failure();
- }
-
- if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
- Kind::AND, Kind::OR, Kind::XOR},
- opName)) {
- if (!isa<IntegerType>(resType))
- return failure();
- }
-
- return success();
+static bool verifyReduceOpAndType(gpu::AllReduceOperation opName,
+ Type resType) {
+ return (opName != gpu::AllReduceOperation::AND &&
+ opName != gpu::AllReduceOperation::OR &&
+ opName != gpu::AllReduceOperation::XOR) ||
+ llvm::isa<IntegerType>(resType);
}
LogicalResult gpu::AllReduceOp::verifyRegions() {
@@ -531,13 +518,12 @@ LogicalResult gpu::AllReduceOp::verifyRegions() {
return emitError("expected gpu.yield op in region");
} else {
gpu::AllReduceOperation opName = *getOp();
- if (failed(verifyReduceOpAndType(opName, getType()))) {
- return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
- << "` reduction operation is not compatible with type "
- << getType();
+ if (!verifyReduceOpAndType(opName, getType())) {
+ return emitError()
+ << '`' << gpu::stringifyAllReduceOperation(opName)
+ << "` accumulator is only compatible with Integer type";
}
}
-
return success();
}
@@ -588,10 +574,9 @@ static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
LogicalResult gpu::SubgroupReduceOp::verify() {
gpu::AllReduceOperation opName = getOp();
- if (failed(verifyReduceOpAndType(opName, getType()))) {
+ if (!verifyReduceOpAndType(opName, getType())) {
return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
- << "` reduction operation is not compatible with type "
- << getType();
+ << "` accumulator is only compatible with Integer type";
}
return success();
}
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index ecee9a7b45e32bd..acf4f6d0e3d6979 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -214,37 +214,32 @@ struct GpuAllReduceRewriter {
/// Returns an accumulator factory that creates an op specified by opName.
AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
- using Kind = gpu::AllReduceOperation;
bool isFloatingPoint = isa<FloatType>(valueType);
switch (opName) {
- case Kind::ADD:
+ case gpu::AllReduceOperation::ADD:
return isFloatingPoint ? getFactory<arith::AddFOp>()
: getFactory<arith::AddIOp>();
- case Kind::MUL:
+ case gpu::AllReduceOperation::MUL:
return isFloatingPoint ? getFactory<arith::MulFOp>()
: getFactory<arith::MulIOp>();
- case Kind::MINSI:
- return getFactory<arith::MinSIOp>();
- case Kind::MINUI:
- return getFactory<arith::MinUIOp>();
- case Kind::MINF:
- return getFactory<arith::MinNumFOp>();
- case Kind::MAXSI:
- return getFactory<arith::MaxSIOp>();
- case Kind::MAXUI:
- return getFactory<arith::MaxUIOp>();
- case Kind::MAXF:
- return getFactory<arith::MaxNumFOp>();
- case Kind::AND:
+ case gpu::AllReduceOperation::AND:
return getFactory<arith::AndIOp>();
- case Kind::OR:
+ case gpu::AllReduceOperation::OR:
return getFactory<arith::OrIOp>();
- case Kind::XOR:
+ case gpu::AllReduceOperation::XOR:
return getFactory<arith::XOrIOp>();
- case Kind::MINIMUMF:
- return getFactory<arith::MinimumFOp>();
- case Kind::MAXIMUMF:
- return getFactory<arith::MaximumFOp>();
+ case gpu::AllReduceOperation::MAX:
+ return isFloatingPoint
+ ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
+ arith::CmpFPredicate::UGT>()
+ : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
+ arith::CmpIPredicate::ugt>();
+ case gpu::AllReduceOperation::MIN:
+ return isFloatingPoint
+ ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
+ arith::CmpFPredicate::ULT>()
+ : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
+ arith::CmpIPredicate::ult>();
}
llvm_unreachable("unknown GPU AllReduceOperation");
}
@@ -252,11 +247,21 @@ struct GpuAllReduceRewriter {
/// Returns an accumulator factory that creates an op of type T.
template <typename T>
AccumulatorFactory getFactory() {
- return [this](Value lhs, Value rhs) {
+ return [&](Value lhs, Value rhs) {
return create<T>(lhs.getType(), lhs, rhs);
};
}
+ /// Returns an accumulator for comparison such as min, max. T is the type
+ /// of the compare op.
+ template <typename T, typename PredicateEnum, PredicateEnum predicate>
+ AccumulatorFactory getCmpFactory() const {
+ return [&](Value lhs, Value rhs) {
+ Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
+ return rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
+ };
+ }
+
/// Creates an if-block skeleton and calls the two factories to generate the
/// ops in the `then` and `else` block..
///
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 20a200e812c1259..c18bb423a6e6001 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -582,22 +582,22 @@ gpu.module @test_module_30 {
%result = gpu.subgroup_reduce add %arg0 uniform {} : (i32) -> (i32)
gpu.return
}
- // CHECK-LABEL: @subgroup_reduce_minsi
- gpu.func @subgroup_reduce_minsi(%arg0 : i32) {
- // CHECK: nvvm.redux.sync min {{.*}}
- %result = gpu.subgroup_reduce minsi %arg0 uniform {} : (i32) -> (i32)
+ // CHECK-LABEL: func @subgroup_reduce_and
+ gpu.func @subgroup_reduce_and(%arg0 : i32) {
+ // CHECK: nvvm.redux.sync and {{.*}}
+ %result = gpu.subgroup_reduce and %arg0 uniform {} : (i32) -> (i32)
gpu.return
}
- // CHECK-LABEL: @subgroup_reduce_maxsi
- gpu.func @subgroup_reduce_maxsi(%arg0 : i32) {
+ // CHECK-LABEL: @subgroup_reduce_max
+ gpu.func @subgroup_reduce_max(%arg0 : i32) {
// CHECK: nvvm.redux.sync max {{.*}}
- %result = gpu.subgroup_reduce maxsi %arg0 uniform {} : (i32) -> (i32)
+ %result = gpu.subgroup_reduce max %arg0 uniform {} : (i32) -> (i32)
gpu.return
}
- // CHECK-LABEL: func @subgroup_reduce_and
- gpu.func @subgroup_reduce_and(%arg0 : i32) {
- // CHECK: nvvm.redux.sync and {{.*}}
- %result = gpu.subgroup_reduce and %arg0 uniform {} : (i32) -> (i32)
+ // CHECK-LABEL: @subgroup_reduce_min
+ gpu.func @subgroup_reduce_min(%arg0 : i32) {
+ // CHECK: nvvm.redux.sync min {{.*}}
+ %result = gpu.subgroup_reduce min %arg0 uniform {} : (i32) -> (i32)
gpu.return
}
// CHECK-LABEL: @subgroup_reduce_or
diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
index feb7ee185a8a858..1e5d64387650ce3 100644
--- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -331,7 +331,7 @@ gpu.module @kernels {
gpu.func @test(%arg : f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupFMin <Workgroup> <Reduce> %[[ARG]] : f32
- %reduced = gpu.all_reduce minf %arg uniform {} : (f32) -> (f32)
+ %reduced = gpu.all_reduce min %arg uniform {} : (f32) -> (f32)
gpu.return
}
}
@@ -351,7 +351,7 @@ gpu.module @kernels {
gpu.func @test(%arg : f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupNonUniformFMin "Workgroup" "Reduce" %[[ARG]] : f32
- %reduced = gpu.all_reduce minf %arg {} : (f32) -> (f32)
+ %reduced = gpu.all_reduce min %arg {} : (f32) -> (f32)
gpu.return
}
}
@@ -371,9 +371,7 @@ gpu.module @kernels {
gpu.func @test(%arg : i32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupSMin <Workgroup> <Reduce> %[[ARG]] : i32
- // CHECK: %{{.*}} = spirv.GroupUMin <Workgroup> <Reduce> %[[ARG]] : i32
- %r0 = gpu.all_reduce minsi %arg uniform {} : (i32) -> (i32)
- %r1 = gpu.all_reduce minui %arg uniform {} : (i32) -> (i32)
+ %reduced = gpu.all_reduce min %arg uniform {} : (i32) -> (i32)
gpu.return
}
}
@@ -392,9 +390,8 @@ gpu.module @kernels {
// CHECK-SAME: (%[[ARG:.*]]: i32)
gpu.func @test(%arg : i32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
- // CHECK: %{{.*}} = spirv.GroupNonUniformUMin "Workgroup" "Reduce" %[[ARG]] : i32
- %r0 = gpu.all_reduce minsi %arg {} : (i32) -> (i32)
- %r1 = gpu.all_reduce minui %arg {} : (i32) -> (i32)
+ // CHECK: %{{.*}} = spirv.GroupNonUniformSMin "Workgroup" "Reduce" %[[ARG]] : i32
+ %reduced = gpu.all_reduce min %arg {} : (i32) -> (i32)
gpu.return
}
}
@@ -414,7 +411,7 @@ gpu.module @kernels {
gpu.func @test(%arg : f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupFMin <Subgroup> <Reduce> %[[ARG]] : f32
- %reduced = gpu.subgroup_reduce minf %arg uniform : (f32) -> (f32)
+ %reduced = gpu.subgroup_reduce min %arg uniform : (f32) -> (f32)
gpu.return
}
}
@@ -434,7 +431,7 @@ gpu.module @kernels {
gpu.func @test(%arg : f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupNonUniformFMin "Subgroup" "Reduce" %[[ARG]] : f32
- %reduced = gpu.subgroup_reduce minf %arg : (f32) -> (f32)
+ %reduced = gpu.subgroup_reduce min %arg : (f32) -> (f32)
gpu.return
}
}
@@ -454,9 +451,7 @@ gpu.module @kernels {
gpu.func @test(%arg : i32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupSMin <Subgroup> <Reduce> %[[ARG]] : i32
- // CHECK: %{{.*}} = spirv.GroupUMin <Subgroup> <Reduce> %[[ARG]] : i32
- %r0 = gpu.subgroup_reduce minsi %arg uniform : (i32) -> (i32)
- %r1 = gpu.subgroup_reduce minui %arg uniform : (i32) -> (i32)
+ %reduced = gpu.subgroup_reduce min %arg uniform : (i32) -> (i32)
gpu.return
}
}
@@ -476,9 +471,7 @@ gpu.module @kernels {
gpu.func @test(%arg : i32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupNonUniformSMin "Subgroup" "Reduce" %[[ARG]] : i32
- // CHECK: %{{.*}} = spirv.GroupNonUniformUMin "Subgroup" "Reduce" %[[ARG]] : i32
- %r0 = gpu.subgroup_reduce minsi %arg : (i32) -> (i32)
- %r1 = gpu.subgroup_reduce minui %arg : (i32) -> (i32)
+ %reduced = gpu.subgroup_reduce min %arg : (i32) -> (i32)
gpu.return
}
}
@@ -498,7 +491,7 @@ gpu.module @kernels {
gpu.func @test(%arg : f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupFMax <Workgroup> <Reduce> %[[ARG]] : f32
- %reduced = gpu.all_reduce maxf %arg uniform {} : (f32) -> (f32)
+ %reduced = gpu.all_reduce max %arg uniform {} : (f32) -> (f32)
gpu.return
}
}
@@ -518,7 +511,7 @@ gpu.module @kernels {
gpu.func @test(%arg : f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupNonUniformFMax "Workgroup" "Reduce" %[[ARG]] : f32
- %reduced = gpu.all_reduce maxf %arg {} : (f32) -> (f32)
+ %reduced = gpu.all_reduce max %arg {} : (f32) -> (f32)
gpu.return
}
}
@@ -538,9 +531,7 @@ gpu.module @kernels {
gpu.func @test(%arg : i32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupSMax <Workgroup> <Reduce> %[[ARG]] : i32
- // CHECK: %{{.*}} = spirv.GroupUMax <Workgroup> <Reduce> %[[ARG]] : i32
- %r0 = gpu.all_reduce maxsi %arg uniform {} : (i32) -> (i32)
- %r1 = gpu.all_reduce maxui %arg uniform {} : (i32) -> (i32)
+ %reduced = gpu.all_reduce max %arg uniform {} : (i32) -> (i32)
gpu.return
}
}
@@ -560,9 +551,7 @@ gpu.module @kernels {
gpu.func @test(%arg : i32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupNonUniformSMax "Workgroup" "Reduce" %[[ARG]] : i32
- // CHECK: %{{.*}} = spirv.GroupNonUniformUMax "Workgroup" "Reduce" %[[ARG]] : i32
- %r0 = gpu.all_reduce maxsi %arg {} : (i32) -> (i32)
- %r1 = gpu.all_reduce maxui %arg {} : (i32) -> (i32)
+ %reduced = gpu.all_reduce max %arg {} : (i32) -> (i32)
gpu.return
}
}
@@ -582,7 +571,7 @@ gpu.module @kernels {
gpu.func @test(%arg : f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupFMax <Subgroup> <Reduce> %[[ARG]] : f32
- %reduced = gpu.subgroup_reduce maxf %arg uniform : (f32) -> (f32)
+ %reduced = gpu.subgroup_reduce max %arg uniform : (f32) -> (f32)
gpu.return
}
}
@@ -602,7 +591,7 @@ gpu.module @kernels {
gpu.func @test(%arg : f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupNonUniformFMax "Subgroup" "Reduce" %[[ARG]] : f32
- %reduced = gpu.subgroup_reduce maxf %arg : (f32) -> (f32)
+ %reduced = gpu.subgroup_reduce max %arg : (f32) -> (f32)
gpu.return
}
}
@@ -622,9 +611,7 @@ gpu.module @kernels {
gpu.func @test(%arg : i32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupSMax <Subgroup> <Reduce> %[[ARG]] : i32
- // CHECK: %{{.*}} = spirv.GroupUMax <Subgroup> <Reduce> %[[ARG]] : i32
- %r0 = gpu.subgroup_reduce maxsi %arg uniform : (i32) -> (i32)
- %r1 = gpu.subgroup_reduce maxui %arg uniform : (i32) -> (i32)
+ %reduced = gpu.subgroup_reduce max %arg uniform : (i32) -> (i32)
gpu.return
}
}
@@ -644,9 +631,7 @@ gpu.module @kernels {
gpu.func @test(%arg : i32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupNonUniformSMax "Subgroup" "Reduce" %[[ARG]] : i32
- // CHECK: %{{.*}} = spirv.GroupNonUniformUMax "Subgroup" "Reduce" %[[ARG]] : i32
- %r0 = gpu.subgroup_reduce maxsi %arg : (i32) -> (i32)
- %r1 = gpu.subgroup_reduce maxui %arg : (i32) -> (i32)
+ %reduced = gpu.subgroup_reduce max %arg : (i32) -> (i32)
gpu.return
}
}
diff --git a/mlir/test/Dialect/GPU/all-reduce-maxf.mlir b/mlir/test/Dialect/GPU/all-reduce-max.mlir
similarity index 69%
rename from mlir/test/Dialect/GPU/all-reduce-maxf.mlir
rename to mlir/test/Dialect/GPU/all-reduce-max.mlir
index b502e587637cdc8..a71544ba0e98d36 100644
--- a/mlir/test/Dialect/GPU/all-reduce-maxf.mlir
+++ b/mlir/test/Dialect/GPU/all-reduce-max.mlir
@@ -44,55 +44,65 @@ gpu.module @kernels {
// CHECK: [[VAL_34:%.*]], [[VAL_35:%.*]] = gpu.shuffle xor [[VAL_0]], [[VAL_6]], [[VAL_32]] : f32
// CHECK: cf.cond_br [[VAL_35]], ^bb2, ^bb3
// CHECK: ^bb2:
- // CHECK: [[VAL_36:%.*]] = arith.maxnumf [[VAL_0]], [[VAL_34]] : f32
- // CHECK: cf.br ^bb4([[VAL_36]] : f32)
+ // CHECK: [[VAL_36:%.*]] = arith.cmpf ugt, [[VAL_0]], [[VAL_34]] : f32
+ // CHECK: [[VAL_37:%.*]] = arith.select [[VAL_36]], [[VAL_0]], [[VAL_34]] : f32
+ // CHECK: cf.br ^bb4([[VAL_37]] : f32)
// CHECK: ^bb3:
// CHECK: cf.br ^bb4([[VAL_0]] : f32)
// CHECK: ^bb4([[VAL_38:%.*]]: f32):
// CHECK: [[VAL_39:%.*]], [[VAL_40:%.*]] = gpu.shuffle xor [[VAL_38]], [[VAL_7]], [[VAL_32]] : f32
// CHECK: cf.cond_br [[VAL_40]], ^bb5, ^bb6
// CHECK: ^bb5:
- // CHECK: [[VAL_41:%.*]] = arith.maxnumf [[VAL_38]], [[VAL_39]] : f32
- // CHECK: cf.br ^bb7([[VAL_41]] : f32)
+ // CHECK: [[VAL_41:%.*]] = arith.cmpf ugt, [[VAL_38]], [[VAL_39]] : f32
+ // CHECK: [[VAL_42:%.*]] = arith.select [[VAL_41]], [[VAL_38]], [[VAL_39]] : f32
+ // CHECK: cf.br ^bb7([[VAL_42]] : f32)
// CHECK: ^bb6:
// CHECK: cf.br ^bb7([[VAL_38]] : f32)
// CHECK: ^bb7([[VAL_43:%.*]]: f32):
// CHECK: [[VAL_44:%.*]], [[VAL_45:%.*]] = gpu.shuffle xor [[VAL_43]], [[VAL_8]], [[VAL_32]] : f32
// CHECK: cf.cond_br [[VAL_45]], ^bb8, ^bb9
// CHECK: ^bb8:
- // CHECK: [[VAL_46:%.*]] = arith.maxnumf [[VAL_43]], [[VAL_44]] : f32
- // CHECK: cf.br ^bb10([[VAL_46]] : f32)
+ // CHECK: [[VAL_46:%.*]] = arith.cmpf ugt, [[VAL_43]], [[VAL_44]] : f32
+ // CHECK: [[VAL_47:%.*]] = arith.select [[VAL_46]], [[VAL_43]], [[VAL_44]] : f32
+ // CHECK: cf.br ^bb10([[VAL_47]] : f32)
// CHECK: ^bb9:
// CHECK: cf.br ^bb10([[VAL_43]] : f32)
// CHECK: ^bb10([[VAL_48:%.*]]: f32):
// CHECK: [[VAL_49:%.*]], [[VAL_50:%.*]] = gpu.shuffle xor [[VAL_48]], [[VAL_9]], [[VAL_32]] : f32
// CHECK: cf.cond_br [[VAL_50]], ^bb11, ^bb12
// CHECK: ^bb11:
- // CHECK: [[VAL_51:%.*]] = arith.maxnumf [[VAL_48]], [[VAL_49]] : f32
- // CHECK: cf.br ^bb13([[VAL_51]] : f32)
+ // CHECK: [[VAL_51:%.*]] = arith.cmpf ugt, [[VAL_48]], [[VAL_49]] : f32
+ // CHECK: [[VAL_52:%.*]] = arith.select [[VAL_51]], [[VAL_48]], [[VAL_49]] : f32
+ // CHECK: cf.br ^bb13([[VAL_52]] : f32)
// CHECK: ^bb12:
// CHECK: cf.br ^bb13([[VAL_48]] : f32)
// CHECK: ^bb13([[VAL_53:%.*]]: f32):
// CHECK: [[VAL_54:%.*]], [[VAL_55:%.*]] = gpu.shuffle xor [[VAL_53]], [[VAL_10]], [[VAL_32]] : f32
// CHECK: cf.cond_br [[VAL_55]], ^bb14, ^bb15
// CHECK: ^bb14:
- // CHECK: [[VAL_56:%.*]] = arith.maxnumf [[VAL_53]], [[VAL_54]] : f32
- // CHECK: cf.br ^bb16([[VAL_56]] : f32)
+ // CHECK: [[VAL_56:%.*]] = arith.cmpf ugt, [[VAL_53]], [[VAL_54]] : f32
+ // CHECK: [[VAL_57:%.*]] = arith.select [[VAL_56]], [[VAL_53]], [[VAL_54]] : f32
+ // CHECK: cf.br ^bb16([[VAL_57]] : f32)
// CHECK: ^bb15:
// CHECK: cf.br ^bb16([[VAL_53]] : f32)
// CHECK: ^bb16([[VAL_58:%.*]]: f32):
// CHECK: cf.br ^bb18([[VAL_58]] : f32)
// CHECK: ^bb17:
// CHECK: [[VAL_59:%.*]], [[VAL_60:%.*]] = gpu.shuffle xor [[VAL_0]], [[VAL_6]], [[VAL_5]] : f32
- // CHECK: [[VAL_62:%.*]] = arith.maxnumf [[VAL_0]], [[VAL_59]] : f32
+ // CHECK: [[VAL_61:%.*]] = arith.cmpf ugt, [[VAL_0]], [[VAL_59]] : f32
+ // CHECK: [[VAL_62:%.*]] = arith.select [[VAL_61]], [[VAL_0]], [[VAL_59]] : f32
// CHECK: [[VAL_63:%.*]], [[VAL_64:%.*]] = gpu.shuffle xor [[VAL_62]], [[VAL_7]], [[VAL_5]] : f32
- // CHECK: [[VAL_66:%.*]] = arith.maxnumf [[VAL_62]], [[VAL_63]] : f32
+ // CHECK: [[VAL_65:%.*]] = arith.cmpf ugt, [[VAL_62]], [[VAL_63]] : f32
+ // CHECK: [[VAL_66:%.*]] = arith.select [[VAL_65]], [[VAL_62]], [[VAL_63]] : f32
// CHECK: [[VAL_67:%.*]], [[VAL_68:%.*]] = gpu.shuffle xor [[VAL_66]], [[VAL_8]], [[VAL_5]] : f32
- // CHECK: [[VAL_70:%.*]] = arith.maxnumf [[VAL_66]], [[VAL_67]] : f32
+ // CHECK: [[VAL_69:%.*]] = arith.cmpf ugt, [[VAL_66]], [[VAL_67]] : f32
+ // CHECK: [[VAL_70:%.*]] = arith.select [[VAL_69]], [[VAL_66]], [[VAL_67]] : f32
// CHECK: [[VAL_71:%.*]], [[VAL_72:%.*]] = gpu.shuffle xor [[VAL_70]], [[VAL_9]], [[VAL_5]] : f32
- // CHECK: [[VAL_74:%.*]] = arith.maxnumf [[VAL_70]], [[VAL_71]] : f32
+ // CHECK: [[VAL_73:%.*]] = arith.cmpf ugt, [[VAL_70]], [[VAL_71]] : f32
+ // CHECK: [[VAL_74:%.*]] = arith.select [[VAL_73]], [[VAL_70]], [[VAL_71]] : f32
// CHECK: [[VAL_75:%.*]], [[VAL_76:%.*]] = gpu.shuffle xor [[VAL_74]], [[VAL_10]], [[VAL_5]] : f32
- // CHECK: [[VAL_78:%.*]] = arith.maxnumf [[VAL_74]], [[VAL_75]] : f32
+ // CHECK: [[VAL_77:%.*]] = arith.cmpf ugt, [[VAL_74]], [[VAL_75]] : f32
+ // CHECK: [[VAL_78:%.*]] = arith.select [[VAL_77]], [[VAL_74]], [[VAL_75]] : f32
// CHECK: cf.br ^bb18([[VAL_78]] : f32)
// CHECK: ^bb18([[VAL_79:%.*]]: f32):
// CHECK: cf.cond_br [[VAL_30]], ^bb19, ^bb20
@@ -118,7 +128,8 @@ gpu.module @kernels {
// CHECK: [[VAL_88:%.*]], [[VAL_89:%.*]] = gpu.shuffle xor [[VAL_86]], [[VAL_6]], [[VAL_83]] : f32
// CHECK: cf.cond_br [[VAL_89]], ^bb24, ^bb25
// CHECK: ^bb24:
- // CHECK: [[VAL_91:%.*]] = arith.maxnumf [[VAL_86]], [[VAL_88]] : f32
+ // CHECK: [[VAL_90:%.*]] = arith.cmpf ugt, [[VAL_86]], [[VAL_88]] : f32
+ // CHECK: [[VAL_91:%.*]] = arith.select [[VAL_90]], [[VAL_86]], [[VAL_88]] : f32
// CHECK: cf.br ^bb26([[VAL_91]] : f32)
// CHECK: ^bb25:
// CHECK: cf.br ^bb26([[VAL_86]] : f32)
@@ -126,7 +137,8 @@ gpu.module @kernels {
// CHECK: [[VAL_93:%.*]], [[VAL_94:%.*]] = gpu.shuffle xor [[VAL_92]], [[VAL_7]], [[VAL_83]] : f32
// CHECK: cf.cond_br [[VAL_94]], ^bb27, ^bb28
// CHECK: ^bb27:
- // CHECK: [[VAL_96:%.*]] = arith.maxnumf [[VAL_92]], [[VAL_93]] : f32
+ // CHECK: [[VAL_95:%.*]] = arith.cmpf ugt, [[VAL_92]], [[VAL_93]] : f32
+ // CHECK: [[VAL_96:%.*]] = arith.select [[VAL_95]], [[VAL_92]], [[VAL_93]] : f32
// CHECK: cf.br ^bb29([[VAL_96]] : f32)
// CHECK: ^bb28:
// CHECK: cf.br ^bb29([[VAL_92]] : f32)
@@ -134,7 +146,8 @@ gpu.module @kernels {
// CHECK: [[VAL_98:%.*]], [[VAL_99:%.*]] = gpu.shuffle xor [[VAL_97]], [[VAL_8]], [[VAL_83]] : f32
// CHECK: cf.cond_br [[VAL_99]], ^bb30, ^bb31
// CHECK: ^bb30:
- // CHECK: [[VAL_101:%.*]] = arith.maxnumf [[VAL_97]], [[VAL_98]] : f32
+ // CHECK: [[VAL_100:%.*]] = arith.cmpf ugt, [[VAL_97]], [[VAL_98]] : f32
+ // CHECK: [[VAL_101:%.*]] = arith.select [[VAL_100]], [[VAL_97]], [[VAL_98]] : f32
// CHECK: cf.br ^bb32([[VAL_101]] : f32)
// CHECK: ^bb31:
// CHECK: cf.br ^bb32([[VAL_97]] : f32)
@@ -142,7 +155,8 @@ gpu.module @kernels {
// CHECK: [[VAL_103:%.*]], [[VAL_104:%.*]] = gpu.shuffle xor [[VAL_102]], [[VAL_9]], [[VAL_83]] : f32
// CHECK: cf.cond_br [[VAL_104]], ^bb33, ^bb34
// CHECK: ^bb33:
- // CHECK: [[VAL_106:%.*]] = arith.maxnumf [[VAL_102]], [[VAL_103]] : f32
+ // CHECK: [[VAL_105:%.*]] = arith.cmpf ugt, [[VAL_102]], [[VAL_103]] : f32
+ // CHECK: [[VAL_106:%.*]] = arith.select [[VAL_105]], [[VAL_102]], [[VAL_103]] : f32
// CHECK: cf.br ^bb35([[VAL_106]] : f32)
// CHECK: ^bb34:
// CHECK: cf.br ^bb35([[VAL_102]] : f32)
@@ -150,7 +164,8 @@ gpu.module @kernels {
// CHECK: [[VAL_108:%.*]], [[VAL_109:%.*]] = gpu.shuffle xor [[VAL_107]], [[VAL_10]], [[VAL_83]] : f32
// CHECK: cf.cond_br [[VAL_109]], ^bb36, ^bb37
// CHECK: ^bb36:
- // CHECK: [[VAL_111:%.*]] = arith.maxnumf [[VAL_107]], [[VAL_108]] : f32
+ // CHECK: [[VAL_110:%.*]] = arith.cmpf ugt, [[VAL_107]], [[VAL_108]] : f32
+ // CHECK: [[VAL_111:%.*]] = arith.select [[VAL_110]], [[VAL_107]], [[VAL_108]] : f32
// CHECK: cf.br ^bb38([[VAL_111]] : f32)
// CHECK: ^bb37:
// CHECK: cf.br ^bb38([[VAL_107]] : f32)
@@ -158,15 +173,20 @@ gpu.module @kernels {
// CHECK: cf.br ^bb40([[VAL_112]] : f32)
// CHECK: ^bb39:
// CHECK: [[VAL_113:%.*]], [[VAL_114:%.*]] = gpu.shuffle xor [[VAL_86]], [[VAL_6]], [[VAL_5]] : f32
- // CHECK: [[VAL_116:%.*]] = arith.maxnumf [[VAL_86]], [[VAL_113]] : f32
+ // CHECK: [[VAL_115:%.*]] = arith.cmpf ugt, [[VAL_86]], [[VAL_113]] : f32
+ // CHECK: [[VAL_116:%.*]] = arith.select [[VAL_115]], [[VAL_86]], [[VAL_113]] : f32
// CHECK: [[VAL_117:%.*]], [[VAL_118:%.*]] = gpu.shuffle xor [[VAL_116]], [[VAL_7]], [[VAL_5]] : f32
- // CHECK: [[VAL_120:%.*]] = arith.maxnumf [[VAL_116]], [[VAL_117]] : f32
+ // CHECK: [[VAL_119:%.*]] = arith.cmpf ugt, [[VAL_116]], [[VAL_117]] : f32
+ // CHECK: [[VAL_120:%.*]] = arith.select [[VAL_119]], [[VAL_116]], [[VAL_117]] : f32
// CHECK: [[VAL_121:%.*]], [[VAL_122:%.*]] = gpu.shuffle xor [[VAL_120]], [[VAL_8]], [[VAL_5]] : f32
- // CHECK: [[VAL_124:%.*]] = arith.maxnumf [[VAL_120]], [[VAL_121]] : f32
+ // CHECK: [[VAL_123:%.*]] = arith.cmpf ugt, [[VAL_120]], [[VAL_121]] : f32
+ // CHECK: [[VAL_124:%.*]] = arith.select [[VAL_123]], [[VAL_120]], [[VAL_121]] : f32
// CHECK: [[VAL_125:%.*]], [[VAL_126:%.*]] = gpu.shuffle xor [[VAL_124]], [[VAL_9]], [[VAL_5]] : f32
- // CHECK: [[VAL_128:%.*]] = arith.maxnumf [[VAL_124]], [[VAL_125]] : f32
+ // CHECK: [[VAL_127:%.*]] = arith.cmpf ugt, [[VAL_124]], [[VAL_125]] : f32
+ // CHECK: [[VAL_128:%.*]] = arith.select [[VAL_127]], [[VAL_124]], [[VAL_125]] : f32
// CHECK: [[VAL_129:%.*]], [[VAL_130:%.*]] = gpu.shuffle xor [[VAL_128]], [[VAL_10]], [[VAL_5]] : f32
- // CHECK: [[VAL_132:%.*]] = arith.maxnumf [[VAL_128]], [[VAL_129]] : f32
+ // CHECK: [[VAL_131:%.*]] = arith.cmpf ugt, [[VAL_128]], [[VAL_129]] : f32
+ // CHECK: [[VAL_132:%.*]] = arith.select [[VAL_131]], [[VAL_128]], [[VAL_129]] : f32
// CHECK: cf.br ^bb40([[VAL_132]] : f32)
// CHECK: ^bb40([[VAL_133:%.*]]: f32):
// CHECK: store [[VAL_133]], [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, #gpu.address_space<workgroup>>
@@ -175,7 +195,7 @@ gpu.module @kernels {
// CHECK: cf.br ^bb42
// CHECK: ^bb42:
// CHECK: gpu.barrier
- %sum = gpu.all_reduce maxf %arg0 uniform {} : (f32) -> (f32)
+ %sum = gpu.all_reduce max %arg0 uniform {} : (f32) -> (f32)
gpu.return
}
diff --git a/mlir/test/Dialect/GPU/all-reduce-add.mlir b/mlir/test/Dialect/GPU/all-reduce.mlir
similarity index 100%
rename from mlir/test/Dialect/GPU/all-reduce-add.mlir
rename to mlir/test/Dialect/GPU/all-reduce.mlir
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 17faccbd091a8bc..3a2197ad4d5a172 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -210,14 +210,6 @@ module attributes {gpu.container_module} {
// -----
-func.func @reduce_bad_type(%arg0 : vector<4xf32>) {
- // expected-error@+1 {{'gpu.all_reduce' op operand #0 must be Integer or Float}}
- %res = gpu.all_reduce add %arg0 {} : (vector<4xf32>) -> vector<4xf32>
- return
-}
-
-// -----
-
func.func @reduce_no_op_no_body(%arg0 : f32) {
// expected-error@+1 {{expected either an op attribute or a non-empty body}}
%res = "gpu.all_reduce"(%arg0) ({}) : (f32) -> (f32)
@@ -245,118 +237,22 @@ func.func @reduce_invalid_op(%arg0 : f32) {
// -----
-func.func @reduce_invalid_op_type_minsi(%arg0 : f32) {
- // expected-error@+1 {{`minsi` reduction operation is not compatible with type 'f32'}}
- %res = gpu.all_reduce minsi %arg0 {} : (f32) -> (f32)
- return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_minui(%arg0 : f32) {
- // expected-error@+1 {{`minui` reduction operation is not compatible with type 'f32'}}
- %res = gpu.all_reduce minui %arg0 {} : (f32) -> (f32)
- return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_maxsi(%arg0 : f32) {
- // expected-error@+1 {{`maxsi` reduction operation is not compatible with type 'f32'}}
- %res = gpu.all_reduce maxsi %arg0 {} : (f32) -> (f32)
- return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_maxui(%arg0 : f32) {
- // expected-error@+1 {{`maxui` reduction operation is not compatible with type 'f32'}}
- %res = gpu.all_reduce maxui %arg0 {} : (f32) -> (f32)
- return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_and(%arg0 : f32) {
- // expected-error@+1 {{`and` reduction operation is not compatible with type 'f32'}}
+func.func @reduce_invalid_op_type(%arg0 : f32) {
+ // expected-error@+1 {{`and` accumulator is only compatible with Integer type}}
%res = gpu.all_reduce and %arg0 {} : (f32) -> (f32)
return
}
// -----
-func.func @reduce_invalid_op_type_or(%arg0 : f32) {
- // expected-error at +1 {{`or` reduction operation is not compatible with type 'f32'}}
- %res = gpu.all_reduce or %arg0 {} : (f32) -> (f32)
- return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_xor(%arg0 : f32) {
- // expected-error@+1 {{`xor` reduction operation is not compatible with type 'f32'}}
- %res = gpu.all_reduce xor %arg0 {} : (f32) -> (f32)
- return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_minf(%arg0 : i32) {
- // expected-error@+1 {{`minf` reduction operation is not compatible with type 'i32'}}
- %res = gpu.all_reduce minf %arg0 {} : (i32) -> (i32)
- return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_maxf(%arg0 : i32) {
- // expected-error@+1 {{`maxf` reduction operation is not compatible with type 'i32'}}
- %res = gpu.all_reduce maxf %arg0 {} : (i32) -> (i32)
- return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_minimumf(%arg0 : i32) {
- // expected-error@+1 {{`minimumf` reduction operation is not compatible with type 'i32'}}
- %res = gpu.all_reduce minimumf %arg0 {} : (i32) -> (i32)
- return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_maximumf(%arg0 : i32) {
- // expected-error@+1 {{`maximumf` reduction operation is not compatible with type 'i32'}}
- %res = gpu.all_reduce maximumf %arg0 {} : (i32) -> (i32)
- return
-}
-
-// -----
-
-func.func @subgroup_reduce_bad_type(%arg0 : vector<2xf32>) {
- // expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float}}
- %res = gpu.subgroup_reduce add %arg0 : (vector<2xf32>) -> vector<2xf32>
- return
-}
-
-// -----
-
-func.func @subgroup_reduce_invalid_op_type_and(%arg0 : f32) {
- // expected-error@+1 {{`and` reduction operation is not compatible with type 'f32'}}
+func.func @subgroup_reduce_invalid_op_type(%arg0 : f32) {
+ // expected-error@+1 {{`and` accumulator is only compatible with Integer type}}
%res = gpu.subgroup_reduce and %arg0 : (f32) -> (f32)
return
}
// -----
-func.func @subgroup_reduce_invalid_op_type_maxf(%arg0 : i32) {
- // expected-error@+1 {{`maxf` reduction operation is not compatible with type 'i32'}}
- %res = gpu.subgroup_reduce maxf %arg0 : (i32) -> (i32)
- return
-}
-
-// -----
-
func.func @reduce_incorrect_region_arguments(%arg0 : f32) {
// expected-error at +1 {{expected two region arguments}}
%res = gpu.all_reduce %arg0 {
@@ -751,11 +647,11 @@ func.func @main() {
%shmemSize = arith.constant 10000 : i32
%c1 = arith.constant 1 : index
gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
- threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
dynamic_shared_memory_size %shmemSize
{
// expected-error @below {{'gpu.dynamic_shared_memory' op address space must be address_space<workgroup>}}
- %0 = gpu.dynamic_shared_memory : memref<?xi8>
+ %0 = gpu.dynamic_shared_memory : memref<?xi8>
gpu.terminator
}
return
@@ -768,11 +664,11 @@ func.func @main() {
%shmemSize = arith.constant 8192 : i32
%c1 = arith.constant 1 : index
gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
- threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
dynamic_shared_memory_size %shmemSize
{
// expected-error @below {{'gpu.dynamic_shared_memory' op result memref type must be memref<?xi8, #gpu.address_space<workgroup>>}}
- %0 = gpu.dynamic_shared_memory : memref<1xi8, #gpu.address_space<workgroup>>
+ %0 = gpu.dynamic_shared_memory : memref<1xi8, #gpu.address_space<workgroup>>
gpu.terminator
}
return
@@ -784,7 +680,7 @@ func.func @main(%arg0 : index) {
%shmemSize = arith.constant 8192 : i32
%c1 = arith.constant 1 : index
gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
- threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
dynamic_shared_memory_size %shmemSize
{
// expected-error @below {{'gpu.dynamic_shared_memory' op address space must be address_space<workgroup>}}
@@ -800,7 +696,7 @@ func.func @main(%arg0 : index) {
%shmemSize = arith.constant 8192 : i32
%c1 = arith.constant 1 : index
gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
- threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
dynamic_shared_memory_size %shmemSize
{
// expected-error @below {{'gpu.dynamic_shared_memory' op result #0 must be 1D memref of 8-bit signless integer values, but got 'memref<?xf32, #gpu.address_space<workgroup>}}