[Mlir-commits] [mlir] e9b82a5 - [mlir][Vector] Add LLVM lowering for masked reductions
Diego Caballero
llvmlistbot at llvm.org
Tue Feb 14 22:15:25 PST 2023
Author: Diego Caballero
Date: 2023-02-15T06:10:11Z
New Revision: e9b82a5c4fb6fa1c0af1b8e2536252b0730f41ef
URL: https://github.com/llvm/llvm-project/commit/e9b82a5c4fb6fa1c0af1b8e2536252b0730f41ef
DIFF: https://github.com/llvm/llvm-project/commit/e9b82a5c4fb6fa1c0af1b8e2536252b0730f41ef.diff
LOG: [mlir][Vector] Add LLVM lowering for masked reductions
This patch adds the conversion patterns to lower masked reduction
operations to the corresponding vp intrinsics in LLVM.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D142177
Added:
Modified:
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6eabe4927ff6a..68865d3c167cf 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
@@ -408,15 +409,154 @@ class VectorCompressStoreOpConversion
}
};
+/// Reduction neutral classes for overloading.
+class ReductionNeutralZero {};
+class ReductionNeutralIntOne {};
+class ReductionNeutralFPOne {};
+class ReductionNeutralAllOnes {};
+class ReductionNeutralSIntMin {};
+class ReductionNeutralUIntMin {};
+class ReductionNeutralSIntMax {};
+class ReductionNeutralUIntMax {};
+class ReductionNeutralFPMin {};
+class ReductionNeutralFPMax {};
+
+/// Create the reduction neutral zero value.
+static Value createReductionNeutralValue(ReductionNeutralZero neutral,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType) {
+ return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
+ rewriter.getZeroAttr(llvmType));
+}
+
+/// Create the reduction neutral integer one value.
+static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType) {
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
+}
+
+/// Create the reduction neutral fp one value.
+static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType) {
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
+}
+
+/// Create the reduction neutral all-ones value.
+static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType) {
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, llvmType,
+ rewriter.getIntegerAttr(
+ llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
+}
+
+/// Create the reduction neutral signed int minimum value.
+static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType) {
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, llvmType,
+ rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
+ llvmType.getIntOrFloatBitWidth())));
+}
+
+/// Create the reduction neutral unsigned int minimum value.
+static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType) {
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, llvmType,
+ rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
+ llvmType.getIntOrFloatBitWidth())));
+}
+
+/// Create the reduction neutral signed int maximum value.
+static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType) {
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, llvmType,
+ rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
+ llvmType.getIntOrFloatBitWidth())));
+}
+
+/// Create the reduction neutral unsigned int maximum value.
+static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType) {
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, llvmType,
+ rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
+ llvmType.getIntOrFloatBitWidth())));
+}
+
+/// Create the reduction neutral fp minimum value.
+static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType) {
+ auto floatType = llvmType.cast<FloatType>();
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, llvmType,
+ rewriter.getFloatAttr(
+ llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
+ /*Negative=*/false)));
+}
+
+/// Create the reduction neutral fp maximum value.
+static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType) {
+ auto floatType = llvmType.cast<FloatType>();
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, llvmType,
+ rewriter.getFloatAttr(
+ llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
+ /*Negative=*/true)));
+}
+
+/// Returns `accumulator` if it has a valid value. Otherwise, creates and
+/// returns a new accumulator value using `ReductionNeutral`.
+template <class ReductionNeutral>
+static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType,
+ Value accumulator) {
+ if (accumulator)
+ return accumulator;
+
+ return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
+ llvmType);
+}
+
+/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
+/// This is used as effective vector length by some intrinsics supporting
+/// dynamic vector lengths at runtime.
+static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType) {
+ VectorType vType = cast<VectorType>(llvmType);
+ auto vShape = vType.getShape();
+ assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
+
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(),
+ rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
+}
+
/// Helper method to lower a `vector.reduction` op that performs an arithmetic
/// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
/// and `ScalarOp` is the scalar operation used to add the accumulation value if
/// non-null.
-template <class VectorOp, class ScalarOp>
+template <class LLVMRedIntrinOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator) {
- Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
+
+ Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
+
if (accumulator)
result = rewriter.create<ScalarOp>(loc, accumulator, result);
return result;
@@ -426,11 +566,11 @@ static Value createIntegerReductionArithmeticOpLowering(
/// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
/// intrinsic to use and `predicate` is the predicate to use to compare+combine
/// the accumulator value if non-null.
-template <class VectorOp>
+template <class LLVMRedIntrinOp>
static Value createIntegerReductionComparisonOpLowering(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
- Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
+ Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
if (accumulator) {
Value cmp =
rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
@@ -460,6 +600,91 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
}
+template <class LLVMRedIntrinOp>
+static Value createFPReductionComparisonOpLowering(
+ ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
+ Value vectorOperand, Value accumulator, bool isMin) {
+ Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
+
+ if (accumulator)
+ result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin);
+
+ return result;
+}
+
+/// Overloaded methods to lower a reduction to an llvm instrinsic that requires
+/// a start value. This start value format spans across fp reductions without
+/// mask and all the masked reduction intrinsics.
+template <class LLVMVPRedIntrinOp, class ReductionNeutral>
+static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType,
+ Value vectorOperand,
+ Value accumulator) {
+ accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
+ llvmType, accumulator);
+ return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
+ /*startValue=*/accumulator,
+ vectorOperand);
+}
+
+template <class LLVMVPRedIntrinOp, class ReductionNeutral>
+static Value
+lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
+ Type llvmType, Value vectorOperand,
+ Value accumulator, bool reassociateFPReds) {
+ accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
+ llvmType, accumulator);
+ return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
+ /*startValue=*/accumulator,
+ vectorOperand, reassociateFPReds);
+}
+
+template <class LLVMVPRedIntrinOp, class ReductionNeutral>
+static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType,
+ Value vectorOperand,
+ Value accumulator, Value mask) {
+ accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
+ llvmType, accumulator);
+ Value vectorLength =
+ createVectorLengthValue(rewriter, loc, vectorOperand.getType());
+ return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
+ /*startValue=*/accumulator,
+ vectorOperand, mask, vectorLength);
+}
+
+template <class LLVMVPRedIntrinOp, class ReductionNeutral>
+static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType,
+ Value vectorOperand,
+ Value accumulator, Value mask,
+ bool reassociateFPReds) {
+ accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
+ llvmType, accumulator);
+ Value vectorLength =
+ createVectorLengthValue(rewriter, loc, vectorOperand.getType());
+ return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
+ /*startValue=*/accumulator,
+ vectorOperand, mask, vectorLength,
+ reassociateFPReds);
+}
+
+template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
+ class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
+static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
+ Location loc, Type llvmType,
+ Value vectorOperand,
+ Value accumulator, Value mask) {
+ if (llvmType.isIntOrIndex())
+ return lowerReductionWithStartValue<LLVMIntVPRedIntrinOp,
+ IntReductionNeutral>(
+ rewriter, loc, llvmType, vectorOperand, accumulator, mask);
+
+ // FP dispatch.
+ return lowerReductionWithStartValue<LLVMFPVPRedIntrinOp, FPReductionNeutral>(
+ rewriter, loc, llvmType, vectorOperand, accumulator, mask);
+}
+
/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
: public ConvertOpToLLVMPattern<vector::ReductionOp> {
@@ -478,6 +703,12 @@ class VectorReductionOpConversion
Value operand = adaptor.getVector();
Value acc = adaptor.getAcc();
Location loc = reductionOp.getLoc();
+
+ // Masked reductions are lowered separately.
+ auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
+ if (maskableOp.isMasked())
+ return failure();
+
if (eltType.isIntOrIndex()) {
// Integer reductions: add/mul/min/max/and/or/xor.
Value result;
@@ -544,45 +775,31 @@ class VectorReductionOpConversion
return failure();
// Floating-point reductions: add/mul/min/max
+ Value result;
if (kind == vector::CombiningKind::ADD) {
- // Optional accumulator (or zero).
- Value acc = adaptor.getOperands().size() > 1
- ? adaptor.getOperands()[1]
- : rewriter.create<LLVM::ConstantOp>(
- reductionOp->getLoc(), llvmType,
- rewriter.getZeroAttr(eltType));
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
- reductionOp, llvmType, acc, operand,
- rewriter.getBoolAttr(reassociateFPReductions));
+ result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
+ ReductionNeutralZero>(
+ rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
} else if (kind == vector::CombiningKind::MUL) {
- // Optional accumulator (or one).
- Value acc = adaptor.getOperands().size() > 1
- ? adaptor.getOperands()[1]
- : rewriter.create<LLVM::ConstantOp>(
- reductionOp->getLoc(), llvmType,
- rewriter.getFloatAttr(eltType, 1.0));
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
- reductionOp, llvmType, acc, operand,
- rewriter.getBoolAttr(reassociateFPReductions));
+ result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
+ ReductionNeutralFPOne>(
+ rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
} else if (kind == vector::CombiningKind::MINF) {
// FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
// NaNs/-0.0/+0.0 in the same way.
- Value result =
- rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand);
- if (acc)
- result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true);
- rewriter.replaceOp(reductionOp, result);
+ result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
+ rewriter, loc, llvmType, operand, acc,
+ /*isMin=*/true);
} else if (kind == vector::CombiningKind::MAXF) {
// FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
// NaNs/-0.0/+0.0 in the same way.
- Value result =
- rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand);
- if (acc)
- result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false);
- rewriter.replaceOp(reductionOp, result);
+ result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
+ rewriter, loc, llvmType, operand, acc,
+ /*isMin=*/false);
} else
return failure();
+ rewriter.replaceOp(reductionOp, result);
return success();
}
@@ -590,6 +807,127 @@ class VectorReductionOpConversion
const bool reassociateFPReductions;
};
+/// Base class to convert a `vector.mask` operation while matching traits
+/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
+/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
+/// method performs a second match against the maskable operation `MaskedOp`.
+/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
+/// implemented by the concrete conversion classes. This method can match
+/// against specific traits of the `vector.mask` and the maskable operation. It
+/// must replace the `vector.mask` operation.
+template <class MaskedOp>
+class VectorMaskOpConversionBase
+ : public ConvertOpToLLVMPattern<vector::MaskOp> {
+public:
+ using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override final {
+ // Match against the maskable operation kind.
+ Operation *maskableOp = maskOp.getMaskableOp();
+ if (!isa<MaskedOp>(maskableOp))
+ return failure();
+ return matchAndRewriteMaskableOp(
+ maskOp, cast<MaskedOp>(maskOp.getMaskableOp()), rewriter);
+ }
+
+protected:
+ virtual LogicalResult
+ matchAndRewriteMaskableOp(vector::MaskOp maskOp,
+ vector::MaskableOpInterface maskableOp,
+ ConversionPatternRewriter &rewriter) const = 0;
+};
+
+class MaskedReductionOpConversion
+ : public VectorMaskOpConversionBase<vector::ReductionOp> {
+
+public:
+ using VectorMaskOpConversionBase<
+ vector::ReductionOp>::VectorMaskOpConversionBase;
+
+ virtual LogicalResult matchAndRewriteMaskableOp(
+ vector::MaskOp maskOp, MaskableOpInterface maskableOp,
+ ConversionPatternRewriter &rewriter) const override {
+ auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
+ auto kind = reductionOp.getKind();
+ Type eltType = reductionOp.getDest().getType();
+ Type llvmType = typeConverter->convertType(eltType);
+ Value operand = reductionOp.getVector();
+ Value acc = reductionOp.getAcc();
+ Location loc = reductionOp.getLoc();
+
+ Value result;
+ switch (kind) {
+ case vector::CombiningKind::ADD:
+ result = lowerReductionWithStartValue<
+ LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
+ ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
+ maskOp.getMask());
+ break;
+ case vector::CombiningKind::MUL:
+ result = lowerReductionWithStartValue<
+ LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
+ ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
+ maskOp.getMask());
+ break;
+ case vector::CombiningKind::MINUI:
+ result = lowerReductionWithStartValue<LLVM::VPReduceUMinOp,
+ ReductionNeutralUIntMax>(
+ rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+ break;
+ case vector::CombiningKind::MINSI:
+ result = lowerReductionWithStartValue<LLVM::VPReduceSMinOp,
+ ReductionNeutralSIntMax>(
+ rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+ break;
+ case vector::CombiningKind::MAXUI:
+ result = lowerReductionWithStartValue<LLVM::VPReduceUMaxOp,
+ ReductionNeutralUIntMin>(
+ rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+ break;
+ case vector::CombiningKind::MAXSI:
+ result = lowerReductionWithStartValue<LLVM::VPReduceSMaxOp,
+ ReductionNeutralSIntMin>(
+ rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+ break;
+ case vector::CombiningKind::AND:
+ result = lowerReductionWithStartValue<LLVM::VPReduceAndOp,
+ ReductionNeutralAllOnes>(
+ rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+ break;
+ case vector::CombiningKind::OR:
+ result = lowerReductionWithStartValue<LLVM::VPReduceOrOp,
+ ReductionNeutralZero>(
+ rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+ break;
+ case vector::CombiningKind::XOR:
+ result = lowerReductionWithStartValue<LLVM::VPReduceXorOp,
+ ReductionNeutralZero>(
+ rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+ break;
+ case vector::CombiningKind::MINF:
+ // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
+ // NaNs/-0.0/+0.0 in the same way.
+ result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp,
+ ReductionNeutralFPMax>(
+ rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+ break;
+ case vector::CombiningKind::MAXF:
+ // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
+ // NaNs/-0.0/+0.0 in the same way.
+ result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp,
+ ReductionNeutralFPMin>(
+ rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+ break;
+ }
+
+ // Replace `vector.mask` operation altogether.
+ rewriter.replaceOp(maskOp, result);
+ return success();
+ }
+};
+
class VectorShuffleOpConversion
: public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
@@ -1381,8 +1719,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorGatherOpConversion, VectorScatterOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
VectorSplatOpLowering, VectorSplatNdOpLowering,
- VectorScalableInsertOpLowering, VectorScalableExtractOpLowering>(
- converter);
+ VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
+ MaskedReductionOpConversion>(converter);
// Transfer ops with rank > 1 are handled by VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
index 8da65b8232ff0..0b612008c4d0f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
@@ -1,7 +1,6 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1' | FileCheck %s
-// RUN: mlir-opt %s -convert-vector-to-llvm='reassociate-fp-reductions use-opaque-pointers=1' | FileCheck %s --check-prefix=REASSOC
+// RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-llvm='reassociate-fp-reductions use-opaque-pointers=1' -split-input-file | FileCheck %s --check-prefix=REASSOC
-//
// CHECK-LABEL: @reduce_add_f32(
// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
@@ -21,7 +20,8 @@ func.func @reduce_add_f32(%arg0: vector<16xf32>) -> f32 {
return %0 : f32
}
-//
+// -----
+
// CHECK-LABEL: @reduce_mul_f32(
// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
// CHECK: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
@@ -40,3 +40,191 @@ func.func @reduce_mul_f32(%arg0: vector<16xf32>) -> f32 {
%0 = vector.reduction <mul>, %arg0 : vector<16xf32> into f32
return %0 : f32
}
+
+// -----
+
+func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+ %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func @masked_reduce_add_f32(
+// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
+
+
+// -----
+
+func.func @masked_reduce_mul_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+ %0 = vector.mask %mask { vector.reduction <mul>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func @masked_reduce_mul_f32(
+// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.fmul"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
+
+
+// -----
+
+func.func @masked_reduce_minf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+ %0 = vector.mask %mask { vector.reduction <minf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func @masked_reduce_minf_f32(
+// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0xFFC00000 : f32) : f32
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.fmin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
+
+// -----
+
+func.func @masked_reduce_maxf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+ %0 = vector.mask %mask { vector.reduction <maxf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func @masked_reduce_maxf_f32(
+// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.fmax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
+
+// -----
+
+func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+ %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+ return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_add_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+
+// -----
+
+func.func @masked_reduce_mul_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+ %0 = vector.mask %mask { vector.reduction <mul>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+ return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_mul_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(1 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[VAL_4:.*]] = "llvm.intr.vp.reduce.mul"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_minui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+ %0 = vector.mask %mask { vector.reduction <minui>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+ return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_minui_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.umin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_maxui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+ %0 = vector.mask %mask { vector.reduction <maxui>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+ return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_maxui_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.umax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_minsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+ %0 = vector.mask %mask { vector.reduction <minsi>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+ return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_minsi_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(127 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.smin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_maxsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+ %0 = vector.mask %mask { vector.reduction <maxsi>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+ return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_maxsi_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-128 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.smax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_or_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+ %0 = vector.mask %mask { vector.reduction <or>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+ return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_or_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.or"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+
+// -----
+
+func.func @masked_reduce_and_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+ %0 = vector.mask %mask { vector.reduction <and>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+ return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_and_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.and"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_xor_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+ %0 = vector.mask %mask { vector.reduction <xor>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+ return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_xor_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+
More information about the Mlir-commits
mailing list