[Mlir-commits] [mlir] ee0ac74 - [mlir][x86] Lower packed type vector.contract to AMX dot-product (#182810)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 17 21:55:01 PDT 2026
Author: Arun Thangamani
Date: 2026-03-18T10:24:55+05:30
New Revision: ee0ac7443e4dc48f0ab2371dd5cbdcca32732e48
URL: https://github.com/llvm/llvm-project/commit/ee0ac7443e4dc48f0ab2371dd5cbdcca32732e48
DIFF: https://github.com/llvm/llvm-project/commit/ee0ac7443e4dc48f0ab2371dd5cbdcca32732e48.diff
LOG: [mlir][x86] Lower packed type vector.contract to AMX dot-product (#182810)
A transform pass to lower `vector.contract` operation to (a)
`amx.tile_mulf` for BF16, or (b) `amx.tile_muli` for Int8 packed types.
Added:
mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
Modified:
mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td
mlir/include/mlir/Dialect/X86/Transforms.h
mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp
mlir/lib/Dialect/X86/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td b/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td
index e7508f9ba4abb..c474cfb47d003 100644
--- a/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td
+++ b/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td
@@ -71,5 +71,16 @@ def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+// Transform-dialect pattern descriptor op, spelled
+// `apply_patterns.x86.vector_contract_to_amx_dot_product`. It takes no
+// operands or custom syntax beyond an optional attribute dictionary.
+def ApplyVectorContractToAMXDotProductPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.x86.vector_contract_to_amx_dot_product",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Collect patterns to lower a BF16/Int8 type vector.contract operation
+ to a BF16/Int8 AMX tiled dot-product.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
#endif // X86_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/X86/Transforms.h b/mlir/include/mlir/Dialect/X86/Transforms.h
index 2862e83f06f79..6ebba5e94ec7c 100644
--- a/mlir/include/mlir/Dialect/X86/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86/Transforms.h
@@ -104,6 +104,12 @@ void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
// grouped with respect to odd/even packed index.
void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns);
+// Adds to `patterns` a set of patterns that lower packed-type (BF16 and Int8)
+// vector contraction operations to their corresponding AMX tiled dot-product
+// operations, which ultimately target the relevant x86 LLVM intrinsics.
+void populateVectorContractToAMXDotProductPatterns(RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
/// Helpers extracted from:
/// - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp b/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp
index 2511f6f1b8b4c..390b21e12b0ed 100644
--- a/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp
+++ b/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp
@@ -47,6 +47,11 @@ void mlir::transform::ApplyShuffleVectorFMAOpsPatternsOp::populatePatterns(
x86::populateShuffleVectorFMAOpsPatterns(patterns);
}
+// Populates `patterns` for this transform op by delegating to
+// x86::populateVectorContractToAMXDotProductPatterns.
+void mlir::transform::ApplyVectorContractToAMXDotProductPatternsOp::
+ populatePatterns(RewritePatternSet &patterns) {
+ x86::populateVectorContractToAMXDotProductPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt
index efb6fde4fade1..9c3695536cda9 100644
--- a/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRX86Transforms
VectorContractBF16ToFMA.cpp
SinkVectorProducerOps.cpp
ShuffleVectorFMAOps.cpp
+ VectorContractToAMXDotProduct.cpp
LINK_LIBS PUBLIC
MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
new file mode 100644
index 0000000000000..85966a85af40e
--- /dev/null
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -0,0 +1,675 @@
+//===- VectorContractToAMXDotProduct.cpp ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86/Transforms.h"
+#include "mlir/Dialect/X86/Utils/X86Utils.h"
+#include "mlir/Dialect/X86/X86Dialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Casting.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86;
+
+namespace {
+
+// Collapses the two innermost dimensions of `input` (the k and vnni
+// dimensions of a packed operand) into a single dimension so that
+// amx.tile_load can load the packed element type correctly.
+//
+// All leading dimensions are kept one-to-one. An input with fewer than two
+// dimensions has nothing to collapse and is returned unchanged.
+static Value collapseInnerDims(OpBuilder &builder, mlir::Location loc,
+                               Value input) {
+  ShapedType inputType = cast<ShapedType>(input.getType());
+  const int64_t rank = inputType.getRank();
+
+  // Covers rank 0 as well as rank 1: for rank 0 the first collapsed dimension
+  // would be negative and the reassociation would hold invalid indices.
+  if (rank < 2)
+    return input;
+
+  const int64_t firstDimToCollapse = rank - 2;
+
+  SmallVector<ReassociationIndices> reassociation;
+  reassociation.reserve(firstDimToCollapse + 1);
+  for (int64_t i = 0; i < firstDimToCollapse; ++i)
+    reassociation.push_back(ReassociationIndices{i});
+
+  // The last group merges the two innermost dimensions.
+  reassociation.push_back(
+      ReassociationIndices{firstDimToCollapse, firstDimToCollapse + 1});
+  return memref::CollapseShapeOp::create(builder, loc, input, reassociation);
+}
+
+// Returns the source buffer and the index values of the read operation that
+// defines `operand` (an operand of a vector.contract).
+//
+// `operand` must be produced by a TransferReadOp or a LoadOp; otherwise
+// failure is returned. When `isNotAcc` is true (LHS/RHS operands), the last
+// index is dropped and the returned buffer has its two innermost dimensions
+// collapsed via collapseInnerDims, which creates a new operation through
+// `rewriter`. When `isNotAcc` is false the buffer and indices are returned
+// as-is.
+static FailureOr<std::pair<Value, SmallVector<Value>>>
+getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
+                bool isNotAcc) {
+  Operation *defOp = operand.getDefiningOp();
+  if (!defOp)
+    return failure();
+
+  Value srcBuff;
+  SmallVector<OpFoldResult> indexVals;
+  llvm::TypeSwitch<Operation *>(defOp)
+      .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+        indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+                                              readOp.getIndices().end());
+        srcBuff = readOp.getOperand(0);
+      });
+
+  // Not a supported read operation.
+  if (!srcBuff)
+    return failure();
+
+  if (isNotAcc) {
+    // A read without indices has no vnni index to drop; bail out instead of
+    // popping from an empty vector.
+    if (indexVals.empty())
+      return failure();
+    indexVals.pop_back();
+  }
+
+  SmallVector<Value> indices;
+  indices.reserve(indexVals.size());
+  for (OpFoldResult ofr : indexVals)
+    indices.push_back(
+        mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+
+  // Done last so that no operation is created on the failure paths above.
+  if (isNotAcc)
+    srcBuff = collapseInnerDims(rewriter, loc, srcBuff);
+
+  return std::make_pair(srcBuff, indices);
+}
+
+// Checks that `contractOp` has the operand shapes this lowering expects.
+//
+// - Accumulator: must be a vector whose dimensions are all 16 or 1
+//   (e.g. <1x16x16> or <16x16>).
+// - LHS and RHS: all dimensions are 16 or 1 except exactly one, which must
+//   equal `blockingFactor` (the vnni dimension, e.g. <1x16x16x2> or
+//   <16x16x4>).
+//
+// When `srcValidate` is set, the buffers read by the LHS/RHS operands must
+// additionally be `srcBuffLhs` and `srcBuffRhs` respectively.
+static LogicalResult validateContractOps(OpBuilder &rewriter,
+                                         vector::ContractionOp contractOp,
+                                         unsigned int blockingFactor,
+                                         Value srcBuffLhs, Value srcBuffRhs,
+                                         bool srcValidate) {
+
+  if (srcValidate) {
+    Location loc = contractOp.getLoc();
+
+    // Buffer read by the LHS operand.
+    auto lhsSource = getSrcIndxValue(rewriter, loc, contractOp.getLhs(), false);
+    if (failed(lhsSource))
+      return failure();
+
+    // Buffer read by the RHS operand.
+    auto rhsSource = getSrcIndxValue(rewriter, loc, contractOp.getRhs(), false);
+    if (failed(rhsSource))
+      return failure();
+
+    // Both operands have to read from the expected buffers.
+    if (lhsSource->first != srcBuffLhs || rhsSource->first != srcBuffRhs)
+      return failure();
+  }
+
+  // Collects every dimension of `shape` that is neither 16 nor 1.
+  auto collectOtherDims = [](ArrayRef<int64_t> shape) {
+    llvm::SmallVector<int64_t> others;
+    for (int64_t dim : shape)
+      if (dim != 16 && dim != 1)
+        others.push_back(dim);
+    return others;
+  };
+
+  // A packed input is valid when its only other dimension is the vnni one.
+  auto isPackedInput = [&](VectorType type) {
+    llvm::SmallVector<int64_t> others = collectOtherDims(type.getShape());
+    return others.size() == 1 && others[0] == blockingFactor;
+  };
+
+  auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
+  if (!accTy)
+    return failure();
+
+  if (!collectOtherDims(accTy.getShape()).empty())
+    return failure();
+
+  if (!isPackedInput(contractOp.getLhsType()))
+    return failure();
+
+  if (!isPackedInput(contractOp.getRhsType()))
+    return failure();
+
+  return success();
+}
+
+// Returns the position, within the offsets of the memref.subview feeding the
+// read that defines `operand`, of the offset equal to the induction variable
+// of `loop`. Callers use it to pick the subview operand to remap when cloning.
+//
+// Returns 0 when `operand` is not defined by a TransferReadOp/LoadOp, when
+// the read buffer is not produced by a memref.subview, or when no offset
+// matches the induction variable.
+// NOTE(review): 0 is also a valid position, so "not found" cannot be told
+// apart from "first offset" by the caller.
+static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
+  // Block arguments have no defining op; TypeSwitch must not see a null.
+  Operation *defOp = operand.getDefiningOp();
+  if (!defOp)
+    return 0;
+
+  Value srcBuff;
+  llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
+      [&](auto readOp) { srcBuff = readOp.getOperand(0); });
+
+  // Not a supported read: there is no buffer to inspect.
+  if (!srcBuff)
+    return 0;
+
+  auto subview = srcBuff.getDefiningOp<memref::SubViewOp>();
+  if (!subview)
+    return 0;
+
+  Value iv = loop.getInductionVar();
+  for (auto it : llvm::enumerate(subview.getOffsets())) {
+    if (it.value() == iv)
+      return it.index();
+  }
+
+  return 0;
+}
+
+// Creates an amx.tile_load of a 16 x (16 * offset) tile of `ipType` from
+// `mat`, indexed like the read operation that defines `operand` but without
+// its last index. For the RHS (`rhs` == true) the last remaining index is
+// additionally multiplied by `offset`.
+//
+// NOTE(review): the result of getSrcIndxValue is dereferenced without a
+// failure check, so `operand` is expected to have been validated by the
+// caller.
+static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
+                                       Value operand, Value mat, Type ipType,
+                                       bool rhs, unsigned int offset) {
+
+  auto sourceAndIndices = getSrcIndxValue(rewriter, loc, operand, false);
+  SmallVector<Value> tileIndices = (*sourceAndIndices).second;
+
+  // Drop the trailing index of the original read.
+  tileIndices.pop_back();
+
+  if (rhs) {
+    Value scale = arith::ConstantIndexOp::create(rewriter, loc, offset);
+    Value scaled =
+        arith::MulIOp::create(rewriter, loc, tileIndices.back(), scale);
+    tileIndices.back() = scaled;
+  }
+
+  amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
+  return amx::TileLoadOp::create(rewriter, loc, tileType, mat, tileIndices);
+}
+
+// Creates one tiled AMX dot-product per contraction in `ops` and returns the
+// resulting accumulator tiles in the same order.
+//
+// `matA`/`matB` are the LHS/RHS buffers; their two innermost dimensions are
+// collapsed before loading. `accIterArgs[i]` is the accumulator tile for
+// `ops[i]`. amx.tile_mulf is emitted when `ipType` is BF16 and amx.tile_muli
+// when it is signless i8. Tile loads are shared between contractions whose
+// operand is defined by the same read operation.
+static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
+                                        SmallVector<vector::ContractionOp> ops,
+                                        Value matA, Value matB, Type ipType,
+                                        Type opType, ValueRange accIterArgs,
+                                        unsigned int offset) {
+
+  Value collapsedLhs = collapseInnerDims(rewriter, loc, matA);
+  Value collapsedRhs = collapseInnerDims(rewriter, loc, matB);
+
+  // Maps a vector transfer_read/load operation to the amx.tile_load that was
+  // already created for it.
+  llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+
+  // Returns the cached tile load for the read defining `operand`, creating
+  // and caching it on first use.
+  auto getOrCreateTileLoad = [&](Value operand, Value mat,
+                                 bool isRhs) -> amx::TileLoadOp {
+    Operation *readOp = operand.getDefiningOp();
+    auto cached = readsToTileLoads.find(readOp);
+    if (cached != readsToTileLoads.end())
+      return cached->second;
+
+    amx::TileLoadOp tileLoad =
+        createTileLoads(rewriter, loc, operand, mat, ipType, isRhs, offset);
+    readsToTileLoads.try_emplace(readOp, tileLoad);
+    return tileLoad;
+  };
+
+  auto accTileType = amx::TileType::get({16, 16}, opType);
+
+  // Iterate over the contraction operations and compute the tiled
+  // dot-product.
+  SmallVector<Value> accumulators;
+  for (size_t idx = 0, numOps = ops.size(); idx < numOps; ++idx) {
+    amx::TileLoadOp lhsTile =
+        getOrCreateTileLoad(ops[idx].getLhs(), collapsedLhs, /*isRhs=*/false);
+    amx::TileLoadOp rhsTile =
+        getOrCreateTileLoad(ops[idx].getRhs(), collapsedRhs, /*isRhs=*/true);
+
+    Value dp;
+    if (ipType.isBF16()) {
+      dp = amx::TileMulFOp::create(rewriter, loc, accTileType, lhsTile,
+                                   rhsTile, accIterArgs[idx]);
+    } else if (ipType.isSignlessInteger(8)) {
+      dp = amx::TileMulIOp::create(rewriter, loc, accTileType, lhsTile,
+                                   rhsTile, accIterArgs[idx]);
+    }
+
+    accumulators.push_back(dp);
+  }
+  return accumulators;
+}
+
+// Creates `size` amx.tile_zero operations of type 16x16 x `opType` right
+// before `outerLoop` and returns their results, to be used as the initial
+// accumulator iter-args of the new loop.
+//
+// Side effect: the insertion point of `rewriter` is left before `outerLoop`.
+static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
+                                          Type opType, scf::ForOp outerLoop,
+                                          int64_t size) {
+  rewriter.setInsertionPoint(outerLoop);
+
+  auto zeroTileType = amx::TileType::get({16, 16}, opType);
+
+  SmallVector<Value> zeroTiles;
+  for (int64_t remaining = size; remaining > 0; --remaining)
+    zeroTiles.push_back(
+        amx::TileZeroOp::create(rewriter, loc, zeroTileType));
+
+  return zeroTiles;
+}
+
+// Implements tiled dot-product operation for a vector.contract operation or a
+// sequence of vector.contracts inside the reduction loops.
+//
+// For example - for Int8 type:
+// ```
+// vector.transfer_read %arg0 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
+// vector.transfer_read %arg1 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
+// vector.contract <16x16x4xi8>, <16x16x4xi8> into <16x16xi32>
+// vector.transfer_write arg2 {{.}*} : vector<16x16xi32>, memref<32x32xi32>
+// ```
+// to
+// ```
+// amx.tile_load %arg0 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
+// amx.tile_load %arg1 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
+// amx.tile_muli !amx.tile<16x64xi8> -> !amx.tile<16x16xi32>
+// amx.tile_store %arg2{{.}*} : memref<32x32xi32>, !amx.tile<16x16xi32>
+// ```
+//
+// Two situations are handled:
+//  - Case 1: the accumulator read, the contraction and the result write are
+//    all in the same block; a single load/multiply/store sequence is emitted.
+//  - Case 2: the accumulator is carried as iter-args through one or two
+//    enclosing scf.for reduction loops; a new loop nest accumulating into
+//    zero-initialized AMX tiles is created, and the initial accumulator value
+//    is added back row by row after the loop.
+struct VectorContractToAMXDotProduct
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind.");
+
+    // BF16 packs 2 elements and Int8 packs 4 elements per 32-bit lane.
+    unsigned int blockingFactor =
+        contractOp.getLhsType().getElementType().isBF16() ? 2 : 4;
+    bool isVnni =
+        isInVnniLayout(contractOp.getOperation(),
+                       contractOp.getIndexingMapsArray(), blockingFactor);
+
+    VectorType lhsTy = contractOp.getLhsType();
+    if (!lhsTy.getElementType().isBF16() &&
+        !lhsTy.getElementType().isSignlessInteger(8))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only BF16/Int8 lowering is supported.");
+
+    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
+
+    if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
+        (lhsTy.getElementType().isSignlessInteger(8) &&
+         !accTy.getElementType().isSignlessInteger(32)))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only F32 for BF16 or Int32 for Int8 "
+                                         "accumulation type is supported.");
+    if (!isVnni)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only VNNI-packed inputs are supported.");
+
+    Operation *accReadOp =
+        traceToVectorReadLikeParentOperation(contractOp.getAcc());
+
+    Operation *resultWriteOp =
+        traceToVectorWriteLikeUserOperation(contractOp.getResult());
+
+    if (!accReadOp || !resultWriteOp)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The ACC operand of the vector.contract should be a "
+                      "transfer_read or a load. And, the result should be "
+                      "stored using transfer_write or store.");
+
+    // Input/output element types of the AMX tiles; BF16 -> F32 by default,
+    // Int8 -> Int32 otherwise.
+    Type ipType = rewriter.getBF16Type();
+    Type opType = rewriter.getF32Type();
+
+    if (lhsTy.getElementType().isSignlessInteger(8)) {
+      ipType = rewriter.getIntegerType(8);
+      opType = rewriter.getIntegerType(32);
+    }
+
+    if (accReadOp->getBlock() == contractOp->getBlock() &&
+        resultWriteOp->getBlock() != contractOp->getBlock())
+      return rewriter.notifyMatchFailure(
+          contractOp, "The accumulator store is in different block.");
+
+    if (accReadOp->getBlock() != contractOp->getBlock() &&
+        resultWriteOp->getBlock() == contractOp->getBlock())
+      return rewriter.notifyMatchFailure(
+          contractOp, "The accumulator read is in different block.");
+
+    // Case 1: For just one VC rewrite. Where all accumulator read/write
+    // within the same block.
+    if (accReadOp->getBlock() == contractOp->getBlock() &&
+        resultWriteOp->getBlock() == contractOp->getBlock()) {
+
+      LogicalResult validate = validateContractOps(
+          rewriter, contractOp, blockingFactor, Value(), Value(), false);
+
+      if (failed(validate))
+        return rewriter.notifyMatchFailure(
+            contractOp, "The contract operation doesn't satisfy the operands "
+                        "dimensions. M, N, and vnni dims are 16, 16, and 2/4. "
+                        "The rest dims should be 1.");
+
+      // Check all three operands before creating anything: fetching the
+      // LHS/RHS sources below creates collapse_shape operations, and a match
+      // failure must not be reported after the IR has been modified.
+      auto isReadLike = [](Value value) {
+        return llvm::isa_and_nonnull<TransferReadOp, LoadOp>(
+            value.getDefiningOp());
+      };
+      if (!isReadLike(contractOp.getLhs()) ||
+          !isReadLike(contractOp.getRhs()) || !isReadLike(contractOp.getAcc()))
+        return rewriter.notifyMatchFailure(
+            contractOp, "The LHS, RHS and ACC operands should be defined by a "
+                        "transfer_read or a load.");
+
+      Location loc = contractOp.getLoc();
+
+      auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                        contractOp.getLhs(), true);
+      if (failed(srcIndxLhs))
+        return rewriter.notifyMatchFailure(contractOp,
+                                           "The LHS src is not a MemRef type.");
+      auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
+
+      auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                        contractOp.getRhs(), true);
+      if (failed(srcIndxRhs))
+        return rewriter.notifyMatchFailure(contractOp,
+                                           "The RHS src is not a MemRef type.");
+      auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
+
+      auto srcIndxAcc = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                        contractOp.getAcc(), false);
+      if (failed(srcIndxAcc))
+        return rewriter.notifyMatchFailure(contractOp,
+                                           "The ACC src is not a MemRef type.");
+      auto [srcBuffAcc, indicesAcc] = *srcIndxAcc;
+
+      // amx.tile_loads
+      auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
+      auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
+                                             srcBuffLhs, indicesLhs);
+      auto loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType,
+                                             srcBuffRhs, indicesRhs);
+
+      auto tileTypeAcc = amx::TileType::get({16, 16}, opType);
+      auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
+                                             srcBuffAcc, indicesAcc);
+
+      // Tiled dot-product.
+      Value dp;
+      if (ipType.isBF16())
+        dp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
+                                     loadRhs, loadAcc);
+
+      if (ipType.isSignlessInteger(8))
+        dp = amx::TileMulIOp::create(rewriter, loc, tileTypeAcc, loadLhs,
+                                     loadRhs, loadAcc);
+
+      amx::TileStoreOp::create(rewriter, loc, srcBuffAcc, indicesAcc, dp);
+
+      // Only the write is erased here; the contraction itself is left
+      // without users.
+      rewriter.eraseOp(resultWriteOp);
+      return success();
+    }
+
+    // Case 2: The acc are passed as iter args through the reduction loop.
+    // We support, reduction loop depth until 2. TODO: Support for n-depth
+    // reduction loop.
+    SmallVector<scf::ForOp> loopLists;
+    Operation *current = contractOp;
+
+    // Walk up the enclosing scf.for loops until reaching the one that lives
+    // in the same block as the accumulator read.
+    while (true) {
+      auto parent = current->getParentOfType<scf::ForOp>();
+      // No (further) enclosing scf.for: the accumulator read is not reachable
+      // through a loop nest, so there is nothing to dereference or rewrite.
+      if (!parent)
+        return rewriter.notifyMatchFailure(
+            contractOp, "The accumulator read is not in a block holding an "
+                        "enclosing scf.for reduction loop.");
+
+      loopLists.push_back(parent);
+
+      if (accReadOp->getBlock() == parent->getBlock()) {
+        break;
+      }
+
+      current = parent.getOperation();
+    }
+
+    if (loopLists.size() > 2 || loopLists.size() == 0)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Rewrite is supported until reduction loop depth of 2.");
+
+    auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                      contractOp.getLhs(), false);
+    if (failed(srcIndxLhs))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "The LHS src is not a MemRef type.");
+    auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
+
+    auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                      contractOp.getRhs(), false);
+    if (failed(srcIndxRhs))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "The RHS src is not a MemRef type.");
+    auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
+
+    // Operations producing the buffers read by the LHS/RHS operands; they
+    // are cloned into the new loop nest with remapped induction variables.
+    Operation *vectorOpLhs = nullptr;
+    llvm::TypeSwitch<Operation *>(contractOp.getLhs().getDefiningOp())
+        .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+          vectorOpLhs = readOp.getBase().getDefiningOp();
+        });
+
+    Operation *vectorOpRhs = nullptr;
+    llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
+        .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+          vectorOpRhs = readOp.getBase().getDefiningOp();
+        });
+
+    // A buffer that is a block argument has no defining operation to clone.
+    // NOTE(review): the cloning below also assumes the operand at
+    // (getIndexPosition(...) + 1) exists on these operations - confirm that
+    // only memref.subview producers reach this point.
+    if (!vectorOpLhs || !vectorOpRhs)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The LHS/RHS buffers should be defined by an operation.");
+
+    // Retrieve all the contraction operations within the loop.
+    SmallVector<vector::ContractionOp> ops;
+    for (mlir::Operation &op : loopLists[0].getBody()->getOperations()) {
+
+      if (auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
+
+        LogicalResult validate = validateContractOps(
+            rewriter, contract, blockingFactor, srcBuffLhs, srcBuffRhs, true);
+
+        if (failed(validate))
+          return rewriter.notifyMatchFailure(
+              contractOp, "The associated contract operations doesn't satisfy "
+                          "the re-write conditions either the dimensions are "
+                          "wrong or MemRef source are different.");
+
+        ops.push_back(contract);
+      }
+    }
+
+    scf::ForOp outerLoop;
+    scf::ForOp innerLoop;
+
+    scf::ForOp newLoop;
+    // Case 2a: Reduction loop depth is 2.
+    if (loopLists.size() == 2) {
+      outerLoop = loopLists[1];
+      innerLoop = loopLists[0];
+
+      SmallVector<Value> loopItrArgs = createTileZeros(
+          rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
+
+      newLoop = scf::ForOp::create(
+          rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
+          outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
+          [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
+              Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
+            auto newInnerLoop = scf::ForOp::create(
+                rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+                innerLoop.getUpperBound(), innerLoop.getStep(),
+                iterArgsOuterLoop,
+                [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
+                    Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
+                  // Clone the LHS buffer producer with the old induction
+                  // variables replaced by the new ones.
+                  IRMapping mapping;
+                  mapping.map(
+                      vectorOpLhs->getOperand(
+                          getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
+                      ivOuterLoop);
+                  mapping.map(
+                      vectorOpLhs->getOperand(
+                          getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
+                      ivNewInnerLoop);
+                  auto lhsClone =
+                      rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
+
+                  // Same for the RHS buffer producer.
+                  IRMapping rhsMapping;
+                  rhsMapping.map(
+                      vectorOpRhs->getOperand(
+                          getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
+                      ivOuterLoop);
+                  rhsMapping.map(
+                      vectorOpRhs->getOperand(
+                          getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
+                      ivNewInnerLoop);
+                  auto rhsClone =
+                      rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping);
+
+                  SmallVector<Value> accumulators = createTiledDp(
+                      rewriter, locNewInnerLoop, ops, lhsClone->getResult(0),
+                      rhsClone->getResult(0), ipType, opType,
+                      iterArgsNewInnerLoop, blockingFactor);
+
+                  scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
+                                       accumulators);
+                });
+
+            scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
+                                 newInnerLoop.getResults());
+          });
+    }
+
+    // Case 2b: Reduction loop depth is 1.
+    if (loopLists.size() == 1) {
+      outerLoop = loopLists[0];
+
+      SmallVector<Value> loopItrArgs = createTileZeros(
+          rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
+      newLoop = scf::ForOp::create(
+          rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
+          outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
+          [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
+              Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
+            IRMapping mapping;
+            mapping.map(
+                vectorOpLhs->getOperand(
+                    getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
+                ivOuterLoop);
+
+            auto lhsClone = rewriterOuterLoop.clone(*vectorOpLhs, mapping);
+
+            IRMapping rhsMapping;
+            rhsMapping.map(
+                vectorOpRhs->getOperand(
+                    getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
+                ivOuterLoop);
+
+            auto rhsClone = rewriterOuterLoop.clone(*vectorOpRhs, rhsMapping);
+
+            SmallVector<Value> accumulators = createTiledDp(
+                rewriter, locOuterLoop, ops, lhsClone->getResult(0),
+                rhsClone->getResult(0), ipType, opType, iterArgsOuterLoop,
+                blockingFactor);
+
+            scf::YieldOp::create(rewriterOuterLoop, locOuterLoop, accumulators);
+          });
+    }
+
+    // post processing after the loop creation.
+    // Copy the amx tile accumulation results to a MemRef buffer, add the
+    // initial accumulation value, and store back to the C-Matrix
+    auto bufferType = MemRefType::get({16, 16}, opType);
+    auto bBuffer =
+        memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
+
+    SmallVector<Value> dps = newLoop.getResults();
+    for (size_t i = 0; i < ops.size(); i++) {
+      vector::ContractionOp contOp = ops[i];
+      Operation *resultWriteOp =
+          traceToVectorWriteLikeUserOperation(contOp.getResult());
+      rewriter.setInsertionPoint(resultWriteOp);
+
+      Value indexOp_0 =
+          arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+
+      amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
+                               ValueRange{indexOp_0, indexOp_0}, dps[i]);
+
+      auto c0 = arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+      auto one =
+          arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
+      auto mBound =
+          arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 16);
+
+      // One iteration per tile row: result row + initial accumulator row.
+      scf::ForOp::create(
+          rewriter, outerLoop.getLoc(), c0, mBound, one, ValueRange{},
+          [&](OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) {
+            auto resultAcc = vector::LoadOp::create(
+                rewriter, loc, VectorType::get(16, opType), bBuffer,
+                ValueRange{iv, c0});
+
+            Operation *accReadOp =
+                traceToVectorReadLikeParentOperation(ops[i].getAcc());
+
+            Value srcBuffAcc;
+            SmallVector<Value> indicesAcc;
+
+            llvm::TypeSwitch<Operation *>(accReadOp)
+                .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+                  srcBuffAcc = readOp.getOperand(0);
+
+                  auto indices = readOp.getIndices();
+                  indicesAcc.reserve(indices.size());
+
+                  llvm::transform(
+                      indices, std::back_inserter(indicesAcc),
+                      [&](OpFoldResult ofr) {
+                        return mlir::getValueOrCreateConstantIndexOp(rewriter,
+                                                                     loc, ofr);
+                      });
+                });
+
+            // Offset the row index (second-to-last position) by the current
+            // row. The base must be read from that same position; reading
+            // index 0 is only equivalent for a rank-2 accumulator buffer.
+            const size_t rowPos = indicesAcc.size() - 2;
+            Value sum =
+                arith::AddIOp::create(builder, loc, iv, indicesAcc[rowPos]);
+            indicesAcc[rowPos] = sum;
+
+            auto acc = vector::LoadOp::create(rewriter, loc,
+                                              VectorType::get(16, opType),
+                                              srcBuffAcc, indicesAcc);
+            Value addition;
+            if (ipType.isBF16())
+              addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
+
+            if (ipType.isSignlessInteger(8))
+              addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
+
+            vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
+                                    indicesAcc);
+
+            scf::YieldOp::create(builder, outerLoop.getLoc());
+          });
+
+      rewriter.eraseOp(resultWriteOp);
+    }
+
+    return success();
+  }
+};
+
+} // namespace
+
+// Adds the VectorContractToAMXDotProduct rewrite pattern to `patterns`,
+// constructed with the context of the pattern set.
+void x86::populateVectorContractToAMXDotProductPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<VectorContractToAMXDotProduct>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
new file mode 100644
index 0000000000000..cde15b680a037
--- /dev/null
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -0,0 +1,1034 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!vecA = vector<1x16x16x4xi8>
+!vecB = vector<1x16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<1x32x16x4xi8>
+!memrefB = memref<1x16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+// Positive case: a single i8 contraction on static memrefs. The leading unit
+// dimension d0 is a reduction, the innermost dimension of both operands has
+// size 4, and the i32 accumulator is read from and written back to %arg2.
+func.func @brgemm_int8(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : i8
+ %32 = ub.poison : i32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @brgemm_int8
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK: x86.amx.tile_muli
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK-NOT: vector.contract
+
+// Transform driver for this split: matches every func.func in the payload and
+// applies the x86 vector_contract_to_amx_dot_product pattern set to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<16x16x4xi8>
+!vecB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x16x4xi8>
+!memrefB = memref<16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+// Positive case: i8 contraction with 3-D operands (no leading batch
+// dimension). Both operands are 16x16x4 and the innermost size-4 dimension is
+// a reduction; the i32 accumulator round-trips through %arg2.
+func.func @matmul_int8(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : i8
+ %32 = ub.poison : i32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_int8
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK: x86.amx.tile_muli
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK-NOT: vector.contract
+
+// Transform driver for this split: matches every func.func in the payload and
+// applies the x86 vector_contract_to_amx_dot_product pattern set to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x16x16x2xbf16>
+!vecB = vector<1x16x16x2xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x2xbf16>
+!memrefB = memref<1x16x32x2xbf16>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+// Positive case: bf16 contraction on static memrefs. The leading unit
+// dimension d0 is a reduction, the innermost dimension of both operands has
+// size 2, and the f32 accumulator is read from and written back to %arg2.
+func.func @brgemm_bf16(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @brgemm_bf16
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x32xbf16>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x32xbf16>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK: x86.amx.tile_mulf
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK-NOT: vector.contract
+
+// Transform driver for this split: matches every func.func in the payload and
+// applies the x86 vector_contract_to_amx_dot_product pattern set to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x16x16x2xbf16>
+!vecB = vector<1x16x16x2xbf16>
+!vecC = vector<1x16x16xf32>
+!memrefA = memref<1x32x16x2xbf16>
+!memrefB = memref<1x16x32x2xbf16>
+!memrefC = memref<1x32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+// Positive case: bf16 contraction where the leading unit dimension d0 is
+// parallel and also indexes the accumulator, so the accumulator vector and
+// its memref are 3-D (1x16x16 read from a 1x32x32 buffer).
+func.func @batch_matmul_bf16(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0, %c0], %32 {in_bounds = [true, true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @batch_matmul_bf16
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x32xbf16>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x32xbf16>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK: x86.amx.tile_mulf
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK-NOT: vector.contract
+
+// Transform driver for this split: matches every func.func in the payload and
+// applies the x86 vector_contract_to_amx_dot_product pattern set to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecAB = vector<1x16x16x2xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>
+!memrefB = memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>
+!memrefC = memref<32x32xf32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+
+// Positive case: four bf16 contractions inside a two-level scf.for nest that
+// carries four 16x16 f32 accumulators as iter_args. The accumulators are read
+// from a 32x32 subview of %arg2 before the nest and written back after it;
+// the operands are read from subviews of %arg0 / %arg1 inside the inner loop,
+// with each LHS read and each RHS read shared by two contractions.
+func.func @brgemm_bf16_loop(%arg0: memref<16x64x64x2xbf16>, %arg1: memref<16x64x128x2xbf16>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> {
+ %0 = ub.poison : f32
+ %1 = ub.poison : bf16
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+
+ scf.for %arg3 = %c0 to %c64 step %c32 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+
+ %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] :
+ memref<64x128xf32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %6:4 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+ %7:4 = scf.for %arg10 = %c0 to %c64 step %c16 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!vecC, !vecC, !vecC, !vecC) {
+
+ %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg10, 0] [1, 32, 16, 2] [1, 1, 1, 1] :
+ memref<16x64x64x2xbf16> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg10, %arg4, 0] [1, 16, 32, 2] [1, 1, 1, 1] :
+ memref<16x64x128x2xbf16> to !memrefB
+ %8 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecAB
+ %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecAB
+
+ %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %9, %arg11 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+ %11 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecAB
+ %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %11, %arg12 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %13 = vector.transfer_read %subview_0[%c0, %c16, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecAB
+ %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %13, %9, %arg13 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %13, %11, %arg14 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+ scf.yield %10, %12, %14, %15 : !vecC, !vecC, !vecC, !vecC
+ }
+ scf.yield %7#0, %7#1, %7#2, %7#3 : !vecC, !vecC, !vecC, !vecC
+ }
+
+ vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ }
+ }
+
+ return %arg2 : memref<64x128xf32>
+}
+
+// CHECK-LABEL: @brgemm_bf16_loop
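+// NOTE(review): FileCheck has no numeric-suffix directive, so the `-2:` and
+// `-4:` lines below are silently skipped. Counted matches use the COUNT-<n>
+// suffix; confirm the expected counts and ordering before converting them.
+// CHECK-2: scf.for {{.*}} -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {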
+// CHECK-4: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+// CHECK-4: x86.amx.tile_load
+// CHECK-4: x86.amx.tile_mulf
+// CHECK: scf.yield {{.*}} : !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+// CHECK-NOT: vector.contract
+
+// Transform driver for this split: matches every func.func in the payload and
+// applies the x86 vector_contract_to_amx_dot_product pattern set to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecAB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<16x16x4xi8, strided<[256, 4, 1], offset: ?>>
+!memrefB = memref<16x32x4xi8, strided<[512, 4, 1], offset: ?>>
+!memrefC = memref<16x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// Positive case: two i8 contractions inside one reduction scf.for that
+// carries two 16x16 i32 accumulators as iter_args. Both contractions use the
+// same LHS read (%5) and different RHS reads from the same subview; the
+// accumulators come from and go back to a 16x32 subview of %arg2.
+func.func @matmul_int8_loop(%arg0: memref<64x64x4xi8>, %arg1: memref<64x128x4xi8>, %arg2: memref<64x128xi32>) {
+ %0 = ub.poison : i32
+ %1 = ub.poison : i8
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ scf.for %arg3 = %c0 to %c64 step %c16 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+
+ %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1] :
+ memref<64x128xi32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %4:2 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+
+ %subview_0 = memref.subview %arg0[%arg3, %arg5, 0] [16, 16, 4] [1, 1, 1] :
+ memref<64x64x4xi8> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg4, 0] [16, 32, 4] [1, 1, 1] :
+ memref<64x128x4xi8> to !memrefB
+ %5 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefA, !vecAB
+ %6 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefB, !vecAB
+ %7 = vector.transfer_read %subview_1[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefB, !vecAB
+
+ %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %5, %6, %arg6 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %5, %7, %arg7 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+ scf.yield %8, %9 : !vecC, !vecC
+ }
+
+ vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ }
+ }
+
+ return
+}
+
+// CHECK-LABEL: @matmul_int8_loop
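+// NOTE(review): FileCheck has no numeric-suffix directive, so the `-2:` and
+// `-3:` lines below are silently skipped. Counted matches use the COUNT-<n>
+// suffix; confirm the expected counts and ordering before converting them.
+// CHECK-2: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>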
+// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>) {
+// CHECK-3: x86.amx.tile_load
+// CHECK-2: x86.amx.tile_muli
+// CHECK: scf.yield {{.*}} !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xi32>, vector<16x16xi32>
+// CHECK-NOT: vector.contract
+
+// Transform driver for this split: matches every func.func in the payload and
+// applies the x86 vector_contract_to_amx_dot_product pattern set to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecAB = vector<1x16x16x4xi8>
+!vecC = vector<1x16x16xi32>
+!memrefA = memref<1x16x16x4xi8, strided<[16384, 256, 4, 1], offset: ?>>
+!memrefB = memref<1x16x32x4xi8, strided<[32768, 512, 4, 1], offset: ?>>
+!memrefC = memref<1x16x32xi32, strided<[8192, 128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>
+
+// Positive case: two i8 contractions with a leading unit parallel dimension,
+// inside a reduction scf.for carrying two 1x16x16 i32 accumulators. An extra
+// enclosing loop over %arg5 selects the batch slice for all three subviews;
+// both contractions share the LHS read (%5).
+func.func @batch_matmul_int8_loop(%arg0: memref<16x64x64x4xi8>, %arg1: memref<16x64x128x4xi8>, %arg2: memref<16x64x128xi32>) {
+ %0 = ub.poison : i32
+ %1 = ub.poison : i8
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ scf.for %arg3 = %c0 to %c64 step %c16 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+ scf.for %arg5 = %c0 to %c16 step %c1 {
+
+ %subview = memref.subview %arg2[%arg5, %arg3, %arg4] [1, 16, 32] [1, 1, 1] :
+ memref<16x64x128xi32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c0, %c16], %0 {in_bounds = [true, true, true]} :
+ !memrefC, !vecC
+ %4:2 = scf.for %arg6 = %c0 to %c64 step %c16 iter_args(%arg7 = %2, %arg8 = %3) -> (!vecC, !vecC) {
+
+ %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg6, 0] [1, 16, 16, 4] [1, 1, 1, 1] :
+ memref<16x64x64x4xi8> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg6, %arg4, 0] [1, 16, 32, 4] [1, 1, 1, 1] :
+ memref<16x64x128x4xi8> to !memrefB
+ %5 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecAB
+ %6 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecAB
+ %7 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecAB
+ %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %5, %6, %arg7 {unroll_shape = array<i64: 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %5, %7, %arg8 {unroll_shape = array<i64: 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ scf.yield %8, %9 : !vecC, !vecC
+ }
+
+ vector.transfer_write %4#1, %subview[%c0, %c0, %c16] {in_bounds = [true, true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %4#0, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+ !vecC, !memrefC
+ }
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: @batch_matmul_int8_loop
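+// NOTE(review): FileCheck has no numeric-suffix directive, so the `-2:` and
+// `-3:` lines below are silently skipped. Counted matches use the COUNT-<n>
+// suffix; confirm the expected counts and ordering before converting them.
+// CHECK-2: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>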
+// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>) {
+// CHECK-3: x86.amx.tile_load
+// CHECK-2: x86.amx.tile_muli
+// CHECK: scf.yield {{.*}} !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xi32>, vector<16x16xi32>
+// CHECK-NOT: vector.contract
+
+// Transform driver for this split: matches every func.func in the payload and
+// applies the x86 vector_contract_to_amx_dot_product pattern set to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x16x16x4xi8>
+!vecB = vector<1x16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<1x32x16x4xi8>
+!memrefB = memref<1x16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+// Negative case: same shapes, maps and iterator types as the i8 case with a
+// unit leading reduction, but the contraction uses combining kind `mul`
+// instead of `add`. The checks after this function expect the
+// vector.contract to remain.
+func.func @negative_invalid_vc_kind(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : i8
+ %32 = ub.poison : i32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<mul>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_invalid_vc_kind
+// CHECK-NOT: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK-NOT: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK: vector.contract
+
+// Transform driver for this split: matches every func.func in the payload and
+// applies the x86 vector_contract_to_amx_dot_product pattern set to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x32x16x2xbf16>
+!vecB = vector<1x16x32x2xbf16>
+!vecC = vector<1x32x32xf32>
+!memrefA = memref<1x32x16x2xbf16>
+!memrefB = memref<1x16x32x2xbf16>
+!memrefC = memref<1x32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+// Negative case: bf16 contraction whose two parallel dimensions are 32 wide
+// (LHS 1x32x16x2, RHS 1x16x32x2, accumulator 1x32x32) rather than the 16x16
+// result shape used by the positive cases. The checks after this function
+// expect the vector.contract to remain.
+func.func @negative_wrong_dimensions(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0, %c0], %32 {in_bounds = [true, true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_wrong_dimensions
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_mulf
+// CHECK-NOT: x86.amx.tile_store
+// CHECK: vector.contract
+
+// Transform driver for this split: matches every func.func in the payload and
+// applies the x86 vector_contract_to_amx_dot_product pattern set to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<16x16x4xi8>
+!vecB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefB = memref<16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+// Negative case: the LHS operand is the function argument %arg0, which is
+// already a vector, so it is not produced by a read from a memref. The RHS
+// and accumulator are still read from memrefs. The checks after this function
+// expect the vector.contract to remain.
+func.func @negative_no_memref_source_LHS(
+ %arg0: !vecA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : i8
+ %32 = ub.poison : i32
+
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_no_memref_source_LHS
+// CHECK-NOT: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK-NOT: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK: vector.contract
+
+
+// Transform driver for this split: matches every func.func in the payload and
+// applies the x86 vector_contract_to_amx_dot_product pattern set to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<16x64xi8>
+!vecB = vector<64x16xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x64xi8>
+!memrefB = memref<64x32xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d1, d2, d3) -> (d1, d3)>
+#map1 = affine_map<(d1, d2, d3) -> (d3, d2)>
+#map2 = affine_map<(d1, d2, d3) -> (d1, d2)>
+// Negative case: plain 2-D i8 matmul (LHS 16x64, RHS 64x16) with three
+// iterators and no extra innermost packed dimension on the operands. The
+// checks after this function expect the vector.contract to remain.
+func.func @negative_no_vnni_packed(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : i8
+ %32 = ub.poison : i32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_no_vnni_packed
+// CHECK-NOT: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK-NOT: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK: vector.contract
+
+
+// Transform driver for this split: matches every func.func in the payload and
+// applies the x86 vector_contract_to_amx_dot_product pattern set to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecAB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<16x16x4xi8, strided<[256, 4, 1], offset: ?>>
+!memrefB = memref<16x32x4xi8, strided<[512, 4, 1], offset: ?>>
+!memrefC = memref<16x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// Negative case: two i8 contractions in one reduction loop whose LHS operands
+// are read from subviews of different memrefs (%arg0 for %5, %arg13 for
+// %odd_load). The checks after this function expect the loop to keep its
+// vector iter_args and the vector.contract ops to remain.
+func.func @negative_VCs_LHS_src_differ(%arg0: memref<64x64x4xi8>, %arg1: memref<64x128x4xi8>, %arg2: memref<64x128xi32>, %arg13: memref<64x64x4xi8>) {
+ %0 = ub.poison : i32
+ %1 = ub.poison : i8
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ scf.for %arg3 = %c0 to %c64 step %c16 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+
+ %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1] :
+ memref<64x128xi32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %4:2 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+
+ %subview_0 = memref.subview %arg0[%arg3, %arg5, 0] [16, 16, 4] [1, 1, 1] :
+ memref<64x64x4xi8> to !memrefA
+ %subview_negative = memref.subview %arg13[%arg3, %arg5, 0] [16, 16, 4] [1, 1, 1] :
+ memref<64x64x4xi8> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg4, 0] [16, 32, 4] [1, 1, 1] :
+ memref<64x128x4xi8> to !memrefB
+ %5 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefA, !vecAB
+ %odd_load = vector.transfer_read %subview_negative[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefA, !vecAB
+ %6 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefB, !vecAB
+ %7 = vector.transfer_read %subview_1[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefB, !vecAB
+
+ %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %5, %6, %arg6 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %odd_load, %7, %arg7 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+ scf.yield %8, %9 : !vecC, !vecC
+ }
+
+ vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ }
+ }
+
+ return
+}
+
+// CHECK-LABEL: @negative_VCs_LHS_src_differ
+// CHECK-NOT: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
+// CHECK-NOT: scf.for {{.*}} -> (!x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>) {
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK-NOT: scf.yield {{.*}} !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>
+// CHECK: scf.for {{.*}} -> (vector<16x16xi32>, vector<16x16xi32>) {
+// CHECK: vector.contract
+
+// Transform driver for this split: matches every func.func in the payload and
+// applies the x86 vector_contract_to_amx_dot_product pattern set to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x16x16x4xi8>
+!vecB = vector<1x16x32x4xi8>
+!vecC = vector<1x16x32xi32>
+!memrefA = memref<1x16x16x4xi8, strided<[16384, 256, 4, 1], offset: ?>>
+!memrefB = memref<1x16x32x4xi8, strided<[32768, 512, 4, 1], offset: ?>>
+!memrefC = memref<1x16x32xi32, strided<[8192, 128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>
+
+// Negative case: a single i8 contraction in a reduction loop whose RHS is
+// 1x16x32x4 and whose accumulator is 1x16x32, i.e. the second parallel
+// dimension is 32 wide instead of 16. The checks after this function expect
+// the vector.contract to remain.
+func.func @negative_wrong_N_dim(%arg0: memref<16x64x64x4xi8>, %arg1: memref<16x64x128x4xi8>, %arg2: memref<16x64x128xi32>) {
+ %0 = ub.poison : i32
+ %1 = ub.poison : i8
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ scf.for %arg3 = %c0 to %c64 step %c16 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+ scf.for %arg5 = %c0 to %c16 step %c1 {
+
+ %subview = memref.subview %arg2[%arg5, %arg3, %arg4] [1, 16, 32] [1, 1, 1] :
+ memref<16x64x128xi32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+ !memrefC, !vecC
+ %3 = scf.for %arg6 = %c0 to %c64 step %c16 iter_args(%arg7 = %2) -> (!vecC) {
+ %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg6, 0] [1, 16, 16, 4] [1, 1, 1, 1] :
+ memref<16x64x64x4xi8> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg6, %arg4, 0] [1, 16, 32, 4] [1, 1, 1, 1] :
+ memref<16x64x128x4xi8> to !memrefB
+ %4 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %5 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+ %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %4, %5, %arg7 {unroll_shape = array<i64: 1, 4, 16, 32, 16>} : !vecA, !vecB into !vecC
+ scf.yield %6 : !vecC
+ }
+ vector.transfer_write %3, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+ !vecC, !memrefC
+ }
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: @negative_wrong_N_dim
+// CHECK-NOT: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK: vector.contract
+
+// Transform entry point: matches every func.func under the payload root
+// (%arg1) and applies the x86 vector_contract_to_amx_dot_product pattern set
+// to the matched functions.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecAB = vector<1x1x16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<1x1x16x16x4xi8, strided<[262144, 16384, 256, 4, 1], offset: ?>>
+!memrefB = memref<1x1x16x32x4xi8, strided<[524288, 32768, 512, 4, 1], offset: ?>>
+!memrefC = memref<16x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d5, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
+
+// Negative test with i8 operands and i32 accumulators. The two vector.contract
+// ops sit inside three nested scf.for loops (%arg5, %arg8, %arg11), each of
+// which threads both accumulators through its iter_args. Going by the test
+// name, this loop-nest depth is what the lowering is expected to reject;
+// confirm the exact limit against the pattern source. The FileCheck lines
+// after this function require that no AMX ops are produced and that a
+// vector.contract is still present.
+func.func @negative_reduction_loop_depth_3(%arg0: memref<2x16x64x64x4xi8>, %arg1: memref<2x16x64x128x4xi8>, %arg2: memref<64x128xi32>) attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
+ %0 = ub.poison : i32
+ %1 = ub.poison : i8
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c2 = arith.constant 2 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ scf.for %arg3 = %c0 to %c64 step %c16 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+ %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1] :
+ memref<64x128xi32> to !memrefC
+ // Two 16x16 accumulators read from the 16x32 C subview at column
+ // offsets 0 and 16.
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ // Three-deep loop nest; every level forwards both accumulators.
+ %4:2 = scf.for %arg5 = %c0 to %c2 step %c1 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+ %5:2 = scf.for %arg8 = %c0 to %c16 step %c1 iter_args(%arg9 = %arg6, %arg10 = %arg7) -> (!vecC, !vecC) {
+ %6:2 = scf.for %arg11 = %c0 to %c64 step %c16 iter_args(%arg12 = %arg9, %arg13 = %arg10) -> (!vecC, !vecC) {
+ %subview_0 = memref.subview %arg0[%arg5, %arg8, %arg3, %arg11, 0] [1, 1, 16, 16, 4] [1, 1, 1, 1, 1] :
+ memref<2x16x64x64x4xi8> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg8, %arg11, %arg4, 0] [1, 1, 16, 32, 4] [1, 1, 1, 1, 1] :
+ memref<2x16x64x128x4xi8> to !memrefB
+ // %7 (A) is shared by both contracts; %8 and %9 read the B subview at
+ // offsets 0 and 16 in its fourth dimension.
+ %7 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true, true]} :
+ !memrefA, !vecAB
+ %8 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true, true]} :
+ !memrefB, !vecAB
+ %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true, true]} :
+ !memrefB, !vecAB
+
+ %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %7, %8, %arg12 {unroll_shape = array<i64: 1, 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %7, %9, %arg13 {unroll_shape = array<i64: 1, 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ scf.yield %10, %11 : !vecC, !vecC
+ }
+ scf.yield %6#0, %6#1 : !vecC, !vecC
+ }
+ scf.yield %5#0, %5#1 : !vecC, !vecC
+ }
+
+ // Write both accumulated halves back to the C subview.
+ vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ }
+ }
+ return
+}
+
+
+// CHECK-LABEL: @negative_reduction_loop_depth_3
+// CHECK-NOT: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK: vector.contract
+
+// Transform entry point: matches every func.func under the payload root
+// (%arg1) and applies the x86 vector_contract_to_amx_dot_product pattern set
+// to the matched functions.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x16x16x2xf16>
+!vecB = vector<1x16x16x2xf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x2xf16>
+!memrefB = memref<1x16x32x2xf16>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+// Negative test: the A and B operands are f16 vectors accumulated into f32
+// (see !vecA, !vecB and !vecC above). Going by the test name, the f16 element
+// type is what the lowering is expected to reject; the commit message lists
+// only BF16 and Int8 as lowered types. The FileCheck lines after this function
+// require that no AMX tile load/mulf/store ops are produced and that the
+// vector.contract is still present.
+func.func @negative_wrong_type_f16(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : f16
+ %32 = ub.poison : f32
+
+ // Operand reads start at the origin of each memref.
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+
+ // Initial accumulator value.
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_wrong_type_f16
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_mulf
+// CHECK-NOT: x86.amx.tile_store
+// CHECK: vector.contract
+
+
+// Transform entry point: matches every func.func under the payload root
+// (%arg1) and applies the x86 vector_contract_to_amx_dot_product pattern set
+// to the matched functions.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x16x16x2xbf16>
+!vecB = vector<1x16x16x2xbf16>
+!vecC = vector<16x16xbf16>
+!memrefA = memref<1x32x16x2xbf16>
+!memrefB = memref<1x16x32x2xbf16>
+!memrefC = memref<32x32xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+// Negative test: bf16 operands with a bf16 accumulator (!vecC above is
+// vector<16x16xbf16>). Going by the test name, the accumulator element type is
+// what the lowering is expected to reject -- presumably an f32 accumulator is
+// required for the bf16 path; confirm against the pattern source. The
+// FileCheck lines after this function require that no AMX tile
+// load/mulf/store ops are produced and that the vector.contract is still
+// present.
+func.func @negative_wrong_acc_type_bf16(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : bf16
+
+ // Operand reads start at the origin of each memref.
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+
+ // Initial accumulator value, read as bf16.
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_wrong_acc_type_bf16
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_mulf
+// CHECK-NOT: x86.amx.tile_store
+// CHECK: vector.contract
+
+
+// Transform entry point: matches every func.func under the payload root
+// (%arg1) and applies the x86 vector_contract_to_amx_dot_product pattern set
+// to the matched functions.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
More information about the Mlir-commits
mailing list