[Mlir-commits] [mlir] 8fe076f - [mlir][VectorToLLVM] Fix bug in lowering of vector.reduce fmax/fmin

Thomas Raoux llvmlistbot at llvm.org
Tue Jul 12 15:03:49 PDT 2022


Author: Thomas Raoux
Date: 2022-07-12T22:03:39Z
New Revision: 8fe076ffe09028bef761b0d0ebdd5842c595ca87

URL: https://github.com/llvm/llvm-project/commit/8fe076ffe09028bef761b0d0ebdd5842c595ca87
DIFF: https://github.com/llvm/llvm-project/commit/8fe076ffe09028bef761b0d0ebdd5842c595ca87.diff

LOG: [mlir][VectorToLLVM] Fix bug in lowering of vector.reduce fmax/fmin

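The lowering of vector.reduction with the minf/maxf combining kinds was ignoring the optional accumulator operand; the reduced value is now combined with the accumulator when one is present.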

Differential Revision: https://reviews.llvm.org/D129597

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index fa4920486aad8..bf9967f0001d3 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -393,6 +394,27 @@ static Value createIntegerReductionComparisonOpLowering(
   return result;
 }
 
+/// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
+/// with vector types.
+static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
+                           Value rhs, bool isMin) {
+  auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
+  Type i1Type = builder.getI1Type();
+  if (auto vecType = lhs.getType().dyn_cast<VectorType>())
+    i1Type = VectorType::get(vecType.getShape(), i1Type);
+  Value cmp = builder.create<LLVM::FCmpOp>(
+      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
+      lhs, rhs);
+  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
+  Value isNan = builder.create<LLVM::FCmpOp>(
+      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
+  Value nan = builder.create<LLVM::ConstantOp>(
+      loc, lhs.getType(),
+      builder.getFloatAttr(floatType,
+                           APFloat::getQNaN(floatType.getFloatSemantics())));
+  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
+}
+
 /// Conversion pattern for all vector reductions.
 class VectorReductionOpConversion
     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
@@ -497,18 +519,25 @@ class VectorReductionOpConversion
       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
           reductionOp, llvmType, acc, operand,
           rewriter.getBoolAttr(reassociateFPReductions));
-    } else if (kind == vector::CombiningKind::MINF)
+    } else if (kind == vector::CombiningKind::MINF) {
       // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
       // NaNs/-0.0/+0.0 in the same way.
-      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
-                                                            llvmType, operand);
-    else if (kind == vector::CombiningKind::MAXF)
+      Value result =
+          rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand);
+      if (acc)
+        result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true);
+      rewriter.replaceOp(reductionOp, result);
+    } else if (kind == vector::CombiningKind::MAXF) {
       // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
       // NaNs/-0.0/+0.0 in the same way.
-      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
-                                                            llvmType, operand);
-    else
+      Value result =
+          rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand);
+      if (acc)
+        result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false);
+      rewriter.replaceOp(reductionOp, result);
+    } else
       return failure();
+
     return success();
   }
 

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 0d8406c793d67..1bca6b35685d4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1213,6 +1213,38 @@ func.func @reduce_mul_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
 
 // -----
 
+func.func @reduce_fmax_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
+  %0 = vector.reduction <maxf>, %arg0, %arg1 : vector<16xf32> into f32
+  return %0 : f32
+}
+// CHECK-LABEL: @reduce_fmax_f32(
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
+//      CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fmax"(%[[A]]) : (vector<16xf32>) -> f32
+//      CHECK: %[[C0:.*]] = llvm.fcmp "ogt" %[[V]], %[[B]] : f32
+//      CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
+//      CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32
+//      CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
+//      CHECK: %[[R:.*]] = llvm.select %[[C1]], %[[NAN]], %[[S0]] : i1, f32
+//      CHECK: return %[[R]] : f32
+
+// -----
+
+func.func @reduce_fmin_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
+  %0 = vector.reduction <minf>, %arg0, %arg1 : vector<16xf32> into f32
+  return %0 : f32
+}
+// CHECK-LABEL: @reduce_fmin_f32(
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
+//      CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fmin"(%[[A]]) : (vector<16xf32>) -> f32
+//      CHECK: %[[C0:.*]] = llvm.fcmp "olt" %[[V]], %[[B]] : f32
+//      CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
+//      CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32
+//      CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
+//      CHECK: %[[R:.*]] = llvm.select %[[C1]], %[[NAN]], %[[S0]] : i1, f32
+//      CHECK: return %[[R]] : f32
+
+// -----
+
 func.func @reduce_minui_i32(%arg0: vector<16xi32>) -> i32 {
   %0 = vector.reduction <minui>, %arg0 : vector<16xi32> into i32
   return %0 : i32


        


More information about the Mlir-commits mailing list