[Mlir-commits] [mlir] e9b82a5 - [mlir][Vector] Add LLVM lowering for masked reductions

Diego Caballero llvmlistbot at llvm.org
Tue Feb 14 22:15:25 PST 2023


Author: Diego Caballero
Date: 2023-02-15T06:10:11Z
New Revision: e9b82a5c4fb6fa1c0af1b8e2536252b0730f41ef

URL: https://github.com/llvm/llvm-project/commit/e9b82a5c4fb6fa1c0af1b8e2536252b0730f41ef
DIFF: https://github.com/llvm/llvm-project/commit/e9b82a5c4fb6fa1c0af1b8e2536252b0730f41ef.diff

LOG: [mlir][Vector] Add LLVM lowering for masked reductions

This patch adds the conversion patterns to lower masked reduction
operations (`vector.mask` wrapping `vector.reduction`) to the
corresponding LLVM vector-predication (VP) intrinsics.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D142177

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6eabe4927ff6a..68865d3c167cf 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -408,15 +409,154 @@ class VectorCompressStoreOpConversion
   }
 };
 
+/// Empty tag types used purely for overload resolution: each one selects a
+/// `createReductionNeutralValue` overload, i.e. which neutral start value is
+/// materialized for a reduction. They carry no state.
+struct ReductionNeutralZero {};
+struct ReductionNeutralIntOne {};
+struct ReductionNeutralFPOne {};
+struct ReductionNeutralAllOnes {};
+struct ReductionNeutralSIntMin {};
+struct ReductionNeutralUIntMin {};
+struct ReductionNeutralSIntMax {};
+struct ReductionNeutralUIntMax {};
+struct ReductionNeutralFPMin {};
+struct ReductionNeutralFPMax {};
+
+/// Creates an `llvmType` integer constant holding `value`. Shared by the
+/// integer neutral-value overloads that are built from an `llvm::APInt`.
+static Value createReductionNeutralIntConstant(
+    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
+    const llvm::APInt &value) {
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, llvmType, rewriter.getIntegerAttr(llvmType, value));
+}
+
+/// Creates an `llvmType` float constant holding a quiet NaN whose sign bit is
+/// set iff `negative`. Shared by the fp min/max neutral-value overloads.
+static Value createReductionNeutralQNaN(ConversionPatternRewriter &rewriter,
+                                        Location loc, Type llvmType,
+                                        bool negative) {
+  auto floatType = cast<FloatType>(llvmType);
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, llvmType,
+      rewriter.getFloatAttr(llvmType,
+                            llvm::APFloat::getQNaN(
+                                floatType.getFloatSemantics(), negative)));
+}
+
+/// Create the reduction neutral zero value (integer or fp zero attribute).
+static Value createReductionNeutralValue(ReductionNeutralZero neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
+                                           rewriter.getZeroAttr(llvmType));
+}
+
+/// Create the reduction neutral integer one value.
+static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
+}
+
+/// Create the reduction neutral fp one value.
+static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
+}
+
+/// Create the reduction neutral all-ones value (every bit set).
+static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return createReductionNeutralIntConstant(
+      rewriter, loc, llvmType,
+      llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()));
+}
+
+/// Create the reduction neutral signed int minimum value.
+static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return createReductionNeutralIntConstant(
+      rewriter, loc, llvmType,
+      llvm::APInt::getSignedMinValue(llvmType.getIntOrFloatBitWidth()));
+}
+
+/// Create the reduction neutral unsigned int minimum value (zero).
+static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return createReductionNeutralIntConstant(
+      rewriter, loc, llvmType,
+      llvm::APInt::getMinValue(llvmType.getIntOrFloatBitWidth()));
+}
+
+/// Create the reduction neutral signed int maximum value.
+static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return createReductionNeutralIntConstant(
+      rewriter, loc, llvmType,
+      llvm::APInt::getSignedMaxValue(llvmType.getIntOrFloatBitWidth()));
+}
+
+/// Create the reduction neutral unsigned int maximum value (all ones).
+static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return createReductionNeutralIntConstant(
+      rewriter, loc, llvmType,
+      llvm::APInt::getMaxValue(llvmType.getIntOrFloatBitWidth()));
+}
+
+/// Create the fp neutral value for max-like reductions (the fp analogue of
+/// the "minimum" integer neutrals). It is a *positive quiet NaN*, not -inf:
+/// the LLVM fmin/fmax reduction intrinsics follow minnum/maxnum semantics, in
+/// which a NaN operand loses to any non-NaN one.
+/// NOTE(review): LangRef documents -QNaN as the `vp.reduce.fmax` neutral value
+/// and +QNaN for `vp.reduce.fmin`, so the sign here (and in the FPMax
+/// overload) appears swapped. It is only observable when every lane is masked
+/// off or NaN; confirm before relying on the sign.
+static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return createReductionNeutralQNaN(rewriter, loc, llvmType,
+                                    /*negative=*/false);
+}
+
+/// Create the fp neutral value for min-like reductions (the fp analogue of
+/// the "maximum" integer neutrals). It is a *negative quiet NaN*, not +inf;
+/// see the FPMin overload for why a NaN works and for the sign caveat.
+static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return createReductionNeutralQNaN(rewriter, loc, llvmType,
+                                    /*negative=*/true);
+}
+
+/// Yields the caller-provided `accumulator` when it is non-null; otherwise
+/// materializes the neutral element selected by the `ReductionNeutral` tag so
+/// that the reduction intrinsic always receives a start value.
+template <class ReductionNeutral>
+static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
+                                    Location loc, Type llvmType,
+                                    Value accumulator) {
+  return accumulator ? accumulator
+                     : createReductionNeutralValue(ReductionNeutral{}, rewriter,
+                                                   loc, llvmType);
+}
+
+/// Materializes the static length of the 1-D vector type `llvmType` as an i32
+/// constant. VP intrinsics take this as their explicit (effective) vector
+/// length operand.
+/// NOTE(review): only the static dimension size is used, so for a scalable
+/// vector this is not the runtime length (vscale * N). Confirm that callers
+/// only pass fixed-length vectors.
+static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
+                                     Location loc, Type llvmType) {
+  auto vecType = cast<VectorType>(llvmType);
+  assert(vecType.getRank() == 1 && "Unexpected multi-dim vector type");
+
+  Type i32Ty = rewriter.getI32Type();
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, i32Ty, rewriter.getIntegerAttr(i32Ty, vecType.getDimSize(0)));
+}
+
 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
 /// non-null.
