[Mlir-commits] [mlir] [mlir][x86] Lower packed type vector.contract to AMX dot-product (online-packing) (PR #188192)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 26 21:48:24 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Arun Thangamani (arun-thmn)
<details>
<summary>Changes</summary>
A transform pass that lowers flat-layout `vector.contract` operations to (a) `amx.tile_mulf` for BF16, or (b) `amx.tile_muli` for Int8 packed types, via `online` packing.
TODOs: A follow-up `patch` is planned to refactor this pass and retire the `convert-vector-to-amx` pass.
---
Patch is 81.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/188192.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp (+897-170)
- (modified) mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir (+480-20)
``````````diff
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 85966a85af40e..744c065b4e05e 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -70,8 +70,9 @@ getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
if (!srcBuff)
return failure();
- if (isNotAcc)
+ if (isNotAcc) {
indexVals.pop_back();
+ }
SmallVector<Value> indices;
indices.reserve(indexVals.size());
@@ -189,37 +190,184 @@ static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
// Creates amx.tile_loads.
static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
Value operand, Value mat, Type ipType,
- bool rhs, unsigned int offset) {
+ bool rhs, unsigned int offset,
+ bool isVnni) {
auto srcIndx = getSrcIndxValue(rewriter, loc, operand, false);
auto [srcBuff, indices] = *srcIndx;
- indices.pop_back();
+ if (isVnni) {
+ indices.pop_back();
+ }
- if (rhs) {
+ if (rhs && isVnni) {
auto cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
indices[indices.size() - 1] = arith::MulIOp::create(
rewriter, loc, indices[indices.size() - 1], cOffset);
}
amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
- return amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
+ auto load = amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
+ return load;
}
-// Creates tiled amx dot-products.
-static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
- SmallVector<vector::ContractionOp> ops,
- Value matA, Value matB, Type ipType,
- Type opType, ValueRange accIterArgs,
- unsigned int offset) {
+static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
+ Type ipType, unsigned int offset, Value packedBuffer,
+ Value indxToStoreInBuffer) {
+
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+
+ auto subview = matB.getDefiningOp<mlir::memref::SubViewOp>();
+ SmallVector<Value> subviewOffset(subview.getOffsets().size(), c0);
+
+ Value cStep = arith::ConstantIndexOp::create(rewriter, loc, offset);
+ Value cBound = arith::ConstantIndexOp::create(rewriter, loc, (16 * offset));
+ Value offsetIndx =
+ arith::ConstantIndexOp::create(rewriter, loc, (offset / 2));
+
+ scf::ForOp::create(
+ rewriter, loc, c0, cBound, cStep, ValueRange{},
+ [&](OpBuilder &nestedBuilder, Location loc, Value iv,
+ ValueRange iterArgs) {
+ subviewOffset[subviewOffset.size() - 2] = iv;
+ auto vec1 = vector::LoadOp::create(
+ rewriter, loc, VectorType::get((16 * offset), ipType), matB,
+ ValueRange(subviewOffset));
+
+ // Increment the iv by 1 or 2 based on the type to load the next 32/64
+ // elements
+ Value incIV = arith::AddIOp::create(rewriter, loc, offsetIndx, iv);
+ subviewOffset[subviewOffset.size() - 2] = incIV;
+ auto vec2 = vector::LoadOp::create(
+ rewriter, loc, VectorType::get((16 * offset), ipType), matB,
+ ValueRange(subviewOffset));
+
+ vector::ShuffleOp shuffle1;
+ vector::ShuffleOp shuffle2;
+
+ if (ipType.isBF16()) {
+
+ shuffle1 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
+ vec2,
+ ArrayRef<int64_t>{0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9,
+ 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50,
+ 19, 51, 24, 56, 25, 57, 26, 58, 27, 59});
+
+ shuffle2 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
+ vec2,
+ ArrayRef<int64_t>{4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13,
+ 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54,
+ 23, 55, 28, 60, 29, 61, 30, 62, 31, 63});
+ }
+
+ if (ipType.isSignlessInteger(8)) {
+
+ shuffle1 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
+ vec2,
+ ArrayRef<int64_t>{
+ 0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3,
+ 35, 67, 99, 4, 36, 68, 100, 5, 37, 69, 101, 6, 38,
+ 70, 102, 7, 39, 71, 103, 8, 40, 72, 104, 9, 41, 73,
+ 105, 10, 42, 74, 106, 11, 43, 75, 107, 12, 44, 76, 108,
+ 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111});
+
+ shuffle2 = vector::ShuffleOp::create(
+ rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
+ vec2,
+ ArrayRef<int64_t>{
+ 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51,
+ 83, 115, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118,
+ 23, 55, 87, 119, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58,
+ 90, 122, 27, 59, 91, 123, 28, 60, 92, 124, 29, 61, 93, 125,
+ 30, 62, 94, 126, 31, 63, 95, 127});
+ }
+
+ // iv to store the shuffled elements
+ Value ivShuff1 = arith::DivUIOp::create(rewriter, loc, iv, cStep);
+ Value ivShuff2 = arith::AddIOp::create(rewriter, loc, ivShuff1, c16);
+
+ vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
+ ValueRange{indxToStoreInBuffer, ivShuff1, c0});
+ vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
+ ValueRange{indxToStoreInBuffer, ivShuff2, c0});
+
+ scf::YieldOp::create(nestedBuilder, loc);
+ });
+}
+
+static llvm::DenseMap<Operation *, amx::TileLoadOp>
+packInputs(OpBuilder &rewriter, Location loc,
+ SmallVector<vector::ContractionOp> ops, Value matB, Type ipType,
+ unsigned int offset, Value packedBuffer, bool pack,
+ Value indxToStoreInBuffer, Value indxToLoadFromMatB) {
+
+ llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+
+ for (size_t j = 0; j < ops.size(); j++) {
+ for (size_t i = 0; i < ops.size(); i++) {
+
+ if (i != j && validatePairVectorContract(ops[j], ops[i], true, 16)) {
+
+ Operation *readOpRhs = ops[j].getRhs().getDefiningOp();
+ auto itRhs = readsToTileLoads.find(readOpRhs);
+ if (itRhs != readsToTileLoads.end()) {
+ continue;
+ }
- auto subviewCollapseLhs = collapseInnerDims(rewriter, loc, matA);
- auto subviewCollapseRhs = collapseInnerDims(rewriter, loc, matB);
+ if (pack) {
+ performShuffle(rewriter, loc, matB, ipType, offset, packedBuffer,
+ indxToStoreInBuffer);
+ }
+
+ amx::TileType tileType =
+ amx::TileType::get({16, (16 * offset)}, ipType);
+ auto loadRow1 =
+ amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
+ ValueRange{indxToLoadFromMatB, c0, c0});
+
+ auto loadRow2 =
+ amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
+ ValueRange{indxToLoadFromMatB, c16, c0});
+
+ readsToTileLoads.try_emplace(readOpRhs, loadRow1);
+ readsToTileLoads.try_emplace(ops[i].getRhs().getDefiningOp(), loadRow2);
+ }
+ }
+ }
+
+ return readsToTileLoads;
+}
+
+// Creates tiled amx dot-products.
+static SmallVector<Value>
+createTiledDp(OpBuilder &rewriter, Location loc,
+ SmallVector<vector::ContractionOp> ops, Value matA, Value matB,
+ Type ipType, Type opType, ValueRange accIterArgs,
+ unsigned int offset, bool isVnni, Value packedBuffer, bool pack,
+ Value indxToStoreInBuffer, Value indxToLoadFromMatB) {
+
+ if (isVnni) {
+ matA = collapseInnerDims(rewriter, loc, matA);
+ matB = collapseInnerDims(rewriter, loc, matB);
+ }
SmallVector<Value> accumulators;
// Stores the amx.tile_load operation vs it's equivalent vector tranfer_read
// or load operations.
llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+ // function call to online pack the input B matrix
+ if (!isVnni) {
+ readsToTileLoads =
+ packInputs(rewriter, loc, ops, matB, ipType, offset, packedBuffer, pack,
+ indxToStoreInBuffer, indxToLoadFromMatB);
+ }
+
// Iterate over the contraction operations and compute the tiled dot-product.
for (size_t i = 0; i < ops.size(); i++) {
@@ -229,8 +377,8 @@ static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
if (itLhs != readsToTileLoads.end()) {
tilesLhs = itLhs->second;
} else {
- tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(),
- subviewCollapseLhs, ipType, false, offset);
+ tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(), matA, ipType,
+ false, offset, isVnni);
readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
}
@@ -240,8 +388,8 @@ static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
if (itRhs != readsToTileLoads.end()) {
tilesRhs = itRhs->second;
} else {
- tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(),
- subviewCollapseRhs, ipType, true, offset);
+ tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(), matB, ipType,
+ true, offset, isVnni);
readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
}
@@ -276,10 +424,186 @@ static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
return loopItrArgs;
}
+static Value bufferIndxToStore(OpBuilder &rewriter, Location loc,
+ Value ivInnerLoop, Value ivOuterLoop,
+ bool isInnerLoopUBHasOddQuot,
+ bool isInnerLoopUBLarger, bool pack,
+ unsigned int blockingFactor) {
+
+ Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
+ Value packOffset =
+ arith::ConstantIndexOp::create(rewriter, loc, (16 * blockingFactor));
+
+ Value quotientInnerLoop =
+ arith::DivUIOp::create(rewriter, loc, ivInnerLoop, packOffset);
+ Value remInnerLoop = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), quotientInnerLoop, c2);
+
+ if (!isInnerLoopUBLarger && !pack) {
+ remInnerLoop = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
+ }
+
+ // if K quotient is odd. Then, BR loop iv is taken
+ // into consideration
+ if (isInnerLoopUBHasOddQuot) {
+ auto remOuterLoop = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
+ auto remAdd = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
+ remInnerLoop, remOuterLoop);
+ remInnerLoop = arith::RemUIOp::create(rewriter, loc,
+ rewriter.getIndexType(), remAdd, c2);
+ }
+ return remInnerLoop;
+}
+
+static scf::ForOp
+createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
+ Value upperBound, Value step, SmallVector<Value> loopItrArgs,
+ Type ipType, Type opType, unsigned int blockingFactor, bool isVnni,
+ Operation *vectorOpLhs, Operation *vectorOpRhs,
+ vector::ContractionOp contractOp, scf::ForOp outerLoop,
+ scf::ForOp innerLoop, SmallVector<vector::ContractionOp> ops,
+ Value ivOuterLoop, Value packedBuffer, bool pack,
+ arith::ConstantIndexOp innerLoopIndex, bool isInnerLoopUBLarger,
+ bool isInnerLoopUBHasOddQuot) {
+
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
+
+ auto newLoop = scf::ForOp::create(
+ rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
+ [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
+ Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
+ IRMapping mapping;
+ if (outerLoop) {
+ mapping.map(vectorOpLhs->getOperand(
+ getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
+ ivOuterLoop);
+ }
+ mapping.map(vectorOpLhs->getOperand(
+ getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
+ ivNewInnerLoop);
+ auto lhsClone = rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
+
+ Value indxToStoreInBuffer = c0;
+ Value indxToLoadFromBuffer = c0;
+
+ if (!isVnni) {
+ if (outerLoop) {
+ if (innerLoopIndex.value() == 0) {
+ if (pack) {
+ ivNewInnerLoop = c0;
+ ivOuterLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
+ c1, ivOuterLoop);
+
+ if (!isInnerLoopUBLarger || isInnerLoopUBHasOddQuot) {
+ indxToStoreInBuffer = arith::RemUIOp::create(
+ rewriter, locNewInnerLoop, rewriter.getIndexType(),
+ ivOuterLoop, c2);
+ }
+
+ Value indxToLoadFromMatB = arith::AddIOp::create(
+ rewriter, loc, indxToStoreInBuffer, c1);
+ indxToLoadFromBuffer = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), indxToLoadFromMatB,
+ c2);
+ }
+
+ } else {
+ Value nLoadIndx = arith::ConstantIndexOp::create(
+ rewriter, locNewInnerLoop, (16 * blockingFactor));
+ ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
+ nLoadIndx, ivNewInnerLoop);
+ indxToStoreInBuffer =
+ bufferIndxToStore(rewriter, loc, ivNewInnerLoop, ivOuterLoop,
+ isInnerLoopUBHasOddQuot,
+ isInnerLoopUBLarger, pack, blockingFactor);
+ Value indxToLoadFromMatB =
+ arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
+ indxToLoadFromBuffer =
+ arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
+ indxToLoadFromMatB, c2);
+ }
+ } else {
+ if (pack) {
+ Value nLoadIndx = arith::ConstantIndexOp::create(
+ rewriter, locNewInnerLoop, (16 * blockingFactor));
+ ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
+ nLoadIndx, ivNewInnerLoop);
+ Value quotient_K = arith::DivUIOp::create(
+ rewriter, loc, ivNewInnerLoop, nLoadIndx);
+ indxToStoreInBuffer = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), quotient_K, c2);
+
+ Value indxToLoadFromMatB =
+ arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
+ indxToLoadFromBuffer =
+ arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
+ indxToLoadFromMatB, c2);
+ }
+ }
+ }
+
+ IRMapping rhsMapping;
+ if (outerLoop) {
+ rhsMapping.map(
+ vectorOpRhs->getOperand(
+ getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
+ ivOuterLoop);
+ }
+ rhsMapping.map(
+ vectorOpRhs->getOperand(
+ getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
+ ivNewInnerLoop);
+ auto rhsClone = rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping);
+
+ Value matB = rhsClone->getResult(0);
+
+ if (!isVnni) {
+ if (outerLoop) {
+ if (!pack) {
+ Value nLoadIndx = arith::ConstantIndexOp::create(
+ rewriter, locNewInnerLoop, (16 * blockingFactor));
+ matB = Value();
+ indxToLoadFromBuffer = c0;
+ indxToLoadFromBuffer =
+ bufferIndxToStore(rewriter, loc, nLoadIndx, ivOuterLoop,
+ isInnerLoopUBHasOddQuot,
+ isInnerLoopUBLarger, pack, blockingFactor);
+ }
+ } else {
+ if (!pack) {
+ Value nLoadIndx = arith::ConstantIndexOp::create(
+ rewriter, locNewInnerLoop, (16 * blockingFactor));
+ matB = Value();
+ Value quotient_K = arith::DivUIOp::create(
+ rewriter, loc, ivNewInnerLoop, nLoadIndx);
+ indxToLoadFromBuffer = arith::RemUIOp::create(
+ rewriter, loc, rewriter.getIndexType(), quotient_K, c2);
+ }
+ }
+ }
+
+ // compute tiled dot-product
+ SmallVector<Value> accumulators = createTiledDp(
+ rewriter, locNewInnerLoop, ops, lhsClone->getResult(0), matB,
+ ipType, opType, iterArgsNewInnerLoop, blockingFactor, isVnni,
+ packedBuffer, pack, indxToStoreInBuffer, indxToLoadFromBuffer);
+
+ scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
+ accumulators);
+ });
+
+ return newLoop;
+}
+
// Implements tiled dot-product operation for a vector.contract operation or a
// sequence of vector.contracts inside the reduction loops.
//
-// For example - for F32 type:
+// For example:
+// Case 1: register blocked vector.contract with prepacked input
// ```
// vector.transfer_read %arg0 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
// vector.transfer_read %arg1 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
@@ -293,6 +617,52 @@ static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
// amx.tile_muli !amx.tile<16x64xi8> -> !amx.tile<16x16xi32>
// amx.tile_store %arg2{{.}*} : memref<32x32xi32>, !amx.tile<16x16xi32>
// ```
+//
+//
+// Case2: vector.contract with register blocked
+//
+// Output IR with online packing (with s/w pipeline advantage):
+// s/w pipeline: load, pack to VNNI, and store the B sub matrix
+// of the 0th batch-reduce and K iteration.
+// scf.for (0 to 31) {
+// - load 0th and 1st vector<32xbf16>, pack into VNNI, store the
+// first shuffle in 0th and 2nd shuffle in 16th index of the
+// buffer.
+// }
+// scf.for (0 to br-2) { batch-reduce loop
+// scf.for (0 to k-2) { K loop
+// - load A matrix
+// - scf.loop for s/w pipeline: load, pack to VNNI, and store the B sub
+// matrix for the next K loop iteration (c) load VNNI pack B matrix of K
+// iteration from the buffer (d) compute the tiled dot-product
+// }
+// Last iteration of the the K Loop (k-1) {
+// - load A matrix
+// - scf.loop for s/w pipeline: load, pack to VNNI, and store the B sub
+// matrix for the next batch-reduce + K loop iteration (c) load VNNI pack B
+// matrix of K iteration from the buffer (d) compute the tiled dot-product
+// }
+// }
+// Last iteration of the batch-reduce loop (br-1) {
+// scf.for (0 to k-2) { K loop
+// - load A matrix
+// - scf.loop for s/w pipeline: load, pack to VNNI, and store the B sub
+// matrix for the next K loop iteration (c) load VNNI pack B matrix of K
+// iteration from the buffer (d) compute the tiled dot-product
+// }...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/188192
More information about the Mlir-commits
mailing list