[Mlir-commits] [mlir] [mlir][x86] Fix - multiple issues / F8 support for AMX dot-product lowering (PR #196984)
Arun Thangamani
llvmlistbot at llvm.org
Mon May 25 09:18:38 PDT 2026
https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/196984
>From 91394a7e8e01385a767407aa2340cf9e2486bc4c Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 11 May 2026 09:02:23 -0700
Subject: [PATCH 01/11] fixes issues with AMX dot-product lowering
---
.../VectorContractToAMXDotProduct.cpp | 391 +++++++++---------
1 file changed, 202 insertions(+), 189 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 94b94292e675f..551fccb47e114 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -27,6 +27,31 @@ using namespace mlir::x86;
namespace {
+static Value traceToVectorWriteLikeUserOperationForAMX(Value v) {
+ if (v.getNumUses() > 1)
+ return nullptr;
+
+ for (OpOperand &use : v.getUses()) {
+ Operation *user = use.getOwner();
+
+ if (!isa<scf::YieldOp>(user)) {
+ return v;
+ }
+
+ // --- SCF YIELD ---
+ if (auto yield = dyn_cast<scf::YieldOp>(user)) {
+ Operation *parent = yield->getParentOp();
+ unsigned idx = use.getOperandNumber();
+ if (auto res =
+ traceToVectorWriteLikeUserOperationForAMX(parent->getResult(idx)))
+ return res;
+ continue;
+ }
+ }
+
+ return nullptr;
+}
+
// Function to collapse the last two dimension (vnni and k) to help the
// amx.tile_load to correctly load the packed element type.
static Value collapseInnerDims(OpBuilder &builder, mlir::Location loc,
@@ -216,22 +241,30 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
auto subview = matB.getDefiningOp<mlir::memref::SubViewOp>();
- SmallVector<Value> subviewOffset(subview.getOffsets().size(), c0);
+ SmallVector<Value> subviewOffset(subview.getMixedOffsets().size(), c0);
Value cStep = arith::ConstantIndexOp::create(rewriter, loc, offset);
Value cBound = arith::ConstantIndexOp::create(rewriter, loc, (16 * offset));
Value offsetIndx =
arith::ConstantIndexOp::create(rewriter, loc, (offset / 2));
+ // llvm::outs() << "check-a:" << matB << " subview:" << subview << "\n";
+ // llvm::outs() << "The size:" << subview.getMixedOffsets().size() << "\n";
+
scf::ForOp::create(
rewriter, loc, c0, cBound, cStep, ValueRange{},
[&](OpBuilder &nestedBuilder, Location loc, Value iv,
ValueRange iterArgs) {
+ // llvm::outs() << "check-a0" << subviewOffset.size() << "\n";
subviewOffset[subviewOffset.size() - 2] = iv;
+
+ // llvm::outs() << "check-a1" << "\n";
auto vec1 = vector::LoadOp::create(
rewriter, loc, VectorType::get((16 * offset), ipType), matB,
ValueRange(subviewOffset));
+ // llvm::outs() << "check-b" << "\n";
+
// Increment the iv by 1 or 2 based on the type to load the next 32/64
// elements
Value incIV = arith::AddIOp::create(rewriter, loc, offsetIndx, iv);
@@ -243,6 +276,8 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
vector::ShuffleOp shuffle1;
vector::ShuffleOp shuffle2;
+ // llvm::outs() << "check-c" << "\n";
+
if (ipType.isBF16()) {
shuffle1 = vector::ShuffleOp::create(
@@ -283,6 +318,8 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
30, 62, 94, 126, 31, 63, 95, 127});
}
+ // llvm::outs() << "check-d" << "\n";
+
// iv to store the shuffled elements
Value ivShuff1 = arith::DivUIOp::create(rewriter, loc, iv, cStep);
Value ivShuff2 = arith::AddIOp::create(rewriter, loc, ivShuff1, c16);
@@ -468,6 +505,8 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
+ int64_t offset = step.getDefiningOp<arith::ConstantIndexOp>().value();
+
auto newLoop = scf::ForOp::create(
rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
[&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
@@ -485,7 +524,6 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
Value indxToStoreInBuffer = c0;
Value indxToLoadFromBuffer = c0;
-
if (!isVnni) {
if (outerLoop) {
if (innerLoopIndex.value() == 0) {
@@ -509,7 +547,7 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
} else {
Value nLoadIndx = arith::ConstantIndexOp::create(
- rewriter, locNewInnerLoop, (16 * blockingFactor));
+ rewriter, locNewInnerLoop, offset);
ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
nLoadIndx, ivNewInnerLoop);
indxToStoreInBuffer = getIndxToLoadStoreFromPckBuffer(
@@ -525,7 +563,7 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
} else {
if (pack) {
Value nLoadIndx = arith::ConstantIndexOp::create(
- rewriter, locNewInnerLoop, (16 * blockingFactor));
+ rewriter, locNewInnerLoop, offset);
ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
nLoadIndx, ivNewInnerLoop);
Value quotient_K = arith::DivUIOp::create(
@@ -541,27 +579,49 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
}
}
}
-
IRMapping rhsMapping;
- if (outerLoop)
- rhsMapping.map(
- vectorOpRhs->getOperand(
- getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
- ivOuterLoop);
- rhsMapping.map(
- vectorOpRhs->getOperand(
- getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
- ivNewInnerLoop);
- auto rhsClone = rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping);
+ Value matB;
+ Operation *rhsOp = vectorOpRhs;
- Value matB = rhsClone->getResult(0);
+ // Clone only if the op has operands.
+ if (rhsOp->getNumOperands() > 0) {
+ if (outerLoop) {
+ int64_t outerPos = getIndexPosition(contractOp.getRhs(), outerLoop);
+
+ if (outerPos >= 0) {
+ unsigned operandIdx = static_cast<unsigned>(outerPos + 1);
+
+ if (operandIdx < rhsOp->getNumOperands()) {
+ rhsMapping.map(rhsOp->getOperand(operandIdx), ivOuterLoop);
+ }
+ }
+ }
+
+ int64_t innerPos = getIndexPosition(contractOp.getRhs(), innerLoop);
+
+ if (innerPos >= 0) {
+ unsigned operandIdx = static_cast<unsigned>(innerPos + 1);
+
+ if (operandIdx < rhsOp->getNumOperands()) {
+ rhsMapping.map(rhsOp->getOperand(operandIdx), ivNewInnerLoop);
+ }
+ }
+
+ auto rhsClone = rewriterNewInnerLoop.clone(*rhsOp, rhsMapping);
+
+ matB = rhsClone->getResult(0);
+
+ } else {
+ // memref.get_global / constants
+ matB = rhsOp->getResult(0);
+ }
if (!isVnni) {
if (outerLoop) {
if (!pack) {
Value nLoadIndx = arith::ConstantIndexOp::create(
- rewriter, locNewInnerLoop, (16 * blockingFactor));
+ rewriter, locNewInnerLoop, offset);
matB = Value();
indxToLoadFromBuffer = c0;
indxToLoadFromBuffer = getIndxToLoadStoreFromPckBuffer(
@@ -572,7 +632,7 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
} else {
if (!pack) {
Value nLoadIndx = arith::ConstantIndexOp::create(
- rewriter, locNewInnerLoop, (16 * blockingFactor));
+ rewriter, locNewInnerLoop, offset);
matB = Value();
Value quotient_K = arith::DivUIOp::create(
rewriter, loc, ivNewInnerLoop, nLoadIndx);
@@ -581,7 +641,6 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
}
}
}
-
// compute tiled dot-product
SmallVector<Value> accumulators = createTiledDp(
rewriter, locNewInnerLoop, ops, lhsClone->getResult(0), matB,
@@ -860,7 +919,7 @@ struct VectorContractToAMXDotProduct
14, 30, 46, 62, 15, 31, 47, 63});
}
- auto rem = arith::RemUIOp::create(
+ auto rem = arith::DivUIOp::create(
rewriter, loc, rewriter.getIndexType(), iv, step);
vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
@@ -988,6 +1047,7 @@ struct VectorContractToAMXDotProduct
scf::ForOp newLoop;
// Case 2a: Reduction loop depth is 2.
if (loopLists.size() == 2) {
+
outerLoop = loopLists[1];
innerLoop = loopLists[0];
@@ -1120,8 +1180,8 @@ struct VectorContractToAMXDotProduct
// Case 2b: Reduction loop depth is 1.
if (loopLists.size() == 1) {
- innerLoop = loopLists[0];
+ innerLoop = loopLists[0];
SmallVector<Value> loopItrArgs = createTileZeros(
rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
@@ -1135,6 +1195,7 @@ struct VectorContractToAMXDotProduct
nullptr, false, false);
} else {
+
bool isInnerLoopUBLarger = false;
bool isInnerLoopUBHasOddQuot = false;
@@ -1154,8 +1215,12 @@ struct VectorContractToAMXDotProduct
rewriter.setInsertionPoint(innerLoop);
auto c0 =
arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 0);
+
+ int64_t stepVal =
+ innerLoop.getStep().getDefiningOp<arith::ConstantIndexOp>().value();
+
auto spillLoopBound = arith::ConstantIndexOp::create(
- rewriter, innerLoop.getLoc(), 16 * blockingFactor);
+ rewriter, innerLoop.getLoc(), stepVal);
Value spillInnerLoop =
arith::SubIOp::create(rewriter, innerLoop.getLoc(),
@@ -1173,10 +1238,8 @@ struct VectorContractToAMXDotProduct
getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
c0);
auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
-
performShuffle(rewriter, innerLoop.getLoc(), rhsClone->getResult(0),
ipType, blockingFactor, packedBuffer, c0);
-
auto newLoopNonSpill = createLoops(
rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
spillInnerLoop, innerLoop.getStep(), loopItrArgs, ipType, opType,
@@ -1200,194 +1263,144 @@ struct VectorContractToAMXDotProduct
// Copy the amx tile accumulation results to a MemRef buffer, add the
// initial accumulation value, and store back to the C-Matrix
+ Location loc = outerLoop.getLoc();
+ Value srcBuffAcc;
+ SmallVector<Value> indicesAcc;
+
+ llvm::TypeSwitch<Operation *>(accReadOp).Case<TransferReadOp, LoadOp>(
+ [&](auto readOp) {
+ srcBuffAcc = readOp.getOperand(0);
+
+ auto indices = readOp.getIndices();
+ indicesAcc.reserve(indices.size());
+
+ llvm::transform(indices, std::back_inserter(indicesAcc),
+ [&](OpFoldResult ofr) {
+ return mlir::getValueOrCreateConstantIndexOp(
+ rewriter, loc, ofr);
+ });
+ });
- if (!isVnni) {
- Location loc = outerLoop.getLoc();
- Operation *accReadOp =
- traceToVectorReadLikeParentOperation(contractOp.getAcc());
-
- Value srcBuffAcc;
- SmallVector<Value> indicesAcc;
-
- llvm::TypeSwitch<Operation *>(accReadOp).Case<TransferReadOp, LoadOp>(
- [&](auto readOp) {
- srcBuffAcc = readOp.getOperand(0);
-
- auto indices = readOp.getIndices();
- indicesAcc.reserve(indices.size());
-
- llvm::transform(indices, std::back_inserter(indicesAcc),
- [&](OpFoldResult ofr) {
- return mlir::getValueOrCreateConstantIndexOp(
- rewriter, loc, ofr);
- });
- });
-
- auto outputShapes =
- mlir::cast<mlir::MemRefType>(srcBuffAcc.getType()).getShape();
- unsigned int M = outputShapes[outputShapes.size() - 2];
- unsigned int N = outputShapes[outputShapes.size() - 1];
-
- SmallVector<Value> dps = newLoop.getResults();
- auto bufferType = MemRefType::get({M, N}, opType);
- auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
-
- // Store the amx tiled-dot product output into an MxN memref.
- for (unsigned int i = 0, k = 0; i < M; i = i + 16) {
- for (unsigned int j = 0; j < N; j = j + 16) {
- Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
- Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
- amx::TileStoreOp::create(rewriter, loc, resultBuffer,
- ValueRange{indexOp_i, indexOp_j}, dps[k]);
- k++;
- }
+ auto outputShapes =
+ mlir::cast<mlir::MemRefType>(srcBuffAcc.getType()).getShape();
+ unsigned int M = outputShapes[outputShapes.size() - 2];
+ unsigned int N = outputShapes[outputShapes.size() - 1];
+
+ SmallVector<Value> dps = newLoop.getResults();
+ auto bufferType = MemRefType::get({M, N}, opType);
+ auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
+
+ // Store the amx tiled-dot product output into an MxN memref.
+ for (unsigned int i = 0, k = 0; i < M; i = i + 16) {
+ for (unsigned int j = 0; j < N; j = j + 16) {
+ Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
+ Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
+ amx::TileStoreOp::create(rewriter, loc, resultBuffer,
+ ValueRange{indexOp_i, indexOp_j}, dps[k]);
+ k++;
}
- auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
- auto c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
- auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
- auto mBound = arith::ConstantIndexOp::create(rewriter, loc, N);
-
- // Create a loop that iterates over the MxN memerf, retrives two rows +
- // shuffle them, add up the C element values and stores them back.
- scf::ForOp::create(
- rewriter, loc, c0, mBound, one, ValueRange{},
- [&](OpBuilder &nestedBuilder, Location loc, Value iv,
- ValueRange iterArgs) {
- auto row = vector::LoadOp::create(rewriter, loc,
- VectorType::get(16, opType),
- resultBuffer, ValueRange{iv, c0});
-
- auto row2 = vector::LoadOp::create(
- rewriter, loc, VectorType::get(16, opType), resultBuffer,
- ValueRange{iv, c16});
-
- auto shuffle1 = vector::ShuffleOp::create(
+ }
+ auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+ auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto mBound = arith::ConstantIndexOp::create(rewriter, loc, N);
+
+ // Create a loop that iterates over the MxN memref, retrieves two rows,
+ // shuffles them, adds the C element values, and stores them back.
+ scf::ForOp::create(
+ rewriter, loc, c0, mBound, one, ValueRange{},
+ [&](OpBuilder &nestedBuilder, Location loc, Value iv,
+ ValueRange iterArgs) {
+ auto row =
+ vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
+ resultBuffer, ValueRange{iv, c0});
+
+ auto row2 =
+ vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
+ resultBuffer, ValueRange{iv, c16});
+
+ Value shuffle1 = row;
+ Value shuffle2 = row2;
+
+ if (!isVnni) {
+ shuffle1 = vector::ShuffleOp::create(
rewriter, loc, VectorType::get(16, opType), row, row2,
ArrayRef<int64_t>{0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20,
21, 22, 23});
- auto shuffle2 = vector::ShuffleOp::create(
+ shuffle2 = vector::ShuffleOp::create(
rewriter, loc, VectorType::get(16, opType), row, row2,
ArrayRef<int64_t>{8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
28, 29, 30, 31});
+ }
+ indicesAcc[indicesAcc.size() - 2] = iv;
+ indicesAcc[indicesAcc.size() - 1] = c0;
- indicesAcc[indicesAcc.size() - 2] = iv;
- indicesAcc[indicesAcc.size() - 1] = c0;
+ Value valueCRow1 =
+ vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
+ srcBuffAcc, indicesAcc);
+ indicesAcc[indicesAcc.size() - 1] = c16;
- Value valueCRow1 = vector::LoadOp::create(
- rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
- indicesAcc);
- indicesAcc[indicesAcc.size() - 1] = c16;
+ Value valueCRow2 =
+ vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
+ srcBuffAcc, indicesAcc);
- Value valueCRow2 = vector::LoadOp::create(
- rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
- indicesAcc);
+ Value addOp;
+ Value addOp2;
- Value addOp;
- Value addOp2;
+ if (ipType.isBF16()) {
+ addOp = arith::AddFOp::create(rewriter, loc, shuffle1, valueCRow1);
- if (ipType.isBF16()) {
- addOp =
- arith::AddFOp::create(rewriter, loc, shuffle1, valueCRow1);
+ addOp2 = arith::AddFOp::create(rewriter, loc, shuffle2, valueCRow2);
+ }
- addOp2 =
- arith::AddFOp::create(rewriter, loc, shuffle2, valueCRow2);
- }
+ if (ipType.isSignlessInteger(8)) {
+ addOp = arith::AddIOp::create(rewriter, loc, shuffle1, valueCRow1);
- if (ipType.isSignlessInteger(8)) {
- addOp =
- arith::AddIOp::create(rewriter, loc, shuffle1, valueCRow1);
+ addOp2 = arith::AddIOp::create(rewriter, loc, shuffle2, valueCRow2);
+ }
- addOp2 =
- arith::AddIOp::create(rewriter, loc, shuffle2, valueCRow2);
- }
- indicesAcc[indicesAcc.size() - 1] = c0;
- vector::StoreOp::create(rewriter, loc, addOp, srcBuffAcc,
- indicesAcc);
- indicesAcc[indicesAcc.size() - 1] = c16;
- vector::StoreOp::create(rewriter, loc, addOp2, srcBuffAcc,
- indicesAcc);
-
- scf::YieldOp::create(nestedBuilder, loc);
- });
- }
+ vector::StoreOp::create(rewriter, loc, addOp, resultBuffer,
+ ValueRange{iv, c0});
+ vector::StoreOp::create(rewriter, loc, addOp2, resultBuffer,
+ ValueRange{iv, c16});
- auto bufferType = MemRefType::get({16, 16}, opType);
- auto resultBuffer =
- memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
- SmallVector<Value> dps = newLoop.getResults();
+ scf::YieldOp::create(nestedBuilder, loc);
+ });
+
+ SmallVector<Value> writeResults;
+ for (unsigned int i = 0; i < M; i = i + 16) {
+ for (unsigned int j = 0; j < N; j = j + 16) {
+ Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
+ Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
+
+ auto flatTy = mlir::VectorType::get({16, 16}, opType);
+
+ int64_t srcRank =
+ (dyn_cast<ShapedType>(resultBuffer.getType())).getRank();
+ Value padding = ub::PoisonOp::create(rewriter, loc, opType);
+ auto map = AffineMap::getMinorIdentityMap(srcRank, flatTy.getRank(),
+ rewriter.getContext());
+ SmallVector<bool> inBounds(flatTy.getRank(), true);
+
+ auto vec1 = vector::TransferReadOp::create(
+ rewriter, loc, flatTy, resultBuffer,
+ ValueRange{indexOp_i, indexOp_j}, padding, map, inBounds);
+ writeResults.push_back(vec1);
+ }
+ }
for (size_t i = 0; i < ops.size(); i++) {
vector::ContractionOp contOp = ops[i];
- Operation *resultWriteOp =
- traceToVectorWriteLikeUserOperation(contOp.getResult());
- if (isVnni) {
- rewriter.setInsertionPoint(resultWriteOp);
-
- Value indexOp_0 =
- arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
-
- amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), resultBuffer,
- ValueRange{indexOp_0, indexOp_0}, dps[i]);
-
- auto c0 =
- arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
- auto one =
- arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
- auto mBound =
- arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 16);
-
- scf::ForOp::create(
- rewriter, outerLoop.getLoc(), c0, mBound, one, ValueRange{},
- [&](OpBuilder &builder, Location loc, Value iv,
- ValueRange iterArgs) {
- auto resultAcc = vector::LoadOp::create(
- rewriter, loc, VectorType::get(16, opType), resultBuffer,
- ValueRange{iv, c0});
-
- Operation *accReadOp =
- traceToVectorReadLikeParentOperation(ops[i].getAcc());
-
- Value srcBuffAcc;
- SmallVector<Value> indicesAcc;
+ Value vecRoc = writeResults[i];
- llvm::TypeSwitch<Operation *>(accReadOp)
- .Case<TransferReadOp, LoadOp>([&](auto readOp) {
- srcBuffAcc = readOp.getOperand(0);
-
- auto indices = readOp.getIndices();
- indicesAcc.reserve(indices.size());
-
- llvm::transform(
- indices, std::back_inserter(indicesAcc),
- [&](OpFoldResult ofr) {
- return mlir::getValueOrCreateConstantIndexOp(
- rewriter, loc, ofr);
- });
- });
-
- Value sum =
- arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
- indicesAcc[indicesAcc.size() - 2] = sum;
-
- auto acc = vector::LoadOp::create(rewriter, loc,
- VectorType::get(16, opType),
- srcBuffAcc, indicesAcc);
- Value addition;
- if (ipType.isBF16())
- addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
-
- if (ipType.isSignlessInteger(8))
- addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
-
- vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
- indicesAcc);
-
- scf::YieldOp::create(builder, outerLoop.getLoc());
- });
+ Value resultWriteOp =
+ traceToVectorWriteLikeUserOperationForAMX(contOp.getResult());
+ if (auto vecType = llvm::dyn_cast<VectorType>(resultWriteOp.getType())) {
+ vecRoc = mlir::vector::ShapeCastOp::create(rewriter, loc, vecType,
+ writeResults[i]);
}
-
- rewriter.eraseOp(resultWriteOp);
+ resultWriteOp.replaceAllUsesWith(vecRoc);
}
return success();
>From 8e448d5f2a3e369dcc9514587e242e913d3e7f67 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 12 May 2026 05:55:29 -0700
Subject: [PATCH 02/11] counting offset on the subview result
---
.../VectorContractToAMXDotProduct.cpp | 47 ++++----
.../X86/AMX/vector-contract-to-tiled-dp.mlir | 100 ++++++++++++++++--
2 files changed, 122 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 551fccb47e114..64e5a6b56504b 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -239,32 +239,24 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
-
- auto subview = matB.getDefiningOp<mlir::memref::SubViewOp>();
- SmallVector<Value> subviewOffset(subview.getMixedOffsets().size(), c0);
+ SmallVector<Value> subviewOffset(
+ llvm::cast<MemRefType>(matB.getType()).getRank(), c0);
Value cStep = arith::ConstantIndexOp::create(rewriter, loc, offset);
Value cBound = arith::ConstantIndexOp::create(rewriter, loc, (16 * offset));
Value offsetIndx =
arith::ConstantIndexOp::create(rewriter, loc, (offset / 2));
- // llvm::outs() << "check-a:" << matB << " subview:" << subview << "\n";
- // llvm::outs() << "The size:" << subview.getMixedOffsets().size() << "\n";
-
scf::ForOp::create(
rewriter, loc, c0, cBound, cStep, ValueRange{},
[&](OpBuilder &nestedBuilder, Location loc, Value iv,
ValueRange iterArgs) {
- // llvm::outs() << "check-a0" << subviewOffset.size() << "\n";
subviewOffset[subviewOffset.size() - 2] = iv;
- // llvm::outs() << "check-a1" << "\n";
auto vec1 = vector::LoadOp::create(
rewriter, loc, VectorType::get((16 * offset), ipType), matB,
ValueRange(subviewOffset));
- // llvm::outs() << "check-b" << "\n";
-
// Increment the iv by 1 or 2 based on the type to load the next 32/64
// elements
Value incIV = arith::AddIOp::create(rewriter, loc, offsetIndx, iv);
@@ -276,8 +268,6 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
vector::ShuffleOp shuffle1;
vector::ShuffleOp shuffle2;
- // llvm::outs() << "check-c" << "\n";
-
if (ipType.isBF16()) {
shuffle1 = vector::ShuffleOp::create(
@@ -318,8 +308,6 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
30, 62, 94, 126, 31, 63, 95, 127});
}
- // llvm::outs() << "check-d" << "\n";
-
// iv to store the shuffled elements
Value ivShuff1 = arith::DivUIOp::create(rewriter, loc, iv, cStep);
Value ivShuff2 = arith::AddIOp::create(rewriter, loc, ivShuff1, c16);
@@ -829,6 +817,8 @@ struct VectorContractToAMXDotProduct
"The ACC src is not a MemRef type.");
auto [srcBuffAcc, indicesAcc] = *srcIndxAcc;
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+
// amx.tile_loads
auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
@@ -856,7 +846,6 @@ struct VectorContractToAMXDotProduct
auto packedBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
// create a loop that does online packing.
- Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value step =
arith::ConstantIndexOp::create(rewriter, loc, blockingFactor);
Value uBound = arith::ConstantIndexOp::create(rewriter, loc,
@@ -951,9 +940,32 @@ struct VectorContractToAMXDotProduct
dp = amx::TileMulIOp::create(rewriter, loc, tileTypeAcc, loadLhs,
loadRhs, loadAcc);
- amx::TileStoreOp::create(rewriter, loc, srcBuffAcc, indicesAcc, dp);
+ auto bufferType = MemRefType::get({16, 16}, opType);
+ auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
+
+ amx::TileStoreOp::create(rewriter, loc, resultBuffer, ValueRange{c0, c0},
+ dp);
+
+ auto flatTy = mlir::VectorType::get({16, 16}, opType);
+ int64_t srcRank =
+ (dyn_cast<ShapedType>(resultBuffer.getType())).getRank();
+ Value padding = ub::PoisonOp::create(rewriter, loc, opType);
+ auto map = AffineMap::getMinorIdentityMap(srcRank, flatTy.getRank(),
+ rewriter.getContext());
+ SmallVector<bool> inBounds(flatTy.getRank(), true);
+
+ Value vecRow = vector::TransferReadOp::create(
+ rewriter, loc, flatTy, resultBuffer, ValueRange{c0, c0}, padding, map,
+ inBounds);
+
+ Value resultOp =
+ traceToVectorWriteLikeUserOperationForAMX(contractOp.getResult());
+ if (auto vecType = llvm::dyn_cast<VectorType>(resultOp.getType())) {
+ vecRow =
+ mlir::vector::ShapeCastOp::create(rewriter, loc, vecType, vecRow);
+ }
- rewriter.eraseOp(resultWriteOp);
+ resultOp.replaceAllUsesWith(vecRow);
return success();
}
@@ -1186,7 +1198,6 @@ struct VectorContractToAMXDotProduct
rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
if (isVnni) {
-
newLoop = createLoops(
rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
innerLoop.getUpperBound(), innerLoop.getStep(), loopItrArgs, ipType,
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
index 1a6deed31eceb..20d269fd6ff88 100644
--- a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -239,13 +239,17 @@ func.func @online_packing_int8(
%3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+ %bias = arith.constant dense<13> : !vecC
+
%4 = vector.contract {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>}
%1, %2, %3 : !vecA, !vecB into !vecC
- vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ %5 = arith.addi %4, %bias : !vecC
+
+ vector.transfer_write %5, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
return %arg2 : !memrefC
}
@@ -259,10 +263,11 @@ func.func @online_packing_int8(
// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xi32>
// CHECK: x86.amx.tile_muli
// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK: vector.transfer_read
+// CHECK: arith.addi
+// CHECK: vector.transfer_write
// CHECK-NOT: vector.contract
-
-
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -695,7 +700,80 @@ func.func @online_packing_bf16_loop(%arg0: memref<16x64x96xbf16>, %arg1: memref<
// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
// CHECK-NOT: vector.contract
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecAB = vector<1x16x16x2xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>
+!memrefB = memref<1x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+
+func.func @brgemm_bf16_with_cano(%arg0: memref<16x32x16x2xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xf32>) -> memref<32x32xf32> {
+ %0 = ub.poison : f32
+ %1 = ub.poison : bf16
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c1 = arith.constant 1 : index
+ %2 = vector.transfer_read %arg2[%c0, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32>, !vecC
+ %3 = vector.transfer_read %arg2[%c0, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32>, !vecC
+ %4 = vector.transfer_read %arg2[%c16, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32>, !vecC
+ %5 = vector.transfer_read %arg2[%c16, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32>, !vecC
+
+ %6:4 = scf.for %arg3 = %c0 to %c16 step %c1 iter_args(%arg4 = %2, %arg5 = %3, %arg6 = %4, %arg7 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+
+ %subview = memref.subview %arg0[%arg3, 0, 0, 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x32x16x2xbf16> to !memrefA
+ %subview_0 = memref.subview %arg1[%arg3, 0, 0, 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x16x32x2xbf16> to !memrefB
+
+ %7 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : !memrefA, !vecAB
+ %8 = vector.transfer_read %subview[%c0, %c16, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : !memrefA, !vecAB
+ %9 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : !memrefB, !vecAB
+ %10 = vector.transfer_read %subview_0[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} : !memrefB, !vecAB
+
+ %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %7, %9, %arg4 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %7, %10, %arg5 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %9, %arg6 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %10, %arg7 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ scf.yield %11, %12, %13, %14 : !vecC, !vecC, !vecC, !vecC
+ }
+ vector.transfer_write %6#3, %arg2[%c16, %c16] {in_bounds = [true, true]} : !vecC, memref<32x32xf32>
+ vector.transfer_write %6#2, %arg2[%c16, %c0] {in_bounds = [true, true]} : !vecC, memref<32x32xf32>
+ vector.transfer_write %6#1, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, memref<32x32xf32>
+ vector.transfer_write %6#0, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, memref<32x32xf32>
+ %alloc = memref.alloc() : memref<32x32xf32>
+ memref.copy %arg2, %alloc : memref<32x32xf32> to memref<32x32xf32>
+ return %alloc : memref<32x32xf32>
+}
+
+// CHECK-LABEL: @brgemm_bf16_with_cano
+// CHECK-1: scf.for {{.*}} -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
+// CHECK-4: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+// CHECK-4: x86.amx.tile_load
+// CHECK-4: x86.amx.tile_mulf
+// CHECK: scf.yield {{.*}} : !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+// CHECK-NOT: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -728,6 +806,7 @@ func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xi8>, %arg1: memr
%c128 = arith.constant 128 : index
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
+ %bias = arith.constant dense<13> : !vecC
scf.for %arg3 = %c0 to %c64 step %c32 {
scf.for %arg4 = %c0 to %c128 step %c32 {
%subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<64x128xi32> to !memrefC
@@ -756,10 +835,16 @@ func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xi8>, %arg1: memr
%8, %10, %arg9 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
scf.yield %11, %12, %13, %14 : !vecC, !vecC, !vecC, !vecC
}
- vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
- vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
- vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
- vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ %7 = arith.addi %6#3, %bias : !vecC
+ %8 = arith.addi %6#2, %bias : !vecC
+ %9 = arith.addi %6#1, %bias : !vecC
+ %10 = arith.addi %6#0, %bias : !vecC
+
+ vector.transfer_write %7, %subview[%c16, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %8, %subview[%c16, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %9, %subview[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %10, %subview[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
}
}
%alloc = memref.alloc() : memref<64x128xi32>
@@ -777,6 +862,7 @@ func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xi8>, %arg1: memr
// CHECK: scf.yield {{.*}} !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>
// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32>
// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xi32>, vector<16xi32>
+// CHECK-COUNT-4: arith.addi
// CHECK-NOT: scf.for {{.*}} vector<16x16xi32>, vector<16x16xi32>, vector<16x16xi32>, vector<16x16xi32>
// CHECK-NOT: vector.contract
>From 4dbcaf4def5e6039ce05b972de8691379535c171 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 12 May 2026 07:52:25 -0700
Subject: [PATCH 03/11] enable support for f8 type
---
.../VectorContractToAMXDotProduct.cpp | 26 ++-
.../X86/AMX/vector-contract-to-tiled-dp.mlir | 219 ++++++++++++++++++
2 files changed, 238 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 64e5a6b56504b..595f585ad2d61 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -285,7 +285,8 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
23, 55, 28, 60, 29, 61, 30, 62, 31, 63});
}
- if (ipType.isSignlessInteger(8)) {
+ if (ipType.isSignlessInteger(8) || ipType.isF8E5M2() ||
+ ipType.isF8E4M3FN()) {
shuffle1 = vector::ShuffleOp::create(
rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
@@ -419,7 +420,7 @@ createTiledDp(OpBuilder &rewriter, Location loc,
auto accTileType = amx::TileType::get({16, 16}, opType);
Value dp;
- if (ipType.isBF16())
+ if (ipType.isBF16() || ipType.isF8E5M2() || ipType.isF8E4M3FN())
dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
tilesRhs, accIterArgs[i]);
@@ -725,15 +726,20 @@ struct VectorContractToAMXDotProduct
VectorType lhsTy = contractOp.getLhsType();
if (!lhsTy.getElementType().isBF16() &&
- !lhsTy.getElementType().isSignlessInteger(8))
+ !lhsTy.getElementType().isSignlessInteger(8) &&
+ !lhsTy.getElementType().isF8E4M3FN() &&
+ !lhsTy.getElementType().isF8E5M2())
return rewriter.notifyMatchFailure(
- contractOp, "Only BF16/Int8 lowering is supported.");
+ contractOp, "Only BF16/Int8/F8 lowering is supported.");
VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
if (!accTy)
return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
- if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
+ if (((lhsTy.getElementType().isBF16() ||
+ lhsTy.getElementType().isF8E4M3FN() ||
+ lhsTy.getElementType().isF8E5M2()) &&
+ !accTy.getElementType().isF32()) ||
(lhsTy.getElementType().isSignlessInteger(8) &&
!accTy.getElementType().isSignlessInteger(32)))
return rewriter.notifyMatchFailure(contractOp,
@@ -760,6 +766,12 @@ struct VectorContractToAMXDotProduct
opType = rewriter.getIntegerType(32);
}
+ if (lhsTy.getElementType().isF8E4M3FN())
+ ipType = rewriter.getF8E4M3FNType();
+
+ if (lhsTy.getElementType().isF8E5M2())
+ ipType = rewriter.getF8E5M2Type();
+
if (accReadOp->getBlock() == contractOp->getBlock() &&
resultWriteOp->getBlock() != contractOp->getBlock())
return rewriter.notifyMatchFailure(
@@ -932,7 +944,7 @@ struct VectorContractToAMXDotProduct
// Tiled dot-product.
Value dp;
- if (ipType.isBF16())
+ if (ipType.isBF16() || ipType.isF8E5M2() || ipType.isF8E4M3FN())
dp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
loadRhs, loadAcc);
@@ -1359,7 +1371,7 @@ struct VectorContractToAMXDotProduct
Value addOp;
Value addOp2;
- if (ipType.isBF16()) {
+ if (ipType.isBF16() || ipType.isF8E5M2() || ipType.isF8E4M3FN()) {
addOp = arith::AddFOp::create(rewriter, loc, shuffle1, valueCRow1);
addOp2 = arith::AddFOp::create(rewriter, loc, shuffle2, valueCRow2);
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
index 20d269fd6ff88..1ebb7010050ab 100644
--- a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -216,6 +216,60 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<1x16x16x4xf8E5M2>
+!vecB = vector<1x16x16x4xf8E5M2>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x4xf8E5M2>
+!memrefB = memref<1x16x32x4xf8E5M2>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_f8E5M2(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : f8E5M2
+ %32 = ub.poison : f32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @brgemm_f8E5M2
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xf8E5M2>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xf8E5M2>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK: x86.amx.tile_mulf
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<16x64xi8>
!vecB = vector<64x16xi8>
!vecC = vector<16x16xi32>
@@ -523,6 +577,88 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecAB = vector<16x16x4xf8E4M3FN>
+!vecC = vector<16x16xf32>
+!memrefA = memref<16x16x4xf8E4M3FN, strided<[256, 4, 1], offset: ?>>
+!memrefB = memref<16x32x4xf8E4M3FN, strided<[512, 4, 1], offset: ?>>
+!memrefC = memref<16x32xf32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+func.func @matmul_f8E4M3FN_loop(%arg0: memref<64x64x4xf8E4M3FN>, %arg1: memref<64x128x4xf8E4M3FN>, %arg2: memref<64x128xf32>) {
+ %0 = ub.poison : f32
+ %1 = ub.poison : f8E4M3FN
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ scf.for %arg3 = %c0 to %c64 step %c16 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+
+ %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1] :
+ memref<64x128xf32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %4:2 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+
+ %subview_0 = memref.subview %arg0[%arg3, %arg5, 0] [16, 16, 4] [1, 1, 1] :
+ memref<64x64x4xf8E4M3FN> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg4, 0] [16, 32, 4] [1, 1, 1] :
+ memref<64x128x4xf8E4M3FN> to !memrefB
+ %5 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefA, !vecAB
+ %6 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefB, !vecAB
+ %7 = vector.transfer_read %subview_1[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefB, !vecAB
+
+ %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %5, %6, %arg6 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %5, %7, %arg7 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+ scf.yield %8, %9 : !vecC, !vecC
+ }
+
+ vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ }
+ }
+
+ return
+}
+
+// CHECK-LABEL: @matmul_f8E4M3FN_loop
+// CHECK-2: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
+// CHECK-3: x86.amx.tile_load
+// CHECK-2: x86.amx.tile_mulf
+// CHECK: scf.yield {{.*}} !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecAB = vector<1x16x16x4xi8>
!vecC = vector<1x16x16xi32>
!memrefA = memref<1x16x16x4xi8, strided<[16384, 256, 4, 1], offset: ?>>
@@ -712,6 +848,89 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<16x64xf8E5M2>
+!vecB = vector<64x16xf8E5M2>
+!vecC = vector<16x16xf32>
+!memrefA = memref<32x64xf8E5M2, strided<[256, 1], offset: ?>>
+!memrefB = memref<64x32xf8E5M2, strided<[128, 1], offset: ?>>
+!memrefC = memref<32x32xf32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xf8E5M2>, %arg1: memref<256x128xf8E5M2>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> {
+ %c16 = arith.constant 16 : index
+ %0 = ub.poison : f32
+ %1 = ub.poison : f8E5M2
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ %c32 = arith.constant 32 : index
+ scf.for %arg3 = %c0 to %c64 step %c32 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+ %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<64x128xf32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+ %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+ %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+ %6:4 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+ %subview_0 = memref.subview %arg0[%arg3, %arg5] [32, 64] [1, 1] : memref<64x256xf8E5M2> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg4] [64, 32] [1, 1] : memref<256x128xf8E5M2> to !memrefB
+ %7 = vector.transfer_read %subview_0[%c0, %c0], %1 {in_bounds = [true, true]} : !memrefA, !vecA
+ %8 = vector.transfer_read %subview_0[%c16, %c0], %1 {in_bounds = [true, true]} : !memrefA, !vecA
+ %9 = vector.transfer_read %subview_1[%c0, %c0], %1 {in_bounds = [true, true]} : !memrefB, !vecB
+ %10 = vector.transfer_read %subview_1[%c0, %c16], %1 {in_bounds = [true, true]} : !memrefB, !vecB
+ %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %7, %9, %arg6 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %7, %10, %arg7 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %9, %arg8 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %10, %arg9 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+ scf.yield %11, %12, %13, %14 : !vecC, !vecC, !vecC, !vecC
+ }
+ vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ }
+ }
+ %alloc = memref.alloc() : memref<64x128xf32>
+ memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
+ return %alloc : memref<64x128xf32>
+}
+
+// CHECK-LABEL: @online_packing_int8_matmul_loop
+// CHECK-COUNT-4: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
+// CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123] : vector<64xf8E5M2>, vector<64xf8E5M2>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127] : vector<64xf8E5M2>, vector<64xf8E5M2>
+// CHECK: x86.amx.tile_load
+// CHECK: x86.amx.tile_mulf
+// CHECK: scf.yield {{.*}} !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecAB = vector<1x16x16x2xbf16>
!vecC = vector<16x16xf32>
!memrefA = memref<1x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>
>From 4d6cc4949aad936848ccb094184775ce964ec210 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 12 May 2026 22:39:12 -0700
Subject: [PATCH 04/11] relaxed a condition on ShapeCast which is not needed.
---
.../VectorContractToAMXDotProduct.cpp | 34 +++++++++----------
mlir/lib/Dialect/X86/Utils/X86Utils.cpp | 2 +-
2 files changed, 17 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 595f585ad2d61..34f6298c00b34 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -27,23 +27,20 @@ using namespace mlir::x86;
namespace {
-static Value traceToVectorWriteLikeUserOperationForAMX(Value v) {
+static Value contractionUsersAfterYield(Value v) {
if (v.getNumUses() > 1)
return nullptr;
for (OpOperand &use : v.getUses()) {
Operation *user = use.getOwner();
- if (!isa<scf::YieldOp>(user)) {
+ if (!isa<scf::YieldOp>(user))
return v;
- }
- // --- SCF YIELD ---
if (auto yield = dyn_cast<scf::YieldOp>(user)) {
Operation *parent = yield->getParentOp();
unsigned idx = use.getOperandNumber();
- if (auto res =
- traceToVectorWriteLikeUserOperationForAMX(parent->getResult(idx)))
+ if (auto res = contractionUsersAfterYield(parent->getResult(idx)))
return res;
continue;
}
@@ -494,7 +491,9 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
- int64_t offset = step.getDefiningOp<arith::ConstantIndexOp>().value();
+ int64_t offset = 16 * blockingFactor;
+ if (auto cst = step.getDefiningOp<arith::ConstantIndexOp>())
+ offset = cst.value();
auto newLoop = scf::ForOp::create(
rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
@@ -970,8 +969,7 @@ struct VectorContractToAMXDotProduct
rewriter, loc, flatTy, resultBuffer, ValueRange{c0, c0}, padding, map,
inBounds);
- Value resultOp =
- traceToVectorWriteLikeUserOperationForAMX(contractOp.getResult());
+ Value resultOp = contractionUsersAfterYield(contractOp.getResult());
if (auto vecType = llvm::dyn_cast<VectorType>(resultOp.getType())) {
vecRow =
mlir::vector::ShapeCastOp::create(rewriter, loc, vecType, vecRow);
@@ -1236,15 +1234,16 @@ struct VectorContractToAMXDotProduct
(((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
rewriter.setInsertionPoint(innerLoop);
+
auto c0 =
arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 0);
-
- int64_t stepVal =
- innerLoop.getStep().getDefiningOp<arith::ConstantIndexOp>().value();
+ int64_t offset = 16 * blockingFactor;
+ if (auto cst =
+ innerLoop.getStep().getDefiningOp<arith::ConstantIndexOp>())
+ offset = cst.value();
auto spillLoopBound = arith::ConstantIndexOp::create(
- rewriter, innerLoop.getLoc(), stepVal);
-
+ rewriter, innerLoop.getLoc(), offset);
Value spillInnerLoop =
arith::SubIOp::create(rewriter, innerLoop.getLoc(),
innerLoop.getUpperBound(), spillLoopBound);
@@ -1417,12 +1416,11 @@ struct VectorContractToAMXDotProduct
vector::ContractionOp contOp = ops[i];
Value vecRoc = writeResults[i];
- Value resultWriteOp =
- traceToVectorWriteLikeUserOperationForAMX(contOp.getResult());
- if (auto vecType = llvm::dyn_cast<VectorType>(resultWriteOp.getType())) {
+ Value resultWriteOp = contractionUsersAfterYield(contOp.getResult());
+ if (auto vecType = llvm::dyn_cast<VectorType>(resultWriteOp.getType()))
vecRoc = mlir::vector::ShapeCastOp::create(rewriter, loc, vecType,
writeResults[i]);
- }
+
resultWriteOp.replaceAllUsesWith(vecRoc);
}
diff --git a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
index aea6bf6adcd4a..2cd6012b9ec03 100644
--- a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
+++ b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
@@ -203,7 +203,7 @@ Operation *traceToVectorWriteLikeUserOperation(Value v) {
if (isa<vector::TransferWriteOp>(user) || isa<vector::StoreOp>(user))
return user;
- if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user))
+ if (isa<vector::ShuffleOp>(user))
return nullptr;
// --- SCF YIELD ---
>From 2574d27a1e094aa9715176fad8f2c524113b9970 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 13 May 2026 20:42:22 -0700
Subject: [PATCH 05/11] extra validation on contraction lhs and rhs type.
---
.../Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp | 4 ++++
mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir | 4 ++--
2 files changed, 6 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 34f6298c00b34..e31c0f6c7586e 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -731,6 +731,10 @@ struct VectorContractToAMXDotProduct
return rewriter.notifyMatchFailure(
contractOp, "Only BF16/Int8/F8 lowering is supported.");
+ if (lhsTy.getElementType() != contractOp.getRhsType().getElementType())
+ return rewriter.notifyMatchFailure(
+ contractOp, "Contraction should have same lhs and rhs type.");
+
VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
if (!accTy)
return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
index 1ebb7010050ab..d8950bd8baaf7 100644
--- a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -858,7 +858,7 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xf8E5M2>, %arg1: memref<256x128xf8E5M2>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> {
+func.func @online_packing_f8E5M2_matmul_loop(%arg0: memref<64x256xf8E5M2>, %arg1: memref<256x128xf8E5M2>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> {
%c16 = arith.constant 16 : index
%0 = ub.poison : f32
%1 = ub.poison : f8E5M2
@@ -906,7 +906,7 @@ func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xf8E5M2>, %arg1:
return %alloc : memref<64x128xf32>
}
-// CHECK-LABEL: @online_packing_int8_matmul_loop
+// CHECK-LABEL: @online_packing_f8E5M2_matmul_loop
// CHECK-COUNT-4: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
// CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123] : vector<64xf8E5M2>, vector<64xf8E5M2>
>From 433ebe854bdf4533903b76a71c49049ea5a676bf Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 13 May 2026 21:01:03 -0700
Subject: [PATCH 06/11] clean-up.
---
.../X86/Transforms/VectorContractToAMXDotProduct.cpp | 11 +++++------
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index e31c0f6c7586e..a425d1534d836 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -581,9 +581,8 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
if (outerPos >= 0) {
unsigned operandIdx = static_cast<unsigned>(outerPos + 1);
- if (operandIdx < rhsOp->getNumOperands()) {
+ if (operandIdx < rhsOp->getNumOperands())
rhsMapping.map(rhsOp->getOperand(operandIdx), ivOuterLoop);
- }
}
}
@@ -592,19 +591,18 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
if (innerPos >= 0) {
unsigned operandIdx = static_cast<unsigned>(innerPos + 1);
- if (operandIdx < rhsOp->getNumOperands()) {
+ if (operandIdx < rhsOp->getNumOperands())
rhsMapping.map(rhsOp->getOperand(operandIdx), ivNewInnerLoop);
- }
}
auto rhsClone = rewriterNewInnerLoop.clone(*rhsOp, rhsMapping);
-
matB = rhsClone->getResult(0);
} else {
// memref.get_global / constants
matB = rhsOp->getResult(0);
}
+
if (!isVnni) {
if (outerLoop) {
if (!pack) {
@@ -1332,7 +1330,7 @@ struct VectorContractToAMXDotProduct
auto mBound = arith::ConstantIndexOp::create(rewriter, loc, N);
// Create a loop that iterates over the MxN memerf, retrives two rows +
- // shuffle them, add up the C element values and stores them back.
+ // shuffle them, add up the C element values and stores them to temp buffer.
scf::ForOp::create(
rewriter, loc, c0, mBound, one, ValueRange{},
[&](OpBuilder &nestedBuilder, Location loc, Value iv,
@@ -1416,6 +1414,7 @@ struct VectorContractToAMXDotProduct
}
}
+ // Replace use of vector.contract with dot-products.
for (size_t i = 0; i < ops.size(); i++) {
vector::ContractionOp contOp = ops[i];
Value vecRoc = writeResults[i];
>From 22473a87e0fcaac055971b48516f4b972d4085a7 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 15 May 2026 09:30:25 -0700
Subject: [PATCH 07/11] minor fix in int8 shuffling
---
.../lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index a425d1534d836..6c02d5229119e 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -876,6 +876,7 @@ struct VectorContractToAMXDotProduct
arith::AddIOp::create(rewriter, loc, nextLoadIndx, iv);
indicesRhs[indicesRhs.size() - 2] = iv;
+ indicesRhs[indicesRhs.size() - 1] = c0;
ValueRange range1(indicesRhs);
auto vec1 = vector::LoadOp::create(
rewriter, loc,
>From a274536b6dc589b9abfa57d5bf78e11f8900001c Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 20 May 2026 23:25:12 -0700
Subject: [PATCH 08/11] clean-up, comments and use of rewriter
---
.../VectorContractToAMXDotProduct.cpp | 54 +++++++++----------
mlir/lib/Dialect/X86/Utils/X86Utils.cpp | 4 +-
2 files changed, 29 insertions(+), 29 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 6c02d5229119e..4b1f9adfde76d 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -27,26 +27,23 @@ using namespace mlir::x86;
namespace {
+// Recursively follows single-use values through scf.yield operations
+// and returns the first non-yield user result in the contraction chain.
static Value contractionUsersAfterYield(Value v) {
- if (v.getNumUses() > 1)
+ if (v.getNumUses() != 1)
return nullptr;
- for (OpOperand &use : v.getUses()) {
- Operation *user = use.getOwner();
+ OpOperand &use = *v.use_begin();
+ Operation *user = use.getOwner();
- if (!isa<scf::YieldOp>(user))
- return v;
+ if (!isa<scf::YieldOp>(user))
+ return v;
- if (auto yield = dyn_cast<scf::YieldOp>(user)) {
- Operation *parent = yield->getParentOp();
- unsigned idx = use.getOperandNumber();
- if (auto res = contractionUsersAfterYield(parent->getResult(idx)))
- return res;
- continue;
- }
- }
+ auto yield = cast<scf::YieldOp>(user);
+ Operation *parent = yield->getParentOp();
+ unsigned idx = use.getOperandNumber();
- return nullptr;
+ return contractionUsersAfterYield(parent->getResult(idx));
}
// Function to collapse the last two dimension (vnni and k) to help the
@@ -138,6 +135,9 @@ static LogicalResult validateContractOps(OpBuilder &rewriter,
if (buffRhs != srcBuffRhs)
return failure();
+
+ if (!contractionUsersAfterYield(contractOp.getResult()))
+ return failure();
}
VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
@@ -572,7 +572,7 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
Value matB;
Operation *rhsOp = vectorOpRhs;
- // Clone only if the op has operands.
+ // Clone for the subview type operations
if (rhsOp->getNumOperands() > 0) {
if (outerLoop) {
@@ -599,7 +599,7 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
matB = rhsClone->getResult(0);
} else {
- // memref.get_global / constants
+ // The mat B is of kind 'memref.get_global @__constant'
matB = rhsOp->getResult(0);
}
@@ -803,7 +803,7 @@ struct VectorContractToAMXDotProduct
return rewriter.notifyMatchFailure(
contractOp, "The contract operation doesn't satisfy the operands "
"dimensions. M, N, and vnni dims are 16, 16, and 2/4. "
- "The rest dims should be 1.");
+ "The rest dims should be 1. Op should have one user.");
Location loc = contractOp.getLoc();
@@ -978,7 +978,7 @@ struct VectorContractToAMXDotProduct
mlir::vector::ShapeCastOp::create(rewriter, loc, vecType, vecRow);
}
- resultOp.replaceAllUsesWith(vecRow);
+ rewriter.replaceAllUsesWith(resultOp, vecRow);
return success();
}
@@ -1044,9 +1044,10 @@ struct VectorContractToAMXDotProduct
if (failed(validate))
return rewriter.notifyMatchFailure(
- contractOp, "The associated contract operations doesn't satisfy "
- "the re-write conditions either the dimensions are "
- "wrong or MemRef source are different.");
+ contractOp,
+ "The associated contract operations doesn't satisfy "
+ "the re-write conditions either the dimensions are "
+ "wrong or MemRef source are different or many users.");
ops.push_back(contract);
}
@@ -1072,7 +1073,6 @@ struct VectorContractToAMXDotProduct
scf::ForOp newLoop;
// Case 2a: Reduction loop depth is 2.
if (loopLists.size() == 2) {
-
outerLoop = loopLists[1];
innerLoop = loopLists[0];
@@ -1328,12 +1328,12 @@ struct VectorContractToAMXDotProduct
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
- auto mBound = arith::ConstantIndexOp::create(rewriter, loc, N);
+ auto nBound = arith::ConstantIndexOp::create(rewriter, loc, N);
// Create a loop that iterates over the MxN memerf, retrives two rows +
// shuffle them, add up the C element values and stores them to temp buffer.
scf::ForOp::create(
- rewriter, loc, c0, mBound, one, ValueRange{},
+ rewriter, loc, c0, nBound, one, ValueRange{},
[&](OpBuilder &nestedBuilder, Location loc, Value iv,
ValueRange iterArgs) {
auto row =
@@ -1418,14 +1418,14 @@ struct VectorContractToAMXDotProduct
// Replace use of vector.contract with dot-products.
for (size_t i = 0; i < ops.size(); i++) {
vector::ContractionOp contOp = ops[i];
- Value vecRoc = writeResults[i];
+ Value vecRow = writeResults[i];
Value resultWriteOp = contractionUsersAfterYield(contOp.getResult());
if (auto vecType = llvm::dyn_cast<VectorType>(resultWriteOp.getType()))
- vecRoc = mlir::vector::ShapeCastOp::create(rewriter, loc, vecType,
+ vecRow = mlir::vector::ShapeCastOp::create(rewriter, loc, vecType,
writeResults[i]);
- resultWriteOp.replaceAllUsesWith(vecRoc);
+ rewriter.replaceAllUsesWith(resultWriteOp, vecRow);
}
return success();
diff --git a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
index 2cd6012b9ec03..a04a026f35ae6 100644
--- a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
+++ b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
@@ -187,8 +187,8 @@ Operation *traceToVectorReadLikeParentOperation(Value v) {
// This function recursively traces a value through its uses to find
// a downstream vector write-like operation (`vector.transfer_write`
// or `vector.store`). It transparently follows values across `scf.for`
-// and `scf.yield` boundaries while stopping if layout-altering ops such
-// as `shape_cast` or `shuffle` are encountered. The traversal returns
+// and `scf.yield` boundaries while stopping if layout-altering ops
+// like `shuffle` are encountered. The traversal returns
// the matching write-like user. Returns `nullptr` if none is found or
// the value has multiple users.
Operation *traceToVectorWriteLikeUserOperation(Value v) {
>From b402d03db7f7b80b694d17dd5f2d183fda82682b Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sun, 24 May 2026 01:02:39 -0700
Subject: [PATCH 09/11] support for the reduction loops to accept dynamic lower
bounds
---
.../VectorContractToAMXDotProduct.cpp | 34 ++++++++++++++++---
1 file changed, 29 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 4b1f9adfde76d..ee6a2bf0ef6ca 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -1138,15 +1138,29 @@ struct VectorContractToAMXDotProduct
rhsMapping.map(
vectorOpRhs->getOperand(
getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
- c0);
+ outerLoop.getLowerBound());
rhsMapping.map(
vectorOpRhs->getOperand(
getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
- c0);
+ innerLoop.getLowerBound());
auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
+ Value quotient_batch = arith::DivUIOp::create(
+ rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
+ outerLoop.getStep());
+ Value quotient_k = arith::DivUIOp::create(rewriter, outerLoop.getLoc(),
+ innerLoop.getLowerBound(),
+ innerLoop.getStep());
+
+ Value quotient_add = arith::AddIOp::create(rewriter, outerLoop.getLoc(),
+ quotient_batch, quotient_k);
+ Value c2 =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 2);
+ Value rem = arith::RemUIOp::create(rewriter, outerLoop.getLoc(),
+ quotient_add, c2);
+
performShuffle(rewriter, outerLoop.getLoc(), rhsClone->getResult(0),
- ipType, blockingFactor, packedBuffer, c0);
+ ipType, blockingFactor, packedBuffer, rem);
// First Set of Loops
auto newLoopNonSpill = scf::ForOp::create(
@@ -1261,10 +1275,20 @@ struct VectorContractToAMXDotProduct
rhsMapping.map(
vectorOpRhs->getOperand(
getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
- c0);
+ innerLoop.getLowerBound());
auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
+
+ Value quotient_k = arith::DivUIOp::create(rewriter, innerLoop.getLoc(),
+ innerLoop.getLowerBound(),
+ innerLoop.getStep());
+ Value c2 =
+ arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 2);
+ Value rem = arith::RemUIOp::create(rewriter, innerLoop.getLoc(),
+ quotient_k, c2);
+
performShuffle(rewriter, innerLoop.getLoc(), rhsClone->getResult(0),
- ipType, blockingFactor, packedBuffer, c0);
+ ipType, blockingFactor, packedBuffer, rem);
+
auto newLoopNonSpill = createLoops(
rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
spillInnerLoop, innerLoop.getStep(), loopItrArgs, ipType, opType,
>From ef01f0b6fa3f4f96ef227c867b77a00898d76c6b Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 25 May 2026 04:51:30 -0700
Subject: [PATCH 10/11] check on loop step + new test-cases
---
.../VectorContractToAMXDotProduct.cpp | 52 ++++-
.../X86/AMX/vector-contract-to-tiled-dp.mlir | 202 ++++++++++++++++++
2 files changed, 247 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index ee6a2bf0ef6ca..82840a4f86d79 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -107,6 +107,20 @@ getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
return std::make_pair(srcBuff, indices);
}
+// Validates that the loop step is a constant index equal to `value` or 1.
+static LogicalResult validateLoopStep(OpBuilder &rewriter, Value step,
+ int64_t value) {
+
+ auto cst = step.getDefiningOp<arith::ConstantIndexOp>();
+ if (!cst)
+ return failure();
+
+ if (cst.value() != value && cst.value() != 1)
+ return failure();
+
+ return success();
+}
+
// Function to validate the vector.contract operation.
static LogicalResult validateContractOps(OpBuilder &rewriter,
vector::ContractionOp contractOp,
@@ -135,11 +149,11 @@ static LogicalResult validateContractOps(OpBuilder &rewriter,
if (buffRhs != srcBuffRhs)
return failure();
-
- if (!contractionUsersAfterYield(contractOp.getResult()))
- return failure();
}
+ if (!contractionUsersAfterYield(contractOp.getResult()))
+ return failure();
+
VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
if (!accTy)
return failure();
@@ -973,10 +987,8 @@ struct VectorContractToAMXDotProduct
inBounds);
Value resultOp = contractionUsersAfterYield(contractOp.getResult());
- if (auto vecType = llvm::dyn_cast<VectorType>(resultOp.getType())) {
- vecRow =
- mlir::vector::ShapeCastOp::create(rewriter, loc, vecType, vecRow);
- }
+ if (auto vecType = llvm::dyn_cast<VectorType>(resultOp.getType()))
+ vecRow = vector::ShapeCastOp::create(rewriter, loc, vecType, vecRow);
rewriter.replaceAllUsesWith(resultOp, vecRow);
return success();
@@ -1076,6 +1088,21 @@ struct VectorContractToAMXDotProduct
outerLoop = loopLists[1];
innerLoop = loopLists[0];
+ LogicalResult validateOuterLoopStep =
+ validateLoopStep(rewriter, outerLoop.getStep(), 1);
+ if (failed(validateOuterLoopStep))
+ return rewriter.notifyMatchFailure(contractOp, "Invalid loop step.");
+
+ int64_t stepValue = 16;
+ if (!isVnni)
+ stepValue = stepValue * blockingFactor;
+ LogicalResult validateInnerLoopStep =
+ validateLoopStep(rewriter, innerLoop.getStep(), stepValue);
+ if (failed(validateInnerLoopStep))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Invalid loop step. The step should be 32 for BF16 and "
+ "64 for Int8/F8.");
+
SmallVector<Value> loopItrArgs = createTileZeros(
rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
@@ -1221,6 +1248,17 @@ struct VectorContractToAMXDotProduct
if (loopLists.size() == 1) {
innerLoop = loopLists[0];
+ int64_t stepValue = 16;
+ if (!isVnni)
+ stepValue = stepValue * blockingFactor;
+
+ LogicalResult validateInnerLoopStep =
+ validateLoopStep(rewriter, innerLoop.getStep(), stepValue);
+ if (failed(validateInnerLoopStep))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Invalid loop step. The step should be 32 for BF16 and "
+ "64 for Int8/F8 or 1 if it is batch loop.");
+
SmallVector<Value> loopItrArgs = createTileZeros(
rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
index d8950bd8baaf7..fb2314dfdf506 100644
--- a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -846,6 +846,111 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+!vecA = vector<16x32xbf16>
+!vecB = vector<32x16xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<32x32xbf16, strided<[96, 1], offset: ?>>
+!memrefB = memref<32x32xbf16, strided<[128, 1], offset: ?>>
+!memrefC = memref<32x32xf32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d1, d2, d3) -> (d1, d3)>
+#map1 = affine_map<(d1, d2, d3) -> (d3, d2)>
+#map2 = affine_map<(d1, d2, d3) -> (d1, d2)>
+
+func.func @online_packing_bf16_loop_lb_non_zero(%arg0: memref<64x96xbf16>, %arg1: memref<96x128xbf16>, %arg2: memref<64x128xf32>, %klb: index, %kub: index) -> memref<64x128xf32> {
+ %0 = ub.poison : f32
+ %1 = ub.poison : bf16
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c96 = arith.constant 96 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ scf.for %arg3 = %c0 to %c64 step %c32 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+
+ %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] :
+ memref<64x128xf32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %7:4 = scf.for %arg10 = %klb to %kub step %c32 iter_args(%arg11 = %2, %arg12 = %3, %arg13 = %4, %arg14 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+ %subview_0 = memref.subview %arg0[%arg3, %arg10] [32, 32] [1, 1] :
+ memref<64x96xbf16> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg10, %arg4] [32, 32] [1, 1] :
+ memref<96x128xbf16> to !memrefB
+ %8 = vector.transfer_read %subview_0[%c0, %c0], %1 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %9 = vector.transfer_read %subview_0[%c16, %c0], %1 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %10 = vector.transfer_read %subview_1[%c0, %c0], %1 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %11 = vector.transfer_read %subview_1[%c0, %c16], %1 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+
+ %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %10, %arg11 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+ %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %11, %arg12 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+ %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %9, %10, %arg13 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+ %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %9, %11, %arg14 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+
+ scf.yield %12, %13, %14, %15 : !vecC, !vecC, !vecC, !vecC
+ }
+ vector.transfer_write %7#3, %subview[%c16, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %7#2, %subview[%c16, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %7#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %7#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ }
+ }
+ %alloc = memref.alloc() : memref<64x128xf32>
+ memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
+ return %alloc : memref<64x128xf32>
+}
+
+// CHECK-LABEL: @online_packing_bf16_loop_lb_non_zero
+// CHECK-COUNT-4: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
+// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
+// CHECK: x86.amx.tile_load
+// CHECK: x86.amx.tile_mulf
+// CHECK: scf.yield {{.*}} : !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+
// -----
!vecA = vector<16x64xf8E5M2>
@@ -1458,6 +1563,103 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<16x32xbf16>
+!vecB = vector<32x16xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<32x32xbf16, strided<[96, 1], offset: ?>>
+!memrefB = memref<32x32xbf16, strided<[128, 1], offset: ?>>
+!memrefC = memref<32x32xf32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d1, d2, d3) -> (d1, d3)>
+#map1 = affine_map<(d1, d2, d3) -> (d3, d2)>
+#map2 = affine_map<(d1, d2, d3) -> (d1, d2)>
+
+func.func @negative_online_packing_bf16_dynamic_loop_step(%arg0: memref<64x96xbf16>, %arg1: memref<96x128xbf16>, %arg2: memref<64x128xf32>, %kStep: index) -> memref<64x128xf32> {
+ %0 = ub.poison : f32
+ %1 = ub.poison : bf16
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c96 = arith.constant 96 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ scf.for %arg3 = %c0 to %c64 step %c32 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+
+ %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] :
+ memref<64x128xf32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %7:4 = scf.for %arg10 = %c0 to %c32 step %kStep iter_args(%arg11 = %2, %arg12 = %3, %arg13 = %4, %arg14 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+ %subview_0 = memref.subview %arg0[%arg3, %arg10] [32, 32] [1, 1] :
+ memref<64x96xbf16> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg10, %arg4] [32, 32] [1, 1] :
+ memref<96x128xbf16> to !memrefB
+ %8 = vector.transfer_read %subview_0[%c0, %c0], %1 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %9 = vector.transfer_read %subview_0[%c16, %c0], %1 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %10 = vector.transfer_read %subview_1[%c0, %c0], %1 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %11 = vector.transfer_read %subview_1[%c0, %c16], %1 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+
+ %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %10, %arg11 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+ %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %11, %arg12 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+ %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %9, %10, %arg13 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+ %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %9, %11, %arg14 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+
+ scf.yield %12, %13, %14, %15 : !vecC, !vecC, !vecC, !vecC
+ }
+ vector.transfer_write %7#3, %subview[%c16, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %7#2, %subview[%c16, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %7#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %7#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ }
+ }
+ %alloc = memref.alloc() : memref<64x128xf32>
+ memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
+ return %alloc : memref<64x128xf32>
+}
+
+// CHECK-LABEL: @negative_online_packing_bf16_dynamic_loop_step
+// CHECK-NOT: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_mulf
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecAB = vector<1x1x16x16x4xi8>
!vecC = vector<16x16xi32>
!memrefA = memref<1x1x16x16x4xi8, strided<[262144, 16384, 256, 4, 1], offset: ?>>
>From 818f161cfdf64d819f6eba91167f4f8f765f8ef4 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 25 May 2026 09:18:16 -0700
Subject: [PATCH 11/11] fix typo on comment.
---
.../Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 82840a4f86d79..1ffc31672296f 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -1256,8 +1256,9 @@ struct VectorContractToAMXDotProduct
validateLoopStep(rewriter, innerLoop.getStep(), stepValue);
if (failed(validateInnerLoopStep))
return rewriter.notifyMatchFailure(
- contractOp, "Invalid loop step. The step should be 32 for BF16 and "
- "64 for Int8/F8 or 1 if it is batch loop.");
+ contractOp,
+ "Invalid loop step. The step should be 32 for BF16 and "
+ "64 for Int8/F8 or 1 if it is a reduction loop other than K.");
SmallVector<Value> loopItrArgs = createTileZeros(
rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
More information about the Mlir-commits
mailing list