[Mlir-commits] [mlir] 9452356 - [mlir][Vector] Add support for masked vector.contract

Diego Caballero llvmlistbot at llvm.org
Tue Feb 14 22:15:27 PST 2023


Author: Diego Caballero
Date: 2023-02-15T06:10:22Z
New Revision: 9452356ddcf44720387eda9b6316b7cadff1f60d

URL: https://github.com/llvm/llvm-project/commit/9452356ddcf44720387eda9b6316b7cadff1f60d
DIFF: https://github.com/llvm/llvm-project/commit/9452356ddcf44720387eda9b6316b7cadff1f60d.diff

LOG: [mlir][Vector] Add support for masked vector.contract

This patch adds support for masking vector.contract ops with the
vector.mask approach. This also includes the lowering of vector.contract
through the vector.outerproduct path to LLVM. For now, this only adds
support for one of the many potential flavors of
vector.contract/vector.outerproduct but unsupported cases will fail
gracefully.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D143965

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b366cf84467ac..6f6d80c11c96a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -91,6 +91,7 @@ def Vector_ContractionOp :
       PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
       PredOpTrait<"third operand acc and result have same element type",
                   TCresVTEtIsSameAsOpBase<0, 2>>,
+      DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
     ]>,
     Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
@@ -632,6 +633,7 @@ def Vector_ExtractOp :
 def Vector_FMAOp :
   Op<Vector_Dialect, "fma", [
        Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
+       DeclareOpInterfaceMethods<MaskableOpInterface>,
        DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
      ] # ElementwiseMappable.traits>,
     Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs,
@@ -923,7 +925,8 @@ def Vector_OuterProductOp :
     PredOpTrait<"lhs operand and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 0>>,
     PredOpTrait<"rhs operand and result have same element type",
-                TCresVTEtIsSameAsOpBase<0, 1>>]>,
+                TCresVTEtIsSameAsOpBase<0, 1>>,
+    DeclareOpInterfaceMethods<MaskableOpInterface>]>,
     Arguments<(ins AnyVector:$lhs, AnyType:$rhs,
                Variadic<AnyVector>:$acc,
                DefaultValuedAttr<Vector_CombiningKindAttr, "CombiningKind::ADD">:$kind)>,

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 68865d3c167cf..a1c1c3dca8362 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1107,12 +1107,48 @@ class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
     VectorType vType = fmaOp.getVectorType();
     if (vType.getRank() > 1)
       return failure();
+
+    // Masked fmas are lowered separately.
+    auto maskableOp = cast<MaskableOpInterface>(fmaOp.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
         fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
     return success();
   }
 };
 
+/// Conversion pattern that turns a masked vector.fma on a 1-D vector into its
+/// LLVM counterpart representation. Non side effecting VP intrinsics are not
+/// fully supported by some backends, including x86, and they don't support
+/// pass-through values either. For these reasons, we generate an unmasked
+/// fma followed by a select instruction to emulate the masking behavior.
+/// This pattern is peepholed by some backends with support for masked fma
+/// instructions. This pattern does not match vectors of rank >= 2.
+class MaskedFMAOp1DConversion
+    : public VectorMaskOpConversionBase<vector::FMAOp> {
+public:
+  using VectorMaskOpConversionBase<vector::FMAOp>::VectorMaskOpConversionBase;
+
+  MaskedFMAOp1DConversion(LLVMTypeConverter &converter, bool fullVPIntr)
+      : VectorMaskOpConversionBase<vector::FMAOp>(converter) {}
+
+  virtual LogicalResult matchAndRewriteMaskableOp(
+      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
+      ConversionPatternRewriter &rewriter) const override {
+    auto fmaOp = cast<FMAOp>(maskableOp.getOperation());
+    Type llvmType = typeConverter->convertType(fmaOp.getVectorType());
+
+    Value fmulAddOp = rewriter.create<LLVM::FMulAddOp>(
+        fmaOp.getLoc(), llvmType, fmaOp.getLhs(), fmaOp.getRhs(),
+        fmaOp.getAcc());
+    rewriter.replaceOpWithNewOp<LLVM::SelectOp>(
+        maskOp, llvmType, maskOp.getMask(), fmulAddOp, fmaOp.getAcc());
+    return success();
+  }
+};
+
 class VectorInsertElementOpConversion
     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
 public:
