[Mlir-commits] [mlir] [mlir][spirv] Fix vector reduction lowerings for FP min/max (PR #69025)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 13 12:36:33 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Daniil Dudkin (unterumarmung)

<details>
<summary>Changes</summary>

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.
    
This commit fixes the vector reduction lowerings for the floating-point min/max kinds by generating additional operations (`spirv.IsNan` and `spirv.Select`) that enforce the NaN-handling semantics required by each reduction kind.
    
This patch addresses tasks 2.4 and 2.5 of the RFC.

Please note that this patch depends on PR #69023.


---

Patch is 23.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69025.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+150-29) 
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+143-11) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 9b29179f3687165..040fa69e2e9f27b 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -28,6 +29,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cassert>
 #include <cstdint>
@@ -351,15 +353,13 @@ struct VectorInsertStridedSliceOpConvert final
   }
 };
 
-template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
-          class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
-struct VectorReductionPattern final
-    : public OpConversionPattern<vector::ReductionOp> {
+template <typename Derived>
+struct VectorReductionPatternBase : OpConversionPattern<vector::ReductionOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+                  ConversionPatternRewriter &rewriter) const final {
     Type resultType = typeConverter->convertType(reduceOp.getType());
     if (!resultType)
       return failure();
@@ -368,9 +368,22 @@ struct VectorReductionPattern final
     if (!srcVectorType || srcVectorType.getRank() != 1)
       return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
 
-    // Extract all elements.
+    SmallVector<Value> extractedElements =
+        extractAllElements(reduceOp, adaptor, srcVectorType, rewriter);
+
+    const auto &self = static_cast<const Derived &>(*this);
+
+    return self.reduceExtracted(reduceOp, extractedElements, resultType,
+                                rewriter);
+  }
+
+private:
+  SmallVector<Value>
+  extractAllElements(vector::ReductionOp reduceOp, OpAdaptor adaptor,
+                     VectorType srcVectorType,
+                     ConversionPatternRewriter &rewriter) const {
     int numElements = srcVectorType.getDimSize(0);
-    SmallVector<Value, 4> values;
+    SmallVector<Value> values;
     values.reserve(numElements + (adaptor.getAcc() != nullptr));
     Location loc = reduceOp.getLoc();
     for (int i = 0; i < numElements; ++i) {
@@ -381,9 +394,26 @@ struct VectorReductionPattern final
     if (Value acc = adaptor.getAcc())
       values.push_back(acc);
 
-    // Reduce them.
-    Value result = values.front();
-    for (Value next : llvm::ArrayRef(values).drop_front()) {
+    return values;
+  }
+};
+
+#define VECTOR_REDUCTION_BASE                                                  \
+  VectorReductionPatternBase<VectorReductionPattern<SPIRVUMaxOp, SPIRVUMinOp,  \
+                                                    SPIRVSMaxOp, SPIRVSMinOp>>
+template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
+          typename SPIRVSMinOp>
+struct VectorReductionPattern final : VECTOR_REDUCTION_BASE {
+  using Base = VECTOR_REDUCTION_BASE;
+  using Base::Base;
+
+  LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
+                                ArrayRef<Value> extractedElements,
+                                Type resultType,
+                                ConversionPatternRewriter &rewriter) const {
+    mlir::Location loc = reduceOp->getLoc();
+    Value result = extractedElements.front();
+    for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
       switch (reduceOp.getKind()) {
 
 #define INT_AND_FLOAT_CASE(kind, iop, fop)                                     \
@@ -403,10 +433,6 @@ struct VectorReductionPattern final
 
         INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
         INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
-        INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
-        INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
-        INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
-        INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
         INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
         INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
         INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
@@ -416,13 +442,105 @@ struct VectorReductionPattern final
       case vector::CombiningKind::OR:
       case vector::CombiningKind::XOR:
         return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
+      default:
+        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
+      }
+    }
+
+    rewriter.replaceOp(reduceOp, result);
+    return success();
+  }
+};
+#undef VECTOR_REDUCTION_BASE
+#undef INT_AND_FLOAT_CASE
+#undef INT_OR_FLOAT_CASE
+
+#define MIN_MAX_PATTERN_BASE                                                   \
+  VectorReductionPatternBase<                                                  \
+      VectorReductionFloatMinMax<SPIRVFMaxOp, SPIRVFMinOp>>
+template <class SPIRVFMaxOp, class SPIRVFMinOp>
+struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE {
+  using Base = MIN_MAX_PATTERN_BASE;
+  using Base::Base;
+
+  LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
+                                ArrayRef<Value> extractedElements,
+                                Type resultType,
+                                ConversionPatternRewriter &rewriter) const {
+    mlir::Location loc = reduceOp->getLoc();
+    Value result = extractedElements.front();
+    for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
+      switch (reduceOp.getKind()) {
+
+#define INT_OR_FLOAT_CASE(kind, fop)                                           \
+  case vector::CombiningKind::kind: {                                          \
+    fop op = rewriter.create<fop>(loc, resultType, result, next);              \
+    result = this->generateActionForOp(rewriter, loc, resultType, op,          \
+                                       vector::CombiningKind::kind);           \
+    break;                                                                     \
+  }
+
+        INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
+        INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
+        INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
+        INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
+
+      default:
+        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
       }
     }
 
     rewriter.replaceOp(reduceOp, result);
     return success();
   }