-template <class VectorOp, class ScalarOp>
+template <class LLVMRedIntrinOp, class ScalarOp>
 static Value createIntegerReductionArithmeticOpLowering(
     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
     Value vectorOperand, Value accumulator) {
-  Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
+
+  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
+
   if (accumulator)
     result = rewriter.create<ScalarOp>(loc, accumulator, result);
   return result;
@@ -426,11 +566,11 @@ static Value createIntegerReductionArithmeticOpLowering(
 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
 /// the accumulator value if non-null.
-template <class VectorOp>
+template <class LLVMRedIntrinOp>
 static Value createIntegerReductionComparisonOpLowering(
     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
     Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
-  Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
+  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
   if (accumulator) {
     Value cmp =
         rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
@@ -460,6 +600,91 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
   return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
 }
 
+/// Lowers an fp min/max reduction without a mask operand. It reduces
+/// `vectorOperand` with the `LLVMRedIntrinOp` intrinsic. When an
+/// `accumulator` is present, it is then combined with the reduced value
+/// through `createMinMaxF`, where `isMin` selects min vs. max.
+template <class LLVMRedIntrinOp>
+static Value createFPReductionComparisonOpLowering(
+    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
+    Value vectorOperand, Value accumulator, bool isMin) {
+  Value reduced =
+      rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
+  if (!accumulator)
+    return reduced;
+  return createMinMaxF(rewriter, loc, reduced, accumulator, /*isMin=*/isMin);
+}
+
+/// Overloaded helpers to lower a reduction to an LLVM intrinsic that takes an
+/// explicit start value. This start-value form is shared by the fp reductions
+/// without a mask and by all the masked (VP) reduction intrinsics. Every
+/// overload first replaces a null `accumulator` with the neutral value
+/// selected by the `ReductionNeutral` tag.
+///
+/// Unmasked variant: builds `LLVMVPRedIntrinOp(startValue, vectorOperand)`.
+template <class LLVMVPRedIntrinOp, class ReductionNeutral>
+static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
+                                          Location loc, Type llvmType,
+                                          Value vectorOperand,
+                                          Value accumulator) {
+  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
+                                                         llvmType, accumulator);
+  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
+                                            /*startValue=*/accumulator,
+                                            vectorOperand);
+}
+
+/// Unmasked variant that also forwards `reassociateFPReds` to the intrinsic
+/// (used for the unmasked fp add/mul reductions).
+template <class LLVMVPRedIntrinOp, class ReductionNeutral>
+static Value
+lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
+                             Type llvmType, Value vectorOperand,
+                             Value accumulator, bool reassociateFPReds) {
+  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
+                                                         llvmType, accumulator);
+  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
+                                            /*startValue=*/accumulator,
+                                            vectorOperand, reassociateFPReds);
+}
+
+/// Masked (VP) variant. It also passes `mask` and an explicit vector length
+/// equal to the static size of `vectorOperand`'s 1-D vector type.
+template <class LLVMVPRedIntrinOp, class ReductionNeutral>
+static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
+                                          Location loc, Type llvmType,
+                                          Value vectorOperand,
+                                          Value accumulator, Value mask) {
+  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
+                                                         llvmType, accumulator);
+  Value vectorLength =
+      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
+  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
+                                            /*startValue=*/accumulator,
+                                            vectorOperand, mask, vectorLength);
+}
+
+/// Masked (VP) variant that additionally forwards `reassociateFPReds`.
+template <class LLVMVPRedIntrinOp, class ReductionNeutral>
+static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
+                                          Location loc, Type llvmType,
+                                          Value vectorOperand,
+                                          Value accumulator, Value mask,
+                                          bool reassociateFPReds) {
+  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
+                                                         llvmType, accumulator);
+  Value vectorLength =
+      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
+  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
+                                            /*startValue=*/accumulator,
+                                            vectorOperand, mask, vectorLength,
+                                            reassociateFPReds);
+}
+
+/// Masked (VP) variant that picks the intrinsic/neutral pair from `llvmType`.
+/// Integer or index types use the `LLVMIntVPRedIntrinOp`/`IntReductionNeutral`
+/// pair; any other type uses the `LLVMFPVPRedIntrinOp`/`FPReductionNeutral`
+/// pair.
+template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
+          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
+static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
+                                          Location loc, Type llvmType,
+                                          Value vectorOperand,
+                                          Value accumulator, Value mask) {
+  if (llvmType.isIntOrIndex())
+    return lowerReductionWithStartValue<LLVMIntVPRedIntrinOp,
+                                        IntReductionNeutral>(
+        rewriter, loc, llvmType, vectorOperand, accumulator, mask);
+
+  // FP dispatch.
+  return lowerReductionWithStartValue<LLVMFPVPRedIntrinOp, FPReductionNeutral>(
+      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
+}
+
 /// Conversion pattern for all vector reductions.
 class VectorReductionOpConversion
     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
@@ -478,6 +703,12 @@ class VectorReductionOpConversion
     Value operand = adaptor.getVector();
     Value acc = adaptor.getAcc();
     Location loc = reductionOp.getLoc();
+
+    // Masked reductions are lowered separately.
+    auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
     if (eltType.isIntOrIndex()) {
       // Integer reductions: add/mul/min/max/and/or/xor.
       Value result;
@@ -544,45 +775,31 @@ class VectorReductionOpConversion
       return failure();
 
     // Floating-point reductions: add/mul/min/max
+    Value result;
     if (kind == vector::CombiningKind::ADD) {
-      // Optional accumulator (or zero).
-      Value acc = adaptor.getOperands().size() > 1
-                      ? adaptor.getOperands()[1]
-                      : rewriter.create<LLVM::ConstantOp>(
-                            reductionOp->getLoc(), llvmType,
-                            rewriter.getZeroAttr(eltType));
-      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
-          reductionOp, llvmType, acc, operand,
-          rewriter.getBoolAttr(reassociateFPReductions));
+      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
+                                            ReductionNeutralZero>(
+          rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
     } else if (kind == vector::CombiningKind::MUL) {
-      // Optional accumulator (or one).
-      Value acc = adaptor.getOperands().size() > 1
-                      ? adaptor.getOperands()[1]
-                      : rewriter.create<LLVM::ConstantOp>(
-                            reductionOp->getLoc(), llvmType,
-                            rewriter.getFloatAttr(eltType, 1.0));
-      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
-          reductionOp, llvmType, acc, operand,
-          rewriter.getBoolAttr(reassociateFPReductions));
+      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
+                                            ReductionNeutralFPOne>(
+          rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
     } else if (kind == vector::CombiningKind::MINF) {
       // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
       // NaNs/-0.0/+0.0 in the same way.
-      Value result =
-          rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand);
-      if (acc)
-        result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true);
-      rewriter.replaceOp(reductionOp, result);
+      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
+          rewriter, loc, llvmType, operand, acc,
+          /*isMin=*/true);
     } else if (kind == vector::CombiningKind::MAXF) {
       // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
       // NaNs/-0.0/+0.0 in the same way.
-      Value result =
-          rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand);
-      if (acc)
-        result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false);
-      rewriter.replaceOp(reductionOp, result);
+      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
+          rewriter, loc, llvmType, operand, acc,
+          /*isMin=*/false);
     } else
       return failure();
 
