[Mlir-commits] [mlir] [mlir][x86] Lower packed type vector.contract to AMX dot-product (online-packing) (PR #188192)
Arun Thangamani
llvmlistbot at llvm.org
Tue Mar 24 01:08:47 PDT 2026
https://github.com/arun-thmn created https://github.com/llvm/llvm-project/pull/188192
Extends the transform pass to lower flat-layout (non-VNNI) `vector.contract` operations to (a) `amx.tile_mulf` for BF16 or (b) `amx.tile_muli` for Int8 packed types, packing the RHS into VNNI layout on the fly (`online` packing).
>From 7b888d75131d67a4e3f2cb2ffdec7e1f82abdf01 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 24 Mar 2026 01:05:51 -0700
Subject: [PATCH] support for amx online-packing
---
.../VectorContractToAMXDotProduct.cpp | 936 +++++++++++++++---
.../X86/AMX/vector-contract-to-tiled-dp.mlir | 349 ++++++-
2 files changed, 1118 insertions(+), 167 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 85966a85af40e..2b159e6f59cb9 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -1,4 +1,4 @@
-//===- VectorContractToAMXDotProduct.cpp ----------------------------------===//
+//===- VectorContractToAMXDotProduct.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -70,8 +70,9 @@ getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
if (!srcBuff)
return failure();
- if (isNotAcc)
+ if (isNotAcc) {
indexVals.pop_back();
+ }
SmallVector<Value> indices;
indices.reserve(indexVals.size());
@@ -189,20 +190,155 @@ static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
// Creates amx.tile_loads.
static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
Value operand, Value mat, Type ipType,
-                                       bool rhs, unsigned int offset) {
+                                       bool rhs, unsigned int offset,
+                                       bool isVnni) {
+  // `isVnni` selects the operand layout: VNNI operands (whose inner dims
+  // createTiledDp collapses) drop the trailing index, and for the RHS the
+  // last remaining index is scaled by `offset`; flat operands keep theirs.
auto srcIndx = getSrcIndxValue(rewriter, loc, operand, false);
auto [srcBuff, indices] = *srcIndx;
-  indices.pop_back();
+  if (isVnni)
+    indices.pop_back();
-  if (rhs) {
+  if (rhs && isVnni) {
auto cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
indices[indices.size() - 1] = arith::MulIOp::create(
rewriter, loc, indices[indices.size() - 1], cOffset);
}
amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
return amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
+}
+
+// Packs the flat RHS view `matB` into VNNI order inside slot `allocStore` of
+// `bBuffer`.
+//
+// Emits an scf.for with iv = 0, offset, ..., 16 * offset - offset. Each
+// iteration loads two vectors of 16 * offset elements whose second-innermost
+// index is iv and iv + offset / 2, interleaves them with two vector.shuffle
+// ops, and stores the results to rows iv / offset and 16 + iv / offset of the
+// slot. Shuffle masks exist only for offset 2 and offset 4.
+static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
+                           Type ipType, unsigned int offset, Value bBuffer,
+                           Value allocStore) {
+  assert((offset == 2 || offset == 4) &&
+         "online packing only supports blocking factors 2 and 4");
+
+  // Index the view with as many indices as its own rank instead of assuming
+  // `matB` is produced by a memref.subview.
+  int64_t rank = cast<MemRefType>(matB.getType()).getRank();
+  assert(rank >= 2 && "expected at least a 2-D RHS view");
+
+  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  SmallVector<Value> vals(rank, c0);
+  Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+  Value cStep = arith::ConstantIndexOp::create(rewriter, loc, offset);
+  Value cBound = arith::ConstantIndexOp::create(rewriter, loc, (16 * offset));
+  Value nLoadIndx = arith::ConstantIndexOp::create(rewriter, loc, (offset / 2));
+  // The loads and the shuffles share the vector<(16*offset) x ipType> type.
+  VectorType vecTy = VectorType::get({(16 * offset)}, ipType);
+
+  scf::ForOp::create(
+      rewriter, loc, c0, cBound, cStep, ValueRange{},
+      [&](OpBuilder &b, Location bodyLoc, Value iv, ValueRange /*iterArgs*/) {
+        Value nextRow = arith::AddIOp::create(b, bodyLoc, nLoadIndx, iv);
+
+        // Load the two source vectors interleaved by this step.
+        vals[rank - 2] = iv;
+        auto vec1 = vector::LoadOp::create(b, bodyLoc, vecTy, matB, vals);
+        vals[rank - 2] = nextRow;
+        auto vec2 = vector::LoadOp::create(b, bodyLoc, vecTy, matB, vals);
+
+        vector::ShuffleOp shuffle1;
+        vector::ShuffleOp shuffle2;
+        if (offset == 2) {
+          shuffle1 = vector::ShuffleOp::create(
+              b, bodyLoc, vecTy, vec1, vec2,
+              ArrayRef<int64_t>{0,  32, 1,  33, 2,  34, 3,  35, 8,  40, 9,
+                                41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50,
+                                19, 51, 24, 56, 25, 57, 26, 58, 27, 59});
+          shuffle2 = vector::ShuffleOp::create(
+              b, bodyLoc, vecTy, vec1, vec2,
+              ArrayRef<int64_t>{4,  36, 5,  37, 6,  38, 7,  39, 12, 44, 13,
+                                45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54,
+                                23, 55, 28, 60, 29, 61, 30, 62, 31, 63});
+        } else {
+          // offset == 4 (guaranteed by the assert above).
+          shuffle1 = vector::ShuffleOp::create(
+              b, bodyLoc, vecTy, vec1, vec2,
+              ArrayRef<int64_t>{
+                  0,   32, 64, 96,  1,  33, 65, 97,  2,  34, 66, 98,  3,
+                  35,  67, 99, 4,   36, 68, 100, 5,  37, 69, 101, 6,  38,
+                  70,  102, 7, 39,  71, 103, 8,  40, 72, 104, 9,  41, 73,
+                  105, 10, 42, 74,  106, 11, 43, 75, 107, 12, 44, 76, 108,
+                  13,  45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111});
+          shuffle2 = vector::ShuffleOp::create(
+              b, bodyLoc, vecTy, vec1, vec2,
+              ArrayRef<int64_t>{
+                  16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51,
+                  83, 115, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118,
+                  23, 55, 87, 119, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58,
+                  90, 122, 27, 59, 91, 123, 28, 60, 92, 124, 29, 61, 93, 125,
+                  30, 62, 94, 126, 31, 63, 95, 127});
+        }
+
+        // shuffle1 goes to row iv / offset of the slot, shuffle2 to row
+        // 16 + iv / offset.
+        Value row = arith::DivUIOp::create(b, bodyLoc, iv, cStep);
+        Value rowPlus16 = arith::AddIOp::create(b, bodyLoc, c16, row);
+        vector::StoreOp::create(b, bodyLoc, shuffle1, bBuffer,
+                                ValueRange{allocStore, row, c0});
+        vector::StoreOp::create(b, bodyLoc, shuffle2, bBuffer,
+                                ValueRange{allocStore, rowPlus16, c0});
+
+        scf::YieldOp::create(b, bodyLoc);
+      });
+}
+
+// Maps RHS reads of `ops` to amx.tile_loads from the packed buffer.
+//
+// For every ordered pair (ops[j], ops[i]) with i != j that is accepted by
+// validatePairVectorContract(..., true, 16):
+//   - the RHS read of ops[j] is mapped to a tile load of rows [0, 16) of slot
+//     `addIdx` of `bBuffer`;
+//   - the RHS read of ops[i] is mapped to a tile load of rows [16, 32).
+// RHS reads that are already mapped are skipped. When `pack` is set, `matB`
+// is first shuffled into slot `allocStore` via performShuffle.
+//
+// Every accepted pair reads the same slot, so the shuffle and the two tile
+// loads are emitted once, at the first accepted pair, and then shared.
+static llvm::DenseMap<Operation *, amx::TileLoadOp>
+packInputs(OpBuilder &rewriter, Location loc,
+           SmallVector<vector::ContractionOp> ops, Value matB, Type ipType,
+           unsigned int offset, Value bBuffer, bool pack, Value allocStore,
+           Value addIdx) {
+  llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+  amx::TileLoadOp upperTile; // Rows [0, 16) of bBuffer[addIdx].
+  amx::TileLoadOp lowerTile; // Rows [16, 32) of bBuffer[addIdx].
+
+  for (size_t j = 0; j < ops.size(); j++) {
+    for (size_t i = 0; i < ops.size(); i++) {
+      if (i == j || !validatePairVectorContract(ops[j], ops[i], true, 16))
+        continue;
+
+      Operation *readOpRhs = ops[j].getRhs().getDefiningOp();
+      if (readsToTileLoads.contains(readOpRhs))
+        continue;
+
+      if (!upperTile) {
+        Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+        Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+
+        if (pack)
+          performShuffle(rewriter, loc, matB, ipType, offset, bBuffer,
+                         allocStore);
+
+        amx::TileType tileType =
+            amx::TileType::get({16, (16 * offset)}, ipType);
+        upperTile = amx::TileLoadOp::create(rewriter, loc, tileType, bBuffer,
+                                            ValueRange{addIdx, c0, c0});
+        lowerTile = amx::TileLoadOp::create(rewriter, loc, tileType, bBuffer,
+                                            ValueRange{addIdx, c16, c0});
+      }
+
+      readsToTileLoads.try_emplace(readOpRhs, upperTile);
+      readsToTileLoads.try_emplace(ops[i].getRhs().getDefiningOp(), lowerTile);
+    }
+  }
+
+  return readsToTileLoads;
}
// Creates tiled amx dot-products.
@@ -210,16 +346,26 @@ static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
SmallVector<vector::ContractionOp> ops,
Value matA, Value matB, Type ipType,
Type opType, ValueRange accIterArgs,
- unsigned int offset) {
+ unsigned int offset, bool isVnni,
+ Value bBuffer, bool pack,
+ Value allocStore, Value addIdx) {
- auto subviewCollapseLhs = collapseInnerDims(rewriter, loc, matA);
- auto subviewCollapseRhs = collapseInnerDims(rewriter, loc, matB);
+ if (isVnni) {
+ matA = collapseInnerDims(rewriter, loc, matA);
+ matB = collapseInnerDims(rewriter, loc, matB);
+ }
SmallVector<Value> accumulators;
// Stores the amx.tile_load operation vs it's equivalent vector tranfer_read
// or load operations.
llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+ // function call to make the flat.
+ if (!isVnni) {
+ readsToTileLoads = packInputs(rewriter, loc, ops, matB, ipType, offset,
+ bBuffer, pack, allocStore, addIdx);
+ }
+
// Iterate over the contraction operations and compute the tiled dot-product.
for (size_t i = 0; i < ops.size(); i++) {
@@ -229,8 +375,8 @@ static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
if (itLhs != readsToTileLoads.end()) {
tilesLhs = itLhs->second;
} else {
- tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(),
- subviewCollapseLhs, ipType, false, offset);
+ tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(), matA, ipType,
+ false, offset, isVnni);
readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
}
@@ -240,8 +386,8 @@ static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
if (itRhs != readsToTileLoads.end()) {
tilesRhs = itRhs->second;
} else {
- tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(),
- subviewCollapseRhs, ipType, true, offset);
+ tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(), matB, ipType,
+ true, offset, isVnni);
readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
}
@@ -276,6 +422,172 @@ static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
return loopItrArgs;
}
+// Returns the index (0 or 1) of a bBuffer slot.
+//
+// - Default: the parity of the K-block index iv_K / (16 * blockingFactor).
+// - When neither `nDimK` nor `pack` is set: the parity of the reduction-loop
+//   iv `iv_red` instead.
+// - When `oddDimK` is set: the parity of `iv_red` is additionally added
+//   modulo 2. Presumably this keeps the slots alternating across reduction
+//   iterations when the number of K blocks is odd; confirm with the author.
+//
+// Only the ops needed for the selected formula are emitted.
+static Value bufferIndxToStore(OpBuilder &rewriter, Location loc, Value iv_K,
+                               Value iv_red, bool oddDimK, bool nDimK,
+                               bool pack, unsigned int blockingFactor) {
+  Type indexType = rewriter.getIndexType();
+  Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
+
+  Value slot;
+  if (!nDimK && !pack) {
+    slot = arith::RemUIOp::create(rewriter, loc, indexType, iv_red, c2);
+  } else {
+    Value blockSize =
+        arith::ConstantIndexOp::create(rewriter, loc, (16 * blockingFactor));
+    Value blockIdx = arith::DivUIOp::create(rewriter, loc, iv_K, blockSize);
+    slot = arith::RemUIOp::create(rewriter, loc, indexType, blockIdx, c2);
+  }
+
+  if (oddDimK) {
+    Value redParity =
+        arith::RemUIOp::create(rewriter, loc, indexType, iv_red, c2);
+    Value sum =
+        arith::AddIOp::create(rewriter, loc, indexType, slot, redParity);
+    slot = arith::RemUIOp::create(rewriter, loc, indexType, sum, c2);
+  }
+
+  return slot;
+}
+
+// Creates one scf.for over [lowerBound, upperBound) with step `step`.
+// The loop carries `loopItrArgs` as iter_args, and its body holds the tiled
+// AMX dot-products of `ops` (via createTiledDp). Returns the new loop.
+//
+// The LHS/RHS views (vectorOpLhs/vectorOpRhs) are cloned with the operand at
+// getIndexPosition(..., loop) + 1 remapped to the new loop's induction
+// variable. When `outerLoop` is non-null, the outer position is remapped to
+// `ivOuterLoop` as well.
+//
+// For flat (non-VNNI) operands, `bBuffer` has two slots:
+//   - the RHS view is shuffled into slot `allocStore` (only when `pack` is set);
+//   - the tiles are read from slot `allocGet`.
+//
+// The RHS view is shifted relative to the current iteration, presumably so
+// that the block consumed next is packed:
+//   - by 16 * blockingFactor along K, for a nested loop with a non-zero
+//     `innerLoopIndex`, or without an outer loop when `pack` is set;
+//   - to K = 0 of outer iteration ivOuterLoop + 1, when `innerLoopIndex` is 0
+//     and `pack` is set.
+// `innerLoopIndex` is only read for flat operands with an outer loop.
+static scf::ForOp
+createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
+            Value upperBound, Value step, SmallVector<Value> loopItrArgs,
+            Type ipType, Type opType, unsigned int blockingFactor, bool isVnni,
+            Operation *vectorOpLhs, Operation *vectorOpRhs,
+            vector::ContractionOp contractOp, scf::ForOp outerLoop,
+            scf::ForOp innerLoop, SmallVector<vector::ContractionOp> ops,
+            Value ivOuterLoop, Value bBuffer, bool pack,
+            arith::ConstantIndexOp innerLoopIndex, bool nDimK, bool oddDimK) {
+  auto newLoop1 = scf::ForOp::create(
+      rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
+      [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
+          Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
+        // Re-materialize the LHS view for this iteration.
+        IRMapping mapping;
+        if (outerLoop) {
+          mapping.map(vectorOpLhs->getOperand(
+                          getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
+                      ivOuterLoop);
+        }
+        mapping.map(vectorOpLhs->getOperand(
+                        getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
+                    ivNewInnerLoop);
+        auto lhsClone = rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
+
+        Value c0 = arith::ConstantIndexOp::create(rewriter, locNewInnerLoop, 0);
+        Value c1 = arith::ConstantIndexOp::create(rewriter, locNewInnerLoop, 1);
+        Value c2 = arith::ConstantIndexOp::create(rewriter, locNewInnerLoop, 2);
+        // bBuffer slot written by the packing shuffle / slot read by the tile
+        // loads; both stay 0 unless updated below.
+        Value allocStore = c0;
+        Value allocGet = c0;
+
+        if (!isVnni) {
+          if (outerLoop) {
+            if (innerLoopIndex.value() == 0) {
+              if (pack) {
+                // Point the RHS view at K = 0 of the next outer iteration.
+                ivNewInnerLoop = c0;
+                ivOuterLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
+                                                    c1, ivOuterLoop);
+                if (!nDimK || oddDimK) {
+                  allocStore = arith::RemUIOp::create(
+                      rewriter, locNewInnerLoop, rewriter.getIndexType(),
+                      ivOuterLoop, c2);
+                }
+                Value addIdx = arith::AddIOp::create(
+                    rewriter, locNewInnerLoop, allocStore, c1);
+                allocGet = arith::RemUIOp::create(rewriter, locNewInnerLoop,
+                                                  rewriter.getIndexType(),
+                                                  addIdx, c2);
+              }
+            } else {
+              // Point the RHS view at the next K block of this iteration.
+              Value nLoadIndx = arith::ConstantIndexOp::create(
+                  rewriter, locNewInnerLoop, (16 * blockingFactor));
+              ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
+                                                     nLoadIndx, ivNewInnerLoop);
+              allocStore = bufferIndxToStore(rewriter, locNewInnerLoop,
+                                             ivNewInnerLoop, ivOuterLoop,
+                                             oddDimK, nDimK, pack,
+                                             blockingFactor);
+              Value addIdx = arith::AddIOp::create(
+                  rewriter, locNewInnerLoop, allocStore, c1);
+              allocGet = arith::RemUIOp::create(rewriter, locNewInnerLoop,
+                                                rewriter.getIndexType(),
+                                                addIdx, c2);
+            }
+          } else if (pack) {
+            // Single reduction loop: point the RHS view at the next K block.
+            Value nLoadIndx = arith::ConstantIndexOp::create(
+                rewriter, locNewInnerLoop, (16 * blockingFactor));
+            ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
+                                                   nLoadIndx, ivNewInnerLoop);
+            Value quotient_K = arith::DivUIOp::create(
+                rewriter, locNewInnerLoop, ivNewInnerLoop, nLoadIndx);
+            allocStore = arith::RemUIOp::create(rewriter, locNewInnerLoop,
+                                                rewriter.getIndexType(),
+                                                quotient_K, c2);
+            Value addIdx = arith::AddIOp::create(rewriter, locNewInnerLoop,
+                                                 allocStore, c1);
+            allocGet = arith::RemUIOp::create(rewriter, locNewInnerLoop,
+                                              rewriter.getIndexType(), addIdx,
+                                              c2);
+          }
+        }
+
+        // The RHS view is only consumed when it is tiled directly (VNNI) or
+        // shuffled into bBuffer (pack). Otherwise the tiles come from bBuffer
+        // alone, and a cloned view would only be a dead op.
+        Value matB;
+        if (isVnni || pack) {
+          IRMapping rhsMapping;
+          if (outerLoop) {
+            rhsMapping.map(
+                vectorOpRhs->getOperand(
+                    getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
+                ivOuterLoop);
+          }
+          rhsMapping.map(
+              vectorOpRhs->getOperand(
+                  getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
+              ivNewInnerLoop);
+          matB = rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping)
+                     ->getResult(0);
+        } else {
+          // Flat, non-packing loop: only the read slot has to be computed.
+          Value nLoadIndx = arith::ConstantIndexOp::create(
+              rewriter, locNewInnerLoop, (16 * blockingFactor));
+          if (outerLoop) {
+            allocGet = bufferIndxToStore(rewriter, locNewInnerLoop, nLoadIndx,
+                                         ivOuterLoop, oddDimK, nDimK, pack,
+                                         blockingFactor);
+          } else {
+            Value quotient_K = arith::DivUIOp::create(
+                rewriter, locNewInnerLoop, ivNewInnerLoop, nLoadIndx);
+            allocGet = arith::RemUIOp::create(rewriter, locNewInnerLoop,
+                                              rewriter.getIndexType(),
+                                              quotient_K, c2);
+          }
+        }
+
+        SmallVector<Value> accumulators = createTiledDp(
+            rewriter, locNewInnerLoop, ops, lhsClone->getResult(0), matB,
+            ipType, opType, iterArgsNewInnerLoop, blockingFactor, isVnni,
+            bBuffer, pack, allocStore, allocGet);
+
+        scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
+                             accumulators);
+      });
+
+  return newLoop1;
+}
+
// Implements tiled dot-product operation for a vector.contract operation or a
// sequence of vector.contracts inside the reduction loops.
//
@@ -326,9 +638,6 @@ struct VectorContractToAMXDotProduct
return rewriter.notifyMatchFailure(contractOp,
"Only F32 for BF16 or Int32 for Int8 "
"accumulation type is supported.");
- if (!isVnni)
- return rewriter.notifyMatchFailure(
- contractOp, "Only VNNI-packed inputs are supported.");
Operation *accReadOp =
traceToVectorReadLikeParentOperation(contractOp.getAcc());
@@ -342,8 +651,13 @@ struct VectorContractToAMXDotProduct
"transfer_read or a load. And, the result should be "
"stored using transfer_write or store.");
- Type ipType = rewriter.getBF16Type();
- Type opType = rewriter.getF32Type();
+ Type ipType;
+ Type opType;
+
+ if (lhsTy.getElementType().isBF16()) {
+ ipType = rewriter.getBF16Type();
+ opType = rewriter.getF32Type();
+ }
if (lhsTy.getElementType().isSignlessInteger(8)) {
ipType = rewriter.getIntegerType(8);
@@ -360,13 +674,21 @@ struct VectorContractToAMXDotProduct
return rewriter.notifyMatchFailure(
contractOp, "The accumulator read is in different block.");
+ unsigned int dimValue = blockingFactor;
+ if (!isVnni)
+ dimValue = 16 * blockingFactor;
+
// Case 1: For just one VC rewrite. Where all accumulator read/write
// within the same block.
if (accReadOp->getBlock() == contractOp->getBlock() &&
resultWriteOp->getBlock() == contractOp->getBlock()) {
+ bool collapse = false;
+ if (isVnni)
+ collapse = true;
+
LogicalResult validate = validateContractOps(
- rewriter, contractOp, blockingFactor, Value(), Value(), false);
+ rewriter, contractOp, dimValue, Value(), Value(), false);
if (failed(validate))
return rewriter.notifyMatchFailure(
@@ -377,18 +699,20 @@ struct VectorContractToAMXDotProduct
Location loc = contractOp.getLoc();
auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
- contractOp.getLhs(), true);
+ contractOp.getLhs(), collapse);
if (failed(srcIndxLhs))
return rewriter.notifyMatchFailure(contractOp,
"The LHS src is not a MemRef type.");
auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
- contractOp.getRhs(), true);
+ contractOp.getRhs(), collapse);
if (failed(srcIndxRhs))
return rewriter.notifyMatchFailure(contractOp,
"The RHS src is not a MemRef type.");
- auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
+ auto rhsSrc = *srcIndxRhs;
+ auto srcBuffRhs = rhsSrc.first;
+ auto indicesRhs = rhsSrc.second;
auto srcIndxAcc = getSrcIndxValue(rewriter, contractOp.getLoc(),
contractOp.getAcc(), false);
@@ -401,8 +725,112 @@ struct VectorContractToAMXDotProduct
auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
srcBuffLhs, indicesLhs);
- auto loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType,
- srcBuffRhs, indicesRhs);
+
+ // Create the subview and then load.
+ //
+ amx::TileLoadOp loadRhs;
+ if (!isVnni) {
+ VectorType vecTy;
+ SmallVector<OpFoldResult> indexVals;
+ llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
+ .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+ indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+ readOp.getIndices().end());
+ vecTy = readOp.getType();
+ });
+ auto one = rewriter.getIndexAttr(1);
+ SmallVector<OpFoldResult> strides(indexVals.size(), one);
+ SmallVector<OpFoldResult> sizes = getAsIndexOpFoldResult(
+ contractOp.getRhs().getDefiningOp()->getContext(),
+ vecTy.getShape());
+ auto subview = memref::SubViewOp::create(rewriter, loc, srcBuffRhs,
+ indexVals, sizes, strides);
+ auto bufferType = MemRefType::get({16, (16 * blockingFactor)}, ipType);
+ auto bBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
+
+ // create a loop that swaps them.
+
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+
+ Value step =
+ arith::ConstantIndexOp::create(rewriter, loc, blockingFactor);
+ Value uBound = arith::ConstantIndexOp::create(rewriter, loc,
+ (blockingFactor * 16));
+ Value nextLoadIndx =
+ arith::ConstantIndexOp::create(rewriter, loc, (blockingFactor / 2));
+ Value nextStoreIndx = arith::ConstantIndexOp::create(
+ rewriter, loc, 16 * (blockingFactor / 2));
+
+ scf::ForOp::create(
+ rewriter, loc, c0, uBound, step, ValueRange{},
+ [&](OpBuilder &nestedBuilder, Location loc, Value iv,
+ ValueRange iterArgs) {
+ Value i1_load =
+ arith::AddIOp::create(rewriter, loc, nextLoadIndx, iv);
+
+ indicesRhs[indicesRhs.size() - 2] = iv;
+ ValueRange range1(indicesRhs);
+ auto vec1 = vector::LoadOp::create(
+ rewriter, loc,
+ VectorType::get(16 * (blockingFactor / 2), ipType), subview,
+ range1);
+
+ indicesRhs[indicesRhs.size() - 2] = i1_load;
+ ValueRange range2(indicesRhs);
+ auto vec2 = vector::LoadOp::create(
+ rewriter, loc,
+ VectorType::get(16 * (blockingFactor / 2), ipType), subview,
+ range2);
+
+ vector::ShuffleOp shuffle1;
+ vector::ShuffleOp shuffle2;
+
+ if (blockingFactor == 2) {
+
+ shuffle1 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
+ ArrayRef<int64_t>{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21,
+ 6, 22, 7, 23});
+
+ shuffle2 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
+ ArrayRef<int64_t>{8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
+ 29, 14, 30, 15, 31});
+ }
+
+ if (blockingFactor == 4) {
+ shuffle1 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
+ ArrayRef<int64_t>{0, 16, 32, 48, 1, 17, 33, 49,
+ 2, 18, 34, 50, 3, 19, 35, 51,
+ 4, 20, 36, 52, 5, 21, 37, 53,
+ 6, 22, 38, 54, 7, 23, 39, 55});
+
+ shuffle2 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
+ ArrayRef<int64_t>{8, 24, 40, 56, 9, 25, 41, 57,
+ 10, 26, 42, 58, 11, 27, 43, 59,
+ 12, 28, 44, 60, 13, 29, 45, 61,
+ 14, 30, 46, 62, 15, 31, 47, 63});
+ }
+
+ auto rem = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), iv, step);
+
+ vector::StoreOp::create(rewriter, loc, shuffle1, bBuffer,
+ ValueRange{rem, c0});
+ vector::StoreOp::create(rewriter, loc, shuffle2, bBuffer,
+ ValueRange{rem, nextStoreIndx});
+
+ scf::YieldOp::create(nestedBuilder, loc);
+ });
+ loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, bBuffer,
+ ValueRange{c0, c0});
+ } else {
+
+ loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, srcBuffRhs,
+ indicesRhs);
+ }
auto tileTypeAcc = amx::TileType::get({16, 16}, opType);
auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
@@ -429,7 +857,6 @@ struct VectorContractToAMXDotProduct
// reduction loop.
SmallVector<scf::ForOp> loopLists;
Operation *current = contractOp;
-
while (true) {
Operation *parent = current->getParentOfType<scf::ForOp>();
loopLists.push_back(dyn_cast<scf::ForOp>(parent));
@@ -440,7 +867,6 @@ struct VectorContractToAMXDotProduct
current = parent;
}
-
if (loopLists.size() > 2 || loopLists.size() == 0)
return rewriter.notifyMatchFailure(
contractOp, "Rewrite is supported until reduction loop depth of 2.");
@@ -458,7 +884,6 @@ struct VectorContractToAMXDotProduct
return rewriter.notifyMatchFailure(contractOp,
"The RHS src is not a MemRef type.");
auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
-
Operation *vectorOpLhs;
llvm::TypeSwitch<Operation *>(contractOp.getLhs().getDefiningOp())
.Case<TransferReadOp, LoadOp>([&](auto readOp) {
@@ -478,7 +903,7 @@ struct VectorContractToAMXDotProduct
if (auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
LogicalResult validate = validateContractOps(
- rewriter, contract, blockingFactor, srcBuffLhs, srcBuffRhs, true);
+ rewriter, contract, dimValue, srcBuffLhs, srcBuffRhs, true);
if (failed(validate))
return rewriter.notifyMatchFailure(
@@ -490,8 +915,8 @@ struct VectorContractToAMXDotProduct
}
}
- scf::ForOp outerLoop;
scf::ForOp innerLoop;
+ scf::ForOp outerLoop;
scf::ForOp newLoop;
// Case 2a: Reduction loop depth is 2.
@@ -502,126 +927,248 @@ struct VectorContractToAMXDotProduct
SmallVector<Value> loopItrArgs = createTileZeros(
rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
- newLoop = scf::ForOp::create(
- rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
- outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
- [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
- Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
- auto newInnerLoop = scf::ForOp::create(
- rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
- innerLoop.getUpperBound(), innerLoop.getStep(),
- iterArgsOuterLoop,
- [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
- Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
- IRMapping mapping;
- mapping.map(
- vectorOpLhs->getOperand(
- getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
- ivOuterLoop);
- mapping.map(
- vectorOpLhs->getOperand(
- getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
- ivNewInnerLoop);
- auto lhsClone =
- rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
-
- IRMapping rhsMapping;
- rhsMapping.map(
- vectorOpRhs->getOperand(
- getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
- ivOuterLoop);
- rhsMapping.map(
- vectorOpRhs->getOperand(
- getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
- ivNewInnerLoop);
- auto rhsClone =
- rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping);
-
- SmallVector<Value> accumulators = createTiledDp(
- rewriter, locNewInnerLoop, ops, lhsClone->getResult(0),
- rhsClone->getResult(0), ipType, opType,
- iterArgsNewInnerLoop, blockingFactor);
-
- scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
- accumulators);
- });
-
- scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
- newInnerLoop.getResults());
- });
+ if (isVnni) {
+ newLoop = scf::ForOp::create(
+ rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
+ outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
+ [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
+ Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
+ auto newInnerLoop = createLoops(
+ rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+ innerLoop.getUpperBound(), innerLoop.getStep(),
+ iterArgsOuterLoop, ipType, opType, blockingFactor, isVnni,
+ vectorOpLhs, vectorOpRhs, contractOp, outerLoop, innerLoop,
+ ops, ivOuterLoop, nullptr, true, nullptr, false, false);
+
+ scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
+ newInnerLoop.getResults());
+ });
+
+ } else {
+
+ bool nDimK = false;
+ bool oddDimK = false;
+
+ int64_t ubVal = 16 * blockingFactor;
+ mlir::Value ub = innerLoop.getUpperBound();
+ if (auto constOp = ub.getDefiningOp<mlir::arith::ConstantOp>()) {
+ if (auto intAttr =
+ llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
+ ubVal = intAttr.getInt();
+ }
+ }
+
+ nDimK = ubVal > 16 * blockingFactor;
+ oddDimK = (((ubVal / (16 * blockingFactor)) % 2) == 1) && nDimK;
+
+ rewriter.setInsertionPoint(outerLoop);
+
+ auto c0 =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+
+ auto c1 =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
+
+ auto spillLoopBound = arith::ConstantIndexOp::create(
+ rewriter, outerLoop.getLoc(), 16 * blockingFactor);
+ Value subBRLoop = arith::SubIOp::create(rewriter, outerLoop.getLoc(),
+ outerLoop.getUpperBound(), c1);
+ Value subKloop =
+ arith::SubIOp::create(rewriter, innerLoop.getLoc(),
+ innerLoop.getUpperBound(), spillLoopBound);
+ auto bufferType =
+ MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
+ auto bBuffer =
+ memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
+
+ // First Shuffling outside the reduction loops
+ IRMapping rhsMapping;
+ rhsMapping.map(
+ vectorOpRhs->getOperand(
+ getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
+ c0);
+ rhsMapping.map(
+ vectorOpRhs->getOperand(
+ getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
+ c0);
+ auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
+
+ performShuffle(rewriter, outerLoop.getLoc(), rhsClone->getResult(0),
+ ipType, blockingFactor, bBuffer, c0);
+
+ // First Set of Loops
+ auto newLoop1 = scf::ForOp::create(
+ rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(), subBRLoop,
+ outerLoop.getStep(), loopItrArgs,
+ [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
+ Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
+ auto newInnerLoop1 = createLoops(
+ rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+ subKloop, innerLoop.getStep(), iterArgsOuterLoop, ipType,
+ opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+ contractOp, outerLoop, innerLoop, ops, ivOuterLoop, bBuffer,
+ true, spillLoopBound, nDimK, oddDimK);
+
+ auto newInnerLoop =
+ createLoops(rewriter, innerLoop.getLoc(), subKloop,
+ innerLoop.getUpperBound(), innerLoop.getStep(),
+ newInnerLoop1.getResults(), ipType, opType,
+ blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+ contractOp, outerLoop, innerLoop, ops,
+ ivOuterLoop, bBuffer, true, c0, nDimK, oddDimK);
+
+ scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
+ newInnerLoop.getResults());
+ });
+
+ // Last set of Loops
+ newLoop = scf::ForOp::create(
+ rewriter, outerLoop.getLoc(), subBRLoop, outerLoop.getUpperBound(),
+ outerLoop.getStep(), newLoop1.getResults(),
+ [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
+ Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
+ auto newInnerLoop1 = createLoops(
+ rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+ subKloop, innerLoop.getStep(), iterArgsOuterLoop, ipType,
+ opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+ contractOp, outerLoop, innerLoop, ops, ivOuterLoop, bBuffer,
+ true, spillLoopBound, nDimK, oddDimK);
+
+ auto newInnerLoop =
+ createLoops(rewriter, innerLoop.getLoc(), subKloop,
+ innerLoop.getUpperBound(), innerLoop.getStep(),
+ newInnerLoop1.getResults(), ipType, opType,
+ blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+ contractOp, outerLoop, innerLoop, ops,
+ ivOuterLoop, bBuffer, false, c0, nDimK, oddDimK);
+
+ scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
+ newInnerLoop.getResults());
+ });
+ }
}
- // Case 2b: Reduction loop depth is 1.
if (loopLists.size() == 1) {
outerLoop = loopLists[0];
+ innerLoop = loopLists[0];
SmallVector<Value> loopItrArgs = createTileZeros(
- rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
- newLoop = scf::ForOp::create(
- rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
- outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
- [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
- Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
- IRMapping mapping;
- mapping.map(
- vectorOpLhs->getOperand(
- getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
- ivOuterLoop);
-
- auto lhsClone = rewriterOuterLoop.clone(*vectorOpLhs, mapping);
-
- IRMapping rhsMapping;
- rhsMapping.map(
- vectorOpRhs->getOperand(
- getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
- ivOuterLoop);
-
- auto rhsClone = rewriterOuterLoop.clone(*vectorOpRhs, rhsMapping);
-
- SmallVector<Value> accumulators = createTiledDp(
- rewriter, locOuterLoop, ops, lhsClone->getResult(0),
- rhsClone->getResult(0), ipType, opType, iterArgsOuterLoop,
- blockingFactor);
-
- scf::YieldOp::create(rewriterOuterLoop, locOuterLoop, accumulators);
- });
+ rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
+
+ if (isVnni) {
+
+ newLoop = createLoops(
+ rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+ innerLoop.getUpperBound(), innerLoop.getStep(), loopItrArgs, ipType,
+ opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+ contractOp, nullptr, innerLoop, ops, nullptr, nullptr, true,
+ nullptr, false, false);
+
+ } else {
+ bool nDimK = false;
+ bool oddDimK = false;
+
+ int64_t ubVal = 16 * blockingFactor;
+ mlir::Value ub = innerLoop.getUpperBound();
+ if (auto constOp = ub.getDefiningOp<mlir::arith::ConstantOp>()) {
+ if (auto intAttr =
+ llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
+ ubVal = intAttr.getInt();
+ }
+ }
+
+ nDimK = ubVal > 16 * blockingFactor;
+ oddDimK = (((ubVal / (16 * blockingFactor)) % 2) == 1) && nDimK;
+ rewriter.setInsertionPoint(innerLoop);
+ auto c0 =
+ arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 0);
+ auto spillLoopBound = arith::ConstantIndexOp::create(
+ rewriter, innerLoop.getLoc(), 16 * blockingFactor);
+
+ Value subKloop =
+ arith::SubIOp::create(rewriter, innerLoop.getLoc(),
+ innerLoop.getUpperBound(), spillLoopBound);
+
+ auto bufferType =
+ MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
+ auto bBuffer =
+ memref::AllocaOp::create(rewriter, innerLoop.getLoc(), bufferType);
+
+ // First Shuffling outside the reduction loops
+ IRMapping rhsMapping;
+ rhsMapping.map(
+ vectorOpRhs->getOperand(
+ getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
+ c0);
+ auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
+
+ performShuffle(rewriter, innerLoop.getLoc(), rhsClone->getResult(0),
+ ipType, blockingFactor, bBuffer, c0);
+
+ auto newLoop1 = createLoops(
+ rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(), subKloop,
+ innerLoop.getStep(), loopItrArgs, ipType, opType, blockingFactor,
+ isVnni, vectorOpLhs, vectorOpRhs, contractOp, nullptr, innerLoop,
+ ops, nullptr, bBuffer, true, spillLoopBound, nDimK, oddDimK);
+
+ newLoop = createLoops(rewriter, innerLoop.getLoc(), subKloop,
+ innerLoop.getUpperBound(), innerLoop.getStep(),
+ newLoop1.getResults(), ipType, opType,
+ blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+ contractOp, nullptr, innerLoop, ops, nullptr,
+ bBuffer, false, c0, nDimK, oddDimK);
+ }
}
- // post processing after the loop creation.
// Copy the amx tile accumulation results to a MemRef buffer, add the
// initial accumulation value, and store back to the C-Matrix
- auto bufferType = MemRefType::get({16, 16}, opType);
- auto bBuffer =
- memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
- SmallVector<Value> dps = newLoop.getResults();
- for (size_t i = 0; i < ops.size(); i++) {
- vector::ContractionOp contOp = ops[i];
- Operation *resultWriteOp =
- traceToVectorWriteLikeUserOperation(contOp.getResult());
- rewriter.setInsertionPoint(resultWriteOp);
+ if (!isVnni) {
+ Location loc = outerLoop.getLoc();
+ SmallVector<Value> dps = newLoop.getResults();
+ auto bufferType = MemRefType::get({32, 32}, opType);
+ auto bBuffer =
+ memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
+ for (int i = 0, k = 0; i < 32; i = i + 16) {
+ for (int j = 0; j < 32; j = j + 16) {
+ Value indexOp_i =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), i);
+ Value indexOp_j =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), j);
+ amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
+ ValueRange{indexOp_i, indexOp_j}, dps[k]);
+ k++;
+ }
+ }
+ auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+ auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto mBound = arith::ConstantIndexOp::create(rewriter, loc, 32);
- Value indexOp_0 =
- arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+ scf::ForOp::create(
+ rewriter, loc, c0, mBound, one, ValueRange{},
+ [&](OpBuilder &nestedBuilder, Location loc, Value iv,
+ ValueRange iterArgs) {
+ auto row = vector::LoadOp::create(rewriter, loc,
+ VectorType::get(16, opType),
+ bBuffer, ValueRange{iv, c0});
- amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
- ValueRange{indexOp_0, indexOp_0}, dps[i]);
+ auto row2 = vector::LoadOp::create(rewriter, loc,
+ VectorType::get(16, opType),
+ bBuffer, ValueRange{iv, c16});
- auto c0 = arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
- auto one =
- arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
- auto mBound =
- arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 16);
+ auto shuffle1 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get(16, opType), row, row2,
+ ArrayRef<int64_t>{0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20,
+ 21, 22, 23});
- scf::ForOp::create(
- rewriter, outerLoop.getLoc(), c0, mBound, one, ValueRange{},
- [&](OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) {
- auto resultAcc = vector::LoadOp::create(
- rewriter, loc, VectorType::get(16, opType), bBuffer,
- ValueRange{iv, c0});
+ auto shuffle2 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get(16, opType), row, row2,
+ ArrayRef<int64_t>{8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
+ 28, 29, 30, 31});
Operation *accReadOp =
- traceToVectorReadLikeParentOperation(ops[i].getAcc());
+ traceToVectorReadLikeParentOperation(contractOp.getAcc());
Value srcBuffAcc;
SmallVector<Value> indicesAcc;
@@ -641,24 +1188,119 @@ struct VectorContractToAMXDotProduct
});
});
- Value sum = arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
- indicesAcc[indicesAcc.size() - 2] = sum;
-
- auto acc = vector::LoadOp::create(rewriter, loc,
- VectorType::get(16, opType),
- srcBuffAcc, indicesAcc);
- Value addition;
- if (ipType.isBF16())
- addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
-
- if (ipType.isSignlessInteger(8))
- addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
-
- vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
+ indicesAcc[indicesAcc.size() - 2] = iv;
+ indicesAcc[indicesAcc.size() - 1] = c0;
+
+ Value valueCRow1 = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
+ indicesAcc);
+ indicesAcc[indicesAcc.size() - 1] = c16;
+
+ Value valueCRow2 = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
+ indicesAcc);
+ Value addOp;
+ Value addOp2;
+
+ if (ipType.isBF16()) {
+ addOp =
+ arith::AddFOp::create(rewriter, loc, shuffle1, valueCRow1);
+
+ addOp2 =
+ arith::AddFOp::create(rewriter, loc, shuffle2, valueCRow2);
+ }
+
+ if (ipType.isSignlessInteger(8)) {
+ addOp =
+ arith::AddIOp::create(rewriter, loc, shuffle1, valueCRow1);
+
+ addOp2 =
+ arith::AddIOp::create(rewriter, loc, shuffle2, valueCRow2);
+ }
+ indicesAcc[indicesAcc.size() - 1] = c0;
+ vector::StoreOp::create(rewriter, loc, addOp, srcBuffAcc,
+ indicesAcc);
+ indicesAcc[indicesAcc.size() - 1] = c16;
+ vector::StoreOp::create(rewriter, loc, addOp2, srcBuffAcc,
indicesAcc);
- scf::YieldOp::create(builder, outerLoop.getLoc());
+ scf::YieldOp::create(nestedBuilder, loc);
});
+ }
+ auto bufferType = MemRefType::get({16, 16}, opType);
+ auto bBuffer =
+ memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
+
+ SmallVector<Value> dps = newLoop.getResults();
+ for (size_t i = 0; i < ops.size(); i++) {
+ vector::ContractionOp contOp = ops[i];
+ Operation *resultWriteOp =
+ traceToVectorWriteLikeUserOperation(contOp.getResult());
+ if (isVnni) {
+ rewriter.setInsertionPoint(resultWriteOp);
+
+ Value indexOp_0 =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+
+ amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
+ ValueRange{indexOp_0, indexOp_0}, dps[i]);
+
+ auto c0 =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+ auto one =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
+ auto mBound =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 16);
+
+ scf::ForOp::create(
+ rewriter, outerLoop.getLoc(), c0, mBound, one, ValueRange{},
+ [&](OpBuilder &builder, Location loc, Value iv,
+ ValueRange iterArgs) {
+ auto resultAcc = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(16, opType), bBuffer,
+ ValueRange{iv, c0});
+
+ Operation *accReadOp =
+ traceToVectorReadLikeParentOperation(ops[i].getAcc());
+
+ Value srcBuffAcc;
+ SmallVector<Value> indicesAcc;
+
+ llvm::TypeSwitch<Operation *>(accReadOp)
+ .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+ srcBuffAcc = readOp.getOperand(0);
+
+ auto indices = readOp.getIndices();
+ indicesAcc.reserve(indices.size());
+
+ llvm::transform(
+ indices, std::back_inserter(indicesAcc),
+ [&](OpFoldResult ofr) {
+ return mlir::getValueOrCreateConstantIndexOp(
+ rewriter, loc, ofr);
+ });
+ });
+
+ Value sum =
+ arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
+ indicesAcc[0] = sum;
+
+ auto acc = vector::LoadOp::create(rewriter, loc,
+ VectorType::get(16, opType),
+ srcBuffAcc, indicesAcc);
+ Value addition;
+ if (ipType.isBF16())
+ addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
+
+ if (ipType.isSignlessInteger(8))
+ addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
+
+ vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
+ indicesAcc);
+
+ scf::YieldOp::create(builder, outerLoop.getLoc());
+ });
+ }
rewriter.eraseOp(resultWriteOp);
}
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
index cde15b680a037..151946453df81 100644
--- a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -216,6 +216,122 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<16x64xi8>
+!vecB = vector<64x16xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x64xi8>
+!memrefB = memref<64x32xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d1, d2, d3) -> (d1, d3)>
+#map1 = affine_map<(d1, d2, d3) -> (d3, d2)>
+#map2 = affine_map<(d1, d2, d3) -> (d1, d2)>
+func.func @online_packing_int8(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : i8
+ %32 = ub.poison : i32
+ // Flat-layout operands (no packed VNNI dim): A 16x64xi8, B 64x16xi8, C 16x16xi32.
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+ // Plain matmul contraction: A(d1, d3) x B(d3, d2) -> C(d1, d2), reducing over d3.
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+// Expected lowering: B is packed online with vector.shuffle, then amx tile_muli.
+// CHECK-LABEL: @online_packing_int8
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK: scf.for
+// CHECK: vector.shuffle{{.*}}[0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51, 4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55] : vector<32xi8>, vector<32xi8>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59, 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63] : vector<32xi8>, vector<32xi8>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK: x86.amx.tile_muli
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK-NOT: vector.contract
+
+// Transform script: apply the x86 vector_contract_to_amx_dot_product patterns to each func.func.
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x16x32xbf16>
+!vecB = vector<1x32x16xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x32xbf16>
+!memrefB = memref<1x32x32xbf16>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @online_packing_bf16(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+ // Flat-layout bf16 operands with a unit leading dim: A 1x16x32, B 1x32x16; C is 16x16xf32.
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+ // Both d0 (leading unit dim) and d3 (K) are reduced; d1/d2 index the 16x16 C tile.
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+// Expected lowering: B rows interleaved pairwise via vector.shuffle, then amx tile_mulf.
+// CHECK-LABEL: @online_packing_bf16
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x32xbf16>
+// CHECK: scf.for
+// CHECK: vector.shuffle{{.*}}[0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23] : vector<16xbf16>, vector<16xbf16>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31] : vector<16xbf16>, vector<16xbf16>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x32xbf16>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK: x86.amx.tile_mulf
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK-NOT: vector.contract
+// Transform script: apply the x86 vector_contract_to_amx_dot_product patterns to each func.func.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecAB = vector<1x16x16x2xbf16>
!vecC = vector<16x16xf32>
!memrefA = memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>
@@ -483,6 +599,199 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<1x16x32xbf16>
+!vecB = vector<1x32x16xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x32xbf16, strided<[6144, 96, 1], offset: ?>>
+!memrefB = memref<1x32x32xbf16, strided<[12288, 128, 1], offset: ?>>
+!memrefC = memref<32x32xf32, strided<[128, 1], offset: ?>>
+// Strides match 32x32 subviews of the 16x64x96 (A), 16x96x128 (B) and 64x128 (C) buffers.
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// Each 32x32 C block is carried as four 16x16 accumulators through the %arg5/%arg10 reduction loops.
+func.func @online_packing_bf16_loop(%arg0: memref<16x64x96xbf16>, %arg1: memref<16x96x128xbf16>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> {
+ %0 = ub.poison : f32
+ %1 = ub.poison : bf16
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c96 = arith.constant 96 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ scf.for %arg3 = %c0 to %c64 step %c32 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+
+ %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] :
+ memref<64x128xf32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %6:4 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+ %7:4 = scf.for %arg10 = %c0 to %c96 step %c32 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!vecC, !vecC, !vecC, !vecC) {
+
+ %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg10] [1, 32, 32] [1, 1, 1] :
+ memref<16x64x96xbf16> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg10, %arg4] [1, 32, 32] [1, 1, 1] :
+ memref<16x96x128xbf16> to !memrefB
+ %8 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefA, !vecA
+ %9 = vector.transfer_read %subview_0[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefA, !vecA
+ %10 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefB, !vecB
+ %11 = vector.transfer_read %subview_1[%c0, %c0, %c16], %1 {in_bounds = [true, true, true]} :
+ !memrefB, !vecB
+
+ %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %10, %arg11 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+ %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %11, %arg12 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+ %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %9, %10, %arg13 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+ %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %9, %11, %arg14 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+
+ scf.yield %12, %13, %14, %15 : !vecC, !vecC, !vecC, !vecC
+ }
+ scf.yield %7#0, %7#1, %7#2, %7#3 : !vecC, !vecC, !vecC, !vecC
+ }
+ vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ }
+ }
+ %alloc = memref.alloc() : memref<64x128xf32>
+ memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
+ return %alloc : memref<64x128xf32>
+}
+// Expected lowering: AMX-tile loop accumulators, vector.shuffle B packing, tile_mulf, then f32 row shuffles after the loops.
+// CHECK-LABEL: @online_packing_bf16_loop
+// CHECK-COUNT-4: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+// CHECK-COUNT-4: scf.for {{.*}} -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
+// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
+// CHECK: x86.amx.tile_load
+// CHECK: x86.amx.tile_mulf
+// CHECK: scf.yield {{.*}} : !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+// CHECK-NOT: vector.contract
+
+// Transform script: apply the x86 vector_contract_to_amx_dot_product patterns to each func.func.
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<16x64xi8>
+!vecB = vector<64x16xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x64xi8, strided<[256, 1], offset: ?>>
+!memrefB = memref<64x32xi8, strided<[128, 1], offset: ?>>
+!memrefC = memref<32x32xi32, strided<[128, 1], offset: ?>>
+// Strides match the 32x64 (A), 64x32 (B) and 32x32 (C) subviews of the 64x256, 256x128 and 64x128 buffers.
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xi8>, %arg1: memref<256x128xi8>, %arg2: memref<64x128xi32>) -> memref<64x128xi32> {
+ %c16 = arith.constant 16 : index
+ %0 = ub.poison : i32
+ %1 = ub.poison : i8
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ %c32 = arith.constant 32 : index
+ scf.for %arg3 = %c0 to %c64 step %c32 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+ %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<64x128xi32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+ %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+ %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+ %6:4 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+ %subview_0 = memref.subview %arg0[%arg3, %arg5] [32, 64] [1, 1] : memref<64x256xi8> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg4] [64, 32] [1, 1] : memref<256x128xi8> to !memrefB
+ %7 = vector.transfer_read %subview_0[%c0, %c0], %1 {in_bounds = [true, true]} : !memrefA, !vecA
+ %8 = vector.transfer_read %subview_0[%c16, %c0], %1 {in_bounds = [true, true]} : !memrefA, !vecA
+ %9 = vector.transfer_read %subview_1[%c0, %c0], %1 {in_bounds = [true, true]} : !memrefB, !vecB
+ %10 = vector.transfer_read %subview_1[%c0, %c16], %1 {in_bounds = [true, true]} : !memrefB, !vecB
+ %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %7, %9, %arg6 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %7, %10, %arg7 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %9, %arg8 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %10, %arg9 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ scf.yield %11, %12, %13, %14 : !vecC, !vecC, !vecC, !vecC
+ }
+ vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ }
+ }
+ %alloc = memref.alloc() : memref<64x128xi32>
+ memref.copy %arg2, %alloc : memref<64x128xi32> to memref<64x128xi32>
+ return %alloc : memref<64x128xi32>
+}
+// Expected lowering: AMX-tile loop accumulators, vector.shuffle B packing on i8 rows, tile_muli, then i32 row shuffles after the loop.
+// CHECK-LABEL: @online_packing_int8_matmul_loop
+// CHECK-COUNT-4: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
+// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>) {
+// CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111] : vector<64xi8>, vector<64xi8>
+// CHECK-NEXT: vector.shuffle{{.*}}[16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123, 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127] : vector<64xi8>, vector<64xi8>
+// CHECK: x86.amx.tile_load
+// CHECK: x86.amx.tile_muli
+// CHECK: scf.yield {{.*}} !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xi32>, vector<16xi32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xi32>, vector<16x16xi32>, vector<16x16xi32>, vector<16x16xi32>
+// CHECK-NOT: vector.contract
+// Transform script: apply the x86 vector_contract_to_amx_dot_product patterns to each func.func.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x16x16x4xi8>
!vecB = vector<1x16x16x4xi8>
!vecC = vector<16x16xi32>
@@ -637,32 +946,32 @@ module attributes {transform.with_named_sequence} {
// -----
-!vecA = vector<16x64xi8>
-!vecB = vector<64x16xi8>
-!vecC = vector<16x16xi32>
-!memrefA = memref<32x64xi8>
-!memrefB = memref<64x32xi8>
-!memrefC = memref<32x32xi32>
-#map = affine_map<(d1, d2, d3) -> (d1, d3)>
-#map1 = affine_map<(d1, d2, d3) -> (d3, d2)>
-#map2 = affine_map<(d1, d2, d3) -> (d1, d2)>
-func.func @negative_no_vnni_packed(
+!vecA = vector<1x16x32xbf16>
+!vecB = vector<1x32x32xbf16>
+!vecC = vector<16x32xf32>
+!memrefA = memref<1x32x32xbf16>
+!memrefB = memref<1x32x32xbf16>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @negative_wrong_dimensions_online_packing(
%arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
{
%c0 = arith.constant 0 : index
- %0 = ub.poison : i8
- %32 = ub.poison : i32
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
- %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
!memrefA, !vecA
- %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
!memrefB, !vecB
%3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
%4 = vector.contract {
indexing_maps = [#map, #map1, #map2],
- iterator_types = ["parallel", "parallel", "reduction"],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
kind = #vector.kind<add>}
%1, %2, %3 : !vecA, !vecB into !vecC
@@ -671,13 +980,13 @@ func.func @negative_no_vnni_packed(
return %arg2 : !memrefC
}
-// CHECK-LABEL: @negative_no_vnni_packed
-// CHECK-NOT: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
-// CHECK-NOT: x86.amx.tile_muli
-// CHECK-NOT: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK-LABEL: @negative_wrong_dimensions_online_packing
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: x86.amx.tile_mulf
+// CHECK-NOT: x86.amx.tile_store
// CHECK: vector.contract
-
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
More information about the Mlir-commits
mailing list