@@ -1279,6 +1315,11 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
     if (vType.getRank() < 2)
       return failure();
 
+    // Masked fmas are lowered separately.
+    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
     auto loc = op.getLoc();
     auto elemType = vType.getElementType();
     Value zero = rewriter.create<arith::ConstantOp>(
@@ -1707,9 +1748,10 @@ void mlir::populateVectorToLLVMConversionPatterns(
   patterns
       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
            VectorExtractElementOpConversion, VectorExtractOpConversion,
-           VectorFMAOp1DConversion, VectorInsertElementOpConversion,
-           VectorInsertOpConversion, VectorPrintOpConversion,
-           VectorTypeCastOpConversion, VectorScaleOpConversion,
+           VectorFMAOp1DConversion, MaskedFMAOp1DConversion,
+           VectorInsertElementOpConversion, VectorInsertOpConversion,
+           VectorPrintOpConversion, VectorTypeCastOpConversion,
+           VectorScaleOpConversion,
            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
            VectorLoadStoreConversion<vector::MaskedLoadOp,
                                      vector::MaskedLoadOpAdaptor>,

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 02cee8c136557..3e145f1b66c91 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -889,6 +889,34 @@ LogicalResult ContractionOp::verify() {
   return success();
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized.
+Type ContractionOp::getExpectedMaskType() {
+  auto indexingMaps = this->getIndexingMapsArray();
+  AffineMap lhsIdxMap = indexingMaps[0];
+  AffineMap rhsIdxMap = indexingMaps[1];
+  VectorType lhsType = this->getLhsType();
+  VectorType rhsType = this->getRhsType();
+
+  unsigned numVecDims = lhsIdxMap.getNumDims();
+  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
+
+  // Using the information in the indexing maps, extract the size of each
+  // dimension in the vector.contract operation from the two input operands.
+  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape()))
+    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
+  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape()))
+    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
+
+  assert(!ShapedType::isDynamicShape(maskShape) &&
+         "Mask shape couldn't be computed");
+
+  return VectorType::get(maskShape,
+                         IntegerType::get(lhsType.getContext(), /*width=*/1));
+}
+
 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
   return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                 getIteratorTypesAttrName(), getKindAttrName()};
@@ -1760,6 +1788,16 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized.
+Type FMAOp::getExpectedMaskType() {
+  auto vecType = this->getVectorType();
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
@@ -2762,6 +2800,16 @@ LogicalResult OuterProductOp::verify() {
   return success();
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized.
+Type OuterProductOp::getExpectedMaskType() {
+  auto vecType = this->getVectorType();
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c7aee6d537a80..9976cf79fb8e2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -147,13 +147,13 @@ static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
 }
 
 /// Helper to create arithmetic operation associated with a kind of contraction.
-static std::optional<Value> createContractArithOp(Location loc, Value x,
-                                                  Value y, Value acc,
-                                                  vector::CombiningKind kind,
-                                                  PatternRewriter &rewriter,
-                                                  bool isInt) {
+static std::optional<Value>
+createContractArithOp(Location loc, Value x, Value y, Value acc,
+                      vector::CombiningKind kind, PatternRewriter &rewriter,
+                      bool isInt, Optional<Value> maybeMask = std::nullopt) {
   using vector::CombiningKind;
   Value mul;
+
   if (isInt) {
     if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
       // Only valid for floating point types.
@@ -169,11 +169,17 @@ static std::optional<Value> createContractArithOp(Location loc, Value x,
       return std::nullopt;
     // Special case for fused multiply-add.
     if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
-      return std::optional<Value>(
-          rewriter.create<vector::FMAOp>(loc, x, y, acc));
+      Operation *fmaOp = rewriter.create<vector::FMAOp>(loc, x, y, acc);
+      if (maybeMask.has_value() && maybeMask.value())
+        fmaOp = maskOperation(rewriter, fmaOp, maybeMask.value());
+      return fmaOp->getResult(0);
     }
     mul = rewriter.create<arith::MulFOp>(loc, x, y);
   }
+
+  assert((!maybeMask.has_value() || !maybeMask.value()) &&
+         "Unsupported masked case");
+
   if (!acc)
     return std::optional<Value>(mul);
   return makeArithReduction(rewriter, loc, kind, mul, acc);
@@ -550,14 +556,27 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
     Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
     vector::CombiningKind kind = op.getKind();
 
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+    Operation *rootOp;
+    Value mask;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = op;
+    }
+
     if (!rhsType) {
       // Special case: AXPY operation.
       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
       std::optional<Value> mult = createContractArithOp(
-          loc, op.getLhs(), b, acc, kind, rewriter, isInt);
+          loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
       if (!mult.has_value())
         return failure();
-      rewriter.replaceOp(op, *mult);
+      rewriter.replaceOp(rootOp, *mult);
       return success();
     }
 
@@ -571,13 +590,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
       Value r = nullptr;
       if (acc)
         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
-      std::optional<Value> m =
-          createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
+      std::optional<Value> m = createContractArithOp(
+          loc, a, op.getRhs(), r, kind, rewriter, isInt, mask);
       if (!m.has_value())
         return failure();
       result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
     }
-    rewriter.replaceOp(op, result);
+
+    rewriter.replaceOp(rootOp, result);
     return success();
   }
 };