+    rewriter.replaceOp(reductionOp, result);
     return success();
   }
 
@@ -590,6 +807,127 @@ class VectorReductionOpConversion
   const bool reassociateFPReductions;
 };
 
+/// Base class to convert a `vector.mask` operation while matching traits
+/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
+/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
+/// method performs a second match against the maskable operation `MaskedOp`.
+/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
+/// implemented by the concrete conversion classes. This method can match
+/// against specific traits of the `vector.mask` and the maskable operation. It
+/// must replace the `vector.mask` operation.
+template <class MaskedOp>
+class VectorMaskOpConversionBase
+    : public ConvertOpToLLVMPattern<vector::MaskOp> {
+public:
+  using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Match against the maskable operation kind. Use a null-tolerant cast:
+    // `isa` asserts on a null operation, and `getMaskableOp()` may return null
+    // for a `vector.mask` that wraps no maskable op (NOTE(review): confirm
+    // against `MaskOp::getMaskableOp`).
+    auto maskedOp = dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
+    if (!maskedOp)
+      return failure();
+    return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
+  }
+
+protected:
+  /// Rewrites `maskOp`, whose nested maskable operation is `maskableOp`
+  /// (already known to be a `MaskedOp`). Implementations must replace
+  /// `maskOp` on success.
+  virtual LogicalResult
+  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
+                            vector::MaskableOpInterface maskableOp,
+                            ConversionPatternRewriter &rewriter) const = 0;
+};
+
+/// Converts a `vector.mask` wrapping a `vector.reduction` into the matching
+/// `llvm.intr.vp.reduce.*` intrinsic. The operands map as follows:
+///   - the `vector.mask` mask becomes the intrinsic's mask operand;
+///   - the static vector size becomes its explicit vector length;
+///   - the reduction's accumulator becomes the start value, or the neutral
+///     value of the combining kind if there is no accumulator.
+/// The whole `vector.mask` operation is replaced by the intrinsic's result.
+class MaskedReductionOpConversion
+    : public VectorMaskOpConversionBase<vector::ReductionOp> {
+
+public:
+  using VectorMaskOpConversionBase<
+      vector::ReductionOp>::VectorMaskOpConversionBase;
+
+  LogicalResult matchAndRewriteMaskableOp(
+      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
+      ConversionPatternRewriter &rewriter) const override {
+    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
+    auto kind = reductionOp.getKind();
+    Type eltType = reductionOp.getDest().getType();
+    Type llvmType = typeConverter->convertType(eltType);
+    if (!llvmType)
+      return failure();
+    Value operand = reductionOp.getVector();
+    Value acc = reductionOp.getAcc();
+    Location loc = reductionOp.getLoc();
+    Value mask = maskOp.getMask();
+
+    // The explicit vector length handed to the VP intrinsics is the static
+    // vector size. For a scalable vector that is not the runtime length
+    // (vscale * N), so bail out instead of reducing over the wrong lanes.
+    if (cast<VectorType>(operand.getType()).isScalable())
+      return rewriter.notifyMatchFailure(
+          maskOp, "masked reductions on scalable vectors are not supported");
+
+    Value result;
+    switch (kind) {
+    case vector::CombiningKind::ADD:
+      result = lowerReductionWithStartValue<
+          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
+          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc, mask);
+      break;
+    case vector::CombiningKind::MUL:
+      result = lowerReductionWithStartValue<
+          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
+          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc, mask);
+      break;
+    case vector::CombiningKind::MINUI:
+      result = lowerReductionWithStartValue<LLVM::VPReduceUMinOp,
+                                            ReductionNeutralUIntMax>(
+          rewriter, loc, llvmType, operand, acc, mask);
+      break;
+    case vector::CombiningKind::MINSI:
+      result = lowerReductionWithStartValue<LLVM::VPReduceSMinOp,
+                                            ReductionNeutralSIntMax>(
+          rewriter, loc, llvmType, operand, acc, mask);
+      break;
+    case vector::CombiningKind::MAXUI:
+      result = lowerReductionWithStartValue<LLVM::VPReduceUMaxOp,
+                                            ReductionNeutralUIntMin>(
+          rewriter, loc, llvmType, operand, acc, mask);
+      break;
+    case vector::CombiningKind::MAXSI:
+      result = lowerReductionWithStartValue<LLVM::VPReduceSMaxOp,
+                                            ReductionNeutralSIntMin>(
+          rewriter, loc, llvmType, operand, acc, mask);
+      break;
+    case vector::CombiningKind::AND:
+      result = lowerReductionWithStartValue<LLVM::VPReduceAndOp,
+                                            ReductionNeutralAllOnes>(
+          rewriter, loc, llvmType, operand, acc, mask);
+      break;
+    case vector::CombiningKind::OR:
+      result = lowerReductionWithStartValue<LLVM::VPReduceOrOp,
+                                            ReductionNeutralZero>(
+          rewriter, loc, llvmType, operand, acc, mask);
+      break;
+    case vector::CombiningKind::XOR:
+      result = lowerReductionWithStartValue<LLVM::VPReduceXorOp,
+                                            ReductionNeutralZero>(
+          rewriter, loc, llvmType, operand, acc, mask);
+      break;
+    case vector::CombiningKind::MINF:
+      // FIXME: MLIR's 'minf' and LLVM's 'vp.reduce.fmin' do not handle
+      // NaNs/-0.0/+0.0 in the same way.
+      result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp,
+                                            ReductionNeutralFPMax>(
+          rewriter, loc, llvmType, operand, acc, mask);
+      break;
+    case vector::CombiningKind::MAXF:
+      // FIXME: MLIR's 'maxf' and LLVM's 'vp.reduce.fmax' do not handle
+      // NaNs/-0.0/+0.0 in the same way.
+      result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp,
+                                            ReductionNeutralFPMin>(
+          rewriter, loc, llvmType, operand, acc, mask);
+      break;
+    }
+
+    // Replace `vector.mask` operation altogether.
+    rewriter.replaceOp(maskOp, result);
+    return success();
+  }
+};
+
 class VectorShuffleOpConversion
     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
 public:
