[Mlir-commits] [mlir] [mlir][spirv] Fix vector reduction lowerings for FP min/max (PR #69053)
Daniil Dudkin
llvmlistbot at llvm.org
Sat Oct 14 08:14:16 PDT 2023
https://github.com/unterumarmung created https://github.com/llvm/llvm-project/pull/69053
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.
This commit fixes the vector reduction lowerings for the floating-point min/max kinds by emitting additional operations that enforce the required NaN-propagation semantics.
This patch addresses tasks 2.4 and 2.5 of the RFC.
>From 4154c1db7f600f07a64ef37b77b4acf6f910f72a Mon Sep 17 00:00:00 2001
From: Daniil Dudkin <unterumarmung at yandex.ru>
Date: Sat, 14 Oct 2023 18:13:44 +0300
Subject: [PATCH] [mlir][spirv] Fix vector reduction lowerings for FP min/max
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.
This commit fixes the vector reduction lowerings for the floating-point min/max kinds by emitting additional operations that enforce the required NaN-propagation semantics.
This patch addresses tasks 2.4 and 2.5 of the RFC.
---
.../VectorToSPIRV/VectorToSPIRV.cpp | 54 +++++-
.../VectorToSPIRV/vector-to-spirv.mlir | 154 ++++++++++++++++--
2 files changed, 194 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 9b29179f3687165..1d46d9503e9760d 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -397,9 +397,12 @@ struct VectorReductionPattern final
break
+// Lowers one reduction step for combining kind `kind` with the SPIR-V op
+// `fop`, then passes the new op through generateActionForOp so the
+// accumulator receives whatever NaN handling `kind` requires.
#define INT_OR_FLOAT_CASE(kind, fop) \
- case vector::CombiningKind::kind: \
- result = rewriter.create<fop>(loc, resultType, result, next); \
- break
+ case vector::CombiningKind::kind: { \
+ fop op = rewriter.create<fop>(loc, resultType, result, next); \
+ result = this->generateActionForOp(rewriter, loc, resultType, op, \
+ vector::CombiningKind::kind); \
+ break; \
+ }
INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
@@ -422,6 +425,51 @@ struct VectorReductionPattern final
rewriter.replaceOp(reduceOp, result);
return success();
}
+
+private:
+  /// How the raw result of a SPIR-V min/max op must be post-processed so that
+  /// it honors the NaN semantics of the vector::CombiningKind being lowered.
+  enum class Action {
+    Nothing,         // Use the op result as-is.
+    PropagateNaN,    // If an operand is NaN, the result is a NaN operand.
+    PropagateNonNaN, // If one operand is NaN, the result is the other one.
+  };
+
+  /// Picks the NaN post-processing needed when `kind` is lowered to the
+  /// SPIR-V op `Op`.
+  template <typename Op>
+  Action getActionForOp(vector::CombiningKind kind) const {
+    // MINIMUMF/MAXIMUMF must yield NaN whenever an input is NaN. Neither the
+    // OpenCL ops (which return the non-NaN operand) nor the GLSL ops (whose
+    // NaN behavior is undefined) provide that, so it is always added.
+    if (kind == vector::CombiningKind::MINIMUMF ||
+        kind == vector::CombiningKind::MAXIMUMF)
+      return Action::PropagateNaN;
+
+    // MINF/MAXF must yield the non-NaN operand. OpenCL fmin/fmax already do
+    // this; GLSL FMin/FMax do not guarantee it, so it is added explicitly.
+    if (kind == vector::CombiningKind::MINF ||
+        kind == vector::CombiningKind::MAXF) {
+      constexpr bool isOpenCLOp = std::is_same_v<Op, spirv::CLFMaxOp> ||
+                                  std::is_same_v<Op, spirv::CLFMinOp>;
+      return isOpenCLOp ? Action::Nothing : Action::PropagateNonNaN;
+    }
+
+    // Any other combining kind needs no post-processing.
+    return Action::Nothing;
+  }
+
+ /// Returns the value that should replace `op` as the new reduction
+ /// accumulator, after applying the NaN handling that getActionForOp<Op>
+ /// selects for `kind`. If no handling is needed, `op` itself is returned.
+ ///
+ /// Otherwise an IsNan check is emitted for each operand (lhs first), then
+ /// two Selects: the lhs check is applied to the op result, and the rhs
+ /// check is applied on top of that.
+ ///  - PropagateNaN: a NaN operand overrides the op result, so the result
+ ///    is NaN whenever either input is NaN.
+ ///  - PropagateNonNaN: when one operand is NaN the other operand is used;
+ ///    when both are NaN the lhs (itself NaN) is returned.
+ ///
+ /// The emission order (IsNan lhs, IsNan rhs, Select, Select) is what the
+ /// FileCheck patterns in vector-to-spirv.mlir expect.
+ template <typename Op>
+ Value generateActionForOp(ConversionPatternRewriter &rewriter,
+ mlir::Location loc, Type resultType, Op op,
+ vector::CombiningKind kind) const {
+ Action action = getActionForOp<Op>(kind);
+
+ if (action == Action::Nothing) {
+ return op;
+ }
+
+ Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getLhs());
+ Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getRhs());
+
+ // If lhs is NaN: keep it (PropagateNaN) or fall back to rhs
+ // (PropagateNonNaN); otherwise keep the op result.
+ Value select1 = rewriter.create<spirv::SelectOp>(
+ loc, resultType, lhsIsNan,
+ action == Action::PropagateNaN ? op.getLhs() : op.getRhs(), op);
+ // The rhs check is outermost, so its choice decides the result when both
+ // operands are NaN.
+ Value select2 = rewriter.create<spirv::SelectOp>(
+ loc, resultType, rhsIsNan,
+ action == Action::PropagateNaN ? op.getRhs() : op.getLhs(), select1);
+
+ return select2;
+ }
};
class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index eba763eab9c292a..91836e556147b8d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -56,9 +56,21 @@ func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
// CHECK: %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
-// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
-// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
-// CHECK: return %[[MAX2]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
%reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
return %reduce : f32
@@ -70,11 +82,51 @@ func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
// CHECK: %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MIN1:.+]] = spirv.CL.fmin %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MIN2:.+]] = spirv.CL.fmin %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
+// minimumf lowers to CL.fmin followed by explicit NaN propagation
+// (IsNan + Select on both operands) after every reduction step.
+func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// CL.fmax already returns the non-NaN operand, which matches maxf, so the
+// lowering is a plain chain of CL.fmax ops with no IsNan/Select wrapping.
+// CHECK-LABEL: func @cl_reduction_maxf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
+// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
+// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
+// CHECK: return %[[MAX2]]
+func.func @cl_reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// CL.fmin already returns the non-NaN operand, which matches minf, so the
+// lowering is a plain chain of CL.fmin ops with no IsNan/Select wrapping.
+// CHECK-LABEL: func @cl_reduction_minf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
// CHECK: %[[MIN1:.+]] = spirv.CL.fmin %[[MIN0]], %[[S2]]
// CHECK: %[[MIN2:.+]] = spirv.CL.fmin %[[MIN1]], %[[S]]
// CHECK: return %[[MIN2]]
-func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
- %reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
+func.func @cl_reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
return %reduce : f32
}
@@ -522,9 +574,21 @@ func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 {
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
// CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
-// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[MAX0]], %[[S2]]
-// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[MAX1]], %[[S]]
-// CHECK: return %[[MAX2]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
%reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
return %reduce : f32
@@ -532,15 +596,55 @@ func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
// -----
+// GL.FMax leaves NaN inputs undefined, so maxf must pick the non-NaN operand
+// explicitly after each step: a NaN lhs selects rhs, then a NaN rhs selects lhs.
+// CHECK-LABEL: func @reduction_maxf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MAX0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MAX1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MAX2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
+func.func @reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// -----
+
// CHECK-LABEL: func @reduction_minimumf
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
// CHECK: %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]]
-// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[MIN0]], %[[S2]]
-// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[MIN1]], %[[S]]
-// CHECK: return %[[MIN2]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
%reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
return %reduce : f32
@@ -548,6 +652,34 @@ func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
// -----
+// GL.FMin leaves NaN inputs undefined, so minf must pick the non-NaN operand
+// explicitly after each step: a NaN lhs selects rhs, then a NaN rhs selects lhs.
+// CHECK-LABEL: func @reduction_minf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MIN0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MIN1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MIN2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
+func.func @reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// -----
+
// CHECK-LABEL: func @reduction_maxsi
// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
More information about the Mlir-commits
mailing list