[Mlir-commits] [mlir] [mlir][spirv] Fix vector reduction lowerings for FP min/max (PR #69025)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 13 12:36:33 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Daniil Dudkin (unterumarmung)
<details>
<summary>Changes</summary>
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.
This commit fixes the vector reduction lowerings for the floating-point min/max kinds by generating additional operations (`spirv.IsNan` and `spirv.Select`) that enforce the NaN-handling semantics each kind requires.
This patch addresses tasks 2.4 and 2.5 of the RFC.
Please note that this patch depends on #<!-- -->69023.
---
Patch is 23.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69025.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+150-29)
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+143-11)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 9b29179f3687165..040fa69e2e9f27b 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@@ -28,6 +29,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <cassert>
#include <cstdint>
@@ -351,15 +353,13 @@ struct VectorInsertStridedSliceOpConvert final
}
};
-template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
- class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
-struct VectorReductionPattern final
- : public OpConversionPattern<vector::ReductionOp> {
+template <typename Derived>
+struct VectorReductionPatternBase : OpConversionPattern<vector::ReductionOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ ConversionPatternRewriter &rewriter) const final {
Type resultType = typeConverter->convertType(reduceOp.getType());
if (!resultType)
return failure();
@@ -368,9 +368,22 @@ struct VectorReductionPattern final
if (!srcVectorType || srcVectorType.getRank() != 1)
return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
- // Extract all elements.
+ SmallVector<Value> extractedElements =
+ extractAllElements(reduceOp, adaptor, srcVectorType, rewriter);
+
+ const auto &self = static_cast<const Derived &>(*this);
+
+ return self.reduceExtracted(reduceOp, extractedElements, resultType,
+ rewriter);
+ }
+
+private:
+ SmallVector<Value>
+ extractAllElements(vector::ReductionOp reduceOp, OpAdaptor adaptor,
+ VectorType srcVectorType,
+ ConversionPatternRewriter &rewriter) const {
int numElements = srcVectorType.getDimSize(0);
- SmallVector<Value, 4> values;
+ SmallVector<Value> values;
values.reserve(numElements + (adaptor.getAcc() != nullptr));
Location loc = reduceOp.getLoc();
for (int i = 0; i < numElements; ++i) {
@@ -381,9 +394,26 @@ struct VectorReductionPattern final
if (Value acc = adaptor.getAcc())
values.push_back(acc);
- // Reduce them.
- Value result = values.front();
- for (Value next : llvm::ArrayRef(values).drop_front()) {
+ return values;
+ }
+};
+
+#define VECTOR_REDUCTION_BASE \
+ VectorReductionPatternBase<VectorReductionPattern<SPIRVUMaxOp, SPIRVUMinOp, \
+ SPIRVSMaxOp, SPIRVSMinOp>>
+template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
+ typename SPIRVSMinOp>
+struct VectorReductionPattern final : VECTOR_REDUCTION_BASE {
+ using Base = VECTOR_REDUCTION_BASE;
+ using Base::Base;
+
+ LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
+ ArrayRef<Value> extractedElements,
+ Type resultType,
+ ConversionPatternRewriter &rewriter) const {
+ mlir::Location loc = reduceOp->getLoc();
+ Value result = extractedElements.front();
+ for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
switch (reduceOp.getKind()) {
#define INT_AND_FLOAT_CASE(kind, iop, fop) \
@@ -403,10 +433,6 @@ struct VectorReductionPattern final
INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
- INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
- INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
- INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
- INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
@@ -416,13 +442,105 @@ struct VectorReductionPattern final
case vector::CombiningKind::OR:
case vector::CombiningKind::XOR:
return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
+ default:
+ return rewriter.notifyMatchFailure(reduceOp, "not handled here");
+ }
+ }
+
+ rewriter.replaceOp(reduceOp, result);
+ return success();
+ }
+};
+#undef VECTOR_REDUCTION_BASE
+#undef INT_AND_FLOAT_CASE
+#undef INT_OR_FLOAT_CASE
+
+#define MIN_MAX_PATTERN_BASE \
+ VectorReductionPatternBase< \
+ VectorReductionFloatMinMax<SPIRVFMaxOp, SPIRVFMinOp>>
+template <class SPIRVFMaxOp, class SPIRVFMinOp>
+struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE {
+ using Base = MIN_MAX_PATTERN_BASE;
+ using Base::Base;
+
+ LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
+ ArrayRef<Value> extractedElements,
+ Type resultType,
+ ConversionPatternRewriter &rewriter) const {
+ mlir::Location loc = reduceOp->getLoc();
+ Value result = extractedElements.front();
+ for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
+ switch (reduceOp.getKind()) {
+
+#define INT_OR_FLOAT_CASE(kind, fop) \
+ case vector::CombiningKind::kind: { \
+ fop op = rewriter.create<fop>(loc, resultType, result, next); \
+ result = this->generateActionForOp(rewriter, loc, resultType, op, \
+ vector::CombiningKind::kind); \
+ break; \
+ }
+
+ INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
+ INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
+ INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
+ INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
+
+ default:
+ return rewriter.notifyMatchFailure(reduceOp, "not handled here");
}
}
rewriter.replaceOp(reduceOp, result);
return success();
}
+
+private:
+ enum class Action { Nothing, PropagateNaN, PropagateNonNaN };
+
+ template <typename Op>
+ Action getActionForOp(vector::CombiningKind kind) const {
+ constexpr bool isCLOp = std::is_same_v<Op, spirv::CLFMaxOp> ||
+ std::is_same_v<Op, spirv::CLFMinOp>;
+ switch (kind) {
+ case vector::CombiningKind::MINIMUMF:
+ case vector::CombiningKind::MAXIMUMF:
+ return Action::PropagateNaN;
+ case vector::CombiningKind::MINF:
+ case vector::CombiningKind::MAXF:
+ // CL ops already have the same semantic for NaNs as MINF/MAXF
+ // GL ops have undefined semantics for NaNs, so we need to explicitly
+ // propagate the non-NaN values
+ return isCLOp ? Action::Nothing : Action::PropagateNonNaN;
+ default:
+ llvm_unreachable("Unexpected case for the switch");
+ }
+ }
+
+ template <typename Op>
+ Value generateActionForOp(ConversionPatternRewriter &rewriter,
+ mlir::Location loc, Type resultType, Op op,
+ vector::CombiningKind kind) const {
+ Action action = getActionForOp<Op>(kind);
+
+ if (action == Action::Nothing) {
+ return op;
+ }
+
+ Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getLhs());
+ Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getRhs());
+
+ Value select1 = rewriter.create<spirv::SelectOp>(
+ loc, resultType, lhsIsNan,
+ action == Action::PropagateNaN ? op.getLhs() : op.getRhs(), op);
+ Value select2 = rewriter.create<spirv::SelectOp>(
+ loc, resultType, rhsIsNan,
+ action == Action::PropagateNaN ? op.getRhs() : op.getLhs(), select1);
+
+ return select2;
+ }
};
+#undef MIN_MAX_PATTERN_BASE
+#undef INT_OR_FLOAT_CASE
class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
public:
@@ -604,25 +722,28 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
};
} // namespace
-#define CL_MAX_MIN_OPS \
- spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \
- spirv::CLSMaxOp, spirv::CLSMinOp
+#define CL_INT_MAX_MIN_OPS \
+ spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
+
+#define GL_INT_MAX_MIN_OPS \
+ spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
-#define GL_MAX_MIN_OPS \
- spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \
- spirv::GLSMaxOp, spirv::GLSMinOp
+#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
+#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
- VectorExtractElementOpConvert, VectorExtractOpConvert,
- VectorExtractStridedSliceOpConvert,
- VectorFmaOpConvert<spirv::GLFmaOp>,
- VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
- VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
- VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
- VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorSplatPattern>(typeConverter, patterns.getContext());
+ patterns.add<
+ VectorBitcastConvert, VectorBroadcastConvert,
+ VectorExtractElementOpConvert, VectorExtractOpConvert,
+ VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
+ VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
+ VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+ VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
+ VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
+ VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorSplatPattern>(typeConverter, patterns.getContext());
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index eba763eab9c292a..91836e556147b8d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -56,9 +56,21 @@ func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
// CHECK: %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
-// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
-// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
-// CHECK: return %[[MAX2]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
%reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
return %reduce : f32
@@ -70,11 +82,51 @@ func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
// CHECK: %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MIN1:.+]] = spirv.CL.fmin %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MIN2:.+]] = spirv.CL.fmin %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
+func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_maxf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
+// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
+// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
+// CHECK: return %[[MAX2]]
+func.func @cl_reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_minf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
// CHECK: %[[MIN1:.+]] = spirv.CL.fmin %[[MIN0]], %[[S2]]
// CHECK: %[[MIN2:.+]] = spirv.CL.fmin %[[MIN1]], %[[S]]
// CHECK: return %[[MIN2]]
-func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
- %reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
+func.func @cl_reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
return %reduce : f32
}
@@ -522,9 +574,21 @@ func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 {
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
// CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
-// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[MAX0]], %[[S2]]
-// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[MAX1]], %[[S]]
-// CHECK: return %[[MAX2]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
%reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
return %reduce : f32
@@ -532,15 +596,55 @@ func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
// -----
+// CHECK-LABEL: func @reduction_maxf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MAX0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MAX1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MAX2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
+func.func @reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// -----
+
// CHECK-LABEL: func @reduction_minimumf
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
// CHECK: %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]]
-// CHECK: ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/69025
More information about the Mlir-commits
mailing list