[Mlir-commits] [mlir] 8f5d519 - [mlir][vector] Implement Workaround Lowerings for Masked `fm**imum` Reductions

Diego Caballero llvmlistbot at llvm.org
Wed Sep 13 15:50:45 PDT 2023


Author: Daniil Dudkin
Date: 2023-09-13T22:49:08Z
New Revision: 8f5d519458aaf8ca7731ee974b912f6897078282

URL: https://github.com/llvm/llvm-project/commit/8f5d519458aaf8ca7731ee974b912f6897078282
DIFF: https://github.com/llvm/llvm-project/commit/8f5d519458aaf8ca7731ee974b912f6897078282.diff

LOG: [mlir][vector] Implement Workaround Lowerings for Masked `fm**imum` Reductions

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

Within LLVM, there are no masked reduction counterparts for vector reductions such as `fmaximum` and `fminimum`.
More information can be found here: https://github.com/llvm/llvm-project/issues/64940#issuecomment-1690694156.

To address this issue in MLIR, where we need to generate appropriate lowerings for these cases, we employ regular non-masked intrinsics.
However, we first modify the input vector with an `llvm.select` operation that replaces the masked-off elements with a "neutral mask value".
The neutral mask value is the smallest possible value for the `fmaximum` reduction and the largest possible value for the `fminimum` reduction.

Depends on D158618

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D158773

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 335e113d12b7e3c..a979237d1f63e17 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -15,13 +15,17 @@
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/Support/Casting.h"
 #include <optional>
 
@@ -603,6 +607,51 @@ createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter,
   return result;
 }
 
+/// Empty tag types that pick the matching `getMaskNeutralValue` overload.
+struct MaskNeutralFMaximum {};
+struct MaskNeutralFMinimum {};
+
+/// Get the mask neutral floating point maximum value, i.e. -inf. Note that
+/// `getSmallest` would yield -denorm_min, which beats any negative element.
+static llvm::APFloat
+getMaskNeutralValue(MaskNeutralFMaximum,
+                    const llvm::fltSemantics &floatSemantics) {
+  return llvm::APFloat::getInf(floatSemantics, /*Negative=*/true);
+}
+/// Get the mask neutral floating point minimum value, i.e. +inf.
+static llvm::APFloat
+getMaskNeutralValue(MaskNeutralFMinimum,
+                    const llvm::fltSemantics &floatSemantics) {
+  return llvm::APFloat::getInf(floatSemantics, /*Negative=*/false);
+}
+
+/// Create a `vectorType` splat constant of the neutral value for MaskNeutral.
+template <typename MaskNeutral>
+static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
+                                    Location loc, Type llvmType,
+                                    Type vectorType) {
+  const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
+  auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
+  auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
+  return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
+}
+
+/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
+/// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
+/// `fmaximum`/`fminimum`.
+/// More information: https://github.com/llvm/llvm-project/issues/64940
+template <class LLVMRedIntrinOp, class MaskNeutral>
+static Value lowerMaskedReductionWithRegular(
+    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
+    Value vectorOperand, Value accumulator, Value mask) {
+  const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
+      rewriter, loc, llvmType, vectorOperand.getType());
+  const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
+      loc, mask, vectorOperand, vectorMaskNeutral);
+  return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
+      rewriter, loc, llvmType, selectedVectorByMask, accumulator);
+}
+
 /// Overloaded methods to lower a reduction to an llvm instrinsic that requires
 /// a start value. This start value format spans across fp reductions without
 /// mask and all the masked reduction intrinsics.
@@ -903,10 +952,16 @@ class MaskedReductionOpConversion
                                             ReductionNeutralFPMin>(
           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
       break;
-    default:
-      return rewriter.notifyMatchFailure(
-          maskOp,
-          "lowering to LLVM is not implemented for this masked operation");
+    case CombiningKind::MAXIMUMF:
+      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
+                                               MaskNeutralFMaximum>(
+          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+      break;
+    case CombiningKind::MINIMUMF:
+      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
+                                               MaskNeutralFMinimum>(
+          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+      break;
     }
 
     // Replace `vector.mask` operation altogether.

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
index c1b1eb05077f2ee..fd2d6ae5a472f16 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
@@ -101,6 +101,36 @@ func.func @masked_reduce_maxf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>)
 
 // -----
 
+func.func @masked_reduce_maximumf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <maximumf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_maximumf_f32(
+// CHECK-SAME:                                      %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME:                                      %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK:           %[[MASK_NEUTRAL:.*]] = llvm.mlir.constant(dense<0xFF800000> : vector<16xf32>) : vector<16xf32>
+// CHECK:           %[[MASKED:.*]] = llvm.select %[[MASK]], %[[INPUT]], %[[MASK_NEUTRAL]] : vector<16xi1>, vector<16xf32>
+// CHECK:           %[[RESULT:.*]] = llvm.intr.vector.reduce.fmaximum(%[[MASKED]])  : (vector<16xf32>) -> f32
+// CHECK:           return %[[RESULT]]
+
+// -----
+
+func.func @masked_reduce_minimumf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <minimumf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_minimumf_f32(
+// CHECK-SAME:                                      %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME:                                      %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK:           %[[MASK_NEUTRAL:.*]] = llvm.mlir.constant(dense<0x7F800000> : vector<16xf32>) : vector<16xf32>
+// CHECK:           %[[MASKED:.*]] = llvm.select %[[MASK]], %[[INPUT]], %[[MASK_NEUTRAL]] : vector<16xi1>, vector<16xf32>
+// CHECK:           %[[RESULT:.*]] = llvm.intr.vector.reduce.fminimum(%[[MASKED]])  : (vector<16xf32>) -> f32
+// CHECK:           return %[[RESULT]]
+
+// -----
+
 func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
   %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
   return %0 : i8


        


More information about the Mlir-commits mailing list