[Mlir-commits] [mlir] [mlir][x86] Lower packed type vector.contract to AMX dot-product (online-packing) (PR #188192)
Arun Thangamani
llvmlistbot at llvm.org
Wed Apr 8 11:09:30 PDT 2026
https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/188192
>From 7b888d75131d67a4e3f2cb2ffdec7e1f82abdf01 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 24 Mar 2026 01:05:51 -0700
Subject: [PATCH 1/5] support for amx online-packing
---
.../VectorContractToAMXDotProduct.cpp | 936 +++++++++++++++---
.../X86/AMX/vector-contract-to-tiled-dp.mlir | 349 ++++++-
2 files changed, 1118 insertions(+), 167 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 85966a85af40e..2b159e6f59cb9 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -1,4 +1,4 @@
-//===- VectorContractToAMXDotProduct.cpp ----------------------------------===//
+//===- VectorContractToAMXDotProduct.cpp ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -70,8 +70,9 @@ getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
if (!srcBuff)
return failure();
- if (isNotAcc)
+ if (isNotAcc) {
indexVals.pop_back();
+ }
SmallVector<Value> indices;
indices.reserve(indexVals.size());
@@ -189,20 +190,155 @@ static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
// Creates amx.tile_loads.
static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
Value operand, Value mat, Type ipType,
- bool rhs, unsigned int offset) {
+ bool rhs, unsigned int offset,
+ bool isVnni) {
auto srcIndx = getSrcIndxValue(rewriter, loc, operand, false);
auto [srcBuff, indices] = *srcIndx;
- indices.pop_back();
+ if (isVnni) {
+ indices.pop_back();
+ }
- if (rhs) {
+ if (rhs && isVnni) {
auto cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
indices[indices.size() - 1] = arith::MulIOp::create(
rewriter, loc, indices[indices.size() - 1], cOffset);
}
amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
- return amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
+ auto load = amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
+ return load;
+}
+
+static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
+ Type ipType, unsigned int offset, Value bBuffer,
+ Value allocStore) {
+
+ auto subview = matB.getDefiningOp<mlir::memref::SubViewOp>();
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ SmallVector<Value> vals(subview.getOffsets().size(), c0);
+
+ // Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+
+ Value cStep = arith::ConstantIndexOp::create(rewriter, loc, offset);
+ Value cBound = arith::ConstantIndexOp::create(rewriter, loc, (16 * offset));
+ Value nLoadIndx = arith::ConstantIndexOp::create(rewriter, loc, (offset / 2));
+
+ scf::ForOp::create(
+ rewriter, loc, c0, cBound, cStep, ValueRange{},
+ [&](OpBuilder &nestedBuilder, Location loc, Value iv,
+ ValueRange iterArgs) {
+ Value i1_load = arith::AddIOp::create(rewriter, loc, nLoadIndx, iv);
+
+ vals[vals.size() - 2] = iv;
+ ValueRange range1(vals);
+ auto vec1 = vector::LoadOp::create(
+ rewriter, loc, VectorType::get((16 * offset), ipType), matB,
+ range1);
+
+ vals[vals.size() - 2] = i1_load;
+ ValueRange range2(vals);
+ auto vec2 = vector::LoadOp::create(
+ rewriter, loc, VectorType::get((16 * offset), ipType), matB,
+ range2);
+
+ vector::ShuffleOp shuffle1;
+ vector::ShuffleOp shuffle2;
+
+ if (offset == 2) {
+
+ shuffle1 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
+ vec2,
+ ArrayRef<int64_t>{0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9,
+ 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50,
+ 19, 51, 24, 56, 25, 57, 26, 58, 27, 59});
+
+ shuffle2 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
+ vec2,
+ ArrayRef<int64_t>{4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13,
+ 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54,
+ 23, 55, 28, 60, 29, 61, 30, 62, 31, 63});
+ }
+
+ if (offset == 4) {
+
+ shuffle1 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
+ vec2,
+ ArrayRef<int64_t>{
+ 0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3,
+ 35, 67, 99, 4, 36, 68, 100, 5, 37, 69, 101, 6, 38,
+ 70, 102, 7, 39, 71, 103, 8, 40, 72, 104, 9, 41, 73,
+ 105, 10, 42, 74, 106, 11, 43, 75, 107, 12, 44, 76, 108,
+ 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111});
+
+ shuffle2 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
+ vec2,
+ ArrayRef<int64_t>{
+ 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51,
+ 83, 115, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118,
+ 23, 55, 87, 119, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58,
+ 90, 122, 27, 59, 91, 123, 28, 60, 92, 124, 29, 61, 93, 125,
+ 30, 62, 94, 126, 31, 63, 95, 127});
+ }
+ Value j_pos = arith::DivUIOp::create(rewriter, loc, iv, cStep);
+ Value j16_pos = arith::AddIOp::create(rewriter, loc, c16, j_pos);
+
+ vector::StoreOp::create(rewriter, loc, shuffle1, bBuffer,
+ ValueRange{allocStore, j_pos, c0});
+ vector::StoreOp::create(rewriter, loc, shuffle2, bBuffer,
+ ValueRange{allocStore, j16_pos, c0});
+
+ scf::YieldOp::create(nestedBuilder, loc);
+ });
+}
+
+static llvm::DenseMap<Operation *, amx::TileLoadOp>
+packInputs(OpBuilder &rewriter, Location loc,
+ SmallVector<vector::ContractionOp> ops, Value matB, Type ipType,
+ unsigned int offset, Value bBuffer, bool pack, Value allocStore,
+ Value addIdx) {
+ llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+
+ for (size_t j = 0; j < ops.size(); j++) {
+ for (size_t i = 0; i < ops.size(); i++) {
+
+ if (i != j && validatePairVectorContract(ops[j], ops[i], true, 16)) {
+
+ Operation *readOpRhs = ops[j].getRhs().getDefiningOp();
+ auto itRhs = readsToTileLoads.find(readOpRhs);
+ if (itRhs != readsToTileLoads.end()) {
+ continue;
+ }
+
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+
+ if (pack) {
+ performShuffle(rewriter, loc, matB, ipType, offset, bBuffer,
+ allocStore);
+ }
+
+ amx::TileType tileType =
+ amx::TileType::get({16, (16 * offset)}, ipType);
+ auto load = amx::TileLoadOp::create(rewriter, loc, tileType, bBuffer,
+ ValueRange{addIdx, c0, c0});
+
+ auto load1 = amx::TileLoadOp::create(rewriter, loc, tileType, bBuffer,
+ ValueRange{addIdx, c16, c0});
+
+ readsToTileLoads.try_emplace(readOpRhs, load);
+
+ readsToTileLoads.try_emplace(ops[i].getRhs().getDefiningOp(), load1);
+ }
+ }
+ }
+
+ return readsToTileLoads;
}
// Creates tiled amx dot-products.
@@ -210,16 +346,26 @@ static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
SmallVector<vector::ContractionOp> ops,
Value matA, Value matB, Type ipType,
Type opType, ValueRange accIterArgs,
- unsigned int offset) {
+ unsigned int offset, bool isVnni,
+ Value bBuffer, bool pack,
+ Value allocStore, Value addIdx) {
- auto subviewCollapseLhs = collapseInnerDims(rewriter, loc, matA);
- auto subviewCollapseRhs = collapseInnerDims(rewriter, loc, matB);
+ if (isVnni) {
+ matA = collapseInnerDims(rewriter, loc, matA);
+ matB = collapseInnerDims(rewriter, loc, matB);
+ }
SmallVector<Value> accumulators;
// Stores the amx.tile_load operation vs it's equivalent vector tranfer_read
// or load operations.
llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+ // Non-VNNI inputs: shuffle B into bBuffer (if `pack`) and tile-load from it.
+ if (!isVnni) {
+ readsToTileLoads = packInputs(rewriter, loc, ops, matB, ipType, offset,
+ bBuffer, pack, allocStore, addIdx);
+ }
+
// Iterate over the contraction operations and compute the tiled dot-product.
for (size_t i = 0; i < ops.size(); i++) {
@@ -229,8 +375,8 @@ static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
if (itLhs != readsToTileLoads.end()) {
tilesLhs = itLhs->second;
} else {
- tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(),
- subviewCollapseLhs, ipType, false, offset);
+ tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(), matA, ipType,
+ false, offset, isVnni);
readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
}
@@ -240,8 +386,8 @@ static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
if (itRhs != readsToTileLoads.end()) {
tilesRhs = itRhs->second;
} else {
- tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(),
- subviewCollapseRhs, ipType, true, offset);
+ tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(), matB, ipType,
+ true, offset, isVnni);
readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
}
@@ -276,6 +422,172 @@ static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
return loopItrArgs;
}
+static Value bufferIndxToStore(OpBuilder &rewriter, Location loc, Value iv_K,
+ Value iv_red, bool oddDimK, bool nDimK,
+ bool pack, unsigned int blockingFactor) {
+
+ Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
+
+ Value nLoadIndx =
+ arith::ConstantIndexOp::create(rewriter, loc, (16 * blockingFactor));
+
+ Value quotient_K = arith::DivUIOp::create(rewriter, loc, iv_K, nLoadIndx);
+ Value rem_K = arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
+ quotient_K, c2);
+
+ if (!nDimK && !pack) {
+ rem_K = arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
+ iv_red, c2);
+ }
+
+ // If the K quotient is odd, the BR loop iv is also taken into account
+ // when computing the buffer index.
+ if (oddDimK) {
+ auto rem_BR = arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
+ iv_red, c2);
+ auto remAdd = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
+ rem_K, rem_BR);
+ rem_K = arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
+ remAdd, c2);
+ }
+
+ return rem_K;
+}
+
+static scf::ForOp
+createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
+ Value upperBound, Value step, SmallVector<Value> loopItrArgs,
+ Type ipType, Type opType, unsigned int blockingFactor, bool isVnni,
+ Operation *vectorOpLhs, Operation *vectorOpRhs,
+ vector::ContractionOp contractOp, scf::ForOp outerLoop,
+ scf::ForOp innerLoop, SmallVector<vector::ContractionOp> ops,
+ Value ivOuterLoop, Value bBuffer, bool pack,
+ arith::ConstantIndexOp innerLoopIndex, bool nDimK, bool oddDimK) {
+
+ auto newLoop1 = scf::ForOp::create(
+ rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
+ [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
+ Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
+ IRMapping mapping;
+ if (outerLoop) {
+ mapping.map(vectorOpLhs->getOperand(
+ getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
+ ivOuterLoop);
+ }
+ mapping.map(vectorOpLhs->getOperand(
+ getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
+ ivNewInnerLoop);
+ auto lhsClone = rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
+
+ Value c0 = arith::ConstantIndexOp::create(rewriter, locNewInnerLoop, 0);
+ Value c1 = arith::ConstantIndexOp::create(rewriter, locNewInnerLoop, 1);
+ Value c2 = arith::ConstantIndexOp::create(rewriter, locNewInnerLoop, 2);
+ Value allocStore = c0;
+ Value allocGet = c0;
+
+ if (!isVnni) {
+ if (outerLoop) {
+ if (innerLoopIndex.value() == 0) {
+ if (pack) {
+ ivNewInnerLoop = c0;
+ ivOuterLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
+ c1, ivOuterLoop);
+
+ if (!nDimK || oddDimK) {
+ allocStore = arith::RemUIOp::create(rewriter, locNewInnerLoop,
+ rewriter.getIndexType(),
+ ivOuterLoop, c2);
+ }
+
+ Value addIdx =
+ arith::AddIOp::create(rewriter, loc, allocStore, c1);
+ allocGet = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), addIdx, c2);
+ }
+
+ } else {
+ Value nLoadIndx = arith::ConstantIndexOp::create(
+ rewriter, locNewInnerLoop, (16 * blockingFactor));
+ ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
+ nLoadIndx, ivNewInnerLoop);
+ allocStore =
+ bufferIndxToStore(rewriter, loc, ivNewInnerLoop, ivOuterLoop,
+ oddDimK, nDimK, pack, blockingFactor);
+ Value addIdx =
+ arith::AddIOp::create(rewriter, loc, allocStore, c1);
+ allocGet = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), addIdx, c2);
+ }
+ } else {
+ if (pack) {
+ Value nLoadIndx = arith::ConstantIndexOp::create(
+ rewriter, locNewInnerLoop, (16 * blockingFactor));
+ ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
+ nLoadIndx, ivNewInnerLoop);
+ Value quotient_K = arith::DivUIOp::create(
+ rewriter, loc, ivNewInnerLoop, nLoadIndx);
+ allocStore = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), quotient_K, c2);
+
+ Value addIdx =
+ arith::AddIOp::create(rewriter, loc, allocStore, c1);
+ allocGet = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), addIdx, c2);
+ }
+ }
+ }
+
+ IRMapping rhsMapping;
+ if (outerLoop) {
+ rhsMapping.map(
+ vectorOpRhs->getOperand(
+ getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
+ ivOuterLoop);
+ }
+ rhsMapping.map(
+ vectorOpRhs->getOperand(
+ getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
+ ivNewInnerLoop);
+ auto rhsClone = rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping);
+
+ Value matB = rhsClone->getResult(0);
+
+ if (!isVnni) {
+ if (outerLoop) {
+ if (!pack) {
+ Value nLoadIndx = arith::ConstantIndexOp::create(
+ rewriter, locNewInnerLoop, (16 * blockingFactor));
+ matB = Value();
+ allocGet = c0;
+ allocGet =
+ bufferIndxToStore(rewriter, loc, nLoadIndx, ivOuterLoop,
+ oddDimK, nDimK, pack, blockingFactor);
+ }
+ } else {
+ if (!pack) {
+ Value nLoadIndx = arith::ConstantIndexOp::create(
+ rewriter, locNewInnerLoop, (16 * blockingFactor));
+ matB = Value();
+ Value quotient_K = arith::DivUIOp::create(
+ rewriter, loc, ivNewInnerLoop, nLoadIndx);
+ allocGet = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), quotient_K, c2);
+ }
+ }
+ }
+
+ SmallVector<Value> accumulators = createTiledDp(
+ rewriter, locNewInnerLoop, ops, lhsClone->getResult(0), matB,
+ ipType, opType, iterArgsNewInnerLoop, blockingFactor, isVnni,
+ bBuffer, pack, allocStore, allocGet);
+
+ scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
+ accumulators);
+ });
+
+ return newLoop1;
+}
+
// Implements tiled dot-product operation for a vector.contract operation or a
// sequence of vector.contracts inside the reduction loops.
//
@@ -326,9 +638,6 @@ struct VectorContractToAMXDotProduct
return rewriter.notifyMatchFailure(contractOp,
"Only F32 for BF16 or Int32 for Int8 "
"accumulation type is supported.");
- if (!isVnni)
- return rewriter.notifyMatchFailure(
- contractOp, "Only VNNI-packed inputs are supported.");
Operation *accReadOp =
traceToVectorReadLikeParentOperation(contractOp.getAcc());
@@ -342,8 +651,13 @@ struct VectorContractToAMXDotProduct
"transfer_read or a load. And, the result should be "
"stored using transfer_write or store.");
- Type ipType = rewriter.getBF16Type();
- Type opType = rewriter.getF32Type();
+ Type ipType;
+ Type opType;
+
+ if (lhsTy.getElementType().isBF16()) {
+ ipType = rewriter.getBF16Type();
+ opType = rewriter.getF32Type();
+ }
if (lhsTy.getElementType().isSignlessInteger(8)) {
ipType = rewriter.getIntegerType(8);
@@ -360,13 +674,21 @@ struct VectorContractToAMXDotProduct
return rewriter.notifyMatchFailure(
contractOp, "The accumulator read is in different block.");
+ unsigned int dimValue = blockingFactor;
+ if (!isVnni)
+ dimValue = 16 * blockingFactor;
+
// Case 1: For just one VC rewrite. Where all accumulator read/write
// within the same block.
if (accReadOp->getBlock() == contractOp->getBlock() &&
resultWriteOp->getBlock() == contractOp->getBlock()) {
+ bool collapse = false;
+ if (isVnni)
+ collapse = true;
+
LogicalResult validate = validateContractOps(
- rewriter, contractOp, blockingFactor, Value(), Value(), false);
+ rewriter, contractOp, dimValue, Value(), Value(), false);
if (failed(validate))
return rewriter.notifyMatchFailure(
@@ -377,18 +699,20 @@ struct VectorContractToAMXDotProduct
Location loc = contractOp.getLoc();
auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
- contractOp.getLhs(), true);
+ contractOp.getLhs(), collapse);
if (failed(srcIndxLhs))
return rewriter.notifyMatchFailure(contractOp,
"The LHS src is not a MemRef type.");
auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
- contractOp.getRhs(), true);
+ contractOp.getRhs(), collapse);
if (failed(srcIndxRhs))
return rewriter.notifyMatchFailure(contractOp,
"The RHS src is not a MemRef type.");
- auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
+ auto rhsSrc = *srcIndxRhs;
+ auto srcBuffRhs = rhsSrc.first;
+ auto indicesRhs = rhsSrc.second;
auto srcIndxAcc = getSrcIndxValue(rewriter, contractOp.getLoc(),
contractOp.getAcc(), false);
@@ -401,8 +725,112 @@ struct VectorContractToAMXDotProduct
auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
srcBuffLhs, indicesLhs);
- auto loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType,
- srcBuffRhs, indicesRhs);
+
+ // Create the subview and then load.
+ //
+ amx::TileLoadOp loadRhs;
+ if (!isVnni) {
+ VectorType vecTy;
+ SmallVector<OpFoldResult> indexVals;
+ llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
+ .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+ indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+ readOp.getIndices().end());
+ vecTy = readOp.getType();
+ });
+ auto one = rewriter.getIndexAttr(1);
+ SmallVector<OpFoldResult> strides(indexVals.size(), one);
+ SmallVector<OpFoldResult> sizes = getAsIndexOpFoldResult(
+ contractOp.getRhs().getDefiningOp()->getContext(),
+ vecTy.getShape());
+ auto subview = memref::SubViewOp::create(rewriter, loc, srcBuffRhs,
+ indexVals, sizes, strides);
+ auto bufferType = MemRefType::get({16, (16 * blockingFactor)}, ipType);
+ auto bBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
+
+ // Loop over B: load two vectors, interleave them, and store into bBuffer.
+
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+
+ Value step =
+ arith::ConstantIndexOp::create(rewriter, loc, blockingFactor);
+ Value uBound = arith::ConstantIndexOp::create(rewriter, loc,
+ (blockingFactor * 16));
+ Value nextLoadIndx =
+ arith::ConstantIndexOp::create(rewriter, loc, (blockingFactor / 2));
+ Value nextStoreIndx = arith::ConstantIndexOp::create(
+ rewriter, loc, 16 * (blockingFactor / 2));
+
+ scf::ForOp::create(
+ rewriter, loc, c0, uBound, step, ValueRange{},
+ [&](OpBuilder &nestedBuilder, Location loc, Value iv,
+ ValueRange iterArgs) {
+ Value i1_load =
+ arith::AddIOp::create(rewriter, loc, nextLoadIndx, iv);
+
+ indicesRhs[indicesRhs.size() - 2] = iv;
+ ValueRange range1(indicesRhs);
+ auto vec1 = vector::LoadOp::create(
+ rewriter, loc,
+ VectorType::get(16 * (blockingFactor / 2), ipType), subview,
+ range1);
+
+ indicesRhs[indicesRhs.size() - 2] = i1_load;
+ ValueRange range2(indicesRhs);
+ auto vec2 = vector::LoadOp::create(
+ rewriter, loc,
+ VectorType::get(16 * (blockingFactor / 2), ipType), subview,
+ range2);
+
+ vector::ShuffleOp shuffle1;
+ vector::ShuffleOp shuffle2;
+
+ if (blockingFactor == 2) {
+
+ shuffle1 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
+ ArrayRef<int64_t>{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21,
+ 6, 22, 7, 23});
+
+ shuffle2 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
+ ArrayRef<int64_t>{8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
+ 29, 14, 30, 15, 31});
+ }
+
+ if (blockingFactor == 4) {
+ shuffle1 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
+ ArrayRef<int64_t>{0, 16, 32, 48, 1, 17, 33, 49,
+ 2, 18, 34, 50, 3, 19, 35, 51,
+ 4, 20, 36, 52, 5, 21, 37, 53,
+ 6, 22, 38, 54, 7, 23, 39, 55});
+
+ shuffle2 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
+ ArrayRef<int64_t>{8, 24, 40, 56, 9, 25, 41, 57,
+ 10, 26, 42, 58, 11, 27, 43, 59,
+ 12, 28, 44, 60, 13, 29, 45, 61,
+ 14, 30, 46, 62, 15, 31, 47, 63});
+ }
+
+ auto rem = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), iv, step);
+
+ vector::StoreOp::create(rewriter, loc, shuffle1, bBuffer,
+ ValueRange{rem, c0});
+ vector::StoreOp::create(rewriter, loc, shuffle2, bBuffer,
+ ValueRange{rem, nextStoreIndx});
+
+ scf::YieldOp::create(nestedBuilder, loc);
+ });
+ loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, bBuffer,
+ ValueRange{c0, c0});
+ } else {
+
+ loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, srcBuffRhs,
+ indicesRhs);
+ }
auto tileTypeAcc = amx::TileType::get({16, 16}, opType);
auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
@@ -429,7 +857,6 @@ struct VectorContractToAMXDotProduct
// reduction loop.
SmallVector<scf::ForOp> loopLists;
Operation *current = contractOp;
-
while (true) {
Operation *parent = current->getParentOfType<scf::ForOp>();
loopLists.push_back(dyn_cast<scf::ForOp>(parent));
@@ -440,7 +867,6 @@ struct VectorContractToAMXDotProduct
current = parent;
}
-
if (loopLists.size() > 2 || loopLists.size() == 0)
return rewriter.notifyMatchFailure(
contractOp, "Rewrite is supported until reduction loop depth of 2.");
@@ -458,7 +884,6 @@ struct VectorContractToAMXDotProduct
return rewriter.notifyMatchFailure(contractOp,
"The RHS src is not a MemRef type.");
auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
-
Operation *vectorOpLhs;
llvm::TypeSwitch<Operation *>(contractOp.getLhs().getDefiningOp())
.Case<TransferReadOp, LoadOp>([&](auto readOp) {
@@ -478,7 +903,7 @@ struct VectorContractToAMXDotProduct
if (auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
LogicalResult validate = validateContractOps(
- rewriter, contract, blockingFactor, srcBuffLhs, srcBuffRhs, true);
+ rewriter, contract, dimValue, srcBuffLhs, srcBuffRhs, true);
if (failed(validate))
return rewriter.notifyMatchFailure(
@@ -490,8 +915,8 @@ struct VectorContractToAMXDotProduct
}
}
- scf::ForOp outerLoop;
scf::ForOp innerLoop;
+ scf::ForOp outerLoop;
scf::ForOp newLoop;
// Case 2a: Reduction loop depth is 2.
@@ -502,126 +927,248 @@ struct VectorContractToAMXDotProduct
SmallVector<Value> loopItrArgs = createTileZeros(
rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
- newLoop = scf::ForOp::create(
- rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
- outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
- [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
- Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
- auto newInnerLoop = scf::ForOp::create(
- rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
- innerLoop.getUpperBound(), innerLoop.getStep(),
- iterArgsOuterLoop,
- [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
- Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
- IRMapping mapping;
- mapping.map(
- vectorOpLhs->getOperand(
- getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
- ivOuterLoop);
- mapping.map(
- vectorOpLhs->getOperand(
- getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
- ivNewInnerLoop);
- auto lhsClone =
- rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
-
- IRMapping rhsMapping;
- rhsMapping.map(
- vectorOpRhs->getOperand(
- getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
- ivOuterLoop);
- rhsMapping.map(
- vectorOpRhs->getOperand(
- getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
- ivNewInnerLoop);
- auto rhsClone =
- rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping);
-
- SmallVector<Value> accumulators = createTiledDp(
- rewriter, locNewInnerLoop, ops, lhsClone->getResult(0),
- rhsClone->getResult(0), ipType, opType,
- iterArgsNewInnerLoop, blockingFactor);
-
- scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
- accumulators);
- });
-
- scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
- newInnerLoop.getResults());
- });
+ if (isVnni) {
+ newLoop = scf::ForOp::create(
+ rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
+ outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
+ [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
+ Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
+ auto newInnerLoop = createLoops(
+ rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+ innerLoop.getUpperBound(), innerLoop.getStep(),
+ iterArgsOuterLoop, ipType, opType, blockingFactor, isVnni,
+ vectorOpLhs, vectorOpRhs, contractOp, outerLoop, innerLoop,
+ ops, ivOuterLoop, nullptr, true, nullptr, false, false);
+
+ scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
+ newInnerLoop.getResults());
+ });
+
+ } else {
+
+ bool nDimK = false;
+ bool oddDimK = false;
+
+ int64_t ubVal = 16 * blockingFactor;
+ mlir::Value ub = innerLoop.getUpperBound();
+ if (auto constOp = ub.getDefiningOp<mlir::arith::ConstantOp>()) {
+ if (auto intAttr =
+ llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
+ ubVal = intAttr.getInt();
+ }
+ }
+
+ nDimK = ubVal > 16 * blockingFactor;
+ oddDimK = (((ubVal / (16 * blockingFactor)) % 2) == 1) && nDimK;
+
+ rewriter.setInsertionPoint(outerLoop);
+
+ auto c0 =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+
+ auto c1 =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
+
+ auto spillLoopBound = arith::ConstantIndexOp::create(
+ rewriter, outerLoop.getLoc(), 16 * blockingFactor);
+ Value subBRLoop = arith::SubIOp::create(rewriter, outerLoop.getLoc(),
+ outerLoop.getUpperBound(), c1);
+ Value subKloop =
+ arith::SubIOp::create(rewriter, innerLoop.getLoc(),
+ innerLoop.getUpperBound(), spillLoopBound);
+ auto bufferType =
+ MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
+ auto bBuffer =
+ memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
+
+ // Initial shuffle of B, emitted before the reduction loops.
+ IRMapping rhsMapping;
+ rhsMapping.map(
+ vectorOpRhs->getOperand(
+ getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
+ c0);
+ rhsMapping.map(
+ vectorOpRhs->getOperand(
+ getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
+ c0);
+ auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
+
+ performShuffle(rewriter, outerLoop.getLoc(), rhsClone->getResult(0),
+ ipType, blockingFactor, bBuffer, c0);
+
+ // First set of loops: outer loop over [lowerBound, upperBound - 1).
+ auto newLoop1 = scf::ForOp::create(
+ rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(), subBRLoop,
+ outerLoop.getStep(), loopItrArgs,
+ [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
+ Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
+ auto newInnerLoop1 = createLoops(
+ rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+ subKloop, innerLoop.getStep(), iterArgsOuterLoop, ipType,
+ opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+ contractOp, outerLoop, innerLoop, ops, ivOuterLoop, bBuffer,
+ true, spillLoopBound, nDimK, oddDimK);
+
+ auto newInnerLoop =
+ createLoops(rewriter, innerLoop.getLoc(), subKloop,
+ innerLoop.getUpperBound(), innerLoop.getStep(),
+ newInnerLoop1.getResults(), ipType, opType,
+ blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+ contractOp, outerLoop, innerLoop, ops,
+ ivOuterLoop, bBuffer, true, c0, nDimK, oddDimK);
+
+ scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
+ newInnerLoop.getResults());
+ });
+
+ // Last set of loops: outer loop over [upperBound - 1, upperBound).
+ newLoop = scf::ForOp::create(
+ rewriter, outerLoop.getLoc(), subBRLoop, outerLoop.getUpperBound(),
+ outerLoop.getStep(), newLoop1.getResults(),
+ [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
+ Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
+ auto newInnerLoop1 = createLoops(
+ rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+ subKloop, innerLoop.getStep(), iterArgsOuterLoop, ipType,
+ opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+ contractOp, outerLoop, innerLoop, ops, ivOuterLoop, bBuffer,
+ true, spillLoopBound, nDimK, oddDimK);
+
+ auto newInnerLoop =
+ createLoops(rewriter, innerLoop.getLoc(), subKloop,
+ innerLoop.getUpperBound(), innerLoop.getStep(),
+ newInnerLoop1.getResults(), ipType, opType,
+ blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+ contractOp, outerLoop, innerLoop, ops,
+ ivOuterLoop, bBuffer, false, c0, nDimK, oddDimK);
+
+ scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
+ newInnerLoop.getResults());
+ });
+ }
}
- // Case 2b: Reduction loop depth is 1.
if (loopLists.size() == 1) {
outerLoop = loopLists[0];
+ innerLoop = loopLists[0];
SmallVector<Value> loopItrArgs = createTileZeros(
- rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
- newLoop = scf::ForOp::create(
- rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
- outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
- [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
- Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
- IRMapping mapping;
- mapping.map(
- vectorOpLhs->getOperand(
- getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
- ivOuterLoop);
-
- auto lhsClone = rewriterOuterLoop.clone(*vectorOpLhs, mapping);
-
- IRMapping rhsMapping;
- rhsMapping.map(
- vectorOpRhs->getOperand(
- getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
- ivOuterLoop);
-
- auto rhsClone = rewriterOuterLoop.clone(*vectorOpRhs, rhsMapping);
-
- SmallVector<Value> accumulators = createTiledDp(
- rewriter, locOuterLoop, ops, lhsClone->getResult(0),
- rhsClone->getResult(0), ipType, opType, iterArgsOuterLoop,
- blockingFactor);
-
- scf::YieldOp::create(rewriterOuterLoop, locOuterLoop, accumulators);
- });
+ rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
+
+ if (isVnni) {
+
+ newLoop = createLoops(
+ rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+ innerLoop.getUpperBound(), innerLoop.getStep(), loopItrArgs, ipType,
+ opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+ contractOp, nullptr, innerLoop, ops, nullptr, nullptr, true,
+ nullptr, false, false);
+
+ } else {
+ bool nDimK = false;
+ bool oddDimK = false;
+
+ int64_t ubVal = 16 * blockingFactor;
+ mlir::Value ub = innerLoop.getUpperBound();
+ if (auto constOp = ub.getDefiningOp<mlir::arith::ConstantOp>()) {
+ if (auto intAttr =
+ llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
+ ubVal = intAttr.getInt();
+ }
+ }
+
+ nDimK = ubVal > 16 * blockingFactor;
+ oddDimK = (((ubVal / (16 * blockingFactor)) % 2) == 1) && nDimK;
+ rewriter.setInsertionPoint(innerLoop);
+ auto c0 =
+ arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 0);
+ auto spillLoopBound = arith::ConstantIndexOp::create(
+ rewriter, innerLoop.getLoc(), 16 * blockingFactor);
+
+ Value subKloop =
+ arith::SubIOp::create(rewriter, innerLoop.getLoc(),
+ innerLoop.getUpperBound(), spillLoopBound);
+
+ auto bufferType =
+ MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
+ auto bBuffer =
+ memref::AllocaOp::create(rewriter, innerLoop.getLoc(), bufferType);
+
+ // Initial shuffle of B, emitted before the reduction loop.
+ IRMapping rhsMapping;
+ rhsMapping.map(
+ vectorOpRhs->getOperand(
+ getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
+ c0);
+ auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
+
+ performShuffle(rewriter, innerLoop.getLoc(), rhsClone->getResult(0),
+ ipType, blockingFactor, bBuffer, c0);
+
+ auto newLoop1 = createLoops(
+ rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(), subKloop,
+ innerLoop.getStep(), loopItrArgs, ipType, opType, blockingFactor,
+ isVnni, vectorOpLhs, vectorOpRhs, contractOp, nullptr, innerLoop,
+ ops, nullptr, bBuffer, true, spillLoopBound, nDimK, oddDimK);
+
+ newLoop = createLoops(rewriter, innerLoop.getLoc(), subKloop,
+ innerLoop.getUpperBound(), innerLoop.getStep(),
+ newLoop1.getResults(), ipType, opType,
+ blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+ contractOp, nullptr, innerLoop, ops, nullptr,
+ bBuffer, false, c0, nDimK, oddDimK);
+ }
}
- // post processing after the loop creation.
// Copy the amx tile accumulation results to a MemRef buffer, add the
// initial accumulation value, and store back to the C-Matrix
- auto bufferType = MemRefType::get({16, 16}, opType);
- auto bBuffer =
- memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
- SmallVector<Value> dps = newLoop.getResults();
- for (size_t i = 0; i < ops.size(); i++) {
- vector::ContractionOp contOp = ops[i];
- Operation *resultWriteOp =
- traceToVectorWriteLikeUserOperation(contOp.getResult());
- rewriter.setInsertionPoint(resultWriteOp);
+ if (!isVnni) {
+ Location loc = outerLoop.getLoc();
+ SmallVector<Value> dps = newLoop.getResults();
+ auto bufferType = MemRefType::get({32, 32}, opType);
+ auto bBuffer =
+ memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
+ for (int i = 0, k = 0; i < 32; i = i + 16) {
+ for (int j = 0; j < 32; j = j + 16) {
+ Value indexOp_i =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), i);
+ Value indexOp_j =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), j);
+ amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
+ ValueRange{indexOp_i, indexOp_j}, dps[k]);
+ k++;
+ }
+ }
+ auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+ auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto mBound = arith::ConstantIndexOp::create(rewriter, loc, 32);
- Value indexOp_0 =
- arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+ scf::ForOp::create(
+ rewriter, loc, c0, mBound, one, ValueRange{},
+ [&](OpBuilder &nestedBuilder, Location loc, Value iv,
+ ValueRange iterArgs) {
+ auto row = vector::LoadOp::create(rewriter, loc,
+ VectorType::get(16, opType),
+ bBuffer, ValueRange{iv, c0});
- amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
- ValueRange{indexOp_0, indexOp_0}, dps[i]);
+ auto row2 = vector::LoadOp::create(rewriter, loc,
+ VectorType::get(16, opType),
+ bBuffer, ValueRange{iv, c16});
- auto c0 = arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
- auto one =
- arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
- auto mBound =
- arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 16);
+ auto shuffle1 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get(16, opType), row, row2,
+ ArrayRef<int64_t>{0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20,
+ 21, 22, 23});
- scf::ForOp::create(
- rewriter, outerLoop.getLoc(), c0, mBound, one, ValueRange{},
- [&](OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) {
- auto resultAcc = vector::LoadOp::create(
- rewriter, loc, VectorType::get(16, opType), bBuffer,
- ValueRange{iv, c0});
+ auto shuffle2 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get(16, opType), row, row2,
+ ArrayRef<int64_t>{8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
+ 28, 29, 30, 31});
Operation *accReadOp =
- traceToVectorReadLikeParentOperation(ops[i].getAcc());
+ traceToVectorReadLikeParentOperation(contractOp.getAcc());
Value srcBuffAcc;
SmallVector<Value> indicesAcc;
@@ -641,24 +1188,119 @@ struct VectorContractToAMXDotProduct
});
});
- Value sum = arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
- indicesAcc[indicesAcc.size() - 2] = sum;
-
- auto acc = vector::LoadOp::create(rewriter, loc,
- VectorType::get(16, opType),
- srcBuffAcc, indicesAcc);
- Value addition;
- if (ipType.isBF16())
- addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
-
- if (ipType.isSignlessInteger(8))
- addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
-
- vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
+ indicesAcc[indicesAcc.size() - 2] = iv;
+ indicesAcc[indicesAcc.size() - 1] = c0;
+
+ Value valueCRow1 = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
+ indicesAcc);
+ indicesAcc[indicesAcc.size() - 1] = c16;
+
+ Value valueCRow2 = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
+ indicesAcc);
+ Value addOp;
+ Value addOp2;
+
+ if (ipType.isBF16()) {
+ addOp =
+ arith::AddFOp::create(rewriter, loc, shuffle1, valueCRow1);
+
+ addOp2 =
+ arith::AddFOp::create(rewriter, loc, shuffle2, valueCRow2);
+ }
+
+ if (ipType.isSignlessInteger(8)) {
+ addOp =
+ arith::AddIOp::create(rewriter, loc, shuffle1, valueCRow1);
+
+ addOp2 =
+ arith::AddIOp::create(rewriter, loc, shuffle2, valueCRow2);
+ }
+ indicesAcc[indicesAcc.size() - 1] = c0;
+ vector::StoreOp::create(rewriter, loc, addOp, srcBuffAcc,
+ indicesAcc);
+ indicesAcc[indicesAcc.size() - 1] = c16;
+ vector::StoreOp::create(rewriter, loc, addOp2, srcBuffAcc,
indicesAcc);
- scf::YieldOp::create(builder, outerLoop.getLoc());
+ scf::YieldOp::create(nestedBuilder, loc);
});
+ }
+ auto bufferType = MemRefType::get({16, 16}, opType);
+ auto bBuffer =
+ memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
+
+ SmallVector<Value> dps = newLoop.getResults();
+ for (size_t i = 0; i < ops.size(); i++) {
+ vector::ContractionOp contOp = ops[i];
+ Operation *resultWriteOp =
+ traceToVectorWriteLikeUserOperation(contOp.getResult());
+ if (isVnni) {
+ rewriter.setInsertionPoint(resultWriteOp);
+
+ Value indexOp_0 =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+
+ amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
+ ValueRange{indexOp_0, indexOp_0}, dps[i]);
+
+ auto c0 =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+ auto one =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
+ auto mBound =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 16);
+
+ scf::ForOp::create(
+ rewriter, outerLoop.getLoc(), c0, mBound, one, ValueRange{},
+ [&](OpBuilder &builder, Location loc, Value iv,
+ ValueRange iterArgs) {
+ auto resultAcc = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(16, opType), bBuffer,
+ ValueRange{iv, c0});
+
+ Operation *accReadOp =
+ traceToVectorReadLikeParentOperation(ops[i].getAcc());
+
+ Value srcBuffAcc;
+ SmallVector<Value> indicesAcc;
+
+ llvm::TypeSwitch<Operation *>(accReadOp)
+ .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+ srcBuffAcc = readOp.getOperand(0);
+
+ auto indices = readOp.getIndices();
+ indicesAcc.reserve(indices.size());
+
+ llvm::transform(
+ indices, std::back_inserter(indicesAcc),
+ [&](OpFoldResult ofr) {
+ return mlir::getValueOrCreateConstantIndexOp(
+ rewriter, loc, ofr);
+ });
+ });
+
+ Value sum =
+ arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
+ indicesAcc[0] = sum;
+
+ auto acc = vector::LoadOp::create(rewriter, loc,
+ VectorType::get(16, opType),
+ srcBuffAcc, indicesAcc);
+ Value addition;
+ if (ipType.isBF16())
+ addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
+
+ if (ipType.isSignlessInteger(8))
+ addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
+
+ vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
+ indicesAcc);
+
+ scf::YieldOp::create(builder, outerLoop.getLoc());
+ });
+ }
rewriter.eraseOp(resultWriteOp);
}
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
index cde15b680a037..151946453df81 100644
--- a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -216,6 +216,122 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<16x64xi8>
+!vecB = vector<64x16xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x64xi8>
+!memrefB = memref<64x32xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d1, d2, d3) -> (d1, d3)>
+#map1 = affine_map<(d1, d2, d3) -> (d3, d2)>
+#map2 = affine_map<(d1, d2, d3) -> (d1, d2)>
+func.func @online_packing_int8(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : i8
+ %32 = ub.poison : i32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @online_packing_int8
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK: scf.for
+// CHECK: vector.shuffle{{.*}}[0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51, 4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55] : vector<32xi8>, vector<32xi8>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59, 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63] : vector<32xi8>, vector<32xi8>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK: x86.amx.tile_muli
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK-NOT: vector.contract
+
+
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x16x32xbf16>
+!vecB = vector<1x32x16xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x32xbf16>
+!memrefB = memref<1x32x32xbf16>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @online_packing_bf16(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @online_packing_bf16
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x32xbf16>
+// CHECK: scf.for
+// CHECK: vector.shuffle{{.*}}[0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23] : vector<16xbf16>, vector<16xbf16>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31] : vector<16xbf16>, vector<16xbf16>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x32xbf16>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK: x86.amx.tile_mulf
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecAB = vector<1x16x16x2xbf16>
!vecC = vector<16x16xf32>
!memrefA = memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>
@@ -483,6 +599,199 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<1x16x32xbf16>
+!vecB = vector<1x32x16xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x32xbf16, strided<[6144, 96, 1], offset: ?>>
+!memrefB = memref<1x32x32xbf16, strided<[12288, 128, 1], offset: ?>>
+!memrefC = memref<32x32xf32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+func.func @online_packing_bf16_loop(%arg0: memref<16x64x96xbf16>, %arg1: memref<16x96x128xbf16>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> {
+ %0 = ub.poison : f32
+ %1 = ub.poison : bf16
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c96 = arith.constant 96 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ scf.for %arg3 = %c0 to %c64 step %c32 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+
+ %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] :
+ memref<64x128xf32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %6:4 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+ %7:4 = scf.for %arg10 = %c0 to %c96 step %c32 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!vecC, !vecC, !vecC, !vecC) {
+
+ %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg10] [1, 32, 32] [1, 1, 1] :
+ memref<16x64x96xbf16> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg10, %arg4] [1, 32, 32] [1, 1, 1] :
+ memref<16x96x128xbf16> to !memrefB
+ %8 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefA, !vecA
+ %9 = vector.transfer_read %subview_0[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefA, !vecA
+ %10 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefB, !vecB
+ %11 = vector.transfer_read %subview_1[%c0, %c0, %c16], %1 {in_bounds = [true, true, true]} :
+ !memrefB, !vecB
+
+ %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %10, %arg11 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+ %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %11, %arg12 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+ %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %9, %10, %arg13 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+ %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %9, %11, %arg14 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+
+ scf.yield %12, %13, %14, %15 : !vecC, !vecC, !vecC, !vecC
+ }
+ scf.yield %7#0, %7#1, %7#2, %7#3 : !vecC, !vecC, !vecC, !vecC
+ }
+ vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ }
+ }
+ %alloc = memref.alloc() : memref<64x128xf32>
+ memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
+ return %alloc : memref<64x128xf32>
+}
+
+// CHECK-LABEL: @online_packing_bf16_loop
+// CHECK-COUNT-4: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+// CHECK-COUNT-4: scf.for {{.*}} -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
+// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
+// CHECK: x86.amx.tile_load
+// CHECK: x86.amx.tile_mulf
+// CHECK: scf.yield {{.*}} : !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+// CHECK-NOT: vector.contract
+
+
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<16x64xi8>
+!vecB = vector<64x16xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x64xi8, strided<[256, 1], offset: ?>>
+!memrefB = memref<64x32xi8, strided<[128, 1], offset: ?>>
+!memrefC = memref<32x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xi8>, %arg1: memref<256x128xi8>, %arg2: memref<64x128xi32>) -> memref<64x128xi32> {
+ %c16 = arith.constant 16 : index
+ %0 = ub.poison : i32
+ %1 = ub.poison : i8
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ %c32 = arith.constant 32 : index
+ scf.for %arg3 = %c0 to %c64 step %c32 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+ %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<64x128xi32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+ %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+ %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+ %6:4 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+ %subview_0 = memref.subview %arg0[%arg3, %arg5] [32, 64] [1, 1] : memref<64x256xi8> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg4] [64, 32] [1, 1] : memref<256x128xi8> to !memrefB
+ %7 = vector.transfer_read %subview_0[%c0, %c0], %1 {in_bounds = [true, true]} : !memrefA, !vecA
+ %8 = vector.transfer_read %subview_0[%c16, %c0], %1 {in_bounds = [true, true]} : !memrefA, !vecA
+ %9 = vector.transfer_read %subview_1[%c0, %c0], %1 {in_bounds = [true, true]} : !memrefB, !vecB
+ %10 = vector.transfer_read %subview_1[%c0, %c16], %1 {in_bounds = [true, true]} : !memrefB, !vecB
+ %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %7, %9, %arg6 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %7, %10, %arg7 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %9, %arg8 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %10, %arg9 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ scf.yield %11, %12, %13, %14 : !vecC, !vecC, !vecC, !vecC
+ }
+ vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ }
+ }
+ %alloc = memref.alloc() : memref<64x128xi32>
+ memref.copy %arg2, %alloc : memref<64x128xi32> to memref<64x128xi32>
+ return %alloc : memref<64x128xi32>
+}
+
+// CHECK-LABEL: @online_packing_int8_matmul_loop
+// CHECK-COUNT-4: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
+// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>) {
+// CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111] : vector<64xi8>, vector<64xi8>
+// CHECK-NEXT: vector.shuffle{{.*}}[16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123, 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127] : vector<64xi8>, vector<64xi8>
+// CHECK: x86.amx.tile_load
+// CHECK: x86.amx.tile_muli
+// CHECK: scf.yield {{.*}} !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xi32>, vector<16xi32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xi32>, vector<16x16xi32>, vector<16x16xi32>, vector<16x16xi32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x16x16x4xi8>
!vecB = vector<1x16x16x4xi8>
!vecC = vector<16x16xi32>
@@ -637,32 +946,32 @@ module attributes {transform.with_named_sequence} {
// -----
-!vecA = vector<16x64xi8>
-!vecB = vector<64x16xi8>
-!vecC = vector<16x16xi32>
-!memrefA = memref<32x64xi8>
-!memrefB = memref<64x32xi8>
-!memrefC = memref<32x32xi32>
-#map = affine_map<(d1, d2, d3) -> (d1, d3)>
-#map1 = affine_map<(d1, d2, d3) -> (d3, d2)>
-#map2 = affine_map<(d1, d2, d3) -> (d1, d2)>
-func.func @negative_no_vnni_packed(
+!vecA = vector<1x16x32xbf16>
+!vecB = vector<1x32x32xbf16>
+!vecC = vector<16x32xf32>
+!memrefA = memref<1x32x32xbf16>
+!memrefB = memref<1x32x32xbf16>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @negative_wrong_dimensions_online_packing(
%arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
{
%c0 = arith.constant 0 : index
- %0 = ub.poison : i8
- %32 = ub.poison : i32
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
- %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
!memrefA, !vecA
- %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
!memrefB, !vecB
%3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
%4 = vector.contract {
indexing_maps = [#map, #map1, #map2],
- iterator_types = ["parallel", "parallel", "reduction"],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
kind = #vector.kind<add>}
%1, %2, %3 : !vecA, !vecB into !vecC
@@ -671,13 +980,13 @@ func.func @negative_no_vnni_packed(
return %arg2 : !memrefC
}
-// CHECK-LABEL: @negative_no_vnni_packed
-// CHECK-NOT: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
-// CHECK-NOT: x86.amx.tile_muli
-// CHECK-NOT: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK-LABEL: @negative_wrong_dimensions_online_packing
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: x86.amx.tile_mulf
+// CHECK-NOT: x86.amx.tile_store
// CHECK: vector.contract
-
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
>From caabcfde4a773121a31a4442e8dcab28c469c47c Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 26 Mar 2026 21:21:44 -0700
Subject: [PATCH 2/5] code refactoring + addition of two -ve test-cases
---
.../VectorContractToAMXDotProduct.cpp | 491 ++++++++++--------
.../X86/AMX/vector-contract-to-tiled-dp.mlir | 151 ++++++
2 files changed, 439 insertions(+), 203 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 2b159e6f59cb9..744c065b4e05e 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -1,4 +1,4 @@
-//===- VectorContractToAMXDotProduct.cpp ----------------------===//
+//===- VectorContractToAMXDotProduct.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -211,42 +211,41 @@ static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
}
static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
- Type ipType, unsigned int offset, Value bBuffer,
- Value allocStore) {
+ Type ipType, unsigned int offset, Value packedBuffer,
+ Value indxToStoreInBuffer) {
- auto subview = matB.getDefiningOp<mlir::memref::SubViewOp>();
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
- SmallVector<Value> vals(subview.getOffsets().size(), c0);
-
- // Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+ auto subview = matB.getDefiningOp<mlir::memref::SubViewOp>();
+ SmallVector<Value> subviewOffset(subview.getOffsets().size(), c0);
+
Value cStep = arith::ConstantIndexOp::create(rewriter, loc, offset);
Value cBound = arith::ConstantIndexOp::create(rewriter, loc, (16 * offset));
- Value nLoadIndx = arith::ConstantIndexOp::create(rewriter, loc, (offset / 2));
+ Value offsetIndx =
+ arith::ConstantIndexOp::create(rewriter, loc, (offset / 2));
scf::ForOp::create(
rewriter, loc, c0, cBound, cStep, ValueRange{},
[&](OpBuilder &nestedBuilder, Location loc, Value iv,
ValueRange iterArgs) {
- Value i1_load = arith::AddIOp::create(rewriter, loc, nLoadIndx, iv);
-
- vals[vals.size() - 2] = iv;
- ValueRange range1(vals);
+ subviewOffset[subviewOffset.size() - 2] = iv;
auto vec1 = vector::LoadOp::create(
rewriter, loc, VectorType::get((16 * offset), ipType), matB,
- range1);
+ ValueRange(subviewOffset));
- vals[vals.size() - 2] = i1_load;
- ValueRange range2(vals);
+ // Advance iv by offset / 2 (1 for bf16, 2 for i8) to load the next
+ // 32/64 elements.
+ Value incIV = arith::AddIOp::create(rewriter, loc, offsetIndx, iv);
+ subviewOffset[subviewOffset.size() - 2] = incIV;
auto vec2 = vector::LoadOp::create(
rewriter, loc, VectorType::get((16 * offset), ipType), matB,
- range2);
+ ValueRange(subviewOffset));
vector::ShuffleOp shuffle1;
vector::ShuffleOp shuffle2;
- if (offset == 2) {
+ if (ipType.isBF16()) {
shuffle1 = vector::ShuffleOp::create(
rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
@@ -263,7 +262,7 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
23, 55, 28, 60, 29, 61, 30, 62, 31, 63});
}
- if (offset == 4) {
+ if (ipType.isSignlessInteger(8)) {
shuffle1 = vector::ShuffleOp::create(
rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
@@ -285,13 +284,15 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
90, 122, 27, 59, 91, 123, 28, 60, 92, 124, 29, 61, 93, 125,
30, 62, 94, 126, 31, 63, 95, 127});
}
- Value j_pos = arith::DivUIOp::create(rewriter, loc, iv, cStep);
- Value j16_pos = arith::AddIOp::create(rewriter, loc, c16, j_pos);
- vector::StoreOp::create(rewriter, loc, shuffle1, bBuffer,
- ValueRange{allocStore, j_pos, c0});
- vector::StoreOp::create(rewriter, loc, shuffle2, bBuffer,
- ValueRange{allocStore, j16_pos, c0});
+ // Row indices in the packed buffer for the two shuffled vectors.
+ Value ivShuff1 = arith::DivUIOp::create(rewriter, loc, iv, cStep);
+ Value ivShuff2 = arith::AddIOp::create(rewriter, loc, ivShuff1, c16);
+
+ vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
+ ValueRange{indxToStoreInBuffer, ivShuff1, c0});
+ vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
+ ValueRange{indxToStoreInBuffer, ivShuff2, c0});
scf::YieldOp::create(nestedBuilder, loc);
});
@@ -300,9 +301,12 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
static llvm::DenseMap<Operation *, amx::TileLoadOp>
packInputs(OpBuilder &rewriter, Location loc,
SmallVector<vector::ContractionOp> ops, Value matB, Type ipType,
- unsigned int offset, Value bBuffer, bool pack, Value allocStore,
- Value addIdx) {
+ unsigned int offset, Value packedBuffer, bool pack,
+ Value indxToStoreInBuffer, Value indxToLoadFromMatB) {
+
llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
for (size_t j = 0; j < ops.size(); j++) {
for (size_t i = 0; i < ops.size(); i++) {
@@ -315,25 +319,23 @@ packInputs(OpBuilder &rewriter, Location loc,
continue;
}
- Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
- Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
-
if (pack) {
- performShuffle(rewriter, loc, matB, ipType, offset, bBuffer,
- allocStore);
+ performShuffle(rewriter, loc, matB, ipType, offset, packedBuffer,
+ indxToStoreInBuffer);
}
amx::TileType tileType =
amx::TileType::get({16, (16 * offset)}, ipType);
- auto load = amx::TileLoadOp::create(rewriter, loc, tileType, bBuffer,
- ValueRange{addIdx, c0, c0});
-
- auto load1 = amx::TileLoadOp::create(rewriter, loc, tileType, bBuffer,
- ValueRange{addIdx, c16, c0});
+ auto loadRow1 =
+ amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
+ ValueRange{indxToLoadFromMatB, c0, c0});
- readsToTileLoads.try_emplace(readOpRhs, load);
+ auto loadRow2 =
+ amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
+ ValueRange{indxToLoadFromMatB, c16, c0});
- readsToTileLoads.try_emplace(ops[i].getRhs().getDefiningOp(), load1);
+ readsToTileLoads.try_emplace(readOpRhs, loadRow1);
+ readsToTileLoads.try_emplace(ops[i].getRhs().getDefiningOp(), loadRow2);
}
}
}
@@ -342,13 +344,12 @@ packInputs(OpBuilder &rewriter, Location loc,
}
// Creates tiled amx dot-products.
-static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
- SmallVector<vector::ContractionOp> ops,
- Value matA, Value matB, Type ipType,
- Type opType, ValueRange accIterArgs,
- unsigned int offset, bool isVnni,
- Value bBuffer, bool pack,
- Value allocStore, Value addIdx) {
+static SmallVector<Value>
+createTiledDp(OpBuilder &rewriter, Location loc,
+ SmallVector<vector::ContractionOp> ops, Value matA, Value matB,
+ Type ipType, Type opType, ValueRange accIterArgs,
+ unsigned int offset, bool isVnni, Value packedBuffer, bool pack,
+ Value indxToStoreInBuffer, Value indxToLoadFromMatB) {
if (isVnni) {
matA = collapseInnerDims(rewriter, loc, matA);
@@ -360,10 +361,11 @@ static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
// or load operations.
llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
- // function call to make the flat.
+ // function call to online pack the input B matrix
if (!isVnni) {
- readsToTileLoads = packInputs(rewriter, loc, ops, matB, ipType, offset,
- bBuffer, pack, allocStore, addIdx);
+ readsToTileLoads =
+ packInputs(rewriter, loc, ops, matB, ipType, offset, packedBuffer, pack,
+ indxToStoreInBuffer, indxToLoadFromMatB);
}
// Iterate over the contraction operations and compute the tiled dot-product.
@@ -422,36 +424,37 @@ static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
return loopItrArgs;
}
-static Value bufferIndxToStore(OpBuilder &rewriter, Location loc, Value iv_K,
- Value iv_red, bool oddDimK, bool nDimK,
- bool pack, unsigned int blockingFactor) {
+static Value bufferIndxToStore(OpBuilder &rewriter, Location loc,
+ Value ivInnerLoop, Value ivOuterLoop,
+ bool isInnerLoopUBHasOddQuot,
+ bool isInnerLoopUBLarger, bool pack,
+ unsigned int blockingFactor) {
Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
-
- Value nLoadIndx =
+ Value packOffset =
arith::ConstantIndexOp::create(rewriter, loc, (16 * blockingFactor));
- Value quotient_K = arith::DivUIOp::create(rewriter, loc, iv_K, nLoadIndx);
- Value rem_K = arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
- quotient_K, c2);
+ Value quotientInnerLoop =
+ arith::DivUIOp::create(rewriter, loc, ivInnerLoop, packOffset);
+ Value remInnerLoop = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), quotientInnerLoop, c2);
- if (!nDimK && !pack) {
- rem_K = arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
- iv_red, c2);
+ if (!isInnerLoopUBLarger && !pack) {
+ remInnerLoop = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
}
// if K quotient is odd. Then, BR loop iv is taken
// into consideration
- if (oddDimK) {
- auto rem_BR = arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
- iv_red, c2);
+ if (isInnerLoopUBHasOddQuot) {
+ auto remOuterLoop = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
auto remAdd = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
- rem_K, rem_BR);
- rem_K = arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
- remAdd, c2);
+ remInnerLoop, remOuterLoop);
+ remInnerLoop = arith::RemUIOp::create(rewriter, loc,
+ rewriter.getIndexType(), remAdd, c2);
}
-
- return rem_K;
+ return remInnerLoop;
}
static scf::ForOp
@@ -461,10 +464,15 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
Operation *vectorOpLhs, Operation *vectorOpRhs,
vector::ContractionOp contractOp, scf::ForOp outerLoop,
scf::ForOp innerLoop, SmallVector<vector::ContractionOp> ops,
- Value ivOuterLoop, Value bBuffer, bool pack,
- arith::ConstantIndexOp innerLoopIndex, bool nDimK, bool oddDimK) {
+ Value ivOuterLoop, Value packedBuffer, bool pack,
+ arith::ConstantIndexOp innerLoopIndex, bool isInnerLoopUBLarger,
+ bool isInnerLoopUBHasOddQuot) {
- auto newLoop1 = scf::ForOp::create(
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
+
+ auto newLoop = scf::ForOp::create(
rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
[&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
@@ -479,11 +487,8 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
ivNewInnerLoop);
auto lhsClone = rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
- Value c0 = arith::ConstantIndexOp::create(rewriter, locNewInnerLoop, 0);
- Value c1 = arith::ConstantIndexOp::create(rewriter, locNewInnerLoop, 1);
- Value c2 = arith::ConstantIndexOp::create(rewriter, locNewInnerLoop, 2);
- Value allocStore = c0;
- Value allocGet = c0;
+ Value indxToStoreInBuffer = c0;
+ Value indxToLoadFromBuffer = c0;
if (!isVnni) {
if (outerLoop) {
@@ -493,16 +498,17 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
ivOuterLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
c1, ivOuterLoop);
- if (!nDimK || oddDimK) {
- allocStore = arith::RemUIOp::create(rewriter, locNewInnerLoop,
- rewriter.getIndexType(),
- ivOuterLoop, c2);
+ if (!isInnerLoopUBLarger || isInnerLoopUBHasOddQuot) {
+ indxToStoreInBuffer = arith::RemUIOp::create(
+ rewriter, locNewInnerLoop, rewriter.getIndexType(),
+ ivOuterLoop, c2);
}
- Value addIdx =
- arith::AddIOp::create(rewriter, loc, allocStore, c1);
- allocGet = arith::RemUIOp::create(
- rewriter, loc, rewriter.getIndexType(), addIdx, c2);
+ Value indxToLoadFromMatB = arith::AddIOp::create(
+ rewriter, loc, indxToStoreInBuffer, c1);
+ indxToLoadFromBuffer = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), indxToLoadFromMatB,
+ c2);
}
} else {
@@ -510,13 +516,15 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
rewriter, locNewInnerLoop, (16 * blockingFactor));
ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
nLoadIndx, ivNewInnerLoop);
- allocStore =
+ indxToStoreInBuffer =
bufferIndxToStore(rewriter, loc, ivNewInnerLoop, ivOuterLoop,
- oddDimK, nDimK, pack, blockingFactor);
- Value addIdx =
- arith::AddIOp::create(rewriter, loc, allocStore, c1);
- allocGet = arith::RemUIOp::create(
- rewriter, loc, rewriter.getIndexType(), addIdx, c2);
+ isInnerLoopUBHasOddQuot,
+ isInnerLoopUBLarger, pack, blockingFactor);
+ Value indxToLoadFromMatB =
+ arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
+ indxToLoadFromBuffer =
+ arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
+ indxToLoadFromMatB, c2);
}
} else {
if (pack) {
@@ -526,13 +534,14 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
nLoadIndx, ivNewInnerLoop);
Value quotient_K = arith::DivUIOp::create(
rewriter, loc, ivNewInnerLoop, nLoadIndx);
- allocStore = arith::RemUIOp::create(
+ indxToStoreInBuffer = arith::RemUIOp::create(
rewriter, loc, rewriter.getIndexType(), quotient_K, c2);
- Value addIdx =
- arith::AddIOp::create(rewriter, loc, allocStore, c1);
- allocGet = arith::RemUIOp::create(
- rewriter, loc, rewriter.getIndexType(), addIdx, c2);
+ Value indxToLoadFromMatB =
+ arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
+ indxToLoadFromBuffer =
+ arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
+ indxToLoadFromMatB, c2);
}
}
}
@@ -558,10 +567,11 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
Value nLoadIndx = arith::ConstantIndexOp::create(
rewriter, locNewInnerLoop, (16 * blockingFactor));
matB = Value();
- allocGet = c0;
- allocGet =
+ indxToLoadFromBuffer = c0;
+ indxToLoadFromBuffer =
bufferIndxToStore(rewriter, loc, nLoadIndx, ivOuterLoop,
- oddDimK, nDimK, pack, blockingFactor);
+ isInnerLoopUBHasOddQuot,
+ isInnerLoopUBLarger, pack, blockingFactor);
}
} else {
if (!pack) {
@@ -570,28 +580,30 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
matB = Value();
Value quotient_K = arith::DivUIOp::create(
rewriter, loc, ivNewInnerLoop, nLoadIndx);
- allocGet = arith::RemUIOp::create(
+ indxToLoadFromBuffer = arith::RemUIOp::create(
rewriter, loc, rewriter.getIndexType(), quotient_K, c2);
}
}
}
+ // compute tiled dot-product
SmallVector<Value> accumulators = createTiledDp(
rewriter, locNewInnerLoop, ops, lhsClone->getResult(0), matB,
ipType, opType, iterArgsNewInnerLoop, blockingFactor, isVnni,
- bBuffer, pack, allocStore, allocGet);
+ packedBuffer, pack, indxToStoreInBuffer, indxToLoadFromBuffer);
scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
accumulators);
});
- return newLoop1;
+ return newLoop;
}
// Implements tiled dot-product operation for a vector.contract operation or a
// sequence of vector.contracts inside the reduction loops.
//
-// For example - for F32 type:
+// For example:
+// Case 1: register blocked vector.contract with prepacked input
// ```
// vector.transfer_read %arg0 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
// vector.transfer_read %arg1 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
@@ -605,6 +617,52 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
// amx.tile_muli !amx.tile<16x64xi8> -> !amx.tile<16x16xi32>
// amx.tile_store %arg2{{.}*} : memref<32x32xi32>, !amx.tile<16x16xi32>
// ```
+//
+//
+// Case 2: register blocked vector.contract with online packing
+//
+// Output IR with online packing (with s/w pipeline advantage):
+// s/w pipeline: load, pack to VNNI, and store the B sub matrix
+// of the 0th batch-reduce and K iteration.
+// scf.for (0 to 31) {
+// - load 0th and 1st vector<32xbf16>, pack into VNNI, store the
+// first shuffle in 0th and 2nd shuffle in 16th index of the
+// buffer.
+// }
+// scf.for (0 to br-2) { batch-reduce loop
+// scf.for (0 to k-2) { K loop
+// - load A matrix
+// - scf.loop for s/w pipeline: load, pack to VNNI, and store the B sub
+// matrix for the next K loop iteration (c) load VNNI pack B matrix of K
+// iteration from the buffer (d) compute the tiled dot-product
+// }
+// Last iteration of the K Loop (k-1) {
+// - load A matrix
+// - scf.loop for s/w pipeline: load, pack to VNNI, and store the B sub
+// matrix for the next batch-reduce + K loop iteration (c) load VNNI pack B
+// matrix of K iteration from the buffer (d) compute the tiled dot-product
+// }
+// }
+// Last iteration of the batch-reduce loop (br-1) {
+// scf.for (0 to k-2) { K loop
+// - load A matrix
+// - scf.loop for s/w pipeline: load, pack to VNNI, and store the B sub
+// matrix for the next K loop iteration (c) load VNNI pack B matrix of K
+// iteration from the buffer (d) compute the tiled dot-product
+// }
+// Last iteration of the K Loop (k-1) {
+// - load A matrix
+// - load VNNI pack B matrix of K iteration from the buffer
+// - compute the tiled dot-product
+// }
+// }
+//
+// scf.for (0 to M)
+// scf.for (0 to N)
+// - Load the ith and i+1th acc
+// - Shuffle them as we packed using vpunpack
+// - Load C matrix and do arith.add with the shuffle
+// - Store back into C matrix
struct VectorContractToAMXDotProduct
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -727,7 +785,6 @@ struct VectorContractToAMXDotProduct
srcBuffLhs, indicesLhs);
// Create the subview and then load.
- //
amx::TileLoadOp loadRhs;
if (!isVnni) {
VectorType vecTy;
@@ -746,12 +803,10 @@ struct VectorContractToAMXDotProduct
auto subview = memref::SubViewOp::create(rewriter, loc, srcBuffRhs,
indexVals, sizes, strides);
auto bufferType = MemRefType::get({16, (16 * blockingFactor)}, ipType);
- auto bBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
-
- // create a loop that swaps them.
+ auto packedBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
+ // create a loop that does online packing.
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
-
Value step =
arith::ConstantIndexOp::create(rewriter, loc, blockingFactor);
Value uBound = arith::ConstantIndexOp::create(rewriter, loc,
@@ -817,14 +872,14 @@ struct VectorContractToAMXDotProduct
auto rem = arith::RemUIOp::create(
rewriter, loc, rewriter.getIndexType(), iv, step);
- vector::StoreOp::create(rewriter, loc, shuffle1, bBuffer,
+ vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
ValueRange{rem, c0});
- vector::StoreOp::create(rewriter, loc, shuffle2, bBuffer,
+ vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
ValueRange{rem, nextStoreIndx});
scf::YieldOp::create(nestedBuilder, loc);
});
- loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, bBuffer,
+ loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
ValueRange{c0, c0});
} else {
@@ -915,6 +970,22 @@ struct VectorContractToAMXDotProduct
}
}
+ if (!isVnni) {
+ unsigned int pairCount = 0;
+ for (size_t j = 0; j < ops.size(); j++) {
+ for (size_t i = j; i < ops.size(); i++) {
+
+ if (i != j && validatePairVectorContract(ops[j], ops[i], true, 16)) {
+ pairCount = pairCount + 2;
+ }
+ }
+ }
+
+ if (pairCount != ops.size())
+ return rewriter.notifyMatchFailure(
+ contractOp, "Coudn't find the pair vector contract ");
+ }
+
scf::ForOp innerLoop;
scf::ForOp outerLoop;
@@ -946,8 +1017,8 @@ struct VectorContractToAMXDotProduct
} else {
- bool nDimK = false;
- bool oddDimK = false;
+ bool isInnerLoopUBLarger = false;
+ bool isInnerLoopUBHasOddQuot = false;
int64_t ubVal = 16 * blockingFactor;
mlir::Value ub = innerLoop.getUpperBound();
@@ -958,27 +1029,27 @@ struct VectorContractToAMXDotProduct
}
}
- nDimK = ubVal > 16 * blockingFactor;
- oddDimK = (((ubVal / (16 * blockingFactor)) % 2) == 1) && nDimK;
+ isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
+ isInnerLoopUBHasOddQuot =
+ (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
rewriter.setInsertionPoint(outerLoop);
auto c0 =
arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
-
auto c1 =
arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
-
auto spillLoopBound = arith::ConstantIndexOp::create(
rewriter, outerLoop.getLoc(), 16 * blockingFactor);
- Value subBRLoop = arith::SubIOp::create(rewriter, outerLoop.getLoc(),
- outerLoop.getUpperBound(), c1);
- Value subKloop =
+
+ Value spillOuterLoop = arith::SubIOp::create(
+ rewriter, outerLoop.getLoc(), outerLoop.getUpperBound(), c1);
+ Value spillInnerLoop =
arith::SubIOp::create(rewriter, innerLoop.getLoc(),
innerLoop.getUpperBound(), spillLoopBound);
auto bufferType =
MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
- auto bBuffer =
+ auto packedBuffer =
memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
// First Shuffling outside the reduction loops
@@ -994,28 +1065,29 @@ struct VectorContractToAMXDotProduct
auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
performShuffle(rewriter, outerLoop.getLoc(), rhsClone->getResult(0),
- ipType, blockingFactor, bBuffer, c0);
+ ipType, blockingFactor, packedBuffer, c0);
// First Set of Loops
- auto newLoop1 = scf::ForOp::create(
- rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(), subBRLoop,
- outerLoop.getStep(), loopItrArgs,
+ auto newLoopNonSpill = scf::ForOp::create(
+ rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
+ spillOuterLoop, outerLoop.getStep(), loopItrArgs,
[&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
auto newInnerLoop1 = createLoops(
rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
- subKloop, innerLoop.getStep(), iterArgsOuterLoop, ipType,
- opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
- contractOp, outerLoop, innerLoop, ops, ivOuterLoop, bBuffer,
- true, spillLoopBound, nDimK, oddDimK);
+ spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
+ ipType, opType, blockingFactor, isVnni, vectorOpLhs,
+ vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
+ ivOuterLoop, packedBuffer, true, spillLoopBound,
+ isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
- auto newInnerLoop =
- createLoops(rewriter, innerLoop.getLoc(), subKloop,
- innerLoop.getUpperBound(), innerLoop.getStep(),
- newInnerLoop1.getResults(), ipType, opType,
- blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
- contractOp, outerLoop, innerLoop, ops,
- ivOuterLoop, bBuffer, true, c0, nDimK, oddDimK);
+ auto newInnerLoop = createLoops(
+ rewriter, innerLoop.getLoc(), spillInnerLoop,
+ innerLoop.getUpperBound(), innerLoop.getStep(),
+ newInnerLoop1.getResults(), ipType, opType, blockingFactor,
+ isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
+ innerLoop, ops, ivOuterLoop, packedBuffer, true, c0,
+ isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
newInnerLoop.getResults());
@@ -1023,24 +1095,26 @@ struct VectorContractToAMXDotProduct
// Last set of Loops
newLoop = scf::ForOp::create(
- rewriter, outerLoop.getLoc(), subBRLoop, outerLoop.getUpperBound(),
- outerLoop.getStep(), newLoop1.getResults(),
+ rewriter, outerLoop.getLoc(), spillOuterLoop,
+ outerLoop.getUpperBound(), outerLoop.getStep(),
+ newLoopNonSpill.getResults(),
[&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
auto newInnerLoop1 = createLoops(
rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
- subKloop, innerLoop.getStep(), iterArgsOuterLoop, ipType,
- opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
- contractOp, outerLoop, innerLoop, ops, ivOuterLoop, bBuffer,
- true, spillLoopBound, nDimK, oddDimK);
+ spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
+ ipType, opType, blockingFactor, isVnni, vectorOpLhs,
+ vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
+ ivOuterLoop, packedBuffer, true, spillLoopBound,
+ isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
- auto newInnerLoop =
- createLoops(rewriter, innerLoop.getLoc(), subKloop,
- innerLoop.getUpperBound(), innerLoop.getStep(),
- newInnerLoop1.getResults(), ipType, opType,
- blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
- contractOp, outerLoop, innerLoop, ops,
- ivOuterLoop, bBuffer, false, c0, nDimK, oddDimK);
+ auto newInnerLoop = createLoops(
+ rewriter, innerLoop.getLoc(), spillInnerLoop,
+ innerLoop.getUpperBound(), innerLoop.getStep(),
+ newInnerLoop1.getResults(), ipType, opType, blockingFactor,
+ isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
+ innerLoop, ops, ivOuterLoop, packedBuffer, false, c0,
+ isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
newInnerLoop.getResults());
@@ -1065,8 +1139,8 @@ struct VectorContractToAMXDotProduct
nullptr, false, false);
} else {
- bool nDimK = false;
- bool oddDimK = false;
+ bool isInnerLoopUBLarger = false;
+ bool isInnerLoopUBHasOddQuot = false;
int64_t ubVal = 16 * blockingFactor;
mlir::Value ub = innerLoop.getUpperBound();
@@ -1077,21 +1151,23 @@ struct VectorContractToAMXDotProduct
}
}
- nDimK = ubVal > 16 * blockingFactor;
- oddDimK = (((ubVal / (16 * blockingFactor)) % 2) == 1) && nDimK;
+ isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
+ isInnerLoopUBHasOddQuot =
+ (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
+
rewriter.setInsertionPoint(innerLoop);
auto c0 =
arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 0);
auto spillLoopBound = arith::ConstantIndexOp::create(
rewriter, innerLoop.getLoc(), 16 * blockingFactor);
- Value subKloop =
+ Value spillInnerLoop =
arith::SubIOp::create(rewriter, innerLoop.getLoc(),
innerLoop.getUpperBound(), spillLoopBound);
auto bufferType =
MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
- auto bBuffer =
+ auto packedBuffer =
memref::AllocaOp::create(rewriter, innerLoop.getLoc(), bufferType);
// First Shuffling outside the reduction loops
@@ -1103,20 +1179,22 @@ struct VectorContractToAMXDotProduct
auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
performShuffle(rewriter, innerLoop.getLoc(), rhsClone->getResult(0),
- ipType, blockingFactor, bBuffer, c0);
+ ipType, blockingFactor, packedBuffer, c0);
- auto newLoop1 = createLoops(
- rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(), subKloop,
- innerLoop.getStep(), loopItrArgs, ipType, opType, blockingFactor,
- isVnni, vectorOpLhs, vectorOpRhs, contractOp, nullptr, innerLoop,
- ops, nullptr, bBuffer, true, spillLoopBound, nDimK, oddDimK);
+ auto newLoopNonSpill = createLoops(
+ rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+ spillInnerLoop, innerLoop.getStep(), loopItrArgs, ipType, opType,
+ blockingFactor, isVnni, vectorOpLhs, vectorOpRhs, contractOp,
+ nullptr, innerLoop, ops, nullptr, packedBuffer, true,
+ spillLoopBound, isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
- newLoop = createLoops(rewriter, innerLoop.getLoc(), subKloop,
+ newLoop = createLoops(rewriter, innerLoop.getLoc(), spillInnerLoop,
innerLoop.getUpperBound(), innerLoop.getStep(),
- newLoop1.getResults(), ipType, opType,
+ newLoopNonSpill.getResults(), ipType, opType,
blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
contractOp, nullptr, innerLoop, ops, nullptr,
- bBuffer, false, c0, nDimK, oddDimK);
+ packedBuffer, false, c0, isInnerLoopUBLarger,
+ isInnerLoopUBHasOddQuot);
}
}
@@ -1125,17 +1203,41 @@ struct VectorContractToAMXDotProduct
if (!isVnni) {
Location loc = outerLoop.getLoc();
+ Operation *accReadOp =
+ traceToVectorReadLikeParentOperation(contractOp.getAcc());
+
+ Value srcBuffAcc;
+ SmallVector<Value> indicesAcc;
+
+ llvm::TypeSwitch<Operation *>(accReadOp).Case<TransferReadOp, LoadOp>(
+ [&](auto readOp) {
+ srcBuffAcc = readOp.getOperand(0);
+
+ auto indices = readOp.getIndices();
+ indicesAcc.reserve(indices.size());
+
+ llvm::transform(indices, std::back_inserter(indicesAcc),
+ [&](OpFoldResult ofr) {
+ return mlir::getValueOrCreateConstantIndexOp(
+ rewriter, loc, ofr);
+ });
+ });
+
+ auto outputShapes =
+ mlir::cast<mlir::MemRefType>(srcBuffAcc.getType()).getShape();
+ unsigned int M = outputShapes[outputShapes.size() - 2];
+ unsigned int N = outputShapes[outputShapes.size() - 1];
+
SmallVector<Value> dps = newLoop.getResults();
- auto bufferType = MemRefType::get({32, 32}, opType);
- auto bBuffer =
- memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
- for (int i = 0, k = 0; i < 32; i = i + 16) {
- for (int j = 0; j < 32; j = j + 16) {
- Value indexOp_i =
- arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), i);
- Value indexOp_j =
- arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), j);
- amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
+ auto bufferType = MemRefType::get({M, N}, opType);
+ auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
+
+ // Store the amx tiled-dot product output into an MxN memref.
+ for (unsigned int i = 0, k = 0; i < M; i = i + 16) {
+ for (unsigned int j = 0; j < N; j = j + 16) {
+ Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
+ Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
+ amx::TileStoreOp::create(rewriter, loc, resultBuffer,
ValueRange{indexOp_i, indexOp_j}, dps[k]);
k++;
}
@@ -1143,19 +1245,21 @@ struct VectorContractToAMXDotProduct
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
- auto mBound = arith::ConstantIndexOp::create(rewriter, loc, 32);
+ auto mBound = arith::ConstantIndexOp::create(rewriter, loc, N);
+ // Create a loop that iterates over the MxN memref, retrieves two rows,
+ // shuffles them, adds the C element values, and stores them back.
scf::ForOp::create(
rewriter, loc, c0, mBound, one, ValueRange{},
[&](OpBuilder &nestedBuilder, Location loc, Value iv,
ValueRange iterArgs) {
auto row = vector::LoadOp::create(rewriter, loc,
VectorType::get(16, opType),
- bBuffer, ValueRange{iv, c0});
+ resultBuffer, ValueRange{iv, c0});
- auto row2 = vector::LoadOp::create(rewriter, loc,
- VectorType::get(16, opType),
- bBuffer, ValueRange{iv, c16});
+ auto row2 = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(16, opType), resultBuffer,
+ ValueRange{iv, c16});
auto shuffle1 = vector::ShuffleOp::create(
rewriter, loc, VectorType::get(16, opType), row, row2,
@@ -1167,27 +1271,6 @@ struct VectorContractToAMXDotProduct
ArrayRef<int64_t>{8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
28, 29, 30, 31});
- Operation *accReadOp =
- traceToVectorReadLikeParentOperation(contractOp.getAcc());
-
- Value srcBuffAcc;
- SmallVector<Value> indicesAcc;
-
- llvm::TypeSwitch<Operation *>(accReadOp)
- .Case<TransferReadOp, LoadOp>([&](auto readOp) {
- srcBuffAcc = readOp.getOperand(0);
-
- auto indices = readOp.getIndices();
- indicesAcc.reserve(indices.size());
-
- llvm::transform(
- indices, std::back_inserter(indicesAcc),
- [&](OpFoldResult ofr) {
- return mlir::getValueOrCreateConstantIndexOp(rewriter,
- loc, ofr);
- });
- });
-
indicesAcc[indicesAcc.size() - 2] = iv;
indicesAcc[indicesAcc.size() - 1] = c0;
@@ -1199,6 +1282,7 @@ struct VectorContractToAMXDotProduct
Value valueCRow2 = vector::LoadOp::create(
rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
indicesAcc);
+
Value addOp;
Value addOp2;
@@ -1227,11 +1311,12 @@ struct VectorContractToAMXDotProduct
scf::YieldOp::create(nestedBuilder, loc);
});
}
+
auto bufferType = MemRefType::get({16, 16}, opType);
- auto bBuffer =
+ auto resultBuffer =
memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
-
SmallVector<Value> dps = newLoop.getResults();
+
for (size_t i = 0; i < ops.size(); i++) {
vector::ContractionOp contOp = ops[i];
Operation *resultWriteOp =
@@ -1242,7 +1327,7 @@ struct VectorContractToAMXDotProduct
Value indexOp_0 =
arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
- amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
+ amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), resultBuffer,
ValueRange{indexOp_0, indexOp_0}, dps[i]);
auto c0 =
@@ -1257,7 +1342,7 @@ struct VectorContractToAMXDotProduct
[&](OpBuilder &builder, Location loc, Value iv,
ValueRange iterArgs) {
auto resultAcc = vector::LoadOp::create(
- rewriter, loc, VectorType::get(16, opType), bBuffer,
+ rewriter, loc, VectorType::get(16, opType), resultBuffer,
ValueRange{iv, c0});
Operation *accReadOp =
@@ -1283,7 +1368,7 @@ struct VectorContractToAMXDotProduct
Value sum =
arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
- indicesAcc[0] = sum;
+ indicesAcc[indicesAcc.size() - 2] = sum;
auto acc = vector::LoadOp::create(rewriter, loc,
VectorType::get(16, opType),
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
index 151946453df81..6bb90a80da66e 100644
--- a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -1341,3 +1341,154 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+!vecA = vector<16x64xi8>
+!vecB = vector<64x16xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<16x64xi8, strided<[256, 1], offset: ?>>
+!memrefB = memref<64x32xi8, strided<[128, 1], offset: ?>>
+!memrefC = memref<16x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+func.func @negative_vc_wrong_order_no_pair(%arg0: memref<64x256xi8>, %arg1: memref<256x128xi8>, %arg2: memref<64x128xi32>) -> memref<64x128xi32> {
+ %0 = ub.poison : i32
+ %1 = ub.poison : i8
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ scf.for %arg3 = %c0 to %c64 step %c16 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+ %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1]
+ : memref<64x128xi32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]}
+ : !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]}
+ : !memrefC, !vecC
+ %4:2 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+ %subview_0 = memref.subview %arg0[%arg3, %arg5] [16, 64] [1, 1]
+ : memref<64x256xi8> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg4] [64, 32] [1, 1]
+ : memref<256x128xi8> to !memrefB
+ %5 = vector.transfer_read %subview_0[%c0, %c0], %1 {in_bounds = [true, true]}
+ : !memrefA, !vecA
+ %6 = vector.transfer_read %subview_1[%c0, %c0], %1 {in_bounds = [true, true]}
+ : !memrefB, !vecB
+ %7 = vector.transfer_read %subview_1[%c0, %c16], %1 {in_bounds = [true, true]}
+ : !memrefB, !vecB
+ %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %5, %7, %arg7 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %5, %6, %arg6 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ scf.yield %9, %8 : !vecC, !vecC
+ }
+ vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]}
+ : !vecC, !memrefC
+ vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]}
+ : !vecC, !memrefC
+ }
+ }
+ %alloc = memref.alloc() : memref<64x128xi32>
+ memref.copy %arg2, %alloc : memref<64x128xi32> to memref<64x128xi32>
+ return %alloc : memref<64x128xi32>
+}
+
+
+// CHECK-LABEL: @negative_vc_wrong_order_no_pair
+// CHECK-NOT: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<16x64xi8>
+!vecB = vector<64x16xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x64xi8, strided<[256, 1], offset: ?>>
+!memrefB = memref<64x16xi8, strided<[128, 1], offset: ?>>
+!memrefC = memref<32x16xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+func.func @negative_vc_no_pair(%arg0: memref<64x256xi8>, %arg1: memref<256x128xi8>, %arg2: memref<64x128xi32>) -> memref<64x128xi32> {
+ %0 = ub.poison : i32
+ %1 = ub.poison : i8
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ %c32 = arith.constant 32 : index
+ %c16 = arith.constant 16 : index
+ scf.for %arg3 = %c0 to %c64 step %c32 {
+ scf.for %arg4 = %c0 to %c128 step %c16 {
+ %subview = memref.subview %arg2[%arg3, %arg4] [32, 16] [1, 1] : memref<64x128xi32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]}
+ : !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]}
+ : !memrefC, !vecC
+ %4:2 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+ %subview_0 = memref.subview %arg0[%arg3, %arg5] [32, 64] [1, 1]
+ : memref<64x256xi8> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg4] [64, 16] [1, 1]
+ : memref<256x128xi8> to !memrefB
+ %5 = vector.transfer_read %subview_0[%c0, %c0], %1 {in_bounds = [true, true]}
+ : !memrefA, !vecA
+ %6 = vector.transfer_read %subview_0[%c16, %c0], %1 {in_bounds = [true, true]}
+ : !memrefA, !vecA
+ %7 = vector.transfer_read %subview_1[%c0, %c0], %1 {in_bounds = [true, true]}
+ : !memrefB, !vecB
+ %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %5, %7, %arg6 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %6, %7, %arg7 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ scf.yield %8, %9 : !vecC, !vecC
+ }
+ vector.transfer_write %4#1, %subview[%c16, %c0] {in_bounds = [true, true]}
+ : !vecC, !memrefC
+ vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]}
+ : !vecC, !memrefC
+ }
+ }
+ %alloc = memref.alloc() : memref<64x128xi32>
+ memref.copy %arg2, %alloc : memref<64x128xi32> to memref<64x128xi32>
+ return %alloc : memref<64x128xi32>
+}
+
+// CHECK-LABEL: @negative_vc_no_pair
+// CHECK-NOT: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
>From ed5e583ca32dc5c610caa6b09f7d2078051b7bb5 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 27 Mar 2026 08:50:28 -0700
Subject: [PATCH 3/5] code refactor.
---
.../VectorContractToAMXDotProduct.cpp | 53 ++++++++-----------
1 file changed, 23 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 744c065b4e05e..63e187613bfc4 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -70,9 +70,8 @@ getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
if (!srcBuff)
return failure();
- if (isNotAcc) {
+ if (isNotAcc)
indexVals.pop_back();
- }
SmallVector<Value> indices;
indices.reserve(indexVals.size());
@@ -206,8 +205,7 @@ static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
}
amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
- auto load = amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
- return load;
+ return amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
}
static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
@@ -424,11 +422,10 @@ static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
return loopItrArgs;
}
-static Value bufferIndxToStore(OpBuilder &rewriter, Location loc,
- Value ivInnerLoop, Value ivOuterLoop,
- bool isInnerLoopUBHasOddQuot,
- bool isInnerLoopUBLarger, bool pack,
- unsigned int blockingFactor) {
+static Value getIndxToLoadStoreFromPckBuffer(
+ OpBuilder &rewriter, Location loc, Value ivInnerLoop, Value ivOuterLoop,
+ bool isInnerLoopUBHasOddQuot, bool isInnerLoopUBLarger, bool pack,
+ unsigned int blockingFactor) {
Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
Value packOffset =
@@ -444,8 +441,6 @@ static Value bufferIndxToStore(OpBuilder &rewriter, Location loc,
rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
}
- // if K quotient is odd. Then, BR loop iv is taken
- // into consideration
if (isInnerLoopUBHasOddQuot) {
auto remOuterLoop = arith::RemUIOp::create(
rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
@@ -454,6 +449,7 @@ static Value bufferIndxToStore(OpBuilder &rewriter, Location loc,
remInnerLoop = arith::RemUIOp::create(rewriter, loc,
rewriter.getIndexType(), remAdd, c2);
}
+
return remInnerLoop;
}
@@ -477,11 +473,11 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
[&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
IRMapping mapping;
- if (outerLoop) {
+ if (outerLoop)
mapping.map(vectorOpLhs->getOperand(
getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
ivOuterLoop);
- }
+
mapping.map(vectorOpLhs->getOperand(
getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
ivNewInnerLoop);
@@ -516,10 +512,10 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
rewriter, locNewInnerLoop, (16 * blockingFactor));
ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
nLoadIndx, ivNewInnerLoop);
- indxToStoreInBuffer =
- bufferIndxToStore(rewriter, loc, ivNewInnerLoop, ivOuterLoop,
- isInnerLoopUBHasOddQuot,
- isInnerLoopUBLarger, pack, blockingFactor);
+ indxToStoreInBuffer = getIndxToLoadStoreFromPckBuffer(
+ rewriter, loc, ivNewInnerLoop, ivOuterLoop,
+ isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
+ blockingFactor);
Value indxToLoadFromMatB =
arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
indxToLoadFromBuffer =
@@ -547,12 +543,12 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
}
IRMapping rhsMapping;
- if (outerLoop) {
+ if (outerLoop)
rhsMapping.map(
vectorOpRhs->getOperand(
getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
ivOuterLoop);
- }
+
rhsMapping.map(
vectorOpRhs->getOperand(
getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
@@ -568,10 +564,10 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
rewriter, locNewInnerLoop, (16 * blockingFactor));
matB = Value();
indxToLoadFromBuffer = c0;
- indxToLoadFromBuffer =
- bufferIndxToStore(rewriter, loc, nLoadIndx, ivOuterLoop,
- isInnerLoopUBHasOddQuot,
- isInnerLoopUBLarger, pack, blockingFactor);
+ indxToLoadFromBuffer = getIndxToLoadStoreFromPckBuffer(
+ rewriter, loc, nLoadIndx, ivOuterLoop,
+ isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
+ blockingFactor);
}
} else {
if (!pack) {
@@ -709,13 +705,8 @@ struct VectorContractToAMXDotProduct
"transfer_read or a load. And, the result should be "
"stored using transfer_write or store.");
- Type ipType;
- Type opType;
-
- if (lhsTy.getElementType().isBF16()) {
- ipType = rewriter.getBF16Type();
- opType = rewriter.getF32Type();
- }
+ Type ipType = rewriter.getBF16Type();
+ Type opType = rewriter.getF32Type();
if (lhsTy.getElementType().isSignlessInteger(8)) {
ipType = rewriter.getIntegerType(8);
@@ -910,6 +901,7 @@ struct VectorContractToAMXDotProduct
// Case 2: The acc are passed as iter args through the reduction loop.
// We support, reduction loop depth until 2. TODO: Support for n-depth
// reduction loop.
+ // TODO: Refactor cases 2a and 2b.
SmallVector<scf::ForOp> loopLists;
Operation *current = contractOp;
while (true) {
@@ -1122,6 +1114,7 @@ struct VectorContractToAMXDotProduct
}
}
+ // Case 2b: Reduction loop depth is 1.
if (loopLists.size() == 1) {
outerLoop = loopLists[0];
innerLoop = loopLists[0];
>From 35ecac82ba288eb9c49124e0dcb2161801e8895f Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 30 Mar 2026 08:46:25 -0700
Subject: [PATCH 4/5] minor change to shuffling order for int8
---
.../Transforms/VectorContractToAMXDotProduct.cpp | 16 ++++++++--------
.../X86/AMX/vector-contract-to-tiled-dp.mlir | 4 ++--
2 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 63e187613bfc4..eb35720b657ea 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -267,19 +267,19 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
vec2,
ArrayRef<int64_t>{
0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3,
- 35, 67, 99, 4, 36, 68, 100, 5, 37, 69, 101, 6, 38,
- 70, 102, 7, 39, 71, 103, 8, 40, 72, 104, 9, 41, 73,
- 105, 10, 42, 74, 106, 11, 43, 75, 107, 12, 44, 76, 108,
- 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111});
+ 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42,
+ 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81,
+ 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120,
+ 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123});
shuffle2 = vector::ShuffleOp::create(
rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
vec2,
ArrayRef<int64_t>{
- 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51,
- 83, 115, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118,
- 23, 55, 87, 119, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58,
- 90, 122, 27, 59, 91, 123, 28, 60, 92, 124, 29, 61, 93, 125,
+ 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39,
+ 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110,
+ 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54,
+ 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125,
30, 62, 94, 126, 31, 63, 95, 127});
}
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
index 6bb90a80da66e..1a6deed31eceb 100644
--- a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -770,8 +770,8 @@ func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xi8>, %arg1: memr
// CHECK-LABEL: @online_packing_int8_matmul_loop
// CHECK-COUNT-4: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>) {
-// CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111] : vector<64xi8>, vector<64xi8>
-// CHECK-NEXT: vector.shuffle{{.*}}[16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123, 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127] : vector<64xi8>, vector<64xi8>
+// CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123] : vector<64xi8>, vector<64xi8>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127] : vector<64xi8>, vector<64xi8>
// CHECK: x86.amx.tile_load
// CHECK: x86.amx.tile_muli
// CHECK: scf.yield {{.*}} !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>
>From b1b4c569ab021bda2906ee8601687c1721d6d289 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 8 Apr 2026 11:08:57 -0700
Subject: [PATCH 5/5] remove un-necessary braces
---
.../X86/Transforms/VectorContractToAMXDotProduct.cpp | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index eb35720b657ea..cc66308e98260 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -966,10 +966,8 @@ struct VectorContractToAMXDotProduct
unsigned int pairCount = 0;
for (size_t j = 0; j < ops.size(); j++) {
for (size_t i = j; i < ops.size(); i++) {
-
- if (i != j && validatePairVectorContract(ops[j], ops[i], true, 16)) {
+ if (i != j && validatePairVectorContract(ops[j], ops[i], true, 16))
pairCount = pairCount + 2;
- }
}
}
@@ -1116,7 +1114,6 @@ struct VectorContractToAMXDotProduct
// Case 2b: Reduction loop depth is 1.
if (loopLists.size() == 1) {
- outerLoop = loopLists[0];
innerLoop = loopLists[0];
SmallVector<Value> loopItrArgs = createTileZeros(
@@ -1189,6 +1186,10 @@ struct VectorContractToAMXDotProduct
packedBuffer, false, c0, isInnerLoopUBLarger,
isInnerLoopUBHasOddQuot);
}
+
+ // Alias outerLoop to innerLoop so that the final store back to the acc
+ // uses the same code for reduction loop depths 1 and 2.
+ outerLoop = innerLoop;
}
// Copy the amx tile accumulation results to a MemRef buffer, add the
More information about the Mlir-commits
mailing list