@@ -1381,8 +1719,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
            VectorGatherOpConversion, VectorScatterOpConversion,
            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
            VectorSplatOpLowering, VectorSplatNdOpLowering,
-           VectorScalableInsertOpLowering, VectorScalableExtractOpLowering>(
-          converter);
+           VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
+           MaskedReductionOpConversion>(converter);
   // Transfer ops with rank > 1 are handled by VectorToSCF.
   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
index 8da65b8232ff0..0b612008c4d0f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
@@ -1,7 +1,6 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1' | FileCheck %s
-// RUN: mlir-opt %s -convert-vector-to-llvm='reassociate-fp-reductions use-opaque-pointers=1' | FileCheck %s --check-prefix=REASSOC
+// RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-llvm='reassociate-fp-reductions use-opaque-pointers=1' -split-input-file | FileCheck %s --check-prefix=REASSOC
 
-//
 // CHECK-LABEL: @reduce_add_f32(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>)
 //      CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
@@ -21,7 +20,8 @@ func.func @reduce_add_f32(%arg0: vector<16xf32>) -> f32 {
   return %0 : f32
 }
 
-//
+// -----
+
 // CHECK-LABEL: @reduce_mul_f32(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>)
 //      CHECK: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
@@ -40,3 +40,191 @@ func.func @reduce_mul_f32(%arg0: vector<16xf32>) -> f32 {
   %0 = vector.reduction <mul>, %arg0 : vector<16xf32> into f32
   return %0 : f32
 }
