[Mlir-commits] [mlir] c339f9e - [mlir][Vector] Support masking for more contraction flavors
Diego Caballero
llvmlistbot at llvm.org
Tue Feb 21 17:52:42 PST 2023
Author: Diego Caballero
Date: 2023-02-22T01:47:44Z
New Revision: c339f9e1c3276bcd8db806bc87045a5ef2079fec
URL: https://github.com/llvm/llvm-project/commit/c339f9e1c3276bcd8db806bc87045a5ef2079fec
DIFF: https://github.com/llvm/llvm-project/commit/c339f9e1c3276bcd8db806bc87045a5ef2079fec.diff
LOG: [mlir][Vector] Support masking for more contraction flavors
This patch adds masking support for more contraction flavors, including those
with any combiner operation (add, mul, min, max, and, or, etc.) and
regular matmul contractions.
Combiner operations that perform vertical reductions (and, therefore,
are not represented with a horizontal reduction operation) can be
executed unmasked. However, the previous value of the accumulator must
be propagated for lanes that shouldn't accumulate. We achieve this by
introducing a select operation after the combining operation that picks,
per lane, either the combined value or the previous accumulator value.
This design decision avoids introducing masking support into all the
arithmetic and logical operations in the Arith dialect. VP intrinsics do
not support pass-thru values either, so we would have to generate the
same sequence when lowering to LLVM anyway. The op + select pattern is
peepholed by some backends with native masking support for those
operations.
Consequently, this patch removes masking support from the vector.fma
operation so that all combiner operations follow the same approach.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D144239
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index deb86df396d1c..56f8b4bf22d21 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -191,7 +191,7 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
/// Return the result value of reducing two scalar/vector values with the
/// corresponding arith operation.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
- Value v1, Value v2);
+ Value v1, Value acc, Value mask = Value());
/// Returns true if `attr` has "parallel" iterator type semantics.
inline bool isParallelIterator(Attribute attr) {
@@ -214,8 +214,17 @@ void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask provided is valid. Otherwise, returns the
/// maskable operation itself.
-Operation *maskOperation(RewriterBase &rewriter, Operation *maskableOp,
- Value mask);
+Operation *maskOperation(OpBuilder &builder, Operation *maskableOp,
+ Value mask, Value passthru = Value());
+
+/// Creates a vector select operation that picks values from `newValue` or
+/// `passthru` for each result vector lane based on `mask`. This utility is used
+/// to propagate the pass-thru value for masked-out or expeculatively executed
+/// lanes. VP intrinsics do not support pass-thru values and every mask-out lane
+/// is set to poison. LLVM backends are usually able to match op + select
+/// patterns and fold them into a native target instructions.
+Value selectPassthru(OpBuilder &builder, Value mask, Value newValue,
+ Value passthru);
} // namespace vector
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c5ebe9f4bcebc..04fb36a520265 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -633,7 +633,6 @@ def Vector_ExtractOp :
def Vector_FMAOp :
Op<Vector_Dialect, "fma", [
Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
- DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
] # ElementwiseMappable.traits>,
Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 159bae829133c..b73c01afa8f93 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -704,11 +704,6 @@ class VectorReductionOpConversion
Value acc = adaptor.getAcc();
Location loc = reductionOp.getLoc();
- // Masked reductions are lowered separately.
- auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
- if (maskableOp.isMasked())
- return failure();
-
if (eltType.isIntOrIndex()) {
// Integer reductions: add/mul/min/max/and/or/xor.
Value result;
@@ -1108,47 +1103,12 @@ class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
if (vType.getRank() > 1)
return failure();
- // Masked fmas are lowered separately.
- auto maskableOp = cast<MaskableOpInterface>(fmaOp.getOperation());
- if (maskableOp.isMasked())
- return failure();
-
rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
return success();
}
};
-/// Conversion pattern that turns a masked vector.fma on a 1-D vector into their
-/// LLVM counterpart representation. Non side effecting VP intrinsics are not
-/// fully supported by some backends, including x86, and they don't support
-/// pass-through values either. For these reasons, we generate an unmasked
-/// fma followed by a select instrution to emulate the masking behavior.
-/// This pattern is peepholed by some backends with support for masked fma
-/// instructions. This pattern does not match vectors of n >= 2 rank.
-class MaskedFMAOp1DConversion
- : public VectorMaskOpConversionBase<vector::FMAOp> {
-public:
- using VectorMaskOpConversionBase<vector::FMAOp>::VectorMaskOpConversionBase;
-
- MaskedFMAOp1DConversion(LLVMTypeConverter &converter, bool fullVPIntr)
- : VectorMaskOpConversionBase<vector::FMAOp>(converter) {}
-
- virtual LogicalResult matchAndRewriteMaskableOp(
- vector::MaskOp maskOp, MaskableOpInterface maskableOp,
- ConversionPatternRewriter &rewriter) const override {
- auto fmaOp = cast<FMAOp>(maskableOp.getOperation());
- Type llvmType = typeConverter->convertType(fmaOp.getVectorType());
-
- Value fmulAddOp = rewriter.create<LLVM::FMulAddOp>(
- fmaOp.getLoc(), llvmType, fmaOp.getLhs(), fmaOp.getRhs(),
- fmaOp.getAcc());
- rewriter.replaceOpWithNewOp<LLVM::SelectOp>(
- maskOp, llvmType, maskOp.getMask(), fmulAddOp, fmaOp.getAcc());
- return success();
- }
-};
-
class VectorInsertElementOpConversion
: public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
@@ -1315,11 +1275,6 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
if (vType.getRank() < 2)
return failure();
- // Masked fmas are lowered separately.
- auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
- if (maskableOp.isMasked())
- return failure();
-
auto loc = op.getLoc();
auto elemType = vType.getElementType();
Value zero = rewriter.create<arith::ConstantOp>(
@@ -1748,10 +1703,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
patterns
.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
- VectorFMAOp1DConversion, MaskedFMAOp1DConversion,
- VectorInsertElementOpConversion, VectorInsertOpConversion,
- VectorPrintOpConversion, VectorTypeCastOpConversion,
- VectorScaleOpConversion,
+ VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+ VectorInsertOpConversion, VectorPrintOpConversion,
+ VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
VectorLoadStoreConversion<vector::MaskedLoadOp,
vector::MaskedLoadOpAdaptor>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8c6609d98b439..eb58f904462a6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1790,16 +1790,6 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}
-// MaskableOpInterface methods.
-
-/// Returns the mask type expected by this operation. Mostly used for
-/// verification purposes. It requires the operation to be vectorized."
-Type FMAOp::getExpectedMaskType() {
- auto vecType = this->getVectorType();
- return VectorType::get(vecType.getShape(),
- IntegerType::get(vecType.getContext(), /*width=*/1));
-}
-
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
@@ -5807,53 +5797,71 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
- CombiningKind kind, Value v1, Value v2) {
+ CombiningKind kind, Value v1, Value acc,
+ Value mask) {
Type t1 = getElementTypeOrSelf(v1.getType());
- Type t2 = getElementTypeOrSelf(v2.getType());
+ Type tAcc = getElementTypeOrSelf(acc.getType());
+ Value result;
+
switch (kind) {
case CombiningKind::ADD:
- if (t1.isIntOrIndex() && t2.isIntOrIndex())
- return b.createOrFold<arith::AddIOp>(loc, v1, v2);
- else if (t1.isa<FloatType>() && t2.isa<FloatType>())
- return b.createOrFold<arith::AddFOp>(loc, v1, v2);
- llvm_unreachable("invalid value types for ADD reduction");
+ if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
+ result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
+ else if (t1.isa<FloatType>() && tAcc.isa<FloatType>())
+ result = b.createOrFold<arith::AddFOp>(loc, v1, acc);
+ else
+ llvm_unreachable("invalid value types for ADD reduction");
+ break;
case CombiningKind::AND:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::AndIOp>(loc, v1, v2);
+ assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+ result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
+ break;
case CombiningKind::MAXF:
- assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+ assert(t1.isa<FloatType>() && tAcc.isa<FloatType>() &&
"expected float values");
- return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
+ result = b.createOrFold<arith::MaxFOp>(loc, v1, acc);
+ break;
case CombiningKind::MINF:
- assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+ assert(t1.isa<FloatType>() && tAcc.isa<FloatType>() &&
"expected float values");
- return b.createOrFold<arith::MinFOp>(loc, v1, v2);
+ result = b.createOrFold<arith::MinFOp>(loc, v1, acc);
+ break;
case CombiningKind::MAXSI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
+ assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+ result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
+ break;
case CombiningKind::MINSI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
+ assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+ result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
+ break;
case CombiningKind::MAXUI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
+ assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+ result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
+ break;
case CombiningKind::MINUI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
+ assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+ result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
+ break;
case CombiningKind::MUL:
- if (t1.isIntOrIndex() && t2.isIntOrIndex())
- return b.createOrFold<arith::MulIOp>(loc, v1, v2);
- else if (t1.isa<FloatType>() && t2.isa<FloatType>())
- return b.createOrFold<arith::MulFOp>(loc, v1, v2);
- llvm_unreachable("invalid value types for MUL reduction");
+ if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
+ result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
+ else if (t1.isa<FloatType>() && tAcc.isa<FloatType>())
+ result = b.createOrFold<arith::MulFOp>(loc, v1, acc);
+ else
+ llvm_unreachable("invalid value types for MUL reduction");
+ break;
case CombiningKind::OR:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::OrIOp>(loc, v1, v2);
+ assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+ result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
+ break;
case CombiningKind::XOR:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
+ assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+ result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
+ break;
};
- llvm_unreachable("unknown CombiningKind");
+
+ assert(result && "unknown CombiningKind");
+ return selectPassthru(b, mask, result, acc);
}
//===----------------------------------------------------------------------===//
@@ -5875,13 +5883,34 @@ void mlir::vector::createMaskOpRegion(OpBuilder &builder,
/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask provided is valid. Otherwise, returns
/// the maskable operation itself.
-Operation *mlir::vector::maskOperation(RewriterBase &rewriter,
- Operation *maskableOp, Value mask) {
+Operation *mlir::vector::maskOperation(OpBuilder &builder,
+ Operation *maskableOp, Value mask,
+ Value passthru) {
if (!mask)
return maskableOp;
- return rewriter.create<MaskOp>(maskableOp->getLoc(),
- maskableOp->getResultTypes(), mask, maskableOp,
- createMaskOpRegion);
+ if (passthru)
+ return builder.create<MaskOp>(maskableOp->getLoc(),
+ maskableOp->getResultTypes(), mask, passthru,
+ maskableOp, createMaskOpRegion);
+ return builder.create<MaskOp>(maskableOp->getLoc(),
+ maskableOp->getResultTypes(), mask, maskableOp,
+ createMaskOpRegion);
+}
+
+/// Creates a vector select operation that picks values from `newValue` or
+/// `passthru` for each result vector lane based on `mask`. This utility is used
+/// to propagate the pass-thru value of vector.mask or for cases where only the
+/// pass-thru value propagation is needed. VP intrinsics do not support
+/// pass-thru values and every mask-out lane is set to poison. LLVM backends are
+/// usually able to match op + select patterns and fold them into a native
+/// target instructions.
+Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
+ Value newValue, Value passthru) {
+ if (!mask)
+ return newValue;
+
+ return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
+ mask, newValue, passthru);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index e7b8cd55df53e..eecf9701e3cab 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -151,8 +151,7 @@ static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
vector::CombiningKind kind, PatternRewriter &rewriter,
- bool isInt,
- std::optional<Value> maybeMask = std::nullopt) {
+ bool isInt, Value mask = Value()) {
using vector::CombiningKind;
Value mul;
@@ -171,20 +170,20 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
return std::nullopt;
// Special case for fused multiply-add.
if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
- Operation *fmaOp = rewriter.create<vector::FMAOp>(loc, x, y, acc);
- if (maybeMask.has_value() && maybeMask.value())
- fmaOp = maskOperation(rewriter, fmaOp, maybeMask.value());
- return fmaOp->getResult(0);
+ Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
+ if (mask)
+ // The fma op doesn't need explicit masking. However, fma ops used in
+ // reductions must preserve previous 'acc' values for masked-out lanes.
+ fma = selectPassthru(rewriter, mask, fma, acc);
+ return fma;
}
mul = rewriter.create<arith::MulFOp>(loc, x, y);
}
- assert((!maybeMask.has_value() || !maybeMask.value()) &&
- "Unsupported masked case");
-
if (!acc)
return std::optional<Value>(mul);
- return makeArithReduction(rewriter, loc, kind, mul, acc);
+
+ return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
}
/// Return the positions of the reductions in the given map.
@@ -587,13 +586,17 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
auto pos = rewriter.getI64ArrayAttr(d);
Value x =
- rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
+ rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
Value r = nullptr;
if (acc)
- r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
+ r = rewriter.create<vector::ExtractOp>(loc, acc, pos);
+ Value extrMask;
+ if (mask)
+ extrMask = rewriter.create<vector::ExtractOp>(loc, mask, pos);
+
std::optional<Value> m = createContractArithOp(
- loc, a, op.getRhs(), r, kind, rewriter, isInt, mask);
+ loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
if (!m.has_value())
return failure();
result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
@@ -638,6 +641,7 @@ struct ContractOpToElementwise
if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::ParallelArith)
return failure();
+
ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
@@ -1564,8 +1568,7 @@ struct UnrolledOuterProductGenerator
mask = maskableOp.getMaskingOp().getMask();
}
- Value t(Value v) {
- static constexpr std::array<int64_t, 2> perm = {1, 0};
+ Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
if (!v)
return v;
return rewriter.create<vector::TransposeOp>(loc, v, perm);
@@ -1620,7 +1623,8 @@ struct UnrolledOuterProductGenerator
bindDims(rewriter.getContext(), m, n, k);
// Classical row-major matmul: Just permute the lhs.
if (layout({{m, k}, {k, n}, {m, n}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
+ return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
+ t(mask, {2, 0, 1}));
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {m, n}})) {
Value tlhs = t(lhs);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index a72ab1515a6d1..0b9b86ade3323 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -418,16 +418,132 @@ func.func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: v
// -----
-func.func @masked_vector_contract(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+func.func @masked_float_add_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
%0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
return %0 : vector<2xf32>
}
-// We can't check for the intermediate 'vector.mask { vector.fma }' state so we
-// just make sure the vector.fma is lowered.
+// CHECK-LABEL: func.func @masked_float_add_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<2xf32>, vector<2xf32>, vector<2xf32>) -> vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_3]], %[[VAL_8]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
-// CHECK: llvm.intr.fmuladd
-// CHECK: llvm.select
+// -----
+
+func.func @masked_float_mul_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<mul>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @masked_float_mul_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.mulf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
+
+// -----
+
+func.func @masked_float_max_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @masked_float_max_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.maxf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
+
+// -----
+
+func.func @masked_float_min_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<minf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @masked_float_min_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.minf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
+
+// -----
+
+func.func @masked_int_add_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<add>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_add_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_mul_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<mul>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_mul_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_max_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<maxsi>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_max_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.maxsi %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_min_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<minui>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_min_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.minui %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_and_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<and>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_and_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.andi %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_or_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<or>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_or_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.ori %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
// -----
@@ -2157,17 +2273,3 @@ func.func @vector_scalable_extract(%vec: vector<[4]xf32>) -> vector<8xf32> {
%0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32>
return %0 : vector<8xf32>
}
-
-// -----
-
-// CHECK-LABEL: func.func @masked_vector_fma(
-// CHECK-SAME: %[[INPUT:.*]]: vector<8xf32>,
-// CHECK-SAME: %[[MASK:.*]]: vector<8xi1>) -> vector<8xf32>
-// CHECK: %[[FMA:.*]] = llvm.intr.fmuladd(%[[INPUT]], %[[INPUT]], %[[INPUT]]) : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-// CHECK: llvm.select %[[MASK]], %[[FMA]], %[[INPUT]] : vector<8xi1>, vector<8xf32>
-
-func.func @masked_vector_fma(%a: vector<8xf32>, %m: vector<8xi1>) -> vector<8xf32> {
- %0 = vector.mask %m { vector.fma %a, %a, %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
- return %0 : vector<8xf32>
-}
-
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 1617cb176e12f..6ad8a096df20f 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -76,6 +76,30 @@ func.func @extract_contract2(%arg0: vector<2x3xf32>,
return %0 : vector<2xf32>
}
+// OUTERPRODUCT-LABEL: func.func @masked_extract_contract2(
+// OUTERPRODUCT-SAME: %[[VAL_0:.*]]: vector<2x3xf32>,
+// OUTERPRODUCT-SAME: %[[VAL_1:.*]]: vector<3xf32>,
+// OUTERPRODUCT-SAME: %[[VAL_2:.*]]: vector<2xf32>,
+// OUTERPRODUCT-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// OUTERPRODUCT: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
+// OUTERPRODUCT: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1>
+// OUTERPRODUCT: vector.mask %[[MASK0]] { vector.outerproduct
+
+// OUTERPRODUCT: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1>
+// OUTERPRODUCT: vector.mask %[[MASK1]] { vector.outerproduct
+
+// OUTERPRODUCT: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1>
+// OUTERPRODUCT: vector.mask %[[MASK2]] { vector.outerproduct
+
+func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
+ %arg1: vector<3xf32>,
+ %arg2: vector<2xf32>,
+ %m: vector<2x3xi1>) -> vector<2xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+ : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
// CHECK-LABEL: func @extract_contract2_int
// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
@@ -182,6 +206,32 @@ func.func @extract_contract4(%arg0: vector<2x2xf32>,
return %0 : vector<2x2xf32>
}
+// OUTERPRODUCT-LABEL: func.func @masked_extract_contract4(
+// OUTERPRODUCT-SAME: %[[VAL_0:.*]]: vector<3x5xf32>,
+// OUTERPRODUCT-SAME: %[[VAL_1:.*]]: vector<5x7xf32>,
+// OUTERPRODUCT-SAME: %[[VAL_2:.*]]: vector<3x7xf32>,
+// OUTERPRODUCT-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// OUTERPRODUCT: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+
+func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
+ %arg1: vector<5x7xf32>,
+ %arg2: vector<3x7xf32>,
+ %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
+ %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
+ return %0 : vector<3x7xf32>
+}
+
#contraction2d_accesses = [
affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (i, j)>,
@@ -1197,26 +1247,4 @@ func.func @parrallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vec
return %0 : f32
}
-func.func @masked_vector_contract(%arg0: vector<2x3xf32>,
- %arg1: vector<3xf32>,
- %arg2: vector<2xf32>,
- %m: vector<2x3xi1>) -> vector<2xf32> {
- %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
- : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
- return %0 : vector<2xf32>
-}
-// OUTERPRODUCT-LABEL: func.func @masked_vector_contract(
-// OUTERPRODUCT-SAME: %[[VAL_0:.*]]: vector<2x3xf32>,
-// OUTERPRODUCT-SAME: %[[VAL_1:.*]]: vector<3xf32>,
-// OUTERPRODUCT-SAME: %[[VAL_2:.*]]: vector<2xf32>,
-// OUTERPRODUCT-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
-// OUTERPRODUCT: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
-// OUTERPRODUCT: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1>
-// OUTERPRODUCT: vector.mask %[[MASK0]] { vector.outerproduct
-
-// OUTERPRODUCT: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1>
-// OUTERPRODUCT: vector.mask %[[MASK1]] { vector.outerproduct
-
-// OUTERPRODUCT: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1>
-// OUTERPRODUCT: vector.mask %[[MASK2]] { vector.outerproduct
More information about the Mlir-commits
mailing list