+
+private:
+  enum class Action { Nothing, PropagateNaN, PropagateNonNaN };
+
+  template <typename Op>
+  Action getActionForOp(vector::CombiningKind kind) const {
+    constexpr bool isCLOp = std::is_same_v<Op, spirv::CLFMaxOp> ||
+                            std::is_same_v<Op, spirv::CLFMinOp>;
+    switch (kind) {
+    case vector::CombiningKind::MINIMUMF:
+    case vector::CombiningKind::MAXIMUMF:
+      return Action::PropagateNaN;
+    case vector::CombiningKind::MINF:
+    case vector::CombiningKind::MAXF:
+      // CL ops already have the same semantic for NaNs as MINF/MAXF
+      // GL ops have undefined semantics for NaNs, so we need to explicitly
+      // propagate the non-NaN values
+      return isCLOp ? Action::Nothing : Action::PropagateNonNaN;
+    default:
+      llvm_unreachable("Unexpected case for the switch");
+    }
+  }
+
+  template <typename Op>
+  Value generateActionForOp(ConversionPatternRewriter &rewriter,
+                            mlir::Location loc, Type resultType, Op op,
+                            vector::CombiningKind kind) const {
+    Action action = getActionForOp<Op>(kind);
+
+    if (action == Action::Nothing) {
+      return op;
+    }
+
+    Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getLhs());
+    Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getRhs());
+
+    Value select1 = rewriter.create<spirv::SelectOp>(
+        loc, resultType, lhsIsNan,
+        action == Action::PropagateNaN ? op.getLhs() : op.getRhs(), op);
+    Value select2 = rewriter.create<spirv::SelectOp>(
+        loc, resultType, rhsIsNan,
+        action == Action::PropagateNaN ? op.getRhs() : op.getLhs(), select1);
+
+    return select2;
+  }
 };
+#undef MIN_MAX_PATTERN_BASE
+#undef INT_OR_FLOAT_CASE
 
 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
 public:
@@ -604,25 +722,28 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
 };
 
 } // namespace
-#define CL_MAX_MIN_OPS                                                         \
-  spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp,          \
-      spirv::CLSMaxOp, spirv::CLSMinOp
+#define CL_INT_MAX_MIN_OPS                                                     \
+  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
+
+#define GL_INT_MAX_MIN_OPS                                                     \
+  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
 
-#define GL_MAX_MIN_OPS                                                         \
-  spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp,          \
-      spirv::GLSMaxOp, spirv::GLSMinOp
+#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
+#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
-  patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
-               VectorExtractElementOpConvert, VectorExtractOpConvert,
-               VectorExtractStridedSliceOpConvert,
-               VectorFmaOpConvert<spirv::GLFmaOp>,
-               VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
-               VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
-               VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
-               VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-               VectorSplatPattern>(typeConverter, patterns.getContext());
+  patterns.add<
+      VectorBitcastConvert, VectorBroadcastConvert,
+      VectorExtractElementOpConvert, VectorExtractOpConvert,
+      VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
+      VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
+      VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+      VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
+      VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
+      VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
+      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+      VectorSplatPattern>(typeConverter, patterns.getContext());
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index eba763eab9c292a..91836e556147b8d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -56,9 +56,21 @@ func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<
 //       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
 //       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
 //       CHECK:   %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