@@ -601,7 +621,12 @@ struct ContractOpToElementwise
 
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
-    // TODO: implement masks
+    // TODO: Support vector.mask.
+    auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
+    // TODO: Remove native masks from contraction op?
     if (!contractOp.getMasks().empty())
       return failure();
 
@@ -1429,7 +1454,12 @@ namespace mlir {
 LogicalResult
 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                  PatternRewriter &rew) const {
-  // TODO: implement masks
+  // TODO: Support vector.mask.
+  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+  if (maskableOp.isMasked())
+    return failure();
+
+  // TODO: Remove native masks from contraction op?
   if (!op.getMasks().empty())
     return failure();
   if (vectorTransformOptions.vectorContractLowering !=
@@ -1525,10 +1555,16 @@ struct UnrolledOuterProductGenerator
   UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
       : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
         kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
-        res(op.getAcc()), lhsType(op.getLhsType()) {}
+        res(op.getAcc()), lhsType(op.getLhsType()) {
+    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+    if (maskableOp.isMasked())
+      mask = maskableOp.getMaskingOp().getMask();
+  }
 
   Value t(Value v) {
     static constexpr std::array<int64_t, 2> perm = {1, 0};
+    if (!v)
+      return v;
     return rewriter.create<vector::TransposeOp>(loc, v, perm);
   }
 
@@ -1547,16 +1583,27 @@ struct UnrolledOuterProductGenerator
     return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
   }
 
-  Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
+  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
+                             Optional<Value> maybeMask = std::nullopt) {
     assert(reductionSize > 0);
+    // Incremental support for masking.
+    if (mask && !maybeMask.has_value())
+      return failure();
+
     Type resElementType = res.getType().cast<VectorType>().getElementType();
     for (int64_t k = 0; k < reductionSize; ++k) {
       Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
       Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
       extractA = promote(extractA, resElementType);
       extractB = promote(extractB, resElementType);
-      res = rewriter.create<vector::OuterProductOp>(loc, res.getType(), extractA,
-                                             extractB, res, kind);
+      Value extractMask;
+      if (maybeMask.has_value() && maybeMask.value())
+        extractMask =
+            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
+
+      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
+          loc, res.getType(), extractA, extractB, res, kind);
+      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
     }
     return res;
   }
@@ -1607,7 +1654,7 @@ struct UnrolledOuterProductGenerator
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
@@ -1646,7 +1693,7 @@ struct UnrolledOuterProductGenerator
 
 private:
   vector::CombiningKind kind;
-  Value lhs, rhs, res;
+  Value lhs, rhs, res, mask;
   VectorType lhsType;
 };
 } // namespace
