[Mlir-commits] [mlir] 8f5d519 - [mlir][vector] Implement Workaround Lowerings for Masked `fm**imum` Reductions
Diego Caballero
llvmlistbot at llvm.org
Wed Sep 13 15:50:45 PDT 2023
Author: Daniil Dudkin
Date: 2023-09-13T22:49:08Z
New Revision: 8f5d519458aaf8ca7731ee974b912f6897078282
URL: https://github.com/llvm/llvm-project/commit/8f5d519458aaf8ca7731ee974b912f6897078282
DIFF: https://github.com/llvm/llvm-project/commit/8f5d519458aaf8ca7731ee974b912f6897078282.diff
LOG: [mlir][vector] Implement Workaround Lowerings for Masked `fm**imum` Reductions
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.
Within LLVM, there are no masked reduction counterparts for vector reductions such as `fmaximum` and `fminimum`.
More information can be found here: https://github.com/llvm/llvm-project/issues/64940#issuecomment-1690694156.
To address this issue in MLIR, where we need to generate appropriate lowerings for these cases, we employ regular non-masked intrinsics.
However, we first rewrite the input vector with an `llvm.select` operation that replaces the masked-off elements with a "neutral mask value", so that they cannot affect the result of the reduction.
The neutral mask value is the lowest possible value (the most negative one, not the one smallest in magnitude) for the `fmaximum` reduction and the highest possible value for the `fminimum` reduction.
Depends on D158618
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D158773
Added:
Modified:
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 335e113d12b7e3c..a979237d1f63e17 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -15,13 +15,17 @@
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/Support/Casting.h"
#include <optional>
@@ -603,6 +607,51 @@ createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter,
return result;
}
+/// Empty tag types used purely for overload resolution: passing one of them to
+/// `getMaskNeutralValue` selects which neutral value a masked floating-point
+/// reduction uses for its masked-off lanes. They are local to this file, so
+/// they live in an anonymous namespace rather than having external linkage.
+namespace {
+class MaskNeutralFMaximum {};
+class MaskNeutralFMinimum {};
+} // namespace
+
+/// Get the mask neutral floating point maximum value: the identity element of
+/// `fmaximum`, i.e. negative infinity. For formats that cannot encode an
+/// infinity, the most negative finite value is returned instead.
+static llvm::APFloat
+getMaskNeutralValue(MaskNeutralFMaximum,
+                    const llvm::fltSemantics &floatSemantics) {
+  // Note: `getSmallest(.., /*Negative=*/true)` is the negative denormal closest
+  // to zero (-1.4e-45 for f32). It is greater than every negative active lane,
+  // so it would leak into the result; only the lowest value is neutral.
+  llvm::APFloat neutral =
+      llvm::APFloat::getInf(floatSemantics, /*Negative=*/true);
+  // Formats without an infinity encoding do not yield an infinity from
+  // getInf (APFloat substitutes a NaN), which would poison the reduction.
+  if (!neutral.isInfinity())
+    neutral = llvm::APFloat::getLargest(floatSemantics, /*Negative=*/true);
+  return neutral;
+}
+/// Get the mask neutral floating point minimum value: the identity element of
+/// `fminimum`, i.e. positive infinity. For formats that cannot encode an
+/// infinity, the largest finite value is returned instead.
+static llvm::APFloat
+getMaskNeutralValue(MaskNeutralFMinimum,
+                    const llvm::fltSemantics &floatSemantics) {
+  // The largest finite value alone is not neutral: if every active lane is
+  // +inf, the reduction would wrongly return that finite value instead.
+  llvm::APFloat neutral =
+      llvm::APFloat::getInf(floatSemantics, /*Negative=*/false);
+  // Formats without an infinity encoding do not yield an infinity from
+  // getInf (APFloat substitutes a NaN), which would poison the reduction.
+  if (!neutral.isInfinity())
+    neutral = llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
+  return neutral;
+}
+
+/// Create the mask neutral floating point MLIR vector constant.
+///
+/// `llvmType` is the (float) element type whose semantics determine the
+/// neutral value picked by `getMaskNeutralValue(MaskNeutral{}, ...)`;
+/// `vectorType` is the shaped type of the result. The value is splatted into a
+/// `DenseElementsAttr` and materialized as an `LLVM::ConstantOp`.
+template <typename MaskNeutral>
+static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
+                                    Location loc, Type llvmType,
+                                    Type vectorType) {
+  const llvm::fltSemantics &floatSemantics =
+      cast<FloatType>(llvmType).getFloatSemantics();
+  llvm::APFloat value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
+  // Use the free-function cast, consistent with `cast<FloatType>` above; the
+  // member `Type::cast<...>()` form is being phased out in MLIR.
+  auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
+  return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
+}
+
+/// Lowers a masked `fmaximum`/`fminimum` reduction through the corresponding
+/// non-masked intrinsic, working around the lack of masked intrinsics for
+/// `fmaximum`/`fminimum`. Lanes disabled by `mask` are first overwritten with
+/// the neutral value selected by `MaskNeutral`. The resulting vector and
+/// `accumulator` are then forwarded to
+/// `createFPReductionComparisonOpLowering<LLVMRedIntrinOp>`.
+/// More information: https://github.com/llvm/llvm-project/issues/64940
+template <class LLVMRedIntrinOp, class MaskNeutral>
+static Value lowerMaskedReductionWithRegular(
+    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
+    Value vectorOperand, Value accumulator, Value mask) {
+  Type operandType = vectorOperand.getType();
+  // Splat constant standing in for every inactive lane.
+  Value neutralFill = createMaskNeutralValue<MaskNeutral>(rewriter, loc,
+                                                          llvmType, operandType);
+  // Active lanes keep their input value; inactive lanes take the filler.
+  Value activeOnly =
+      rewriter.create<LLVM::SelectOp>(loc, mask, vectorOperand, neutralFill);
+  return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
+      rewriter, loc, llvmType, activeOnly, accumulator);
+}
+
/// Overloaded methods to lower a reduction to an llvm instrinsic that requires
/// a start value. This start value format spans across fp reductions without
/// mask and all the masked reduction intrinsics.
@@ -903,10 +952,16 @@ class MaskedReductionOpConversion
ReductionNeutralFPMin>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
- default:
- return rewriter.notifyMatchFailure(
- maskOp,
- "lowering to LLVM is not implemented for this masked operation");
+ case CombiningKind::MAXIMUMF:
+ result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
+ MaskNeutralFMaximum>(
+ rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+ break;
+ case CombiningKind::MINIMUMF:
+ result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
+ MaskNeutralFMinimum>(
+ rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+ break;
}
// Replace `vector.mask` operation altogether.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
index c1b1eb05077f2ee..fd2d6ae5a472f16 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
@@ -101,6 +101,36 @@ func.func @masked_reduce_maxf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>)
// -----
+func.func @masked_reduce_maximumf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <maximumf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+  return %0 : f32
+}
+
+// Masked-off lanes must be filled with the identity of `maximumf`, which is
+// -inf (printed in hexadecimal as 0xFF800000), before the unmasked reduction.
+// CHECK-LABEL: func.func @masked_reduce_maximumf_f32(
+// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK: %[[MASK_NEUTRAL:.*]] = llvm.mlir.constant(dense<0xFF800000> : vector<16xf32>) : vector<16xf32>
+// CHECK: %[[MASKED:.*]] = llvm.select %[[MASK]], %[[INPUT]], %[[MASK_NEUTRAL]] : vector<16xi1>, vector<16xf32>
+// CHECK: %[[RESULT:.*]] = llvm.intr.vector.reduce.fmaximum(%[[MASKED]]) : (vector<16xf32>) -> f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @masked_reduce_minimumf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <minimumf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+  return %0 : f32
+}
+
+// Masked-off lanes must be filled with the identity of `minimumf`, which is
+// +inf (printed in hexadecimal as 0x7F800000), before the unmasked reduction.
+// CHECK-LABEL: func.func @masked_reduce_minimumf_f32(
+// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK: %[[MASK_NEUTRAL:.*]] = llvm.mlir.constant(dense<0x7F800000> : vector<16xf32>) : vector<16xf32>
+// CHECK: %[[MASKED:.*]] = llvm.select %[[MASK]], %[[INPUT]], %[[MASK_NEUTRAL]] : vector<16xi1>, vector<16xf32>
+// CHECK: %[[RESULT:.*]] = llvm.intr.vector.reduce.fminimum(%[[MASKED]]) : (vector<16xf32>) -> f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
return %0 : i8
More information about the Mlir-commits
mailing list