[Mlir-commits] [mlir] b37174e - [mlir][x86vector] Shuffle BF16 vector.contract output for Flat layout. (#174590)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 16 19:57:03 PST 2026
Author: Arun Thangamani
Date: 2026-02-17T09:26:59+05:30
New Revision: b37174efc0a5296f3a5b6f4c14d5820772e2d928
URL: https://github.com/llvm/llvm-project/commit/b37174efc0a5296f3a5b6f4c14d5820772e2d928
DIFF: https://github.com/llvm/llvm-project/commit/b37174efc0a5296f3a5b6f4c14d5820772e2d928.diff
LOG: [mlir][x86vector] Shuffle BF16 vector.contract output for Flat layout. (#174590)
This patch shuffles the output of a `bf16` type `non-vnni` packed
`vector.contract` operation (`flat` layout). The output of the
contraction operation is shuffled to match the `flat` layout before it
is stored in the `acc` matrix.
With this transform schedule, the `vector.contract` will be lowered
to one of the following operations:
- x86vector::DotBF16Op, with the `B` matrix shuffled to compensate for the
`flat` layout (supported as part of this PR), or
- vector.fma with loads + broadcasts using `bf16` packed operations
(supported as part of this PR).
Added:
Modified:
mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
index 2de9a3122cbd9..0f209a0b815e4 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
@@ -9,6 +9,9 @@
#ifndef MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
#define MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
@@ -26,6 +29,40 @@ namespace x86vector {
bool isInVnniLayout(Operation *op, llvm::ArrayRef<AffineMap> indexingMaps,
std::optional<unsigned> blockingFactor = std::nullopt);
+// Returns true if two contraction ops form a valid pair for VNNI packing.
+// It verifies that both contractions share the appropriate operand, read from
+// the same source buffer, and use constant indices that differ by 8 or 16.
+bool validatePairVectorContract(vector::ContractionOp contractOp,
+ vector::ContractionOp pairContOp,
+ bool rhsHasMultipleNonUnitDims,
+ int64_t nonUnitDimValue);
+
+// Walks backward from a value to find its originating vector read-like op
+// (vector.transfer_read or vector.load), following scf.for iter-args but
+// stopping at layout-transforming ops; returns the read op or nullptr.
+Operation *traceToVectorReadLikeParentOperation(Value v);
+
+// Recursively traces a value to find a downstream vector write-like op
+// (vector.transfer_write or vector.store), crossing scf.for/yield but
+// stopping at layout-altering ops. Returns nullptr if no vector writer/store
+// ops or there are multiple users.
+Operation *traceToVectorWriteLikeUserOperation(Value v);
+
+// Packs the accumulators of two flat BF16 vector.contraction ops into a
+// VNNI-packed layout and replaces the original accumulators to enable post-read
+// packing transformations.
+LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA,
+ Operation *opB,
+ vector::ContractionOp contractA,
+ vector::ContractionOp contractB,
+ int64_t nonUnitDimAcc, VectorType accTy);
+
+// Shuffles vectors produced by vector.contraction ops into a flat layout
+// before they are written to memory.
+LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter,
+ Operation *opA, Operation *opB,
+ int64_t nonUnitDimAcc, VectorType accTy);
+
} // namespace x86vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index c60d9b91c18e5..66fbfbd93c8e3 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
@@ -29,7 +30,7 @@ using namespace mlir::x86vector;
// Verifies that the LHS and RHS operands of a vector.contract are load or
// vector.transfer_read operations on a memref source buffer, and checks
// their bounds, dimensions, offsets, and strides.
-static bool validateVectorContractOperands(Value prodOp) {
+static bool validateVectorContractOperands(Value prodOp, bool isVnni) {
Operation *defOp = prodOp.getDefiningOp();
if (!defOp)
return false;
@@ -62,11 +63,12 @@ static bool validateVectorContractOperands(Value prodOp) {
// Return false if the two innermost strides of the memref are not contiguous.
// The x86vector.avx.cvt.packed.even/odd.indexed_to_f32 operations require
// an eight-element tuple of bf16 values to be contiguous.
- if (!llvm::cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(2))
+ int dimsToCheck = isVnni ? 2 : 1;
+ if (!cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(dimsToCheck))
return false;
// Return false if the vnni offset of load or transfer_read is not zero.
- if (getConstantIntValue(indexVals.back()) != 0)
+ if (isVnni && getConstantIntValue(indexVals.back()) != 0)
return false;
return true;
@@ -96,7 +98,8 @@ static bool validateVectorContractOperands(Value prodOp) {
// ```
static SmallVector<memref::SubViewOp>
getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
- ArrayRef<int64_t> nonUnitDimShape, bool isUnitDim) {
+ ArrayRef<int64_t> nonUnitDimShape, bool isUnitDim,
+ bool isVNNI) {
Operation *defOp = prodOp.getDefiningOp();
@@ -122,11 +125,26 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
}
}
- int vnniDimSize = isUnitDim ? 1 : 2;
+ auto one = rewriter.getIndexAttr(1);
+ SmallVector<memref::SubViewOp> subviews;
+
+ if (!isVNNI) {
+ SmallVector<OpFoldResult> strides(indexVals.size(), one);
+ SmallVector<OpFoldResult> sizes(indexVals.size(), one);
+ // Retrive twice the nonUnit dim BF16 element for both even and odd
+ // index elements.
+ if (!isUnitDim)
+ mnDimSize = 2 * mnDimSize;
+ sizes[mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
+ auto subview = memref::SubViewOp::create(rewriter, loc, srcBuff, indexVals,
+ sizes, strides);
+ subviews.push_back(subview);
+ return subviews;
+ }
+ int vnniDimSize = isUnitDim ? 1 : 2;
auto nonVNNIDimSize = indexVals.size() - 1;
// Create the size and stride offsets.
- auto one = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> strides(indexVals.size(), one);
SmallVector<OpFoldResult> sizes(nonVNNIDimSize, one);
@@ -139,7 +157,6 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
if (isUnitDim)
indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);
- llvm::SmallVector<memref::SubViewOp> subviews;
auto subview = memref::SubViewOp::create(rewriter, loc, srcBuff, indexVals,
sizes, strides);
subviews.push_back(subview);
@@ -168,7 +185,7 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
// Implements outer product contraction as a sequence of BF16-packed
// operation even/odd loads and FMA operations.
//
-// For example:
+// For example (VNNI packed):
// ```
// %1 = vector.load from memref (%m1) -> vector<1x1x2xbf16>
// %2 = vector.load from memref (%m2) -> vector<1x8x2xbf16>
@@ -183,6 +200,24 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
// return vector.fma %4, %5, %3
// ```
+//
+// For example (Flat layout):
+// ```
+// %1 = vector.load from memref (%m1) -> vector<1x1xbf16>
+// %2 = vector.load from memref (%m2) -> vector<1x8xbf16>
+// %3 = vector.contract %1, %2, %arg1
+// %4 = vector.load from memref (%m2) -> vector<1x8xbf16>
+// %5 = vector.contract %1, %4, %arg2
+// scf.yield %3, %4
+// ```
+// to
+// ```
+// %1 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
+// %2 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
+// %3 = vector.fma %1, %2, %arg1
+// %4 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
+// %5 = vector.fma %1, %4, %arg2
+// scf.yield %3, %5
struct VectorContractBF16ToFMA
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -202,11 +237,9 @@ struct VectorContractBF16ToFMA
return rewriter.notifyMatchFailure(contractOp,
"Only BF16 lowering is supported.");
- if (!isInVnniLayout(contractOp.getOperation(),
- contractOp.getIndexingMapsArray(),
- /*blockingFactor=*/2))
- return rewriter.notifyMatchFailure(contractOp,
- "Input matrices not in VNNI format.");
+ bool isVnni = isInVnniLayout(contractOp.getOperation(),
+ contractOp.getIndexingMapsArray(),
+ /*blockingFactor=*/2);
VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
if (!accTy)
@@ -216,6 +249,14 @@ struct VectorContractBF16ToFMA
return rewriter.notifyMatchFailure(
contractOp, "Only F32 acumulation supported for BF16 type.");
+ ArrayRef<int64_t> accShape = accTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimAcc;
+ llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
+ [](int64_t dim) { return dim != 1; });
+ if (nonUnitDimAcc.size() != 1)
+ return rewriter.notifyMatchFailure(
+ contractOp, "A or B should be a non-unit dim in acc.");
+
ArrayRef<int64_t> lhsShape = lhsTy.getShape();
llvm::SmallVector<int64_t> nonUnitDimLhs;
llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
@@ -227,35 +268,37 @@ struct VectorContractBF16ToFMA
llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
[](int64_t dim) { return dim != 1; });
- if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
+ if (isVnni && (nonUnitDimLhs.size() - 1) > 0 &&
+ (nonUnitDimRhs.size() - 1) > 0)
return rewriter.notifyMatchFailure(contractOp,
"Excepts unit dimensions for either "
"LHS or RHS shape other than VNNI.");
- if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
+ if (isVnni && (nonUnitDimLhs.size() - 1) != 1 &&
+ (nonUnitDimRhs.size() - 1) != 1)
return rewriter.notifyMatchFailure(
contractOp,
"Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
- ArrayRef<int64_t> accShape = accTy.getShape();
- llvm::SmallVector<int64_t> nonUnitDimAcc;
- llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
- [](int64_t dim) { return dim != 1; });
- if (nonUnitDimAcc.size() != 1)
+ if (!isVnni && nonUnitDimLhs.size() > 0 && nonUnitDimRhs.size() > 0)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Excepts unit dimensions for either "
+ "LHS or RHS shape.");
+
+ if (!isVnni && nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1)
return rewriter.notifyMatchFailure(
- contractOp, "A or B should be a non-unit dim in acc.");
+ contractOp,
+ "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
// Non-unit dimensions should match the vector length of BF16.
- unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
- : nonUnitDimRhs.front();
- if (nonUnitDim != 4 && nonUnitDim != 8 &&
- !(nonUnitDimAcc.front() == nonUnitDim))
+ unsigned int nonUnitDim = nonUnitDimAcc.front();
+ if (nonUnitDim != 4 && nonUnitDim != 8)
return rewriter.notifyMatchFailure(
contractOp, "BF16 packed load operation expects non-unit (LHR or "
"RHS) dim and acc dim of size 4/8.");
- if (!validateVectorContractOperands(contractOp.getLhs()) ||
- !validateVectorContractOperands(contractOp.getRhs())) {
+ if (!validateVectorContractOperands(contractOp.getLhs(), isVnni) ||
+ !validateVectorContractOperands(contractOp.getRhs(), isVnni)) {
return rewriter.notifyMatchFailure(
contractOp, "The LHS or RHS is in an invalid format. Either it has "
"false in-bounds, "
@@ -273,7 +316,12 @@ struct VectorContractBF16ToFMA
// vector<1x8x2xbf16>, we create two subview for the LHS and one subview
// for the RHS. In the opposite case (non-unit dimension on the LHS), we
// do vice-versa.
- bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
+
+ bool rhsHasMultipleNonUnitDims = nonUnitDimRhs.size() > 0;
+ if (isVnni) {
+ rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
+ }
+
// Select which operand is "unit" and which is "non-unit".
Value unitSrc =
rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
@@ -283,12 +331,74 @@ struct VectorContractBF16ToFMA
ArrayRef<int64_t> nonUnitDimShape =
rhsHasMultipleNonUnitDims ? rhsShape : lhsShape;
+ // Get the pair vector.contract operation. The pair is decided on:
+ // (1) - the unitDim operand Lhs or Rhs should be same,
+ // (2) - the defining source memref should be same for nonUnitDim
+ // operation, (3) - the nonUnit dim offset difference between the
+ // vector.contracts should be 8.
+ vector::ContractionOp pairContractOp;
+ if (!isVnni) {
+ Operation *nextOp = contractOp;
+ while ((nextOp = nextOp->getNextNode())) {
+ auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
+
+ if (!contOp)
+ continue;
+
+ if (validatePairVectorContract(contractOp, contOp,
+ rhsHasMultipleNonUnitDims,
+ nonUnitDimAcc.front())) {
+ pairContractOp = contOp;
+ break;
+ }
+ }
+
+ if (!pairContractOp)
+ return failure();
+
+ Operation *accReadOp0 =
+ traceToVectorReadLikeParentOperation(contractOp.getAcc());
+ Operation *accReadOp1 =
+ traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
+
+ // Iterate down to find the users of contact operations until it is store
+ // or transfer_write.
+ Operation *resultWriteOp0 =
+ traceToVectorWriteLikeUserOperation(contractOp.getResult());
+ Operation *resultWriteOp1 =
+ traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
+
+ if (!accReadOp0 || !accReadOp1)
+ return rewriter.notifyMatchFailure(
+ contractOp,
+ "Operand doesn't have load or transfer_read as its parent op");
+
+ if (!resultWriteOp0 || !resultWriteOp1)
+ return rewriter.notifyMatchFailure(
+ contractOp,
+ "The use of contract operations are neither vector.store "
+ "or transfer_write or has multiple users");
+
+ if (contractOp->getBlock() == accReadOp1->getBlock() &&
+ contractOp->isBeforeInBlock(accReadOp1))
+ return rewriter.notifyMatchFailure(
+ contractOp, "The load/read operation of pair contract operation is "
+ "after the contractOp");
+
+ if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
+ resultWriteOp0->isBeforeInBlock(pairContractOp)) {
+ return rewriter.notifyMatchFailure(
+ contractOp, "The store/write operation of contract operation is "
+ "before the pair contract operation");
+ }
+ }
+
// Build subviews.
- auto unitDimSubview = getSubviewFromVectorInput(loc, rewriter, unitSrc,
- nonUnitDimShape, true);
+ auto unitDimSubview = getSubviewFromVectorInput(
+ loc, rewriter, unitSrc, nonUnitDimShape, true, isVnni);
auto nonUnitDimSubview = getSubviewFromVectorInput(
- loc, rewriter, nonUnitSrc, nonUnitDimShape, false);
+ loc, rewriter, nonUnitSrc, nonUnitDimShape, false, isVnni);
auto castAcc = vector::ShapeCastOp::create(
rewriter, loc,
@@ -297,6 +407,79 @@ struct VectorContractBF16ToFMA
VectorType dstType =
VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
+ if (!isVnni) {
+
+ // Validate and shuffle the accumulator
+ Operation *accReadOp0 =
+ traceToVectorReadLikeParentOperation(contractOp.getAcc());
+ Operation *accReadOp1 =
+ traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
+
+ // Iterate down to find the users of contact operations until it is store
+ // or transfer_write.
+ Operation *resultWriteOp0 =
+ traceToVectorWriteLikeUserOperation(contractOp.getResult());
+ Operation *resultWriteOp1 =
+ traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
+
+ // Shuffle the accumulators of the contract operations.
+ LogicalResult readShuffle =
+ shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
+ pairContractOp, nonUnitDim, accTy);
+
+ if (failed(readShuffle))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Accumulator read is not by transfer_read or load");
+
+ // Shuffle the output of contract operations before its use.
+ LogicalResult writeShuffle = shuffleBeforeWriteLikeOp(
+ rewriter, resultWriteOp0, resultWriteOp1, nonUnitDim, accTy);
+
+ if (failed(writeShuffle))
+ return rewriter.notifyMatchFailure(
+ contractOp,
+ "Write to accumulator is not by transfer_write or store");
+
+ rewriter.setInsertionPoint(contractOp);
+ castAcc = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
+ contractOp.getAcc());
+
+ auto loadBcstBF16ElementToF32 = x86vector::BcstToPackedF32Op::create(
+ rewriter, loc, dstType, unitDimSubview[0]);
+ auto loadEvenIdxElementF32 =
+ x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
+ nonUnitDimSubview[0]);
+ auto evenIdxFMA =
+ vector::FMAOp::create(rewriter, loc, loadBcstBF16ElementToF32,
+ loadEvenIdxElementF32, castAcc);
+ auto castEvenFma =
+ vector::ShapeCastOp::create(rewriter, loc, accTy, evenIdxFMA);
+ rewriter.replaceOp(contractOp, castEvenFma);
+
+ rewriter.setInsertionPoint(pairContractOp);
+ auto pairContOpLoc = pairContractOp.getLoc();
+ VectorType accTyPairCont =
+ dyn_cast<VectorType>(pairContractOp.getAccType());
+ auto castAccPairCont = vector::ShapeCastOp::create(
+ rewriter, pairContOpLoc,
+ VectorType::get(nonUnitDimAcc.front(),
+ accTyPairCont.getElementType()),
+ pairContractOp.getAcc());
+
+ auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
+ rewriter, pairContOpLoc, dstType, nonUnitDimSubview[0]);
+ auto oddIdxFMA = vector::FMAOp::create(
+ rewriter, pairContOpLoc, loadBcstBF16ElementToF32,
+ loadOddIdxElementF32, castAccPairCont);
+ auto castOddFma = vector::ShapeCastOp::create(rewriter, pairContOpLoc,
+ accTyPairCont, oddIdxFMA);
+ rewriter.replaceOp(pairContractOp, castOddFma);
+
+ return success();
+ }
+
// Load, broadcast, and do FMA for odd indexed BF16 elements.
auto loadBcstOddIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[0]);
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index 89aa53307b95d..56f214171af96 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
@@ -27,10 +28,81 @@ using namespace mlir::x86vector;
namespace {
+// Returns true if the A or B matrix vector is packed (shuffled) to
+// VNNI layout, already.
+static bool isNonUnitDimOperandShuffled(Value nonUnitDimOperand) {
+ if (Operation *defOp = nonUnitDimOperand.getDefiningOp()) {
+ if (isa<vector::ShuffleOp>(defOp))
+ return true;
+
+ if (isa<vector::ShapeCastOp>(defOp)) {
+ Operation *defOpShpCst = defOp->getOperand(0).getDefiningOp();
+ if (isa<vector::ShuffleOp>(defOpShpCst))
+ return true;
+ }
+ }
+
+ return false;
+}
+
+static void rewriteUses(mlir::Value oldVal, mlir::Value newVal,
+ mlir::Operation *targetContract,
+ mlir::PatternRewriter &rewriter) {
+ for (mlir::OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) {
+
+ mlir::Operation *user = use.getOwner();
+ if (mlir::isa<mlir::vector::ContractionOp>(user) ||
+ mlir::isa<mlir::scf::ForOp>(user)) {
+ use.set(newVal);
+ }
+ }
+}
+
+// Function to convert the flat layout A or B matrix vector<32xbf16>
+// into VNNI packed layout using the vpunpack operations
+static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
+ mlir::Operation *opA,
+ mlir::Operation *opB,
+ mlir::vector::ContractionOp contractA,
+ mlir::vector::ContractionOp contractB,
+ int64_t nonUnitDimAcc,
+ mlir::VectorType Ty) {
+ mlir::Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
+
+ rewriter.setInsertionPointAfter(insertAfter);
+ mlir::Location loc = insertAfter->getLoc();
+
+ auto elemTy = Ty.getElementType();
+ auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
+
+ auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
+ opA->getResult(0));
+ auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
+ opB->getResult(0));
+
+ static constexpr int64_t maskLo[] = {
+ 0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43,
+ 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
+ static constexpr int64_t maskHi[] = {
+ 4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47,
+ 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
+
+ auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+ castB, maskLo);
+ auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+ castB, maskHi);
+
+ auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
+ auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
+
+ rewriteUses(opA->getResult(0), newA.getResult(), contractA, rewriter);
+ rewriteUses(opB->getResult(0), newB.getResult(), contractB, rewriter);
+}
+
// Implements packed type outer product contraction as a sequence
// of broadcast and packed dot-product operations.
//
-// For example - for F32 type:
+// For example - for bf16 type (VNNI):
// ```
// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
// ```
@@ -39,6 +111,25 @@ namespace {
// vector.broadcast %lhs to <32xbf16>
// x86vector.avx512.dot vector<32xbf16> -> vector<16xf32>
// ```
+//
+// For example - for bf16 type (Flat layout):
+// ```
+// %1 = vector.load -> <2x16xbf16>
+// %2 = vector.load -> <2x16xbf16>
+// vector.contract <1x2xbf16>, %1 into <1x16xf32>
+// vector.contract <1x2xbf16>, %2 into <1x16xf32>
+// ```
+// to
+// ```
+// %1 = vector.load -> <2x16xbf16>
+// %2 = vector.load -> <2x16xbf16>
+// %3 = vector.shuffle %1, %2 [0, 32, 1, ... 27, 59]
+// %4 = vector.shuffle %1, %2 [4, 36, 5, ... 31, 63]
+// vector.broadcast %lhs to <32xbf16>
+// x86vector.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
+// vector.broadcast %lhs to <32xbf16>
+// x86vector.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
+// ```
struct VectorContractToPackedTypeDotProduct
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -57,10 +148,40 @@ struct VectorContractToPackedTypeDotProduct
contractOp, "Only BF16/Int8 lowering is supported.");
unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
- if (!isInVnniLayout(contractOp.getOperation(),
- contractOp.getIndexingMapsArray(), blockingFactor))
- return rewriter.notifyMatchFailure(contractOp,
- "Input matrices not in VNNI format.");
+ bool isVnni =
+ isInVnniLayout(contractOp.getOperation(),
+ contractOp.getIndexingMapsArray(), blockingFactor);
+
+ if (lhsTy.getElementType().isSignlessInteger(8) && !isVnni)
+ return failure();
+
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ if (!accTy)
+ return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
+
+ ArrayRef<int64_t> accShape = accTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimAcc;
+ llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
+ [](int64_t dim) { return dim != 1; });
+ if (nonUnitDimAcc.size() != 1)
+ return rewriter.notifyMatchFailure(
+ contractOp, "A or B should be a non-unit dim in acc.");
+
+ int64_t nonUnitDimValue = nonUnitDimAcc.front();
+ // Non-unit dimensions should match the vector length of BF16 or Int8
+ // dot-product.
+ if (lhsTy.getElementType().isBF16() && nonUnitDimValue != 4 &&
+ nonUnitDimValue != 8 && nonUnitDimValue != 16)
+ return rewriter.notifyMatchFailure(
+ contractOp, "BF16 dot-product operation expects non-unit (LHR or "
+ "RHS) dim and acc dim of size 4/8/16.");
+
+ if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDimValue != 4 &&
+ nonUnitDimValue != 8 && nonUnitDimValue != 16 &&
+ nonUnitDimAcc.front() == nonUnitDimValue)
+ return rewriter.notifyMatchFailure(
+ contractOp, "Int8 dot-product operation expects non-unit (LHR or "
+ "RHS) dim and acc dim of size 4/8/16.");
ArrayRef<int64_t> lhsShape = lhsTy.getShape();
llvm::SmallVector<int64_t> nonUnitDimLhs;
@@ -76,16 +197,20 @@ struct VectorContractToPackedTypeDotProduct
if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
return rewriter.notifyMatchFailure(contractOp,
"Excepts unit dimensions for either "
- "LHS or RHS shape other than VNNI.");
+ "LHS or RHS shape.");
if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
return rewriter.notifyMatchFailure(
contractOp,
"Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
- VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
- if (!accTy)
- return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
+ bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
+ int64_t extraFlatDim = rhsHasMultipleNonUnitDims ? nonUnitDimLhs.front()
+ : nonUnitDimRhs.front();
+
+ if (!isVnni && (extraFlatDim != blockingFactor))
+ return rewriter.notifyMatchFailure(
+ contractOp, "The K or reduction dim for flat layout should be 2.");
if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
(lhsTy.getElementType().isSignlessInteger(8) &&
@@ -94,127 +219,188 @@ struct VectorContractToPackedTypeDotProduct
"Only F32 for BF16 or Int32 for Int8 "
"accumulation type is supported.");
- ArrayRef<int64_t> accShape = accTy.getShape();
- llvm::SmallVector<int64_t> nonUnitDimAcc;
- llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
- [](int64_t dim) { return dim != 1; });
- if (nonUnitDimAcc.size() != 1)
- return rewriter.notifyMatchFailure(
- contractOp, "A or B should be a non-unit dim in acc.");
+ Value unitDimOperand =
+ rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
+ Value nonUnitDimOperand =
+ rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
- // Non-unit dimensions should match the vector length of BF16 or Int8
- // dot-product.
- unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
- : nonUnitDimRhs.front();
- if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 &&
- nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim)
- return rewriter.notifyMatchFailure(
- contractOp, "BF16 dot-product operation expects non-unit (LHR or "
- "RHS) dim and acc dim of size 4/8/16.");
+ // If the A or B matrix vector of the contact operation is not packed, then
+ // find it's pair contract operation and pack (shuffle) them to VNNI packed.
+ if (!isVnni) {
+ vector::ContractionOp pairContractOp;
+ Operation *nextOp = contractOp;
+ while ((nextOp = nextOp->getNextNode())) {
+ auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
- if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
- nonUnitDim != 8 && nonUnitDim != 16 &&
- nonUnitDimAcc.front() == nonUnitDim)
- return rewriter.notifyMatchFailure(
- contractOp, "Int8 dot-product operation expects non-unit (LHR or "
- "RHS) dim and acc dim of size 4/8/16.");
+ if (!contOp)
+ continue;
+ if (validatePairVectorContract(contractOp, contOp,
+ rhsHasMultipleNonUnitDims,
+ nonUnitDimValue)) {
+ pairContractOp = contOp;
+ break;
+ }
+ }
+
+ // If the accumulators are shuffled we get nullptr else the
+ // transfer_read or load operations.
+ Operation *accRead =
+ traceToVectorReadLikeParentOperation(contractOp.getAcc());
+
+ if (!pairContractOp &&
+ (!isNonUnitDimOperandShuffled(nonUnitDimOperand) || accRead))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Could not find a contract pair");
+
+ // Validate and shuffle the accumulator
+ if (accRead) {
+ // Trace back to the load or transfer_read operations of the contract
+ // accumulators.
+ Operation *accReadOp0 =
+ traceToVectorReadLikeParentOperation(contractOp.getAcc());
+ Operation *accReadOp1 =
+ traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
+
+ // Iterate down to find the users of contact operations until it is
+ // store or transfer_write.
+ Operation *resultWriteOp0 =
+ traceToVectorWriteLikeUserOperation(contractOp.getResult());
+ Operation *resultWriteOp1 =
+ traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
+
+ if (!accReadOp0 || !accReadOp1)
+ return rewriter.notifyMatchFailure(
+ contractOp,
+ "Operands doesn't have load or transfer_read as it's parent op");
+
+ if (!resultWriteOp0 || !resultWriteOp1)
+ return rewriter.notifyMatchFailure(
+ contractOp,
+ "The use of contract operations are neither vector.store "
+ "or transfer_write or has multiple users.");
+
+ if (contractOp->getBlock() == accReadOp1->getBlock() &&
+ contractOp->isBeforeInBlock(accReadOp1))
+ return rewriter.notifyMatchFailure(
+ contractOp,
+ "The load/read operation of pair contract operation is "
+ "after the contractOp");
+
+ if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
+ resultWriteOp0->isBeforeInBlock(pairContractOp))
+ return rewriter.notifyMatchFailure(
+ contractOp, "The store/write operation of contract operation is "
+ "before the pair contract operation");
+ // Shuffle the accumulators of the contract operations.
+ LogicalResult readShuffle =
+ shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
+ pairContractOp, nonUnitDimValue, accTy);
+
+ if (failed(readShuffle))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Accumulator read is not by transfer_read or load");
+
+ // Shuffle the output of contract operations before it's use.
+ LogicalResult writeShuffle = shuffleBeforeWriteLikeOp(
+ rewriter, resultWriteOp0, resultWriteOp1, nonUnitDimValue, accTy);
+
+ if (failed(writeShuffle))
+ return rewriter.notifyMatchFailure(
+ contractOp,
+ "Write to accumulator is not by transfer_write or store");
+ }
+
+ if (!isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
+ Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
+ ? pairContractOp.getRhs()
+ : pairContractOp.getLhs();
+
+ // Get the non-packed A or B matrix's vector<32xbf16> elements.
+ Operation *nonUnitDimReadOp =
+ traceToVectorReadLikeParentOperation(nonUnitDimOperand);
+ Operation *nonUnitDimReadOpPairContract =
+ traceToVectorReadLikeParentOperation(nonUnitDimOperandPairContract);
+
+ if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
+ return rewriter.notifyMatchFailure(
+ contractOp, "Could not find a valid contract pair");
+
+ if (contractOp->getBlock() ==
+ nonUnitDimReadOpPairContract->getBlock() &&
+ contractOp->isBeforeInBlock(nonUnitDimReadOpPairContract))
+ return rewriter.notifyMatchFailure(
+ contractOp,
+ "The load/read operation of pair contract operation is "
+ "after the contractOp");
+
+ VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
+ ? contractOp.getRhsType()
+ : contractOp.getLhsType();
+
+ packNonUnitDimOperandToVNNI(
+ rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract,
+ contractOp, pairContractOp, blockingFactor * nonUnitDimValue,
+ nonUnitDimTy);
+
+ nonUnitDimOperand = rhsHasMultipleNonUnitDims ? contractOp.getRhs()
+ : contractOp.getLhs();
+ }
+ }
+
+ rewriter.setInsertionPoint(contractOp);
auto loc = contractOp.getLoc();
auto castAcc = vector::ShapeCastOp::create(
rewriter, loc,
VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
contractOp.getAcc());
+ VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
+ ? contractOp.getRhsType()
+ : contractOp.getLhsType();
+ VectorType unitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getLhsType()
+ : contractOp.getRhsType();
+
Value dp;
- // Broadcast the unit-dimension LHS or RHS to match the vector length of the
- // corresponding non-unit dimension on the other operand. For example,
- // if LHS has type vector<1x1x2xbf16> and RHS has type vector<1x16x2xbf16>,
- // we broadcast the LHS to vector<16x2xbf16>. In the opposite case (non-unit
- // dimension on the LHS), we broadcast the RHS instead.
- if ((nonUnitDimRhs.size() - 1) > 0) {
- auto castRhs = vector::ShapeCastOp::create(
- rewriter, loc,
- VectorType::get(nonUnitDimRhs.front() * nonUnitDimRhs.back(),
- rhsTy.getElementType()),
- contractOp.getRhs());
- auto castLhs = vector::ShapeCastOp::create(
- rewriter, loc,
- VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
- contractOp.getLhs());
- auto bitcastLhs = vector::BitCastOp::create(
- rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
- castLhs);
- auto broadcastLhs = vector::BroadcastOp::create(
- rewriter, loc,
- VectorType::get({nonUnitDimRhs.front()}, rewriter.getIntegerType(32)),
- bitcastLhs);
- auto bitcastLhsPkType = vector::BitCastOp::create(
- rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
+ auto castNonUnitDim = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(blockingFactor * nonUnitDimValue,
+ nonUnitDimTy.getElementType()),
+ nonUnitDimOperand);
- if (lhsTy.getElementType().isBF16()) {
- dp = x86vector::DotBF16Op::create(
- rewriter, loc,
- VectorType::get(nonUnitDimRhs.front(), rewriter.getF32Type()),
- castAcc, bitcastLhsPkType, castRhs);
- }
+ auto castUnitDim = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(blockingFactor, unitDimTy.getElementType()),
+ unitDimOperand);
+ auto bitcastUnitDim = vector::BitCastOp::create(
+ rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
+ castUnitDim);
+ auto broadcastUnitDim = vector::BroadcastOp::create(
+ rewriter, loc,
+ VectorType::get({nonUnitDimValue}, rewriter.getIntegerType(32)),
+ bitcastUnitDim);
+ auto bitcastUnitDimPkType = vector::BitCastOp::create(
+ rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
- if (lhsTy.getElementType().isSignlessInteger(8)) {
- if (nonUnitDimAcc.front() == 16) {
- dp = x86vector::AVX10DotInt8Op::create(
- rewriter, loc,
- VectorType::get(nonUnitDimRhs.front(),
- rewriter.getIntegerType(32)),
- castAcc, bitcastLhsPkType, castRhs);
- } else {
- dp = x86vector::DotInt8Op::create(
- rewriter, loc,
- VectorType::get(nonUnitDimRhs.front(),
- rewriter.getIntegerType(32)),
- castAcc, bitcastLhsPkType, castRhs);
- }
- }
- } else {
- auto castLhs = vector::ShapeCastOp::create(
+ if (lhsTy.getElementType().isBF16()) {
+ dp = x86vector::DotBF16Op::create(
rewriter, loc,
- VectorType::get(nonUnitDimLhs.front() * nonUnitDimLhs.back(),
- lhsTy.getElementType()),
- contractOp.getLhs());
- auto castRhs = vector::ShapeCastOp::create(
- rewriter, loc,
- VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
- contractOp.getRhs());
- auto bitcastRhs = vector::BitCastOp::create(
- rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
- castRhs);
- auto broadcastRhs = vector::BroadcastOp::create(
- rewriter, loc,
- VectorType::get({nonUnitDimLhs.front()}, rewriter.getIntegerType(32)),
- bitcastRhs);
- auto bitcastRhsPkType = vector::BitCastOp::create(
- rewriter, loc, castLhs.getResult().getType(), broadcastRhs);
+ VectorType::get(nonUnitDimValue, rewriter.getF32Type()), castAcc,
+ bitcastUnitDimPkType, castNonUnitDim);
+ }
- if (lhsTy.getElementType().isBF16()) {
- dp = x86vector::DotBF16Op::create(
+ if (lhsTy.getElementType().isSignlessInteger(8)) {
+ if (nonUnitDimAcc.front() == 16) {
+ dp = x86vector::AVX10DotInt8Op::create(
rewriter, loc,
- VectorType::get(nonUnitDimLhs.front(), rewriter.getF32Type()),
- castAcc, castLhs, bitcastRhsPkType);
- }
-
- if (lhsTy.getElementType().isSignlessInteger(8)) {
- if (nonUnitDimAcc.front() == 16) {
- dp = x86vector::AVX10DotInt8Op::create(
- rewriter, loc,
- VectorType::get(nonUnitDimLhs.front(),
- rewriter.getIntegerType(32)),
- castAcc, castLhs, bitcastRhsPkType);
- } else {
- dp = x86vector::DotInt8Op::create(
- rewriter, loc,
- VectorType::get(nonUnitDimLhs.front(),
- rewriter.getIntegerType(32)),
- castAcc, castLhs, bitcastRhsPkType);
- }
+ VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
+ castAcc, bitcastUnitDimPkType, castNonUnitDim);
+ } else {
+ dp = x86vector::DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
+ castAcc, bitcastUnitDimPkType, castNonUnitDim);
}
}
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index ccb2e92fdd9e2..98147b3d884a9 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -10,11 +10,18 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include <cassert>
+
namespace mlir {
namespace x86vector {
@@ -104,5 +111,304 @@ bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
return true;
}
+struct ShuffleMasks {
+ llvm::ArrayRef<int64_t> maskLo;
+ llvm::ArrayRef<int64_t> maskHi;
+};
+
+inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc) {
+ // We only support these two layouts for now.
+ assert((nonUnitDimAcc == 8 || nonUnitDimAcc == 16) &&
+ "Unsupported nonUnitDimAcc value");
+ // Do interleaving between two <8xf32> targeting AVX2.
+ static constexpr int64_t maskLo8[] = {0, 8, 1, 9, 2, 10, 3, 11};
+ static constexpr int64_t maskHi8[] = {4, 12, 5, 13, 6, 14, 7, 15};
+
+ // Shuffle two <16xf32> as below targeting AVX512.
+ static constexpr int64_t maskLo16[] = {0, 1, 2, 3, 16, 17, 18, 19,
+ 4, 5, 6, 7, 20, 21, 22, 23};
+ static constexpr int64_t maskHi16[] = {8, 9, 10, 11, 24, 25, 26, 27,
+ 12, 13, 14, 15, 28, 29, 30, 31};
+
+ if (nonUnitDimAcc == 16)
+ return {maskLo16, maskHi16};
+
+ return {maskLo8, maskHi8};
+}
+
+// This function walks backward from a value to locate its originating
+// vector read-like operation (`vector.transfer_read` or `vector.load`).
+// It follows simple forwarding through unary ops and across `scf.for`
+// loop iter-arguments, while stopping if layout-transforming ops such
+// as `shape_cast` or `shuffle` are encountered. The traversal returns
+// the read-like defining operation or `nullptr` if no valid source
+// is found.
+Operation *traceToVectorReadLikeParentOperation(Value v) {
+ while (true) {
+ // Case 1: Value defined by an operation
+ if (Operation *defOp = v.getDefiningOp()) {
+ if (isa<vector::TransferReadOp, vector::LoadOp>(defOp))
+ return defOp;
+
+ return nullptr;
+ }
+
+ // Case 2: BlockArgument (scf.for iter_arg)
+ if (auto barg = dyn_cast<BlockArgument>(v)) {
+ auto *parentOp = barg.getOwner()->getParentOp();
+
+ if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
+ unsigned argNum = barg.getArgNumber();
+
+ // arg0 = induction variable (not an iter_arg)
+ if (argNum == 0)
+ return nullptr;
+
+ unsigned iterIdx = argNum - 1;
+ v = forOp.getInitArgs()[iterIdx];
+ continue;
+ }
+
+ return nullptr;
+ }
+
+ return nullptr;
+ }
+}
+
+// This function recursively traces a value through its uses to find
+// a downstream vector write-like operation (`vector.transfer_write`
+// or `vector.store`). It transparently follows values across `scf.for`
+// and `scf.yield` boundaries while stopping if layout-altering ops such
+// as `shape_cast` or `shuffle` are encountered. The traversal returns
+// the matching write-like user. Returns `nullptr` if none is found or
+// the value has multiple users.
+Operation *traceToVectorWriteLikeUserOperation(Value v) {
+
+ if (v.getNumUses() > 1)
+ return nullptr;
+
+ for (OpOperand &use : v.getUses()) {
+ Operation *user = use.getOwner();
+
+ // --- TERMINAL OPS ---
+ if (isa<vector::TransferWriteOp>(user) || isa<vector::StoreOp>(user))
+ return user;
+
+ if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user))
+ return nullptr;
+
+ // --- SCF YIELD ---
+ if (auto yield = dyn_cast<scf::YieldOp>(user)) {
+ Operation *parent = yield->getParentOp();
+ unsigned idx = use.getOperandNumber();
+ if (auto *res =
+ traceToVectorWriteLikeUserOperation(parent->getResult(idx)))
+ return res;
+ continue;
+ }
+
+ // --- SCF FOR ---
+ if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+ unsigned idx = use.getOperandNumber();
+ if (auto *res = traceToVectorWriteLikeUserOperation(forOp.getResult(idx)))
+ return res;
+ continue;
+ }
+
+ // --- GENERIC CASE ---
+ for (Value res : user->getResults()) {
+ if (auto *found = traceToVectorWriteLikeUserOperation(res))
+ return found;
+ }
+ }
+
+ return nullptr;
+}
+
+// This function packs the accumulator of two flat BF16 vector.contract
+// operations into VNNI packed and are then replaced in their respective
+// contraction ops, enabling post-read layout or packing transformations.
+// TODO: replace all use with the packed value along with contration
+// and for op.
+LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA,
+ Operation *opB,
+ vector::ContractionOp contractA,
+ vector::ContractionOp contractB,
+ int64_t nonUnitDimAcc, VectorType accTy) {
+
+ if (!isa<vector::TransferReadOp, vector::LoadOp>(opA) ||
+ !isa<vector::TransferReadOp, vector::LoadOp>(opB)) {
+ return failure();
+ }
+
+ Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
+
+ rewriter.setInsertionPointAfter(insertAfter);
+ Location loc = insertAfter->getLoc();
+
+ auto elemTy = accTy.getElementType();
+ auto flatTy = VectorType::get(nonUnitDimAcc, elemTy);
+
+ auto castA =
+ vector::ShapeCastOp::create(rewriter, loc, flatTy, opA->getResult(0));
+ auto castB =
+ vector::ShapeCastOp::create(rewriter, loc, flatTy, opB->getResult(0));
+
+ auto masks = getShuffleMasks(nonUnitDimAcc);
+
+ auto shuffleLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+ castB, masks.maskLo);
+ auto shuffleHi = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+ castB, masks.maskHi);
+
+ auto newAccA = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleLo);
+ auto newAccB = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleHi);
+
+ rewriter.replaceUsesWithIf(
+ opA->getResult(0), newAccA.getResult(), [&](OpOperand &use) {
+ return isa<vector::ContractionOp, scf::ForOp>(use.getOwner());
+ });
+
+ rewriter.replaceUsesWithIf(
+ opB->getResult(0), newAccB.getResult(), [&](OpOperand &use) {
+ return isa<vector::ContractionOp, scf::ForOp>(use.getOwner());
+ });
+
+ return success();
+}
+
+// This function shuffles the vectors written by vector.contract operation
+// as a flat layout structure before they are stored.
+LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter,
+ Operation *opA, Operation *opB,
+ int64_t nonUnitDimAcc,
+ VectorType accTy) {
+ // Helper to extract vector operand from write-like ops
+ auto getWrittenVector = [](Operation *op) -> Value {
+ if (auto write = dyn_cast<vector::TransferWriteOp>(op))
+ return write.getVector();
+ if (auto store = dyn_cast<vector::StoreOp>(op))
+ return store.getValueToStore();
+ return nullptr;
+ };
+
+ Value vecA = getWrittenVector(opA);
+ Value vecB = getWrittenVector(opB);
+
+ if (!vecA || !vecB)
+ return failure();
+
+ // Decide insertion point and location
+ Operation *insertBefore = opA->isBeforeInBlock(opB) ? opA : opB;
+
+ rewriter.setInsertionPoint(insertBefore);
+ Location loc = insertBefore->getLoc();
+
+ auto elemTy = accTy.getElementType();
+ auto flatTy = VectorType::get(nonUnitDimAcc, elemTy);
+
+ // Flatten vectors
+ auto castA = vector::ShapeCastOp::create(rewriter, loc, flatTy, vecA);
+ auto castB = vector::ShapeCastOp::create(rewriter, loc, flatTy, vecB);
+
+ // TODO: derive shuffle masks instead of hard-coding
+ auto masks = getShuffleMasks(nonUnitDimAcc);
+
+ auto shuffledLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+ castB, masks.maskLo);
+ auto shuffledHi = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+ castB, masks.maskHi);
+
+ // Cast back to accumulator type
+ auto newVecA = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledLo);
+ auto newVecB = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledHi);
+
+ // Update write operands in place
+ opA->setOperand(0, newVecA.getResult());
+ opB->setOperand(0, newVecB.getResult());
+
+ return success();
+}
+
+// Return true if vector.contract operations matches on below conditions:
+// (1) - the unitDim operand Lhs or Rhs should be same,
+// (2) - the defining source memref should be same for nonUnitDim
+// operation,
+// (3) - the nonUnit dim offset
diff erence between the
+// vector.contracts should be 8 or 16.
+bool validatePairVectorContract(vector::ContractionOp contractOp,
+ vector::ContractionOp pairContOp,
+ bool rhsHasMultipleNonUnitDims,
+ int64_t nonUnitDimValue) {
+ if (rhsHasMultipleNonUnitDims &&
+ !(contractOp.getLhs() == pairContOp.getLhs()))
+ return false;
+
+ if (!rhsHasMultipleNonUnitDims &&
+ !(contractOp.getRhs() == pairContOp.getRhs()))
+ return false;
+
+ auto nonUnitOperand =
+ rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
+ auto nonUnitOperandPairContOp =
+ rhsHasMultipleNonUnitDims ? pairContOp.getRhs() : pairContOp.getLhs();
+
+ Value srcBuff;
+ SmallVector<OpFoldResult> indexVals;
+ llvm::TypeSwitch<Operation *>(nonUnitOperand.getDefiningOp())
+ .Case<vector::TransferReadOp, vector::LoadOp>([&](auto readOp) {
+ srcBuff = readOp.getOperand(0);
+ indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+ readOp.getIndices().end());
+ })
+ .Case<vector::ShapeCastOp>([&](vector::ShapeCastOp op) {
+ srcBuff = op.getSource();
+ indexVals.clear();
+ });
+
+ Value srcBuffPairContOp;
+ SmallVector<OpFoldResult> indexValsPairContOp;
+ llvm::TypeSwitch<Operation *>(nonUnitOperandPairContOp.getDefiningOp())
+ .Case<vector::TransferReadOp, vector::LoadOp>([&](auto readOp) {
+ srcBuffPairContOp = readOp.getOperand(0);
+ indexValsPairContOp = SmallVector<OpFoldResult>(
+ readOp.getIndices().begin(), readOp.getIndices().end());
+ })
+ .Case<vector::ShapeCastOp>([&](vector::ShapeCastOp op) {
+ srcBuffPairContOp = op.getSource();
+ indexVals.clear();
+ });
+
+ if (!srcBuff || !srcBuffPairContOp)
+ return false;
+
+ auto shuffleLw = srcBuff.getDefiningOp<vector::ShuffleOp>();
+ auto shuffleHw = srcBuffPairContOp.getDefiningOp<vector::ShuffleOp>();
+
+ if (shuffleLw && shuffleHw)
+ return shuffleLw.getV1() == shuffleHw.getV1() &&
+ shuffleLw.getV2() == shuffleHw.getV2();
+
+ if (srcBuff != srcBuffPairContOp)
+ return false;
+
+ for (size_t i = 0; i < indexVals.size(); i++) {
+ auto v0 = getConstantIntValue(indexVals[i]);
+ auto v1 = getConstantIntValue(indexValsPairContOp[i]);
+
+ if (!v0 || !v1)
+ return false;
+
+ if (*v1 == *v0)
+ continue;
+
+ if ((*v1 - *v0) != nonUnitDimValue)
+ return false;
+ }
+
+ return true;
+}
+
} // namespace x86vector
} // namespace mlir
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
index e7a70429490d1..7def26de7c75b 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
@@ -233,6 +233,317 @@ module attributes {transform.with_named_sequence} {
// -----
+// Flat (non-VNNI) bf16 contraction pair sharing the unit-dim A operand, with
+// B and the accumulator read at column offsets 0 and 8 via
+// vector.transfer_read and written back via vector.transfer_write. The
+// accumulators are expected to be shuffled before the FMA lowering and the
+// results shuffled again after the FMAs.
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x8xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1xbf16>
+!memrefB = memref<1x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_to_fma_flat_layout(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %3 = vector.transfer_read %arg1[%c0, %c8], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %5 = vector.transfer_read %arg2[%c0, %c8], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %4
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %5
+ : !vecA, !vecB into !vecC
+
+ vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %7, %arg2[%c0, %c8] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_to_fma_flat_layout
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<4x1xbf16> to memref<1x1xbf16, {{.*}}>
+// CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x32xbf16> to memref<1x16xbf16, {{.*}}>
+// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Same flat bf16 contraction pair as @matmul_to_fma_flat_layout, but the
+// operands are read with vector.load and the results written with
+// vector.store.
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x8xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1xbf16>
+!memrefB = memref<1x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_to_fma_flat_layout_load(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+ %1 = vector.load %arg0[%c0, %c0] :
+ !memrefA, !vecA
+ %2 = vector.load %arg1[%c0, %c0] :
+ !memrefB, !vecB
+ %3 = vector.load %arg1[%c0, %c8] :
+ !memrefB, !vecB
+ %4 = vector.load %arg2[%c0, %c0] :
+ !memrefC, !vecC
+ %5 = vector.load %arg2[%c0, %c8] :
+ !memrefC, !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %4
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %5
+ : !vecA, !vecB into !vecC
+
+ vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+ vector.store %7, %arg2[%c0, %c8] : !memrefC, !vecC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_to_fma_flat_layout_load
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<4x1xbf16> to memref<1x1xbf16, {{.*}}>
+// CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x32xbf16> to memref<1x16xbf16, {{.*}}>
+// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Flat bf16 contraction pair inside nested scf.for loops. The accumulators
+// are read before the loops, carried through both loops as iter_args, and
+// written after the loops (in reverse order). The first pair of shuffles is
+// expected before the loops, and the second pair after the innermost
+// scf.yield that follows the FMA lowering.
+!vecA = vector<1x1x1xbf16>
+!vecB = vector<1x1x8xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<1x1x1xbf16, strided<[2048, 32, 1], offset: ?>>
+!memrefB = memref<1x1x16xbf16, strided<[2048, 64, 1], offset: ?>>
+!memrefC = memref<1x16xf32, strided<[64, 1], offset: ?>>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+func.func @matmul_to_fma_flat_layout_loop(%arg0: memref<16x64x32xbf16>, %arg1: memref<16x32x64xbf16>,
+ %arg2: memref<64x64xf32>) -> memref<64x64xf32> {
+ %c8 = arith.constant 8 : index
+ %0 = ub.poison : f32
+ %1 = ub.poison : bf16
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ scf.for %arg3 = %c0 to %c64 step %c1 {
+ scf.for %arg4 = %c0 to %c64 step %c16 {
+ %subview = memref.subview %arg2[%arg3, %arg4] [1, 16] [1, 1] :
+ memref<64x64xf32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c8], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %4:2 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+ %5:2 = scf.for %arg8 = %c0 to %c32 step %c1 iter_args(%arg9 = %arg6, %arg10 = %arg7) -> (!vecC, !vecC) {
+
+ %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg8] [1, 1, 1] [1, 1, 1] :
+ memref<16x64x32xbf16> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg8, %arg4] [1, 1, 16] [1, 1, 1] :
+ memref<16x32x64xbf16> to !memrefB
+
+ %6 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1
+ {in_bounds = [true, true, true]} : !memrefA, !vecA
+ %7 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1
+ {in_bounds = [true, true, true]} : !memrefB, !vecB
+ %8 = vector.transfer_read %subview_1[%c0, %c0, %c8], %1
+ {in_bounds = [true, true, true]} : !memrefB, !vecB
+
+ %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %6, %7, %arg9 {unroll_shape = array<i64: 1, 1, 8, 1>} : !vecA, !vecB into !vecC
+ %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %6, %8, %arg10 {unroll_shape = array<i64: 1, 1, 8, 1>} : !vecA, !vecB into !vecC
+
+ scf.yield %9, %10 : !vecC, !vecC
+ }
+ scf.yield %5#0, %5#1 : !vecC, !vecC
+ }
+
+ vector.transfer_write %4#1, %subview[%c0, %c8] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ }
+ }
+
+ return %arg2 : memref<64x64xf32>
+}
+
+// CHECK-LABEL: @matmul_to_fma_flat_layout_loop
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: scf.yield
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Flat bf16 contraction pair where A (vector<8x1xbf16>) carries the non-unit
+// dimension and the unit-dim B operand is shared. A and the accumulator are
+// read at row offsets 0 and 8.
+!vecA = vector<8x1xbf16>
+!vecB = vector<1x1xbf16>
+!vecC = vector<8x1xf32>
+!memrefA = memref<32x1xbf16>
+!memrefB = memref<1x4xbf16>
+!memrefC = memref<32x4xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_to_fma_flat_layout_bcstB(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+
+ %3 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+
+ %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %5 = vector.transfer_read %arg2[%c8, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %4
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %2, %3, %5
+ : !vecA, !vecB into !vecC
+
+ vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %7, %arg2[%c8, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_to_fma_flat_layout_bcstB
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x4xbf16> to memref<1x1xbf16, {{.*}}>
+// CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<32x1xbf16> to memref<16x1xbf16, {{.*}}>
+// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[4, 1], offset: ?>>
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x1x2xbf16>
!vecB = vector<1x8x2xbf16>
!vecC = vector<1x8xf32>
@@ -372,6 +683,75 @@ module attributes {transform.with_named_sequence} {
// -----
+// Negative test: both contraction results have an extra user (arith.addf)
+// besides their vector.store, so the flat-layout rewrite must not apply.
+// No shuffle or FMA should be produced, and the vector.contract ops remain.
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x8xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1xbf16>
+!memrefB = memref<1x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_multiple_vc_users_flat(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+ %1 = vector.load %arg0[%c0, %c0] :
+ !memrefA, !vecA
+ %2 = vector.load %arg1[%c0, %c0] :
+ !memrefB, !vecB
+ %3 = vector.load %arg1[%c0, %c8] :
+ !memrefB, !vecB
+ %4 = vector.load %arg2[%c0, %c0] :
+ !memrefC, !vecC
+ %5 = vector.load %arg2[%c0, %c8] :
+ !memrefC, !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %4
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %5
+ : !vecA, !vecB into !vecC
+
+ vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+ vector.store %7, %arg2[%c0, %c8] : !memrefC, !vecC
+
+ %8 = arith.addf %6, %7 : !vecC
+ vector.store %8, %arg2[%c0, %c16] : !memrefC, !vecC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_multiple_vc_users_flat
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.fma
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x1x2xbf16>
!vecB = vector<1x8x2xbf16>
!vecC = vector<1x8xf32>
@@ -413,6 +793,276 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x8xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1xbf16>
+!memrefB = memref<1x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Negative: the two B loads (%2, %3) start at columns 0 and 16, while each
+// covers only 8 columns (!vecB = vector<1x8xbf16>). The offset difference is
+// therefore 16 rather than 8, and no bf16 FMA lowering is expected.
+func.func @negative_offset_diff_is_not_8(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+ %1 = vector.load %arg0[%c0, %c0] :
+ !memrefA, !vecA
+ %2 = vector.load %arg1[%c0, %c0] :
+ !memrefB, !vecB
+ %3 = vector.load %arg1[%c0, %c16] :
+ !memrefB, !vecB
+ %4 = vector.load %arg2[%c0, %c0] :
+ !memrefC, !vecC
+ %5 = vector.load %arg2[%c0, %c16] :
+ !memrefC, !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %4
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %5
+ : !vecA, !vecB into !vecC
+
+ vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+ vector.store %7, %arg2[%c0, %c16] : !memrefC, !vecC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_offset_diff_is_not_8
+// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
+// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
+// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x8xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1xbf16>
+!memrefB = memref<1x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Negative: the contract on the column-8 operands (%7) is defined before the
+// contract on the column-0 operands (%6). The bf16-to-FMA pattern is expected
+// to leave both vector.contract ops in place.
+func.func @negative_vector_contracts_not_in_order(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %3 = vector.transfer_read %arg1[%c0, %c8], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %5 = vector.transfer_read %arg2[%c0, %c8], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %5
+ : !vecA, !vecB into !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %4
+ : !vecA, !vecB into !vecC
+
+ vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %7, %arg2[%c0, %c8] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_vector_contracts_not_in_order
+// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
+// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
+// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<8x1xbf16>
+!vecB = vector<1x1xbf16>
+!vecC = vector<8x1xf32>
+!memrefA = memref<32x1xbf16>
+!memrefB = memref<1x4xbf16>
+!memrefC = memref<32x4xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Negative: the second A read (%2) uses the dynamic row index %arg3, so the
+// offset between the two A tiles is not a compile-time constant. No bf16 FMA
+// lowering is expected.
+func.func @negative_flat_layout_dynamic_index(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC, %arg3: index) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg0[%arg3, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+
+ %3 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+
+ %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %5 = vector.transfer_read %arg2[%c8, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %4
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %2, %3, %5
+ : !vecA, !vecB into !vecC
+
+ vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %7, %arg2[%c8, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_flat_layout_dynamic_index
+// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
+// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
+// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<8x2xbf16>
+!vecB = vector<2x1xbf16>
+!vecC = vector<8x1xf32>
+!memrefA = memref<32x2xbf16>
+!memrefB = memref<2x4xbf16>
+!memrefC = memref<32x4xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Negative: the reduction (K) dimension is 2 (!vecA = vector<8x2xbf16>,
+// !vecB = vector<2x1xbf16>) instead of 1. No bf16 FMA lowering is expected.
+func.func @negative_non_unit_K_dim(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+
+ %3 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+
+ %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %5 = vector.transfer_read %arg2[%c8, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %4
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %2, %3, %5
+ : !vecA, !vecB into !vecC
+
+ vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %7, %arg2[%c8, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_non_unit_K_dim
+// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
+// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
+// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x1x1x2xbf16>
!vecB = vector<1x1x16x2xbf16>
!vecC = vector<1x1x16xf32>
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
index e26a575e2bc90..36af0f9c171bd 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
@@ -309,13 +309,13 @@ module attributes {transform.with_named_sequence} {
// -----
-!vecA = vector<16x1x2xbf16>
-!vecB = vector<1x1x2xbf16>
-!vecC = vector<16x1xf32>
+// Expects an i8 vector.contract to be lowered to a vector.broadcast followed
+// by x86vector.avx.dot.i8.
+!vecA = vector<1x1x4xi8>
+!vecB = vector<1x8x4xi8>
+!vecC = vector<1x8xi32>
#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
-func.func @matmul_outer_product_to_bf16dp_bcst_B(
+func.func @matmul_outer_product_to_int8dp(
%arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
{
%0 = vector.contract {
@@ -327,9 +327,78 @@ func.func @matmul_outer_product_to_bf16dp_bcst_B(
return %0 : !vecC
}
-// CHECK-LABEL: @matmul_outer_product_to_bf16dp_bcst_B
+// CHECK-LABEL: @matmul_outer_product_to_int8dp
// CHECK: vector.broadcast
+// CHECK: x86vector.avx.dot.i8
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x2xbf16>
+!vecB = vector<2x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Flat (non-VNNI) bf16 layout: two contracts over adjacent 16-column B/C tiles
+// (columns 0 and 16). The expected output is:
+//   - the two f32 accumulators shuffled together,
+//   - the two bf16 B tiles interleaved by a second pair of shuffles,
+//   - two x86vector.avx512.dot ops,
+//   - the dot results shuffled again with the accumulator masks.
+func.func @matmul_bf16dp_flat_layout(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %4 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %5 = vector.transfer_read %arg1[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %4, %2
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %5, %3
+ : !vecA, !vecB into !vecC
+
+ vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_bf16dp_flat_layout
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
+// CHECK: x86vector.avx512.dot
// CHECK: x86vector.avx512.dot
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -343,27 +412,163 @@ module attributes {transform.with_named_sequence} {
// -----
-!vecA = vector<1x1x4xi8>
-!vecB = vector<1x8x4xi8>
-!vecC = vector<1x8xi32>
-#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
-#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
-#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
-func.func @matmul_outer_product_to_int8dp(
- %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x2x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<1x1x2xbf16, strided<[2048, 32, 1], offset: ?>>
+!memrefB = memref<1x2x32xbf16, strided<[2048, 64, 1], offset: ?>>
+!memrefC = memref<1x32xf32, strided<[64, 1], offset: ?>>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// Flat bf16 layout inside a batch-reduce loop nest whose accumulators are
+// carried through iter_args. Expected placement of the generated ops:
+//   - accumulator shuffles before the two inner scf.for loops and again after
+//     their scf.yield ops;
+//   - B shuffles and both x86vector.avx512.dot ops inside the loops.
+// The results are written back in reverse order (%4#1 before %4#0).
+func.func @brmatmul_bf16dp_flat_layout_loop(%arg0: memref<16x64x32xbf16>, %arg1: memref<16x32x64xbf16>,
+ %arg2: memref<64x64xf32>) -> memref<64x64xf32> {
+ %0 = ub.poison : f32
+ %1 = ub.poison : bf16
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ scf.for %arg3 = %c0 to %c64 step %c1 {
+ scf.for %arg4 = %c0 to %c64 step %c32 {
+ %subview = memref.subview %arg2[%arg3, %arg4] [1, 32] [1, 1]
+ : memref<64x64xf32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]}
+ : !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]}
+ : !memrefC, !vecC
+
+ %4:2 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+ %5:2 = scf.for %arg8 = %c0 to %c32 step %c2 iter_args(%arg9 = %arg6, %arg10 = %arg7) -> (!vecC, !vecC) {
+
+ %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg8] [1, 1, 2] [1, 1, 1]
+ : memref<16x64x32xbf16> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg8, %arg4] [1, 2, 32] [1, 1, 1]
+ : memref<16x32x64xbf16> to !memrefB
+
+ %6 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]}
+ : !memrefA, !vecA
+ %7 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]}
+ : !memrefB, !vecB
+ %8 = vector.transfer_read %subview_1[%c0, %c0, %c16], %1 {in_bounds = [true, true, true]}
+ : !memrefB, !vecB
+
+ %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %6, %7, %arg9
+ {unroll_shape = array<i64: 1, 1, 16, 2>} : !vecA, !vecB into !vecC
+ %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %6, %8, %arg10
+ {unroll_shape = array<i64: 1, 1, 16, 2>} : !vecA, !vecB into !vecC
+
+ scf.yield %9, %10 : !vecC, !vecC
+ }
+ scf.yield %5#0, %5#1 : !vecC, !vecC
+ }
+
+ vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]}
+ : !vecC, !memrefC
+ vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]}
+ : !vecC, !memrefC
+ }
+ }
+
+ return %arg2 : memref<64x64xf32>
+}
+
+// CHECK-LABEL: @brmatmul_bf16dp_flat_layout_loop
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
+// CHECK: x86vector.avx512.dot
+// CHECK: x86vector.avx512.dot
+// CHECK: scf.yield
+// CHECK: scf.yield
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x2xbf16>
+!vecB = vector<2x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// B is pre-shuffled (%8, %9) with the same interleaving masks the flat-layout
+// lowering emits for B. Both contracts are expected to lower to
+// x86vector.avx512.dot with no vector.contract remaining.
+func.func @matmul_bf16dp_flat_layout_B_shuffled(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
{
- %0 = vector.contract {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+ %1 = vector.load %arg0[%c0, %c0] :
+ !memrefA, !vecA
+ %2 = vector.load %arg1[%c0, %c0] :
+ !memrefB, !vecB
+ %3 = vector.load %arg1[%c0, %c16] :
+ !memrefB, !vecB
+ %4 = vector.load %arg2[%c0, %c0] :
+ !memrefC, !vecC
+ %5 = vector.load %arg2[%c0, %c16] :
+ !memrefC, !vecC
+
+ // Flatten both B tiles and interleave them element-wise.
+ %6 = vector.shape_cast %2 : !vecB to vector<32xbf16>
+ %7 = vector.shape_cast %3 : !vecB to vector<32xbf16>
+ %8 = vector.shuffle %6, %7 [0, 32, 1, 33, 2, 34, 3, 35,
+ 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18,
+ 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] :
+ vector<32xbf16>, vector<32xbf16>
+ %9 = vector.shuffle %6, %7 [4, 36, 5, 37, 6, 38, 7, 39,
+ 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53,
+ 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] :
+ vector<32xbf16>, vector<32xbf16>
+
+ %10 = vector.shape_cast %8 : vector<32xbf16> to !vecB
+ %11 = vector.shape_cast %9 : vector<32xbf16> to !vecB
+
+ %12 = vector.contract {
indexing_maps = [#map, #map1, #map2],
- iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>}
- %arg0, %arg1, %arg2
+ %1, %10, %4
: !vecA, !vecB into !vecC
- return %0 : !vecC
+
+ %13 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %11, %5
+ : !vecA, !vecB into !vecC
+
+ vector.store %12, %arg2[%c0, %c0] : !memrefC, !vecC
+ vector.store %13, %arg2[%c0, %c16] : !memrefC, !vecC
+
+ return %arg2 : !memrefC
}
-// CHECK-LABEL: @matmul_outer_product_to_int8dp
-// CHECK: vector.broadcast
-// CHECK: x86vector.avx.dot.i8
+// CHECK-LABEL: @matmul_bf16dp_flat_layout_B_shuffled
+// CHECK: x86vector.avx512.dot
+// CHECK: x86vector.avx512.dot
+// CHECK-NOT: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -749,3 +954,527 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+!vecA = vector<1x4xbf16>
+!vecB = vector<4x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x4xbf16>
+!memrefB = memref<4x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Negative: the reduction (K) dimension is 4 (!vecA = vector<1x4xbf16>,
+// !vecB = vector<4x16xbf16>) instead of 2. No x86vector.avx512.dot is
+// expected.
+func.func @negative_flat_other_dim_is_not_2(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %3 = vector.transfer_read %arg1[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %5 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %4
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %5
+ : !vecA, !vecB into !vecC
+
+ vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_flat_other_dim_is_not_2
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x2xbf16>
+!vecB = vector<2x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Negative: the two B reads (%2, %3) start at columns 0 and 32, while each
+// covers 16 columns (!vecB = vector<2x16xbf16>). The offset difference is
+// therefore 32 rather than 16, and no x86vector.avx512.dot is expected.
+func.func @negative_flat_offset_diff_is_not16(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %3 = vector.transfer_read %arg1[%c0, %c32], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %5 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %4
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %5
+ : !vecA, !vecB into !vecC
+
+ vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_flat_offset_diff_is_not16
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x2xbf16>
+!vecB = vector<2x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Negative: the second B read (%3) uses the dynamic column offset %arg3, so
+// its distance from the first B read is unknown at compile time. No
+// x86vector.avx512.dot is expected.
+func.func @negative_flat_dynamic_offset(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC, %arg3: index) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %3 = vector.transfer_read %arg1[%c0, %arg3], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %5 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %4
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %5
+ : !vecA, !vecB into !vecC
+
+ vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_flat_dynamic_offset
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x2xbf16>
+!vecB = vector<2x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Negative: the operands of the second contract (%3, %5) are read only after
+// the first vector.contract (%6). No x86vector.avx512.dot is expected.
+func.func @negative_flat_read_after_contract(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %4
+ : !vecA, !vecB into !vecC
+
+ %3 = vector.transfer_read %arg1[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %5 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %5
+ : !vecA, !vecB into !vecC
+
+ vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_flat_read_after_contract
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Negative: the contract on the column-16 accumulator (%7) is defined before
+// the contract on the column-0 accumulator (%6). No shuffles are expected
+// around the contracts.
+// NOTE(review): %3 is loaded at column 32 (the positive flat-layout test uses
+// 16), and K is 1 here (!vecA = vector<1x1xbf16>). Confirm the test is
+// rejected for the ordering reason and not for either of these differences.
+func.func @negative_contracts_not_in_order(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+ %1 = vector.load %arg0[%c0, %c0] :
+ !memrefA, !vecA
+ %2 = vector.load %arg1[%c0, %c0] :
+ !memrefB, !vecB
+ %3 = vector.load %arg1[%c0, %c32] :
+ !memrefB, !vecB
+ %4 = vector.load %arg2[%c0, %c0] :
+ !memrefC, !vecC
+ %5 = vector.load %arg2[%c0, %c16] :
+ !memrefC, !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %5
+ : !vecA, !vecB into !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %4
+ : !vecA, !vecB into !vecC
+
+ vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+ vector.store %7, %arg2[%c0, %c16] : !memrefC, !vecC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_contracts_not_in_order
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.contract
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.store
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x32xbf16>
+!vecC = vector<1x32xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Negative case: B and the accumulator are 1x32 vectors (!vecB, !vecC), and
+// the second contraction's B/C operands sit 32 columns after the first. The
+// FileCheck directives after this function expect no vector.shuffle.
+func.func @negative_dim_is_32(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %1 = vector.load %arg0[%c0, %c0] :
+ !memrefA, !vecA
+ %2 = vector.load %arg1[%c0, %c0] :
+ !memrefB, !vecB
+ %3 = vector.load %arg1[%c0, %c32] :
+ !memrefB, !vecB
+ %4 = vector.load %arg2[%c0, %c0] :
+ !memrefC, !vecC
+ %5 = vector.load %arg2[%c0, %c32] :
+ !memrefC, !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %4
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %5
+ : !vecA, !vecB into !vecC
+
+ vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+ vector.store %7, %arg2[%c0, %c32] : !memrefC, !vecC
+
+ return %arg2 : !memrefC
+}
+
+// No vector.shuffle may appear before the first vector.contract, nor between
+// it and the first vector.store.
+// CHECK-LABEL: @negative_dim_is_32
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.contract
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.store
+
+// Transform script: match every func.func in the payload and apply the
+// x86vector packed-type dot-product contraction patterns to each of them.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %funcs {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Matmul indexing maps (d0 = m, d1 = n, d2 = k): A[m, k], B[k, n], C[m, n].
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Operand types: bf16 A/B vectors accumulating into 16-wide f32 C vectors.
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x64xf32>
+// Negative case: the two B loads are 32 columns apart (%c0 vs. %c32), while
+// the matching accumulator loads/stores are only 16 columns apart (%c0 vs.
+// %c16). The FileCheck directives after this function expect no
+// vector.shuffle.
+func.func @negative_offset_diff_is_32(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %1 = vector.load %arg0[%c0, %c0] :
+ !memrefA, !vecA
+ %2 = vector.load %arg1[%c0, %c0] :
+ !memrefB, !vecB
+ %3 = vector.load %arg1[%c0, %c32] :
+ !memrefB, !vecB
+ %4 = vector.load %arg2[%c0, %c0] :
+ !memrefC, !vecC
+ %5 = vector.load %arg2[%c0, %c16] :
+ !memrefC, !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %4
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %5
+ : !vecA, !vecB into !vecC
+
+ vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+ vector.store %7, %arg2[%c0, %c16] : !memrefC, !vecC
+
+ return %arg2 : !memrefC
+}
+
+// No vector.shuffle may appear before the first vector.contract, nor between
+// it and the first vector.store.
+// CHECK-LABEL: @negative_offset_diff_is_32
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.contract
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.store
+
+// Transform script: match every func.func in the payload and apply the
+// x86vector packed-type dot-product contraction patterns to each of them.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %funcs {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Matmul indexing maps (d0 = m, d1 = n, d2 = k): A[m, k], B[k, n], C[m, n].
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Operand types: bf16 A/B vectors accumulating into 16-wide f32 C vectors.
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x64xf32>
+// Negative case: the second B load uses the runtime column offset %arg3, so
+// the distance between the two B loads is not a compile-time constant. The
+// FileCheck directives after this function expect no vector.shuffle.
+func.func @negative_dynamic_offset(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC, %arg3: index) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %1 = vector.load %arg0[%c0, %c0] :
+ !memrefA, !vecA
+ %2 = vector.load %arg1[%c0, %c0] :
+ !memrefB, !vecB
+ %3 = vector.load %arg1[%c0, %arg3] :
+ !memrefB, !vecB
+ %4 = vector.load %arg2[%c0, %c0] :
+ !memrefC, !vecC
+ %5 = vector.load %arg2[%c0, %c16] :
+ !memrefC, !vecC
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %4
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %3, %5
+ : !vecA, !vecB into !vecC
+
+ vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+ vector.store %7, %arg2[%c0, %c16] : !memrefC, !vecC
+
+ return %arg2 : !memrefC
+}
+
+// No vector.shuffle may appear before the first vector.contract, nor between
+// it and the first vector.store.
+// CHECK-LABEL: @negative_dynamic_offset
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.contract
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.store
+
+// Transform script: match every func.func in the payload and apply the
+// x86vector packed-type dot-product contraction patterns to each of them.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %funcs {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits mailing list