@@ -1668,7 +1715,7 @@ struct UnrolledOuterProductGenerator
 /// otherwise supports any layout permutation of the matrix-multiply.
 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
     vector::ContractionOp op, PatternRewriter &rewriter) const {
-  // TODO: implement masks
+  // TODO: Remove native masks from contraction op?
   if (!op.getMasks().empty())
     return failure();
 
@@ -1679,20 +1726,31 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
   if (failed(filter(op)))
     return failure();
 
+  // Vector mask setup.
+  OpBuilder::InsertionGuard guard(rewriter);
+  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+  Operation *rootOp;
+  if (maskableOp.isMasked()) {
+    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+    rootOp = maskableOp.getMaskingOp();
+  } else {
+    rootOp = op;
+  }
+
   UnrolledOuterProductGenerator e(rewriter, op);
   FailureOr<Value> matmatRes = e.matmat();
   if (succeeded(matmatRes)) {
-    rewriter.replaceOp(op, *matmatRes);
+    rewriter.replaceOp(rootOp, *matmatRes);
     return success();
   }
   FailureOr<Value> matvecRes = e.matvec();
   if (succeeded(matvecRes)) {
-    rewriter.replaceOp(op, *matvecRes);
+    rewriter.replaceOp(rootOp, *matvecRes);
     return success();
   }
   FailureOr<Value> tmatvecRes = e.tmatvec();
   if (succeeded(tmatvecRes)) {
-    rewriter.replaceOp(op, *tmatvecRes);
+    rewriter.replaceOp(rootOp, *tmatvecRes);
     return success();
   }
 
@@ -1702,7 +1760,12 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
 LogicalResult
 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                             PatternRewriter &rewriter) const {
-  // TODO: implement masks
+  // TODO: Support vector.mask.
+  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+  if (maskableOp.isMasked())
+    return failure();
+
+  // TODO: Remove native masks from contraction op?
   if (!op.getMasks().empty())
     return failure();
 
@@ -1834,7 +1897,12 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
 LogicalResult
 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                        PatternRewriter &rewriter) const {
-  // TODO: implement masks.
+  // TODO: Support vector.mask.
+  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+  if (maskableOp.isMasked())
+    return failure();
+
+  // TODO: Remove native masks from contraction op?
   if (!op.getMasks().empty())
     return failure();
 

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index f38f79968726c..a72ab1515a6d1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -416,6 +416,18 @@ func.func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: v
 // CHECK:       %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
 // CHECK:       return %[[T19]] : vector<2x3xf32>
 
+// -----
+
+func.func @masked_vector_contract(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// We can't check for the intermediate 'vector.mask { vector.fma }' state so we
+// just make sure the vector.fma is lowered.
+
+// CHECK:           llvm.intr.fmuladd
+// CHECK:           llvm.select
 
 // -----
 
@@ -2145,3 +2157,17 @@ func.func @vector_scalable_extract(%vec: vector<[4]xf32>) -> vector<8xf32> {
   %0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32>
   return %0 : vector<8xf32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @masked_vector_fma(
+// CHECK-SAME:                                 %[[INPUT:.*]]: vector<8xf32>,
+// CHECK-SAME:                                 %[[MASK:.*]]: vector<8xi1>) -> vector<8xf32>
+// CHECK:           %[[FMA:.*]] = llvm.intr.fmuladd(%[[INPUT]], %[[INPUT]], %[[INPUT]])  : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+// CHECK:           llvm.select %[[MASK]], %[[FMA]], %[[INPUT]] : vector<8xi1>, vector<8xf32>
+
+func.func @masked_vector_fma(%a: vector<8xf32>, %m: vector<8xi1>) -> vector<8xf32> {
+  %0 = vector.mask %m { vector.fma %a, %a, %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index d4dc35d670c38..1617cb176e12f 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -1196,3 +1196,27 @@ func.func @parrallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vec
   %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32
   return %0 : f32
 }
+
+func.func @masked_vector_contract(%arg0: vector<2x3xf32>,
+                                  %arg1: vector<3xf32>,
+                                  %arg2: vector<2xf32>,
+                                  %m: vector<2x3xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+          : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// OUTERPRODUCT-LABEL:   func.func @masked_vector_contract(
+// OUTERPRODUCT-SAME:                                      %[[VAL_0:.*]]: vector<2x3xf32>,
+// OUTERPRODUCT-SAME:                                      %[[VAL_1:.*]]: vector<3xf32>,
+// OUTERPRODUCT-SAME:                                      %[[VAL_2:.*]]: vector<2xf32>,
+// OUTERPRODUCT-SAME:                                      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// OUTERPRODUCT:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
+// OUTERPRODUCT:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1>
+// OUTERPRODUCT:           vector.mask %[[MASK0]] { vector.outerproduct
+
+// OUTERPRODUCT:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1>
+// OUTERPRODUCT:           vector.mask %[[MASK1]] { vector.outerproduct
+
+// OUTERPRODUCT:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1>
+// OUTERPRODUCT:           vector.mask %[[MASK2]] { vector.outerproduct


        


More information about the Mlir-commits mailing list