[Mlir-commits] [mlir] dad9de0 - [mlir][vector] Improve lowering to LLVM for `minf`, `maxf` reductions

Daniil Dudkin llvmlistbot at llvm.org
Wed Aug 2 10:27:30 PDT 2023


Author: Daniil Dudkin
Date: 2023-08-02T20:26:59+03:00
New Revision: dad9de0ae5360b18c890985d212bec266bf8c122

URL: https://github.com/llvm/llvm-project/commit/dad9de0ae5360b18c890985d212bec266bf8c122
DIFF: https://github.com/llvm/llvm-project/commit/dad9de0ae5360b18c890985d212bec266bf8c122.diff

LOG: [mlir][vector] Improve lowering to LLVM for `minf`, `maxf` reductions

This patch improves the lowering by changing the target LLVM intrinsics from
`reduce.fmax` and `reduce.fmin`, which handle NaNs (and -0.0/+0.0) differently
from MLIR's `minf`/`maxf`, to `reduce.fmaximum` and `reduce.fminimum`.

Fixes #63969

Depends on D155869

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D155877

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index fc93f0537c47f0..1e9913048e8ac4 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -573,35 +573,31 @@ static Value createIntegerReductionComparisonOpLowering(
   return result;
 }
 
-/// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
-/// with vector types.
-static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
-                           Value rhs, bool isMin) {
-  auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
-  Type i1Type = builder.getI1Type();
-  if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
-    i1Type = VectorType::get(vecType.getShape(), i1Type);
-  Value cmp = builder.create<LLVM::FCmpOp>(
-      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
-      lhs, rhs);
-  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
-  Value isNan = builder.create<LLVM::FCmpOp>(
-      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
-  Value nan = builder.create<LLVM::ConstantOp>(
-      loc, lhs.getType(),
-      builder.getFloatAttr(floatType,
-                           APFloat::getQNaN(floatType.getFloatSemantics())));
-  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
-}
+namespace {
+template <typename Source>
+struct VectorToScalarMapper;
+template <>
+struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
+  using Type = LLVM::MaximumOp;
+};
+template <>
+struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
+  using Type = LLVM::MinimumOp;
+};
+} // namespace
 
 template <class LLVMRedIntrinOp>
-static Value createFPReductionComparisonOpLowering(
-    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
-    Value vectorOperand, Value accumulator, bool isMin) {
+static Value
+createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter,
+                                      Location loc, Type llvmType,
+                                      Value vectorOperand, Value accumulator) {
   Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
 
-  if (accumulator)
-    result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin);
+  if (accumulator) {
+    result =
+        rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
+            loc, result, accumulator);
+  }
 
   return result;
 }
@@ -774,17 +770,13 @@ class VectorReductionOpConversion
                                             ReductionNeutralFPOne>(
           rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
     } else if (kind == vector::CombiningKind::MINF) {
-      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
-      // NaNs/-0.0/+0.0 in the same way.
-      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
-          rewriter, loc, llvmType, operand, acc,
-          /*isMin=*/true);
+      result =
+          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
+              rewriter, loc, llvmType, operand, acc);
     } else if (kind == vector::CombiningKind::MAXF) {
-      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
-      // NaNs/-0.0/+0.0 in the same way.
-      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
-          rewriter, loc, llvmType, operand, acc,
-          /*isMin=*/false);
+      result =
+          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
+              rewriter, loc, llvmType, operand, acc);
     } else
       return failure();
 

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 40d4934b2a40f6..fa119e290ae8d8 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1374,12 +1374,8 @@ func.func @reduce_fmax_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
 }
 // CHECK-LABEL: @reduce_fmax_f32(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
-//      CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmax(%[[A]]) : (vector<16xf32>) -> f32
-//      CHECK: %[[C0:.*]] = llvm.fcmp "ogt" %[[V]], %[[B]] : f32
-//      CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
-//      CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32
-//      CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
-//      CHECK: %[[R:.*]] = llvm.select %[[C1]], %[[NAN]], %[[S0]] : i1, f32
+//      CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmaximum(%[[A]]) : (vector<16xf32>) -> f32
+//      CHECK: %[[R:.*]] = llvm.intr.maximum(%[[V]], %[[B]]) : (f32, f32) -> f32
 //      CHECK: return %[[R]] : f32
 
 // -----
@@ -1390,12 +1386,8 @@ func.func @reduce_fmin_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
 }
 // CHECK-LABEL: @reduce_fmin_f32(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
-//      CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmin(%[[A]]) : (vector<16xf32>) -> f32
-//      CHECK: %[[C0:.*]] = llvm.fcmp "olt" %[[V]], %[[B]] : f32
-//      CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
-//      CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32
-//      CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
-//      CHECK: %[[R:.*]] = llvm.select %[[C1]], %[[NAN]], %[[S0]] : i1, f32
+//      CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fminimum(%[[A]]) : (vector<16xf32>) -> f32
+//      CHECK: %[[R:.*]] = llvm.intr.minimum(%[[V]], %[[B]]) : (f32, f32) -> f32
 //      CHECK: return %[[R]] : f32
 
 // -----


        


More information about the Mlir-commits mailing list