-//       CHECK:   %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
-//       CHECK:   %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
-//       CHECK:   return %[[MAX2]]
+//       CHECK:   %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+//       CHECK:   %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+//       CHECK:   %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
+//       CHECK:   %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+//       CHECK:   %[[MAX1:.+]] = spirv.CL.fmax %[[SELECT1]], %[[S2]]
+//       CHECK:   %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+//       CHECK:   %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+//       CHECK:   %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
+//       CHECK:   %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+//       CHECK:   %[[MAX2:.+]] = spirv.CL.fmax %[[SELECT3]], %[[S]]
+//       CHECK:   %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+//       CHECK:   %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+//       CHECK:   %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
+//       CHECK:   %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+//       CHECK:   return %[[SELECT5]]
 func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
   %reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
   return %reduce : f32
@@ -70,11 +82,51 @@ func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
 //       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
 //       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
 //       CHECK:   %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
+//       CHECK:   %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+//       CHECK:   %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+//       CHECK:   %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32
+//       CHECK:   %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+//       CHECK:   %[[MIN1:.+]] = spirv.CL.fmin %[[SELECT1]], %[[S2]]
+//       CHECK:   %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+//       CHECK:   %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+//       CHECK:   %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32
+//       CHECK:   %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+//       CHECK:   %[[MIN2:.+]] = spirv.CL.fmin %[[SELECT3]], %[[S]]
+//       CHECK:   %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+//       CHECK:   %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+//       CHECK:   %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32
+//       CHECK:   %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+//       CHECK:   return %[[SELECT5]]
+func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_maxf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
+//       CHECK:   %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
+//       CHECK:   %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
+//       CHECK:   return %[[MAX2]]
+func.func @cl_reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_minf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
 //       CHECK:   %[[MIN1:.+]] = spirv.CL.fmin %[[MIN0]], %[[S2]]
 //       CHECK:   %[[MIN2:.+]] = spirv.CL.fmin %[[MIN1]], %[[S]]
 //       CHECK:   return %[[MIN2]]
-func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
-  %reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
+func.func @cl_reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
   return %reduce : f32
 }
 
@@ -522,9 +574,21 @@ func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 {
 //       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
 //       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
 //       CHECK:   %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
-//       CHECK:   %[[MAX1:.+]] = spirv.GL.FMax %[[MAX0]], %[[S2]]
-//       CHECK:   %[[MAX2:.+]] = spirv.GL.FMax %[[MAX1]], %[[S]]
-//       CHECK:   return %[[MAX2]]
+//       CHECK:   %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+//       CHECK:   %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+//       CHECK:   %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
+//       CHECK:   %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+//       CHECK:   %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
+//       CHECK:   %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+//       CHECK:   %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+//       CHECK:   %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
+//       CHECK:   %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+//       CHECK:   %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
+//       CHECK:   %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+//       CHECK:   %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+//       CHECK:   %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
+//       CHECK:   %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+//       CHECK:   return %[[SELECT5]]
 func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
   %reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
   return %reduce : f32
@@ -532,15 +596,55 @@ func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @reduction_maxf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
+//       CHECK:   %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+//       CHECK:   %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+//       CHECK:   %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MAX0]] : i1, f32
+//       CHECK:   %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32
+//       CHECK:   %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
+//       CHECK:   %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+//       CHECK:   %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+//       CHECK:   %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MAX1]] : i1, f32
+//       CHECK:   %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32
+//       CHECK:   %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
+//       CHECK:   %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+//       CHECK:   %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+//       CHECK:   %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MAX2]] : i1, f32
+//       CHECK:   %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32
+//       CHECK:   return %[[SELECT5]]
+func.func @reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @reduction_minimumf
 //  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
 //       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
 //       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
 //       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
 //       CHECK:   %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]]
-//       CHECK:  ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/69025


More information about the Mlir-commits mailing list