+
+// -----
+
+// Masked reductions: `vector.mask` + `vector.reduction` lower to the
+// `llvm.intr.vp.reduce.*` intrinsics. Without an accumulator the start value
+// is the neutral element of the combining kind, and the explicit vector length
+// is the static vector size materialized as an i32 constant.
+
+func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_add_f32(
+// CHECK-SAME:                              %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME:                              %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK:           %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK:           "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
+
+
+// -----
+
+func.func @masked_reduce_mul_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <mul>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_mul_f32(
+// CHECK-SAME:                              %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME:                              %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+// CHECK:           %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK:           "llvm.intr.vp.reduce.fmul"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
+
+
+// -----
+
+// The neutral start value for `minf` is a negative quiet NaN (0xFFC00000).
+
+func.func @masked_reduce_minf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <minf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_minf_f32(
+// CHECK-SAME:                                      %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME:                                      %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(0xFFC00000 : f32) : f32
+// CHECK:           %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK:           "llvm.intr.vp.reduce.fmin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
+
+// -----
+
+// The neutral start value for `maxf` is a positive quiet NaN (0x7FC00000).
+
+func.func @masked_reduce_maxf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <maxf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_maxf_f32(
+// CHECK-SAME:                                      %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME:                                      %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
+// CHECK:           %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK:           "llvm.intr.vp.reduce.fmax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
+
+// -----
+
+func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_add_i8(
+// CHECK-SAME:                             %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME:                             %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK:           %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+
+// -----
+
+func.func @masked_reduce_mul_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <mul>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_mul_i8(
+// CHECK-SAME:                             %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME:                             %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(1 : i8) : i8
+// CHECK:           %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           %[[VAL_4:.*]] = "llvm.intr.vp.reduce.mul"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_minui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <minui>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_minui_i8(
+// CHECK-SAME:                               %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME:                               %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8
+// CHECK:           %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           "llvm.intr.vp.reduce.umin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_maxui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <maxui>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_maxui_i8(
+// CHECK-SAME:                               %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME:                               %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK:           %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           "llvm.intr.vp.reduce.umax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_minsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <minsi>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_minsi_i8(
+// CHECK-SAME:                               %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME:                               %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(127 : i8) : i8
+// CHECK:           %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           "llvm.intr.vp.reduce.smin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_maxsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <maxsi>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_maxsi_i8(
+// CHECK-SAME:                               %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME:                               %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(-128 : i8) : i8
+// CHECK:           %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           "llvm.intr.vp.reduce.smax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_or_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <or>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_or_i8(
+// CHECK-SAME:                            %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME:                            %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK:           %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           "llvm.intr.vp.reduce.or"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+
+// -----
+
+func.func @masked_reduce_and_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <and>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_and_i8(
+// CHECK-SAME:                             %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME:                             %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8
+// CHECK:           %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           "llvm.intr.vp.reduce.and"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_xor_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <xor>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_xor_i8(
+// CHECK-SAME:                             %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME:                             %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK:           %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+


        


More information about the Mlir-commits mailing list