[Mlir-commits] [mlir] [mlir][x86] Lower vector.contract to packed type tiled dot-product. (PR #182810)
Arun Thangamani
llvmlistbot at llvm.org
Wed Mar 4 22:21:52 PST 2026
https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/182810
>From 4700e8d10e8a4dca1b466dffc4a0b1552d324967 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sun, 22 Feb 2026 22:47:13 -0800
Subject: [PATCH 1/7] initial commit for VC to tiled-dp lowering
---
mlir/include/mlir/Dialect/AMX/CMakeLists.txt | 2 +
.../AMX/TransformOps/AMXTransformOps.h | 31 +
.../AMX/TransformOps/AMXTransformOps.td | 32 +
.../Dialect/AMX/TransformOps/CMakeLists.txt | 4 +
mlir/include/mlir/Dialect/AMX/Transforms.h | 7 +
mlir/lib/Dialect/AMX/CMakeLists.txt | 1 +
.../AMX/TransformOps/AMXTransformOps.cpp | 57 ++
.../Dialect/AMX/TransformOps/CMakeLists.txt | 17 +
.../lib/Dialect/AMX/Transforms/CMakeLists.txt | 1 +
...torContractToPackedTypeTiledDotProduct.cpp | 586 ++++++++++++++++++
mlir/lib/RegisterAllExtensions.cpp | 2 +
.../AMX/vector-contract-to-tiled-dp.mlir | 242 ++++++++
12 files changed, 982 insertions(+)
create mode 100644 mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
create mode 100644 mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
create mode 100644 mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
create mode 100644 mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
create mode 100644 mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
index f875c78d240cc..b211cb3b53fd7 100644
--- a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
@@ -3,3 +3,5 @@ add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)
add_mlir_interface(AMXInterfaces)
add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
new file mode 100644
index 0000000000000..8806635df8eb5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
@@ -0,0 +1,31 @@
+//===- AMXTransformOps.h - AMX transform ops --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
+#define MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// AMX Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace amx {
+/// Adds the AMX transform dialect extension (the transform ops declared in
+/// AMXTransformOps.td) to `registry`.
+void registerTransformDialectExtension(DialectRegistry &registry);
+
+} // namespace amx
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
new file mode 100644
index 0000000000000..74bcba3c37e57
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
@@ -0,0 +1,32 @@
+//===- AMXTransformOps.td - AMX transform ops --*- tablegen -*-------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef AMX_TRANSFORM_OPS
+#define AMX_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/RegionKindInterface.td"
+// Pattern-descriptor transform op for the vector.contract to AMX tiled dot-product lowering.
+def ApplyVectorContractToPackedTypeTiledDotProductPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Collect patterns to lower a BF16/Int8 type vector.contract operation
+ to a BF16/Int8 tiled dot-product.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+
+#endif // AMX_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..41255b936be71
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS AMXTransformOps.td)
+mlir_tablegen(AMXTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(AMXTransformOps.cpp.inc -gen-op-defs)
+add_mlir_dialect_tablegen_target(MLIRAMXTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 7391ec2ff6b14..f95bf296c1555 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -28,6 +28,13 @@ void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
/// Register LLVM conversion interface for AMX dialect.
void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
+namespace amx {
+/// Adds the vector.contract to AMX tiled dot-product rewrite pattern to `patterns`.
+void populateVectorContractToPackedTypeTiledDotProductPatterns(
+ RewritePatternSet &patterns);
+
+} // namespace amx
+
} // namespace mlir
#endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/AMX/CMakeLists.txt b/mlir/lib/Dialect/AMX/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/AMX/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp b/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
new file mode 100644
index 0000000000000..6ff573407b42b
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
@@ -0,0 +1,57 @@
+//===- AMXTransformOps.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::amx;
+using namespace mlir::transform;
+
+/// Fills `patterns` by forwarding to the AMX pattern population entry point
+/// for the vector.contract to tiled dot-product lowering.
+void mlir::transform::ApplyVectorContractToPackedTypeTiledDotProductPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  ::mlir::amx::populateVectorContractToPackedTypeTiledDotProductPatterns(
+      patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace { // Transform dialect extension that registers the AMX transform ops.
+class AMXTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ AMXTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AMXTransformDialectExtension)
+ // Declares the AMX and LLVM dialects as generated, then registers the op list from the generated .cpp.inc.
+ AMXTransformDialectExtension() {
+ declareGeneratedDialect<amx::AMXDialect>();
+ declareGeneratedDialect<LLVM::LLVMDialect>();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.cpp.inc"
+
+/// Adds AMXTransformDialectExtension to `registry`.
+void mlir::amx::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<AMXTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..30b4304586ab7
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRAMXTransformOps
+ AMXTransformOps.cpp
+
+ DEPENDS
+ MLIRAMXTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRVectorDialect
+ MLIRSideEffectInterfaces
+ MLIRTransformDialect
+ MLIRTransformDialectUtils
+ MLIRAMXDialect
+ MLIRAMXTransforms
+ )
diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
index e827bc475e930..e422f275c7742 100644
--- a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRAMXTransforms
LegalizeForLLVMExport.cpp
+ VectorContractToPackedTypeTiledDotProduct.cpp
LINK_LIBS PUBLIC
MLIRAMXDialect
diff --git a/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
new file mode 100644
index 0000000000000..37fd3bb33848b
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
@@ -0,0 +1,586 @@
+//===- VectorContractToPackedTypeTiledDotProduct.cpp ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Casting.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::amx;
+
+namespace {
+
+/// Walks the producer chain of `v` backwards until it reaches the
+/// vector.transfer_read / vector.load that produced it.
+///
+/// Only one kind of indirection is looked through: when `v` is a region
+/// iter_arg of an scf.for, the search continues from the init operand tied to
+/// that iter_arg. Every other producer (vector.shape_cast and vector.shuffle
+/// included), the loop induction variable, and block arguments that do not
+/// belong to an scf.for end the search.
+///
+/// Returns the read-like operation, or nullptr when none is reached (also for
+/// a null `v`).
+static Operation *traceToVectorReadLikeParentOperation(Value v) {
+  while (v) {
+    // Produced by an operation: either the read we are looking for, or a
+    // producer we do not look through.
+    if (Operation *defOp = v.getDefiningOp()) {
+      if (isa<vector::TransferReadOp, vector::LoadOp>(defOp))
+        return defOp;
+      return nullptr;
+    }
+
+    // Otherwise `v` is a block argument; only scf.for iter_args are followed.
+    // The owning block may be detached, hence dyn_cast_or_null.
+    auto barg = cast<BlockArgument>(v);
+    auto forOp = dyn_cast_or_null<scf::ForOp>(barg.getOwner()->getParentOp());
+    if (!forOp)
+      return nullptr;
+
+    // Null for the induction variable, which has no tied init operand.
+    OpOperand *init = forOp.getTiedLoopInit(barg);
+    if (!init)
+      return nullptr;
+
+    v = init->get();
+  }
+
+  return nullptr;
+}
+
+/// Searches the transitive users of `v`, depth first, for a
+/// vector.transfer_write / vector.store; the first one found is returned.
+///
+/// How each user of `v` is handled:
+///  - vector.transfer_write / vector.store: returned.
+///  - vector.shape_cast / vector.shuffle: this call returns nullptr right
+///    away, without looking at the remaining users.
+///  - scf.yield: the search continues from the result of the yield's parent
+///    op that has the same position as the yielded operand.
+///  - scf.for: the search continues from the loop result tied to the init
+///    operand that uses `v`.
+///  - anything else: the search continues from each result of the user.
+///
+/// Returns nullptr when no write-like user is found.
+static Operation *traceToVectorWriteLikeUserOperation(Value v) {
+  for (OpOperand &use : v.getUses()) {
+    Operation *user = use.getOwner();
+
+    // --- TERMINAL OPS ---
+    if (isa<vector::TransferWriteOp, vector::StoreOp>(user))
+      return user;
+
+    if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user))
+      return nullptr;
+
+    // --- SCF YIELD ---
+    if (isa<scf::YieldOp>(user)) {
+      Operation *parent = user->getParentOp();
+      unsigned idx = use.getOperandNumber();
+      // Not every parent maps yield operand #i to result #i; stay in bounds.
+      if (parent && idx < parent->getNumResults()) {
+        if (Operation *found =
+                traceToVectorWriteLikeUserOperation(parent->getResult(idx)))
+          return found;
+      }
+      continue;
+    }
+
+    // --- SCF FOR ---
+    if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+      // The operand number also counts the lower bound, upper bound and step,
+      // so it cannot be used directly as a result index. Ask the loop for the
+      // result tied to this init operand instead (null for non-init operands).
+      if (OpResult tied = forOp.getTiedLoopResult(&use)) {
+        if (Operation *found = traceToVectorWriteLikeUserOperation(tied))
+          return found;
+      }
+      continue;
+    }
+
+    // --- GENERIC CASE ---
+    for (Value res : user->getResults()) {
+      if (Operation *found = traceToVectorWriteLikeUserOperation(res))
+        return found;
+    }
+  }
+
+  return nullptr;
+}
+
+/// Returns a memref.collapse_shape of `input` in which every dimension from
+/// `firstDimToCollapse` onwards is merged into one trailing dimension, while
+/// the dimensions before it are kept as they are. A rank-1 `input` is
+/// returned unchanged.
+static Value collapseInnerDims(OpBuilder &builder, mlir::Location loc,
+                               Value input, int64_t firstDimToCollapse) {
+  const int64_t rank = cast<ShapedType>(input.getType()).getRank();
+  if (rank == 1)
+    return input;
+
+  SmallVector<ReassociationIndices> groups;
+
+  // Each leading dimension forms its own single-element group.
+  for (int64_t dim = 0; dim < firstDimToCollapse; ++dim)
+    groups.push_back(ReassociationIndices{dim});
+
+  // All remaining dimensions end up together in the last group.
+  ReassociationIndices trailing;
+  for (int64_t dim = firstDimToCollapse; dim < rank; ++dim)
+    trailing.push_back(dim);
+  groups.push_back(trailing);
+
+  return memref::CollapseShapeOp::create(builder, loc, input, groups);
+}
+
+/// Returns the source buffer and the index values of the
+/// vector.transfer_read / vector.load that defines `operand`.
+///
+/// When `isNotAcc` is true, the innermost index is dropped and the buffer is
+/// replaced by collapseInnerDims(buffer, rank - 2), i.e. a view whose two
+/// innermost dimensions are merged into one.
+///
+/// If `operand` is not defined by one of those read ops, a null buffer and an
+/// empty index list are returned and no IR is created.
+static std::pair<Value, SmallVector<Value>> getSrcIndxValue(OpBuilder &rewriter,
+                                                            Location loc,
+                                                            Value operand,
+                                                            bool isNotAcc) {
+  Value srcBuff;
+  SmallVector<Value> indices;
+
+  // Block arguments have no defining op; TypeSwitch must not see a null op.
+  if (Operation *defOp = operand.getDefiningOp()) {
+    llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
+        [&](auto readOp) {
+          srcBuff = readOp.getOperand(0);
+          // The read indices already are SSA values, so they are taken as is.
+          indices.assign(readOp.getIndices().begin(),
+                         readOp.getIndices().end());
+        });
+  }
+
+  if (!srcBuff)
+    return {Value(), SmallVector<Value>()};
+
+  if (isNotAcc) {
+    // Drop the innermost index to match the collapsed buffer built below.
+    if (!indices.empty())
+      indices.pop_back();
+
+    int64_t rank = cast<ShapedType>(srcBuff.getType()).getRank();
+    srcBuff = collapseInnerDims(rewriter, loc, srcBuff, rank - 2);
+  }
+
+  return {srcBuff, indices};
+}
+
+/// Returns the position, among the dynamic offsets of the memref.subview
+/// feeding the read that defines `operand`, of the offset that is the
+/// induction variable of `loop`.
+///
+/// Returns 0 when `operand` is not defined by a vector.transfer_read /
+/// vector.load, when the read's source is not a memref.subview result, or
+/// when no offset matches. Note that 0 is also a valid position, so callers
+/// cannot tell "not found" from "found first".
+static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
+  Operation *defOp = operand.getDefiningOp();
+  if (!defOp)
+    return 0;
+
+  Value srcBuff;
+  llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
+      [&](auto readOp) { srcBuff = readOp.getOperand(0); });
+  if (!srcBuff)
+    return 0;
+
+  auto subview = srcBuff.getDefiningOp<memref::SubViewOp>();
+  if (!subview)
+    return 0;
+
+  Value iv = loop.getInductionVar();
+  for (auto it : llvm::enumerate(subview.getOffsets())) {
+    if (it.value() == iv)
+      return it.index();
+  }
+
+  return 0;
+}
+
+/// Creates an amx.tile_load from `mat`, addressed with the indices of the
+/// vector.transfer_read / vector.load that defines `operand`, minus the
+/// innermost one.
+///
+/// For the right-hand side (`rhs` == true) the last remaining index is
+/// multiplied by 2 when `ipType` is bf16 and by 4 otherwise.
+/// NOTE(review): presumably the packing factor of the collapsed innermost
+/// dimension; confirm against the expected operand layout.
+///
+/// The loaded tile is 16x32 for bf16 and 16x64 for any other `ipType`.
+/// Returns a null op when `operand` has no read-like producer or the read
+/// has no indices.
+static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
+                                       Value operand, Value mat, Type ipType,
+                                       bool rhs) {
+  SmallVector<Value> indices;
+
+  // Block arguments have no defining op; TypeSwitch must not see a null op.
+  if (Operation *defOp = operand.getDefiningOp()) {
+    llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
+        [&](auto readOp) {
+          indices.assign(readOp.getIndices().begin(),
+                         readOp.getIndices().end());
+        });
+  }
+
+  if (indices.empty())
+    return amx::TileLoadOp();
+
+  // Drop the innermost index.
+  indices.pop_back();
+
+  const bool isBF16 = ipType.isBF16();
+
+  if (rhs && !indices.empty()) {
+    Value factor =
+        arith::ConstantIndexOp::create(rewriter, loc, isBF16 ? 2 : 4);
+    indices.back() =
+        arith::MulIOp::create(rewriter, loc, indices.back(), factor);
+  }
+
+  const int64_t tileCols = isBF16 ? 32 : 64;
+  amx::TileType tileType = amx::TileType::get({16, tileCols}, ipType);
+
+  return amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
+}
+
+/// Emits one AMX tile multiply per contraction in `ops` and returns the
+/// resulting tiles in the same order.
+///
+/// `matA` / `matB` are the LHS / RHS source buffers; both are first passed
+/// through collapseInnerDims(buffer, rank - 2). For every contraction the LHS
+/// and RHS tiles are loaded with createTileLoads; a load is created only once
+/// per defining read op and reused by later contractions. `accIterArgs[i]` is
+/// the accumulator tile of `ops[i]`, so it must hold at least `ops.size()`
+/// values.
+///
+/// amx::TileMulFOp is used when `ipType` is bf16 and amx::TileMulIOp when it
+/// is a signless 8-bit integer; for any other type a null value is pushed.
+/// The result tiles have type 16x16 x `opType`.
+static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
+                                        ArrayRef<vector::ContractionOp> ops,
+                                        Value matA, Value matB, Type ipType,
+                                        Type opType, ValueRange accIterArgs) {
+  auto collapseLastTwoDims = [&](Value mat) -> Value {
+    int64_t rank = cast<ShapedType>(mat.getType()).getRank();
+    return collapseInnerDims(rewriter, loc, mat, rank - 2);
+  };
+  Value collapsedA = collapseLastTwoDims(matA);
+  Value collapsedB = collapseLastTwoDims(matB);
+
+  // Tile loads already emitted, keyed by the read op defining the operand.
+  llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+  auto getOrCreateTileLoad = [&](Value operand, Value mat,
+                                 bool rhs) -> amx::TileLoadOp {
+    Operation *readOp = operand.getDefiningOp();
+    auto it = readsToTileLoads.find(readOp);
+    if (it != readsToTileLoads.end())
+      return it->second;
+
+    amx::TileLoadOp load =
+        createTileLoads(rewriter, loc, operand, mat, ipType, rhs);
+    readsToTileLoads.try_emplace(readOp, load);
+    return load;
+  };
+
+  // The accumulator tile type is the same for every contraction.
+  auto accTileType = amx::TileType::get({16, 16}, opType);
+
+  SmallVector<Value> accumulators;
+  accumulators.reserve(ops.size());
+
+  for (size_t i = 0; i < ops.size(); i++) {
+    // Op handles are cheap to copy; a local copy allows the non-const
+    // accessors to be called.
+    vector::ContractionOp contractOp = ops[i];
+
+    amx::TileLoadOp tilesLhs =
+        getOrCreateTileLoad(contractOp.getLhs(), collapsedA, /*rhs=*/false);
+    amx::TileLoadOp tilesRhs =
+        getOrCreateTileLoad(contractOp.getRhs(), collapsedB, /*rhs=*/true);
+
+    Value dp;
+    if (ipType.isBF16())
+      dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
+                                   tilesRhs, accIterArgs[i]);
+    else if (ipType.isSignlessInteger(8))
+      dp = amx::TileMulIOp::create(rewriter, loc, accTileType, tilesLhs,
+                                   tilesRhs, accIterArgs[i]);
+
+    accumulators.push_back(dp);
+  }
+
+  return accumulators;
+}
+
+/// Creates `size` amx.tile_zero ops of type 16x16 x `opType` right before
+/// `outerLoop` and returns their results.
+///
+/// Side effect: the insertion point of `rewriter` is moved to just before
+/// `outerLoop` and is not restored.
+static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
+                                          Type opType, scf::ForOp outerLoop,
+                                          int64_t size) {
+  rewriter.setInsertionPoint(outerLoop);
+
+  auto tileType = amx::TileType::get({16, 16}, opType);
+
+  SmallVector<Value> zeroTiles;
+  for (int64_t remaining = size; remaining > 0; --remaining) {
+    Value tile = amx::TileZeroOp::create(rewriter, loc, tileType);
+    zeroTiles.push_back(tile);
+  }
+
+  return zeroTiles;
+}
+
+// Rewrites a vector.contract (ADD combining kind) whose accumulator traces
+// back to a vector read and whose result reaches a vector write into AMX
+// tile ops: amx::TileLoadOp for the operands and amx::TileMulFOp (bf16) or
+// amx::TileMulIOp (i8) for the product. When the contraction sits inside one
+// or two scf.for loops, the loops are rebuilt with amx::TileZeroOp
+// accumulators and the result tiles are added back to the accumulator's
+// source buffer after the loop nest. Deeper nests are rejected.
+struct VectorContractToPackedTypeTiledDotProduct
+ : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind.");
+
+ Operation *accReadOp =
+ traceToVectorReadLikeParentOperation(contractOp.getAcc());
+
+ Operation *resultWriteOp =
+ traceToVectorWriteLikeUserOperation(contractOp.getResult());
+
+ if (!accReadOp || !resultWriteOp)
+ return failure();
+
+ /*if (accReadOp->getBlock() == contractOp->getBlock())
+ return failure();
+
+ if (resultWriteOp->getBlock() == contractOp->getBlock())
+ return failure(); */
+ // Straight-line case: the accumulator read, the contraction and the result
+ // write share one block. NOTE(review): tile types are hard-coded to bf16/f32.
+ if (accReadOp->getBlock() == contractOp->getBlock() &&
+ resultWriteOp->getBlock() == contractOp->getBlock()) {
+ Location loc = contractOp.getLoc();
+ auto tileType = amx::TileType::get({16, 32}, rewriter.getBF16Type());
+
+ auto [srcBuffLhs, indicesLhs] = getSrcIndxValue(
+ rewriter, contractOp.getLoc(), contractOp.getLhs(), true);
+ auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
+ srcBuffLhs, indicesLhs);
+
+ auto [srcBuffRhs, indicesRhs] = getSrcIndxValue(
+ rewriter, contractOp.getLoc(), contractOp.getRhs(), true);
+ auto loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType,
+ srcBuffRhs, indicesRhs);
+
+ auto [srcBuffAcc, indicesAcc] = getSrcIndxValue(
+ rewriter, contractOp.getLoc(), contractOp.getAcc(), false);
+ auto tileTypeAcc = amx::TileType::get({16, 16}, rewriter.getF32Type());
+ auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
+ srcBuffAcc, indicesAcc);
+
+ auto tdp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
+ loadRhs, loadAcc);
+ amx::TileStoreOp::create(rewriter, loc, srcBuffAcc, indicesAcc, tdp);
+
+ rewriter.eraseOp(resultWriteOp);
+ return success();
+ }
+ // Loop case: collect the enclosing scf.for ops, innermost first, up to the loop located in the accumulator read's block.
+ SmallVector<scf::ForOp> list;
+ Operation *current = contractOp;
+ // NOTE(review): getParentOfType may return null; `parent` is dereferenced below without a check.
+ while (true) {
+ Operation *parent = current->getParentOfType<scf::ForOp>();
+ list.push_back(dyn_cast<scf::ForOp>(parent));
+
+ if (accReadOp->getBlock() == parent->getBlock()) {
+ break;
+ }
+
+ current = parent;
+ }
+
+ if (list.size() > 2 || list.size() == 0)
+ return failure();
+ // Gather every vector.contract found directly in the innermost loop body.
+ SmallVector<vector::ContractionOp> ops;
+
+ for (mlir::Operation &op : list[0].getBody()->getOperations()) {
+
+ if (auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
+ ops.push_back(contract);
+ }
+ }
+ // Element types: bf16 -> f32 or i8 -> i32. NOTE(review): both stay null for any other LHS element type.
+ Type ipType;
+ Type opType;
+ VectorType lhsTy = contractOp.getLhsType();
+ if (lhsTy.getElementType().isBF16()) {
+ ipType = rewriter.getBF16Type();
+ opType = rewriter.getF32Type();
+ }
+
+ if (lhsTy.getElementType().isSignlessInteger(8)) {
+ ipType = rewriter.getIntegerType(8);
+ opType = rewriter.getIntegerType(32);
+ }
+
+ scf::ForOp outerLoop;
+ scf::ForOp innerLoop;
+ // NOTE(review): both reads are null unless the operands come directly from vector.transfer_read; they are used unchecked below.
+ auto vectorReadOpLhs =
+ contractOp.getLhs().getDefiningOp<vector::TransferReadOp>();
+ auto vectorReadOpRhs =
+ contractOp.getRhs().getDefiningOp<vector::TransferReadOp>();
+ // Two enclosing loops: rebuild both, carrying the accumulator tiles as iter_args.
+ scf::ForOp newLoop;
+ if (list.size() == 2) {
+ outerLoop = list[1];
+ innerLoop = list[0];
+
+ SmallVector<Value> loopItrArgs = createTileZeros(
+ rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
+
+ newLoop = scf::ForOp::create(
+ rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
+ outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
+ [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
+ Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
+ auto newInnerLoop = scf::ForOp::create(
+ rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+ innerLoop.getUpperBound(), innerLoop.getStep(),
+ iterArgsOuterLoop,
+ [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
+ Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
+ IRMapping mapping;
+ mapping.map(
+ vectorReadOpLhs.getBase().getDefiningOp()->getOperand(
+ getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
+ ivOuterLoop);
+ mapping.map(
+ vectorReadOpLhs.getBase().getDefiningOp()->getOperand(
+ getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
+ ivNewInnerLoop);
+ auto lhsClone = rewriterNewInnerLoop.clone(
+ *vectorReadOpLhs.getBase().getDefiningOp(), mapping);
+
+ IRMapping rhsMapping;
+ rhsMapping.map(
+ vectorReadOpRhs.getBase().getDefiningOp()->getOperand(
+ getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
+ ivOuterLoop);
+ rhsMapping.map(
+ vectorReadOpRhs.getBase().getDefiningOp()->getOperand(
+ getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
+ ivNewInnerLoop);
+ auto rhsClone = rewriterNewInnerLoop.clone(
+ *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
+
+ SmallVector<Value> accumulators = createTiledDp(
+ rewriter, locNewInnerLoop, ops, lhsClone->getResult(0),
+ rhsClone->getResult(0), ipType, opType,
+ iterArgsNewInnerLoop);
+
+ scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
+ accumulators);
+ });
+
+ scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
+ newInnerLoop.getResults());
+ });
+ }
+ // Single enclosing loop: same rebuild, without the inner loop.
+ if (list.size() == 1) {
+ outerLoop = list[0];
+
+ SmallVector<Value> loopItrArgs = createTileZeros(
+ rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
+ newLoop = scf::ForOp::create(
+ rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
+ outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
+ [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
+ Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
+ IRMapping mapping;
+ mapping.map(
+ vectorReadOpLhs.getBase().getDefiningOp()->getOperand(
+ getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
+ ivOuterLoop);
+
+ auto lhsClone = rewriterOuterLoop.clone(
+ *vectorReadOpLhs.getBase().getDefiningOp(), mapping);
+
+ IRMapping rhsMapping;
+ rhsMapping.map(
+ vectorReadOpRhs.getBase().getDefiningOp()->getOperand(
+ getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
+ ivOuterLoop);
+
+ auto rhsClone = rewriterOuterLoop.clone(
+ *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
+
+ SmallVector<Value> accumulators = createTiledDp(
+ rewriter, locOuterLoop, ops, lhsClone->getResult(0),
+ rhsClone->getResult(0), ipType, opType, iterArgsOuterLoop);
+
+ scf::YieldOp::create(rewriterOuterLoop, locOuterLoop, accumulators);
+ });
+ }
+ // After the loop nest: each result tile is stored to a 16x16 scratch alloca and then
+ // added, one 16-element row at a time, into the buffer the accumulator was read from.
+ auto bufferType = MemRefType::get({16, 16}, opType);
+ auto bBuffer =
+ memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
+
+ SmallVector<Value> dps = newLoop.getResults();
+
+ for (size_t i = 0; i < ops.size(); i++) {
+ vector::ContractionOp contOp = ops[i];
+ Operation *resultWriteOp =
+ traceToVectorWriteLikeUserOperation(contOp.getResult());
+ rewriter.setInsertionPoint(resultWriteOp);
+
+ Value indexOp_0 =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+
+ amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
+ ValueRange{indexOp_0, indexOp_0}, dps[i]);
+
+ auto c0 = arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+ auto one =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
+ auto mBound =
+ arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 16);
+
+ scf::ForOp::create(
+ rewriter, outerLoop.getLoc(), c0, mBound, one, ValueRange{},
+ [&](OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) {
+ auto row1 = vector::LoadOp::create(rewriter, loc,
+ VectorType::get(16, opType),
+ bBuffer, ValueRange{iv, c0});
+
+ Operation *readOp1 =
+ traceToVectorReadLikeParentOperation(ops[i].getAcc());
+
+ Value srcBuff;
+ SmallVector<OpFoldResult> indexVals;
+ llvm::TypeSwitch<Operation *>(readOp1).Case<TransferReadOp, LoadOp>(
+ [&](auto readOp) {
+ indexVals = SmallVector<OpFoldResult>(
+ readOp.getIndices().begin(), readOp.getIndices().end());
+ srcBuff = readOp.getOperand(0);
+ });
+
+ SmallVector<Value> indices;
+ indices.reserve(indexVals.size());
+
+ for (OpFoldResult ofr : indexVals) {
+ indices.push_back(
+ mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+ }
+
+ Value sum = arith::AddIOp::create(builder, loc, iv, indices[0]);
+ indices[0] = sum;
+
+ auto row2 = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(16, opType), srcBuff, indices);
+
+ Value addition;
+ if (ipType.isBF16())
+ addition = arith::AddFOp::create(rewriter, loc, row1, row2);
+
+ if (ipType.isSignlessInteger(8))
+ addition = arith::AddIOp::create(rewriter, loc, row1, row2);
+
+ vector::StoreOp::create(builder, loc, addition, srcBuff, indices);
+
+ scf::YieldOp::create(builder, outerLoop.getLoc());
+ });
+
+ rewriter.eraseOp(resultWriteOp);
+ }
+
+ return success();
+ }
+};
+
+} // namespace
+
+/// Adds VectorContractToPackedTypeTiledDotProduct to `patterns`, constructed
+/// with the context of the pattern set.
+void amx::populateVectorContractToPackedTypeTiledDotProductPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<VectorContractToPackedTypeTiledDotProduct>(context);
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 14cfb7b5ac352..dbd482ba85ee1 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -33,6 +33,7 @@
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
@@ -97,6 +98,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);
+ amx::registerTransformDialectExtension(registry);
bufferization::registerTransformDialectExtension(registry);
dlti::registerTransformDialectExtension(registry);
func::registerTransformDialectExtension(registry);
diff --git a/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
new file mode 100644
index 0000000000000..081eb56a029f6
--- /dev/null
+++ b/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
@@ -0,0 +1,242 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+
+ module {
+ func.func @brgemm_amx(%arg0: memref<16x64x64x2xbf16>, %arg1: memref<16x64x128x2xbf16>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
+ %0 = ub.poison : f32
+ %1 = ub.poison : bf16
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ scf.for %arg3 = %c0 to %c64 step %c32 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+ %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<64x128xf32> to memref<32x32xf32, strided<[128, 1], offset: ?>>
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+ %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+ %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+ %6:4 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>) {
+ %7:4 = scf.for %arg10 = %c0 to %c64 step %c16 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>) {
+ %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg10, 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x64x2xbf16> to memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>
+ %subview_1 = memref.subview %arg1[%arg5, %arg10, %arg4, 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x64x128x2xbf16> to memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>
+ %8 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+ %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+ %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %9, %arg11 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
+ %11 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+ %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %11, %arg12 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
+ %13 = vector.transfer_read %subview_0[%c0, %c16, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+ %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %13, %9, %arg13 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
+ %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %13, %11, %arg14 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
+ scf.yield %10, %12, %14, %15 : vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+ }
+ scf.yield %7#0, %7#1, %7#2, %7#3 : vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+ }
+ vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+ vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+ vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+ vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+ }
+ }
+ %alloc = memref.alloc() : memref<64x128xf32>
+ memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
+ return %alloc : memref<64x128xf32>
+ }
+ }
+
+// CHECK-LABEL: @brgemm_amx
+// CHECK: amx.tile_mulf
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+ module {
+ func.func @batch_amx(%arg0: memref<64x64x2xbf16>, %arg1: memref<64x128x2xbf16>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
+ %0 = ub.poison : f32
+ %1 = ub.poison : bf16
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c32 = arith.constant 32 : index
+ %c16 = arith.constant 16 : index
+ scf.for %arg3 = %c0 to %c64 step %c32 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+ %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<64x128xf32> to memref<32x32xf32, strided<[128, 1], offset: ?>>
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+ %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+ %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+ %6:4 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>) {
+ %subview_0 = memref.subview %arg0[%arg3, %arg5, 0] [32, 16, 2] [1, 1, 1] : memref<64x64x2xbf16> to memref<32x16x2xbf16, strided<[128, 2, 1], offset: ?>>
+ %subview_1 = memref.subview %arg1[%arg5, %arg4, 0] [16, 32, 2] [1, 1, 1] : memref<64x128x2xbf16> to memref<16x32x2xbf16, strided<[256, 2, 1], offset: ?>>
+ %7 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} : memref<32x16x2xbf16, strided<[128, 2, 1], offset: ?>>, vector<16x16x2xbf16>
+ %8 = vector.transfer_read %subview_0[%c16, %c0, %c0], %1 {in_bounds = [true, true, true]} : memref<32x16x2xbf16, strided<[128, 2, 1], offset: ?>>, vector<16x16x2xbf16>
+ %9 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} : memref<16x32x2xbf16, strided<[256, 2, 1], offset: ?>>, vector<16x16x2xbf16>
+ %10 = vector.transfer_read %subview_1[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} : memref<16x32x2xbf16, strided<[256, 2, 1], offset: ?>>, vector<16x16x2xbf16>
+ %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %9, %arg6 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
+ %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %10, %arg7 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
+ %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %9, %arg8 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
+ %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %10, %arg9 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
+ scf.yield %11, %12, %13, %14 : vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+ }
+ vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+ vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+ vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+ vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+ }
+ }
+ %alloc = memref.alloc() : memref<64x128xf32>
+ memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
+ return %alloc : memref<64x128xf32>
+ }
+ }
+
+// CHECK-LABEL: @batch_amx
+// CHECK: amx.tile_mulf
+// CHECK-NOT: vector.contract
+
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>
+
+ module {
+ func.func @matmul_amx(%arg0: memref<16x64x64x2xbf16>, %arg1: memref<16x64x128x2xbf16>, %arg2: memref<16x64x128xf32>) -> memref<16x64x128xf32> attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
+ %0 = ub.poison : f32
+ %1 = ub.poison : bf16
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ scf.for %arg3 = %c0 to %c64 step %c32 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+ scf.for %arg5 = %c0 to %c16 step %c1 {
+ %subview = memref.subview %arg2[%arg5, %arg3, %arg4] [1, 32, 32] [1, 1, 1] : memref<16x64x128xf32> to memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
+ %2 = vector.transfer_read %subview[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
+ %3 = vector.transfer_read %subview[%c0, %c0, %c16], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
+ %4 = vector.transfer_read %subview[%c0, %c16, %c0], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
+ %5 = vector.transfer_read %subview[%c0, %c16, %c16], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
+ %6:4 = scf.for %arg6 = %c0 to %c64 step %c16 iter_args(%arg7 = %2, %arg8 = %3, %arg9 = %4, %arg10 = %5) -> (vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>) {
+ %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg6, 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x64x2xbf16> to memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>
+ %subview_1 = memref.subview %arg1[%arg5, %arg6, %arg4, 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x64x128x2xbf16> to memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>
+ %7 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+ %8 = vector.transfer_read %subview_0[%c0, %c16, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+ %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+ %10 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+ %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %9, %arg7 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
+ %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %10, %arg8 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
+ %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %9, %arg9 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
+ %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %10, %arg10 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
+ scf.yield %11, %12, %13, %14 : vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>
+ }
+ vector.transfer_write %6#3, %subview[%c0, %c16, %c16] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
+ vector.transfer_write %6#2, %subview[%c0, %c16, %c0] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
+ vector.transfer_write %6#1, %subview[%c0, %c0, %c16] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
+ vector.transfer_write %6#0, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
+ }
+ }
+ }
+ %alloc = memref.alloc() : memref<16x64x128xf32>
+ memref.copy %arg2, %alloc : memref<16x64x128xf32> to memref<16x64x128xf32>
+ return %alloc : memref<16x64x128xf32>
+ }
+ }
+
+// CHECK-LABEL: @matmul_amx
+// CHECK: amx.tile_mulf
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x16x16x2xbf16>
+!vecB = vector<1x16x16x2xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x2xbf16>
+!memrefB = memref<1x16x32x2xbf16>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @amx(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @amx
+// CHECK: amx.tile_mulf
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
>From 43b55930d0edf7e032ecce2738d2e610fc1c3122 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sun, 22 Feb 2026 22:55:44 -0800
Subject: [PATCH 2/7] fix code formatting errors
---
mlir/lib/RegisterAllExtensions.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index dbd482ba85ee1..6f01088a0ac3b 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -32,8 +32,8 @@
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
-#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h"
+#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
>From 4097bfb53ce3ab19376a99c7e875753d53001811 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sun, 22 Feb 2026 23:27:54 -0800
Subject: [PATCH 3/7] replacing common fun with x86vector utils API
---
...torContractToPackedTypeTiledDotProduct.cpp | 96 ++-----------------
1 file changed, 6 insertions(+), 90 deletions(-)
diff --git a/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
index 37fd3bb33848b..faac5530373e9 100644
--- a/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
@@ -26,85 +27,6 @@ using namespace mlir::amx;
namespace {
-static Operation *traceToVectorReadLikeParentOperation(Value v) {
- while (true) {
- // Case 1: Value defined by an operation
- if (Operation *defOp = v.getDefiningOp()) {
- if (isa<vector::TransferReadOp, vector::LoadOp>(defOp)) {
- return defOp;
- }
-
- if (isa<vector::ShapeCastOp, vector::ShuffleOp>(defOp)) {
- return nullptr;
- }
-
- return nullptr;
- }
-
- // Case 2: BlockArgument (scf.for iter_arg)
- if (auto barg = dyn_cast<BlockArgument>(v)) {
- auto *parentOp = barg.getOwner()->getParentOp();
-
- if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
- unsigned argNum = barg.getArgNumber();
-
- // arg0 = induction variable (not an iter_arg)
- if (argNum == 0)
- return nullptr;
-
- unsigned iterIdx = argNum - 1;
- v = forOp.getInitArgs()[iterIdx];
- continue;
- }
-
- return nullptr;
- }
-
- return nullptr;
- }
-}
-
-static Operation *traceToVectorWriteLikeUserOperation(Value v) {
- for (OpOperand &use : v.getUses()) {
- Operation *user = use.getOwner();
-
- // --- TERMINAL OPS ---
- if (isa<vector::TransferWriteOp>(user) || isa<vector::StoreOp>(user)) {
- return user;
- }
-
- if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user)) {
- return nullptr;
- }
-
- // --- SCF YIELD ---
- if (auto yield = dyn_cast<scf::YieldOp>(user)) {
- Operation *parent = yield->getParentOp();
- unsigned idx = use.getOperandNumber();
- if (auto *res =
- traceToVectorWriteLikeUserOperation(parent->getResult(idx)))
- return res;
- continue;
- }
-
- // --- SCF FOR ---
- if (auto forOp = dyn_cast<scf::ForOp>(user)) {
- unsigned idx = use.getOperandNumber();
- if (auto *res = traceToVectorWriteLikeUserOperation(forOp.getResult(idx)))
- return res;
- continue;
- }
-
- // --- GENERIC CASE ---
- for (Value res : user->getResults()) {
- if (auto *found = traceToVectorWriteLikeUserOperation(res))
- return found;
- }
- }
-
- return nullptr;
-}
-
static Value collapseInnerDims(OpBuilder &builder, mlir::Location loc,
Value input, int64_t firstDimToCollapse) {
ShapedType inputType = cast<ShapedType>(input.getType());
@@ -297,13 +219,6 @@ static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
return loopItrArgs;
}
-// vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32>
-// ```
-// to
-// ```
-// vector.broadcast %lhs to <16xf32>
-// vector.fma vector<16xf32>
-// ```
struct VectorContractToPackedTypeTiledDotProduct
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -316,10 +231,10 @@ struct VectorContractToPackedTypeTiledDotProduct
"Expects add combining kind.");
Operation *accReadOp =
- traceToVectorReadLikeParentOperation(contractOp.getAcc());
+ x86vector::traceToVectorReadLikeParentOperation(contractOp.getAcc());
Operation *resultWriteOp =
- traceToVectorWriteLikeUserOperation(contractOp.getResult());
+ x86vector::traceToVectorWriteLikeUserOperation(contractOp.getResult());
if (!accReadOp || !resultWriteOp)
return failure();
@@ -510,7 +425,7 @@ struct VectorContractToPackedTypeTiledDotProduct
for (size_t i = 0; i < ops.size(); i++) {
vector::ContractionOp contOp = ops[i];
Operation *resultWriteOp =
- traceToVectorWriteLikeUserOperation(contOp.getResult());
+ x86vector::traceToVectorWriteLikeUserOperation(contOp.getResult());
rewriter.setInsertionPoint(resultWriteOp);
Value indexOp_0 =
@@ -533,7 +448,8 @@ struct VectorContractToPackedTypeTiledDotProduct
bBuffer, ValueRange{iv, c0});
Operation *readOp1 =
- traceToVectorReadLikeParentOperation(ops[i].getAcc());
+ x86vector::traceToVectorReadLikeParentOperation(
+ ops[i].getAcc());
Value srcBuff;
SmallVector<OpFoldResult> indexVals;
>From 4f74ede48afb561423fe79204b35d6415539c97e Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 2 Mar 2026 23:45:37 -0800
Subject: [PATCH 4/7] clean-up + more test-cases
---
mlir/include/mlir/Dialect/AMX/Transforms.h | 7 +-
...torContractToPackedTypeTiledDotProduct.cpp | 515 +++++++----
.../AMX/vector-contract-to-tiled-dp.mlir | 867 +++++++++++++++---
3 files changed, 1089 insertions(+), 300 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index f95bf296c1555..50eab54ac58ab 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -30,10 +30,13 @@ void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
namespace amx {
+// A set of patterns for lowering 32-bit packed vector contraction operations
+// to their corresponding packed-type tiled dot-product operations using AMX,
+// ultimately targeting the relevant x86 LLVM intrinsics (e.g., for BF16 and
+// Int8).
void populateVectorContractToPackedTypeTiledDotProductPatterns(
RewritePatternSet &patterns);
-
-}
+} // namespace amx
} // namespace mlir
diff --git a/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
index faac5530373e9..af4f3ef1c934a 100644
--- a/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
@@ -27,25 +27,37 @@ using namespace mlir::amx;
namespace {
+// Function to collapse the last two dimensions (vnni and k) to help
+// amx.tile_load correctly load the packed element type.
static Value collapseInnerDims(OpBuilder &builder, mlir::Location loc,
- Value input, int64_t firstDimToCollapse) {
+ Value input) {
ShapedType inputType = cast<ShapedType>(input.getType());
+ int64_t firstDimToCollapse = inputType.getRank() - 2;
+
if (inputType.getRank() == 1)
return input;
+
SmallVector<ReassociationIndices> reassociation;
for (int64_t i = 0; i < firstDimToCollapse; ++i)
reassociation.push_back(ReassociationIndices{i});
+
ReassociationIndices collapsedIndices;
for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
collapsedIndices.push_back(i);
+
reassociation.push_back(collapsedIndices);
return memref::CollapseShapeOp::create(builder, loc, input, reassociation);
}
-static std::pair<Value, SmallVector<Value>> getSrcIndxValue(OpBuilder &rewriter,
- Location loc,
- Value operand,
- bool isNotAcc) {
+// Get the MemRef source and offset index for the operands of
+// vector.contract.
+static FailureOr<std::pair<Value, SmallVector<Value>>>
+getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
+ bool isNotAcc) {
+ Operation *defOp = operand.getDefiningOp();
+ if (!defOp)
+ return failure();
+
Value srcBuff;
SmallVector<OpFoldResult> indexVals;
llvm::TypeSwitch<Operation *>(operand.getDefiningOp())
@@ -55,6 +67,9 @@ static std::pair<Value, SmallVector<Value>> getSrcIndxValue(OpBuilder &rewriter,
srcBuff = readOp.getOperand(0);
});
+ if (!srcBuff)
+ return failure();
+
if (isNotAcc) {
indexVals.pop_back();
}
@@ -68,24 +83,95 @@ static std::pair<Value, SmallVector<Value>> getSrcIndxValue(OpBuilder &rewriter,
}
if (isNotAcc) {
- auto subviewType = cast<ShapedType>(srcBuff.getType());
- auto subviewRank = subviewType.getRank();
- srcBuff = collapseInnerDims(rewriter, loc, srcBuff, subviewRank - 2);
+ srcBuff = collapseInnerDims(rewriter, loc, srcBuff);
+ }
+
+ return std::make_pair(srcBuff, indices);
+}
+
+// Function to validate the vector.contract operation.
+static LogicalResult validateContractOps(OpBuilder &rewriter,
+ vector::ContractionOp contractOp,
+ unsigned int blockingFactor,
+ Value srcBuffLhs, Value srcBuffRhs,
+ bool srcValidate) {
+
+ if (srcValidate) {
+ // Get the MemRef buffer of LHS operand.
+ auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+ contractOp.getLhs(), false);
+ if (failed(srcIndxLhs))
+ return failure();
+ auto [buffLhs, indicesLhs] = *srcIndxLhs;
+
+ // Get the MemRef buffer of RHS operand.
+ auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+ contractOp.getRhs(), false);
+ if (failed(srcIndxRhs))
+ return failure();
+ auto [buffRhs, indicesRhs] = *srcIndxRhs;
+
+ // Return failure if the Memref buff didn't match.
+ if (buffLhs != srcBuffLhs)
+ return failure();
+
+ if (buffRhs != srcBuffRhs)
+ return failure();
}
- return {srcBuff, indices};
+
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ if (!accTy)
+ return failure();
+
+ // The Accumulator dims should be 16 or 1. Like <1x16x16> or <16x16>.
+ ArrayRef<int64_t> accShape = accTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimAcc;
+ llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
+ [](int64_t dim) { return (dim != 16 && dim != 1); });
+
+ if (nonUnitDimAcc.size() != 0)
+ return failure();
+
+ // The LHS dims should be 16 or vnni or 1. Like <1x16x16x2> or
+ // <16x16x4>. The vnni dims should be 2 or 4.
+ VectorType lhsTy = contractOp.getLhsType();
+ ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimLhs;
+ llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
+ [](int64_t dim) { return (dim != 16 && dim != 1); });
+
+ if (nonUnitDimLhs.size() != 1)
+ return failure();
+
+ if (nonUnitDimLhs[0] != blockingFactor)
+ return failure();
+
+ // The RHS dims should be 16 or vnni or 1. Like <1x16x16x2> or
+ // <16x16x4>. The vnni dims should be 2 or 4.
+ VectorType rhsTy = contractOp.getRhsType();
+ ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimRhs;
+ llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
+ [](int64_t dim) { return (dim != 16 && dim != 1); });
+
+ if (nonUnitDimRhs.size() != 1)
+ return failure();
+
+ if (nonUnitDimRhs[0] != blockingFactor)
+ return failure();
+
+ return success();
}
+// Returns the loop index position to get mapped during the
+// MemRef type clone.
static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
Value iv = loop.getInductionVar();
Value srcBuff;
- SmallVector<OpFoldResult> indexVals;
llvm::TypeSwitch<Operation *>(operand.getDefiningOp())
- .Case<TransferReadOp, LoadOp>([&](auto readOp) {
- indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
- readOp.getIndices().end());
- srcBuff = readOp.getOperand(0);
- });
+ .Case<TransferReadOp, LoadOp>(
+ [&](auto readOp) { srcBuff = readOp.getOperand(0); });
auto subview = srcBuff.getDefiningOp<memref::SubViewOp>();
if (!subview)
@@ -101,103 +187,76 @@ static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
return 0;
}
+// Creates amx.tile_loads.
static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
Value operand, Value mat, Type ipType,
- bool rhs) {
-
- SmallVector<OpFoldResult> indexVals;
- llvm::TypeSwitch<Operation *>(operand.getDefiningOp())
- .Case<TransferReadOp, LoadOp>([&](auto readOp) {
- indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
- readOp.getIndices().end());
- });
-
- indexVals.pop_back();
- SmallVector<Value> indices;
- indices.reserve(indexVals.size());
+ bool rhs, unsigned int offset) {
- for (OpFoldResult ofr : indexVals) {
- indices.push_back(
- mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
- }
+ auto srcIndx = getSrcIndxValue(rewriter, loc, operand, false);
+ auto [srcBuff, indices] = *srcIndx;
+ indices.pop_back();
if (rhs) {
- int offset = 4;
- if (ipType.isBF16()) {
- offset = 2;
- }
-
- auto c2 = arith::ConstantIndexOp::create(rewriter, loc, offset);
- indices[indices.size() - 1] =
- arith::MulIOp::create(rewriter, loc, indices[indices.size() - 1], c2);
- }
-
- amx::TileType tileType = amx::TileType::get({16, 64}, ipType);
- if (ipType.isBF16()) {
- tileType = amx::TileType::get({16, 32}, ipType);
+ auto cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
+ indices[indices.size() - 1] = arith::MulIOp::create(
+ rewriter, loc, indices[indices.size() - 1], cOffset);
}
+ amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
auto load = amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
-
return load;
}
+// Creates tiled amx dot-products.
static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
SmallVector<vector::ContractionOp> ops,
Value matA, Value matB, Type ipType,
- Type opType, ValueRange accIterArgs) {
- auto subviewType = cast<ShapedType>(matA.getType());
- auto subviewRank = subviewType.getRank();
- auto collapsedOpnd = collapseInnerDims(rewriter, loc, matA, subviewRank - 2);
+ Type opType, ValueRange accIterArgs,
+ unsigned int offset) {
- auto subviewType1 = cast<ShapedType>(matB.getType());
- auto subviewRank1 = subviewType1.getRank();
- auto collapsedOpnd1 =
- collapseInnerDims(rewriter, loc, matB, subviewRank1 - 2);
+ auto subviewCollapseLhs = collapseInnerDims(rewriter, loc, matA);
+ auto subviewCollapseRhs = collapseInnerDims(rewriter, loc, matB);
SmallVector<Value> accumulators;
+ // Maps each vector transfer_read or load operation to its equivalent
+ // amx.tile_load operation.
llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+ // Iterate over the contraction operations and compute the tiled dot-product.
for (size_t i = 0; i < ops.size(); i++) {
Operation *readOpLhs = ops[i].getLhs().getDefiningOp();
-
amx::TileLoadOp tilesLhs;
-
- auto it = readsToTileLoads.find(readOpLhs);
- if (it != readsToTileLoads.end()) {
- tilesLhs = it->second;
+ auto itLhs = readsToTileLoads.find(readOpLhs);
+ if (itLhs != readsToTileLoads.end()) {
+ tilesLhs = itLhs->second;
} else {
-
- tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(), collapsedOpnd,
- ipType, false);
+ tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(),
+ subviewCollapseLhs, ipType, false, offset);
readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
}
Operation *readOpRhs = ops[i].getRhs().getDefiningOp();
-
amx::TileLoadOp tilesRhs;
-
- auto it1 = readsToTileLoads.find(readOpRhs);
- if (it1 != readsToTileLoads.end()) {
- tilesRhs = it1->second;
+ auto itRhs = readsToTileLoads.find(readOpRhs);
+ if (itRhs != readsToTileLoads.end()) {
+ tilesRhs = itRhs->second;
} else {
-
- tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(), collapsedOpnd1,
- ipType, true);
+ tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(),
+ subviewCollapseRhs, ipType, true, offset);
readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
}
- auto tileType1 = amx::TileType::get({16, 16}, opType);
+ auto accTileType = amx::TileType::get({16, 16}, opType);
Value dp;
if (ipType.isBF16())
- dp = amx::TileMulFOp::create(rewriter, loc, tileType1, tilesLhs, tilesRhs,
- accIterArgs[i]);
+ dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
+ tilesRhs, accIterArgs[i]);
if (ipType.isSignlessInteger(8))
- dp = amx::TileMulIOp::create(rewriter, loc, tileType1, tilesLhs, tilesRhs,
- accIterArgs[i]);
+ dp = amx::TileMulIOp::create(rewriter, loc, accTileType, tilesLhs,
+ tilesRhs, accIterArgs[i]);
accumulators.push_back(dp);
}
@@ -219,6 +278,23 @@ static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
return loopItrArgs;
}
+// Implements tiled dot-product operation for a vector.contract operation or a
+// sequence of vector.contracts inside the reduction loops.
+//
+// For example - for Int8 type:
+// ```
+// vector.transfer_read %arg0 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
+// vector.transfer_read %arg1 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
+// vector.contract <16x16x4xi8>, <16x16x4xi8> into <16x16xi32>
+// vector.transfer_write %arg2 {{.}*} : vector<16x16xi32>, memref<32x32xi32>
+// ```
+// to
+// ```
+// amx.tile_load %arg0 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
+// amx.tile_load %arg1 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
+// amx.tile_muli !amx.tile<16x64xi8> -> !amx.tile<16x16xi32>
+// amx.tile_store %arg2{{.}*} : memref<32x32xi32>, !amx.tile<16x16xi32>
+// ```
struct VectorContractToPackedTypeTiledDotProduct
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -230,6 +306,32 @@ struct VectorContractToPackedTypeTiledDotProduct
return rewriter.notifyMatchFailure(contractOp,
"Expects add combining kind.");
+ unsigned int blockingFactor =
+ contractOp.getLhsType().getElementType().isBF16() ? 2 : 4;
+ bool isVnni = x86vector::isInVnniLayout(contractOp.getOperation(),
+ contractOp.getIndexingMapsArray(),
+ blockingFactor);
+
+ VectorType lhsTy = contractOp.getLhsType();
+ if (!lhsTy.getElementType().isBF16() &&
+ !lhsTy.getElementType().isSignlessInteger(8))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Only BF16/Int8 lowering is supported.");
+
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ if (!accTy)
+ return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
+
+ if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
+ (lhsTy.getElementType().isSignlessInteger(8) &&
+ !accTy.getElementType().isSignlessInteger(32)))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Only F32 for BF16 or Int32 for Int8 "
+ "accumulation type is supported.");
+ if (!isVnni)
+ return rewriter.notifyMatchFailure(
+ contractOp, "Only VNNI-packed inputs are supported.");
+
Operation *accReadOp =
x86vector::traceToVectorReadLikeParentOperation(contractOp.getAcc());
@@ -237,50 +339,110 @@ struct VectorContractToPackedTypeTiledDotProduct
x86vector::traceToVectorWriteLikeUserOperation(contractOp.getResult());
if (!accReadOp || !resultWriteOp)
- return failure();
+ return rewriter.notifyMatchFailure(
+ contractOp, "The ACC operand of the vector.contract should be a "
+ "transfer_read or a load. And, the result should be "
+ "stored using transfer_write or store.");
- /*if (accReadOp->getBlock() == contractOp->getBlock())
- return failure();
+ Type ipType;
+ Type opType;
+
+ if (lhsTy.getElementType().isBF16()) {
+ ipType = rewriter.getBF16Type();
+ opType = rewriter.getF32Type();
+ }
+
+ if (lhsTy.getElementType().isSignlessInteger(8)) {
+ ipType = rewriter.getIntegerType(8);
+ opType = rewriter.getIntegerType(32);
+ }
- if (resultWriteOp->getBlock() == contractOp->getBlock())
- return failure(); */
+ if (accReadOp->getBlock() == contractOp->getBlock() &&
+ resultWriteOp->getBlock() != contractOp->getBlock())
+ return rewriter.notifyMatchFailure(
+ contractOp, "The accumulator store is in different block.");
+
+ if (accReadOp->getBlock() != contractOp->getBlock() &&
+ resultWriteOp->getBlock() == contractOp->getBlock())
+ return rewriter.notifyMatchFailure(
+ contractOp, "The accumulator read is in different block.");
- // case for just one vc rewrite.
+ // Case 1: A single vector.contract rewrite, where the accumulator read and
+ // write are within the same block.
if (accReadOp->getBlock() == contractOp->getBlock() &&
resultWriteOp->getBlock() == contractOp->getBlock()) {
- Location loc = contractOp.getLoc();
- auto tileType = amx::TileType::get({16, 32}, rewriter.getBF16Type());
- auto [srcBuffLhs, indicesLhs] = getSrcIndxValue(
- rewriter, contractOp.getLoc(), contractOp.getLhs(), true);
+ LogicalResult validate = validateContractOps(
+ rewriter, contractOp, blockingFactor, Value(), Value(), false);
+
+ if (failed(validate))
+ return rewriter.notifyMatchFailure(
+ contractOp, "The contract operation doesn't satisfy the operands "
+ "dimensions. M, N, and vnni dims are 16, 16, and 2/4. "
+ "The rest dims should be 1.");
+
+ Location loc = contractOp.getLoc();
+ llvm::outs() << "Reaching-here1" << "\n";
+ auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+ contractOp.getLhs(), true);
+ llvm::outs() << "Reaching-here2" << "\n";
+ if (failed(srcIndxLhs))
+ return rewriter.notifyMatchFailure(contractOp,
+ "The LHS src is not a MemRef type.");
+ llvm::outs() << "Reaching-here3" << "\n";
+
+ auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
+
+ auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+ contractOp.getRhs(), true);
+ if (failed(srcIndxRhs))
+ return rewriter.notifyMatchFailure(contractOp,
+ "The RHS src is not a MemRef type.");
+ auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
+
+ auto srcIndxAcc = getSrcIndxValue(rewriter, contractOp.getLoc(),
+ contractOp.getAcc(), false);
+ if (failed(srcIndxAcc))
+ return rewriter.notifyMatchFailure(contractOp,
+ "The ACC src is not a MemRef type.");
+ auto [srcBuffAcc, indicesAcc] = *srcIndxAcc;
+
+ // amx.tile_loads
+ auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
srcBuffLhs, indicesLhs);
-
- auto [srcBuffRhs, indicesRhs] = getSrcIndxValue(
- rewriter, contractOp.getLoc(), contractOp.getRhs(), true);
auto loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType,
srcBuffRhs, indicesRhs);
- auto [srcBuffAcc, indicesAcc] = getSrcIndxValue(
- rewriter, contractOp.getLoc(), contractOp.getAcc(), false);
- auto tileTypeAcc = amx::TileType::get({16, 16}, rewriter.getF32Type());
+ auto tileTypeAcc = amx::TileType::get({16, 16}, opType);
auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
srcBuffAcc, indicesAcc);
- auto tdp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
- loadRhs, loadAcc);
- amx::TileStoreOp::create(rewriter, loc, srcBuffAcc, indicesAcc, tdp);
+ // Tiled dot-product.
+ Value dp;
+ if (ipType.isBF16())
+ dp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
+ loadRhs, loadAcc);
+
+ if (ipType.isSignlessInteger(8))
+ dp = amx::TileMulIOp::create(rewriter, loc, tileTypeAcc, loadLhs,
+ loadRhs, loadAcc);
+
+ amx::TileStoreOp::create(rewriter, loc, srcBuffAcc, indicesAcc, dp);
rewriter.eraseOp(resultWriteOp);
return success();
}
- SmallVector<scf::ForOp> list;
+ // Case 2: The accumulators are passed as iter args through the reduction
+ // loops. Reduction loop depths up to 2 are supported. TODO: Support n-depth
+ // reduction loops.
+ SmallVector<scf::ForOp> loopLists;
Operation *current = contractOp;
while (true) {
Operation *parent = current->getParentOfType<scf::ForOp>();
- list.push_back(dyn_cast<scf::ForOp>(parent));
+ loopLists.push_back(dyn_cast<scf::ForOp>(parent));
if (accReadOp->getBlock() == parent->getBlock()) {
break;
@@ -289,43 +451,63 @@ struct VectorContractToPackedTypeTiledDotProduct
current = parent;
}
- if (list.size() > 2 || list.size() == 0)
- return failure();
+ if (loopLists.size() > 2 || loopLists.size() == 0)
+ return rewriter.notifyMatchFailure(
+ contractOp, "Rewrite is supported until reduction loop depth of 2.");
- SmallVector<vector::ContractionOp> ops;
+ auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+ contractOp.getLhs(), false);
+ if (failed(srcIndxLhs))
+ return rewriter.notifyMatchFailure(contractOp,
+ "The LHS src is not a MemRef type.");
+ auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
- for (mlir::Operation &op : list[0].getBody()->getOperations()) {
+ auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+ contractOp.getRhs(), false);
+ if (failed(srcIndxRhs))
+ return rewriter.notifyMatchFailure(contractOp,
+ "The RHS src is not a MemRef type.");
+ auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
+
+ Operation *vectorOpLhs;
+ llvm::TypeSwitch<Operation *>(contractOp.getLhs().getDefiningOp())
+ .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+ vectorOpLhs = readOp.getBase().getDefiningOp();
+ });
+
+ Operation *vectorOpRhs;
+ llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
+ .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+ vectorOpRhs = readOp.getBase().getDefiningOp();
+ });
+
+ // Retrieve all the contraction operations within the loop.
+ SmallVector<vector::ContractionOp> ops;
+ for (mlir::Operation &op : loopLists[0].getBody()->getOperations()) {
if (auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
- ops.push_back(contract);
- }
- }
- Type ipType;
- Type opType;
- VectorType lhsTy = contractOp.getLhsType();
- if (lhsTy.getElementType().isBF16()) {
- ipType = rewriter.getBF16Type();
- opType = rewriter.getF32Type();
- }
+ LogicalResult validate = validateContractOps(
+ rewriter, contract, blockingFactor, srcBuffLhs, srcBuffRhs, true);
- if (lhsTy.getElementType().isSignlessInteger(8)) {
- ipType = rewriter.getIntegerType(8);
- opType = rewriter.getIntegerType(32);
+ if (failed(validate))
+ return rewriter.notifyMatchFailure(
+ contractOp, "The associated contract operations doesn't satisfy "
+ "the re-write conditions either the dimensions are "
+ "wrong or MemRef source are different.");
+
+ ops.push_back(contract);
+ }
}
scf::ForOp outerLoop;
scf::ForOp innerLoop;
- auto vectorReadOpLhs =
- contractOp.getLhs().getDefiningOp<vector::TransferReadOp>();
- auto vectorReadOpRhs =
- contractOp.getRhs().getDefiningOp<vector::TransferReadOp>();
-
scf::ForOp newLoop;
- if (list.size() == 2) {
- outerLoop = list[1];
- innerLoop = list[0];
+ // Case 2a: Reduction loop depth is 2.
+ if (loopLists.size() == 2) {
+ outerLoop = loopLists[1];
+ innerLoop = loopLists[0];
SmallVector<Value> loopItrArgs = createTileZeros(
rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
@@ -343,32 +525,32 @@ struct VectorContractToPackedTypeTiledDotProduct
Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
IRMapping mapping;
mapping.map(
- vectorReadOpLhs.getBase().getDefiningOp()->getOperand(
+ vectorOpLhs->getOperand(
getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
ivOuterLoop);
mapping.map(
- vectorReadOpLhs.getBase().getDefiningOp()->getOperand(
+ vectorOpLhs->getOperand(
getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
ivNewInnerLoop);
- auto lhsClone = rewriterNewInnerLoop.clone(
- *vectorReadOpLhs.getBase().getDefiningOp(), mapping);
+ auto lhsClone =
+ rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
IRMapping rhsMapping;
rhsMapping.map(
- vectorReadOpRhs.getBase().getDefiningOp()->getOperand(
+ vectorOpRhs->getOperand(
getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
ivOuterLoop);
rhsMapping.map(
- vectorReadOpRhs.getBase().getDefiningOp()->getOperand(
+ vectorOpRhs->getOperand(
getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
ivNewInnerLoop);
- auto rhsClone = rewriterNewInnerLoop.clone(
- *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
+ auto rhsClone =
+ rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping);
SmallVector<Value> accumulators = createTiledDp(
rewriter, locNewInnerLoop, ops, lhsClone->getResult(0),
rhsClone->getResult(0), ipType, opType,
- iterArgsNewInnerLoop);
+ iterArgsNewInnerLoop, blockingFactor);
scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
accumulators);
@@ -379,8 +561,9 @@ struct VectorContractToPackedTypeTiledDotProduct
});
}
- if (list.size() == 1) {
- outerLoop = list[0];
+ // Case 2b: Reduction loop depth is 1.
+ if (loopLists.size() == 1) {
+ outerLoop = loopLists[0];
SmallVector<Value> loopItrArgs = createTileZeros(
rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
@@ -391,37 +574,37 @@ struct VectorContractToPackedTypeTiledDotProduct
Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
IRMapping mapping;
mapping.map(
- vectorReadOpLhs.getBase().getDefiningOp()->getOperand(
+ vectorOpLhs->getOperand(
getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
ivOuterLoop);
- auto lhsClone = rewriterOuterLoop.clone(
- *vectorReadOpLhs.getBase().getDefiningOp(), mapping);
+ auto lhsClone = rewriterOuterLoop.clone(*vectorOpLhs, mapping);
IRMapping rhsMapping;
rhsMapping.map(
- vectorReadOpRhs.getBase().getDefiningOp()->getOperand(
+ vectorOpRhs->getOperand(
getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
ivOuterLoop);
- auto rhsClone = rewriterOuterLoop.clone(
- *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
+ auto rhsClone = rewriterOuterLoop.clone(*vectorOpRhs, rhsMapping);
SmallVector<Value> accumulators = createTiledDp(
rewriter, locOuterLoop, ops, lhsClone->getResult(0),
- rhsClone->getResult(0), ipType, opType, iterArgsOuterLoop);
+ rhsClone->getResult(0), ipType, opType, iterArgsOuterLoop,
+ blockingFactor);
scf::YieldOp::create(rewriterOuterLoop, locOuterLoop, accumulators);
});
}
- // post processing after the loop
+ // Post-processing after the loop creation.
+ // Copy the amx tile accumulation results to a MemRef buffer, add the
+ // initial accumulation value, and store back to the C-Matrix.
auto bufferType = MemRefType::get({16, 16}, opType);
auto bBuffer =
memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
SmallVector<Value> dps = newLoop.getResults();
-
for (size_t i = 0; i < ops.size(); i++) {
vector::ContractionOp contOp = ops[i];
Operation *resultWriteOp =
@@ -443,45 +626,47 @@ struct VectorContractToPackedTypeTiledDotProduct
scf::ForOp::create(
rewriter, outerLoop.getLoc(), c0, mBound, one, ValueRange{},
[&](OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) {
- auto row1 = vector::LoadOp::create(rewriter, loc,
- VectorType::get(16, opType),
- bBuffer, ValueRange{iv, c0});
+ auto resultAcc = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(16, opType), bBuffer,
+ ValueRange{iv, c0});
- Operation *readOp1 =
+ Operation *accReadOp =
x86vector::traceToVectorReadLikeParentOperation(
ops[i].getAcc());
- Value srcBuff;
- SmallVector<OpFoldResult> indexVals;
- llvm::TypeSwitch<Operation *>(readOp1).Case<TransferReadOp, LoadOp>(
- [&](auto readOp) {
- indexVals = SmallVector<OpFoldResult>(
- readOp.getIndices().begin(), readOp.getIndices().end());
- srcBuff = readOp.getOperand(0);
- });
+ Value srcBuffAcc;
+ SmallVector<Value> indicesAcc;
- SmallVector<Value> indices;
- indices.reserve(indexVals.size());
+ llvm::TypeSwitch<Operation *>(accReadOp)
+ .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+ srcBuffAcc = readOp.getOperand(0);
- for (OpFoldResult ofr : indexVals) {
- indices.push_back(
- mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
- }
+ auto indices = readOp.getIndices();
+ indicesAcc.reserve(indices.size());
- Value sum = arith::AddIOp::create(builder, loc, iv, indices[0]);
- indices[0] = sum;
+ llvm::transform(
+ indices, std::back_inserter(indicesAcc),
+ [&](OpFoldResult ofr) {
+ return mlir::getValueOrCreateConstantIndexOp(rewriter,
+ loc, ofr);
+ });
+ });
- auto row2 = vector::LoadOp::create(
- rewriter, loc, VectorType::get(16, opType), srcBuff, indices);
+ Value sum = arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
+ indicesAcc[0] = sum;
+ auto acc = vector::LoadOp::create(rewriter, loc,
+ VectorType::get(16, opType),
+ srcBuffAcc, indicesAcc);
Value addition;
if (ipType.isBF16())
- addition = arith::AddFOp::create(rewriter, loc, row1, row2);
+ addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
if (ipType.isSignlessInteger(8))
- addition = arith::AddIOp::create(rewriter, loc, row1, row2);
+ addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
- vector::StoreOp::create(builder, loc, addition, srcBuff, indices);
+ vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
+ indicesAcc);
scf::YieldOp::create(builder, outerLoop.getLoc());
});
diff --git a/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
index 081eb56a029f6..ccc1a0f0c9bdc 100644
--- a/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
@@ -1,57 +1,311 @@
// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+!vecA = vector<1x16x16x4xi8>
+!vecB = vector<1x16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<1x32x16x4xi8>
+!memrefB = memref<1x16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_int8(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : i8
+ %32 = ub.poison : i32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @brgemm_int8
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x64xi8>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x64xi8>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x16xi32>
+// CHECK: amx.tile_muli
+// CHECK: amx.tile_store {{.*}} !amx.tile<16x16xi32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<16x16x4xi8>
+!vecB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x16x4xi8>
+!memrefB = memref<16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_int8(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : i8
+ %32 = ub.poison : i32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_int8
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x64xi8>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x64xi8>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x16xi32>
+// CHECK: amx.tile_muli
+// CHECK: amx.tile_store {{.*}} !amx.tile<16x16xi32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x16x16x2xbf16>
+!vecB = vector<1x16x16x2xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x2xbf16>
+!memrefB = memref<1x16x32x2xbf16>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_bf16(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @brgemm_bf16
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x32xbf16>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x32xbf16>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x16xf32>
+// CHECK: amx.tile_mulf
+// CHECK: amx.tile_store {{.*}} !amx.tile<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x16x16x2xbf16>
+!vecB = vector<1x16x16x2xbf16>
+!vecC = vector<1x16x16xf32>
+!memrefA = memref<1x32x16x2xbf16>
+!memrefB = memref<1x16x32x2xbf16>
+!memrefC = memref<1x32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_bf16(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0, %c0], %32 {in_bounds = [true, true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @batch_matmul_bf16
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x32xbf16>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x32xbf16>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x16xf32>
+// CHECK: amx.tile_mulf
+// CHECK: amx.tile_store {{.*}} !amx.tile<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecAB = vector<1x16x16x2xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>
+!memrefB = memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>
+!memrefC = memref<32x32xf32, strided<[128, 1], offset: ?>>
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
- module {
- func.func @brgemm_amx(%arg0: memref<16x64x64x2xbf16>, %arg1: memref<16x64x128x2xbf16>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
- %0 = ub.poison : f32
- %1 = ub.poison : bf16
- %c0 = arith.constant 0 : index
- %c64 = arith.constant 64 : index
- %c128 = arith.constant 128 : index
- %c16 = arith.constant 16 : index
- %c32 = arith.constant 32 : index
- %c1 = arith.constant 1 : index
- scf.for %arg3 = %c0 to %c64 step %c32 {
- scf.for %arg4 = %c0 to %c128 step %c32 {
- %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<64x128xf32> to memref<32x32xf32, strided<[128, 1], offset: ?>>
- %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
- %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
- %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
- %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
- %6:4 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>) {
- %7:4 = scf.for %arg10 = %c0 to %c64 step %c16 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>) {
- %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg10, 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x64x2xbf16> to memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>
- %subview_1 = memref.subview %arg1[%arg5, %arg10, %arg4, 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x64x128x2xbf16> to memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>
- %8 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
- %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
- %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %9, %arg11 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
- %11 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
- %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %11, %arg12 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
- %13 = vector.transfer_read %subview_0[%c0, %c16, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
- %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %13, %9, %arg13 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
- %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %13, %11, %arg14 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
- scf.yield %10, %12, %14, %15 : vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
- }
- scf.yield %7#0, %7#1, %7#2, %7#3 : vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
- }
- vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
- vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
- vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
- vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+func.func @brgemm_bf16_loop(%arg0: memref<16x64x64x2xbf16>, %arg1: memref<16x64x128x2xbf16>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> {  // bf16 contraction with four 16x16 f32 accumulators carried through two nested loops
+ %0 = ub.poison : f32
+ %1 = ub.poison : bf16
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+
+ scf.for %arg3 = %c0 to %c64 step %c32 {  // rows of %arg2, 32 at a time
+ scf.for %arg4 = %c0 to %c128 step %c32 {  // columns of %arg2, 32 at a time
+
+ %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] :
+ memref<64x128xf32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :  // four 16x16 quadrants of the 32x32 block
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %6:4 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (!vecC, !vecC, !vecC, !vecC) {  // %arg5 selects the leading index of %arg0 and %arg1
+ %7:4 = scf.for %arg10 = %c0 to %c64 step %c16 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!vecC, !vecC, !vecC, !vecC) {  // %arg10 advances the shared 64-long dimension by 16
+
+ %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg10, 0] [1, 32, 16, 2] [1, 1, 1, 1] :
+ memref<16x64x64x2xbf16> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg10, %arg4, 0] [1, 16, 32, 2] [1, 1, 1, 1] :
+ memref<16x64x128x2xbf16> to !memrefB
+ %8 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecAB
+ %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecAB
+
+ %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %9, %arg11 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+ %11 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} :  // second 16-wide slice of %subview_1
+ !memrefB, !vecAB
+ %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %11, %arg12 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %13 = vector.transfer_read %subview_0[%c0, %c16, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :  // second 16-row slice of %subview_0
+ !memrefA, !vecAB
+ %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %13, %9, %arg13 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %13, %11, %arg14 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+ scf.yield %10, %12, %14, %15 : !vecC, !vecC, !vecC, !vecC
}
+ scf.yield %7#0, %7#1, %7#2, %7#3 : !vecC, !vecC, !vecC, !vecC
}
- %alloc = memref.alloc() : memref<64x128xf32>
- memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
- return %alloc : memref<64x128xf32>
+
+ vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} :  // each accumulator goes back to the quadrant it was read from
+ !vecC, !memrefC
+ vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
}
}
-// CHECK-LABEL: @brgemm_amx
-// CHECK: amx.tile_mulf
+ return %arg2 : memref<64x128xf32>  // the input memref itself is returned
+}
+
+// CHECK-LABEL: @brgemm_bf16_loop
+// CHECK-COUNT-2: scf.for {{.*}} -> (!amx.tile<16x16xf32>, !amx.tile<16x16xf32>, !amx.tile<16x16xf32>, !amx.tile<16x16xf32>) {
+// CHECK-COUNT-4: amx.tile_zero : !amx.tile<16x16xf32>
+// CHECK-COUNT-4: amx.tile_load
+// CHECK-COUNT-4: amx.tile_mulf
+// CHECK: scf.yield {{.*}} : !amx.tile<16x16xf32>, !amx.tile<16x16xf32>, !amx.tile<16x16xf32>, !amx.tile<16x16xf32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
// CHECK-NOT: vector.contract
module attributes {transform.with_named_sequence} {
@@ -66,55 +320,75 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecAB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<16x16x4xi8, strided<[256, 4, 1], offset: ?>>
+!memrefB = memref<16x32x4xi8, strided<[512, 4, 1], offset: ?>>
+!memrefC = memref<16x32xi32, strided<[128, 1], offset: ?>>
+
#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- module {
- func.func @batch_amx(%arg0: memref<64x64x2xbf16>, %arg1: memref<64x128x2xbf16>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
- %0 = ub.poison : f32
- %1 = ub.poison : bf16
- %c0 = arith.constant 0 : index
- %c64 = arith.constant 64 : index
- %c128 = arith.constant 128 : index
- %c32 = arith.constant 32 : index
- %c16 = arith.constant 16 : index
- scf.for %arg3 = %c0 to %c64 step %c32 {
- scf.for %arg4 = %c0 to %c128 step %c32 {
- %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<64x128xf32> to memref<32x32xf32, strided<[128, 1], offset: ?>>
- %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
- %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
- %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
- %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
- %6:4 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>) {
- %subview_0 = memref.subview %arg0[%arg3, %arg5, 0] [32, 16, 2] [1, 1, 1] : memref<64x64x2xbf16> to memref<32x16x2xbf16, strided<[128, 2, 1], offset: ?>>
- %subview_1 = memref.subview %arg1[%arg5, %arg4, 0] [16, 32, 2] [1, 1, 1] : memref<64x128x2xbf16> to memref<16x32x2xbf16, strided<[256, 2, 1], offset: ?>>
- %7 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} : memref<32x16x2xbf16, strided<[128, 2, 1], offset: ?>>, vector<16x16x2xbf16>
- %8 = vector.transfer_read %subview_0[%c16, %c0, %c0], %1 {in_bounds = [true, true, true]} : memref<32x16x2xbf16, strided<[128, 2, 1], offset: ?>>, vector<16x16x2xbf16>
- %9 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} : memref<16x32x2xbf16, strided<[256, 2, 1], offset: ?>>, vector<16x16x2xbf16>
- %10 = vector.transfer_read %subview_1[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} : memref<16x32x2xbf16, strided<[256, 2, 1], offset: ?>>, vector<16x16x2xbf16>
- %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %9, %arg6 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
- %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %10, %arg7 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
- %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %9, %arg8 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
- %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %10, %arg9 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
- scf.yield %11, %12, %13, %14 : vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
- }
- vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
- vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
- vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
- vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
- }
+func.func @matmul_int8_loop(%arg0: memref<64x64x4xi8>, %arg1: memref<64x128x4xi8>, %arg2: memref<64x128xi32>) {  // i8 contraction with two 16x16 i32 accumulators carried through one loop
+ %0 = ub.poison : i32
+ %1 = ub.poison : i8
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ scf.for %arg3 = %c0 to %c64 step %c16 {  // rows of %arg2, 16 at a time
+ scf.for %arg4 = %c0 to %c128 step %c32 {  // columns of %arg2, 32 at a time
+
+ %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1] :
+ memref<64x128xi32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :  // left and right 16x16 halves of the 16x32 block
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %4:2 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {  // %arg5 advances the shared 64-long dimension by 16
+
+ %subview_0 = memref.subview %arg0[%arg3, %arg5, 0] [16, 16, 4] [1, 1, 1] :
+ memref<64x64x4xi8> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg4, 0] [16, 32, 4] [1, 1, 1] :
+ memref<64x128x4xi8> to !memrefB
+ %5 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :  // one LHS vector, used by both contractions
+ !memrefA, !vecAB
+ %6 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefB, !vecAB
+ %7 = vector.transfer_read %subview_1[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefB, !vecAB
+
+ %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %5, %6, %arg6 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %5, %7, %arg7 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+ scf.yield %8, %9 : !vecC, !vecC
}
- %alloc = memref.alloc() : memref<64x128xf32>
- memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
- return %alloc : memref<64x128xf32>
+
+ vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]} :  // each accumulator goes back to the half it was read from
+ !vecC, !memrefC
+ vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
}
}
-// CHECK-LABEL: @batch_amx
-// CHECK: amx.tile_mulf
-// CHECK-NOT: vector.contract
+ return
+}
+// CHECK-LABEL: @matmul_int8_loop
+// CHECK-COUNT-2: amx.tile_zero : !amx.tile<16x16xi32>
+// CHECK: scf.for {{.*}} -> (!amx.tile<16x16xi32>, !amx.tile<16x16xi32>) {
+// CHECK-COUNT-3: amx.tile_load
+// CHECK-COUNT-2: amx.tile_muli
+// CHECK: scf.yield {{.*}} !amx.tile<16x16xi32>, !amx.tile<16x16xi32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xi32>, vector<16x16xi32>
+// CHECK-NOT: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -128,56 +402,73 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecAB = vector<1x16x16x4xi8>
+!vecC = vector<1x16x16xi32>
+!memrefA = memref<1x16x16x4xi8, strided<[16384, 256, 4, 1], offset: ?>>
+!memrefB = memref<1x16x32x4xi8, strided<[32768, 512, 4, 1], offset: ?>>
+!memrefC = memref<1x16x32xi32, strided<[8192, 128, 1], offset: ?>>
+
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>
- module {
- func.func @matmul_amx(%arg0: memref<16x64x64x2xbf16>, %arg1: memref<16x64x128x2xbf16>, %arg2: memref<16x64x128xf32>) -> memref<16x64x128xf32> attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
- %0 = ub.poison : f32
- %1 = ub.poison : bf16
- %c0 = arith.constant 0 : index
- %c64 = arith.constant 64 : index
- %c128 = arith.constant 128 : index
- %c16 = arith.constant 16 : index
- %c32 = arith.constant 32 : index
- %c1 = arith.constant 1 : index
- scf.for %arg3 = %c0 to %c64 step %c32 {
- scf.for %arg4 = %c0 to %c128 step %c32 {
- scf.for %arg5 = %c0 to %c16 step %c1 {
- %subview = memref.subview %arg2[%arg5, %arg3, %arg4] [1, 32, 32] [1, 1, 1] : memref<16x64x128xf32> to memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
- %2 = vector.transfer_read %subview[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
- %3 = vector.transfer_read %subview[%c0, %c0, %c16], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
- %4 = vector.transfer_read %subview[%c0, %c16, %c0], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
- %5 = vector.transfer_read %subview[%c0, %c16, %c16], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
- %6:4 = scf.for %arg6 = %c0 to %c64 step %c16 iter_args(%arg7 = %2, %arg8 = %3, %arg9 = %4, %arg10 = %5) -> (vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>) {
- %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg6, 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x64x2xbf16> to memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>
- %subview_1 = memref.subview %arg1[%arg5, %arg6, %arg4, 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x64x128x2xbf16> to memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>
- %7 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
- %8 = vector.transfer_read %subview_0[%c0, %c16, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
- %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
- %10 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
- %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %9, %arg7 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
- %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %10, %arg8 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
- %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %9, %arg9 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
- %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %10, %arg10 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
- scf.yield %11, %12, %13, %14 : vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>
- }
- vector.transfer_write %6#3, %subview[%c0, %c16, %c16] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
- vector.transfer_write %6#2, %subview[%c0, %c16, %c0] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
- vector.transfer_write %6#1, %subview[%c0, %c0, %c16] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
- vector.transfer_write %6#0, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
- }
+func.func @batch_matmul_int8_loop(%arg0: memref<16x64x64x4xi8>, %arg1: memref<16x64x128x4xi8>, %arg2: memref<16x64x128xi32>) {  // i8 contraction with a leading parallel dimension and two loop-carried 1x16x16 i32 accumulators
+ %0 = ub.poison : i32
+ %1 = ub.poison : i8
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ scf.for %arg3 = %c0 to %c64 step %c16 {  // rows of %arg2, 16 at a time
+ scf.for %arg4 = %c0 to %c128 step %c32 {  // columns of %arg2, 32 at a time
+ scf.for %arg5 = %c0 to %c16 step %c1 {  // leading index shared by %arg0, %arg1 and %arg2
+
+ %subview = memref.subview %arg2[%arg5, %arg3, %arg4] [1, 16, 32] [1, 1, 1] :
+ memref<16x64x128xi32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c0, %c16], %0 {in_bounds = [true, true, true]} :
+ !memrefC, !vecC
+ %4:2 = scf.for %arg6 = %c0 to %c64 step %c16 iter_args(%arg7 = %2, %arg8 = %3) -> (!vecC, !vecC) {  // %arg6 advances the shared 64-long dimension by 16
+
+ %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg6, 0] [1, 16, 16, 4] [1, 1, 1, 1] :
+ memref<16x64x64x4xi8> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg6, %arg4, 0] [1, 16, 32, 4] [1, 1, 1, 1] :
+ memref<16x64x128x4xi8> to !memrefB
+ %5 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :  // one LHS vector, used by both contractions
+ !memrefA, !vecAB
+ %6 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecAB
+ %7 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecAB
+ %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %5, %6, %arg7 {unroll_shape = array<i64: 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %5, %7, %arg8 {unroll_shape = array<i64: 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ scf.yield %8, %9 : !vecC, !vecC
}
+
+ vector.transfer_write %4#1, %subview[%c0, %c0, %c16] {in_bounds = [true, true, true]} :  // each accumulator goes back to the slice it was read from
+ !vecC, !memrefC
+ vector.transfer_write %4#0, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+ !vecC, !memrefC
}
- %alloc = memref.alloc() : memref<16x64x128xf32>
- memref.copy %arg2, %alloc : memref<16x64x128xf32> to memref<16x64x128xf32>
- return %alloc : memref<16x64x128xf32>
}
}
+ return
+}
-// CHECK-LABEL: @matmul_amx
-// CHECK: amx.tile_mulf
+// CHECK-LABEL: @batch_matmul_int8_loop
+// CHECK-COUNT-2: amx.tile_zero : !amx.tile<16x16xi32>
+// CHECK: scf.for {{.*}} -> (!amx.tile<16x16xi32>, !amx.tile<16x16xi32>) {
+// CHECK-COUNT-3: amx.tile_load
+// CHECK-COUNT-2: amx.tile_muli
+// CHECK: scf.yield {{.*}} !amx.tile<16x16xi32>, !amx.tile<16x16xi32>
+// CHECK-NOT: scf.for {{.*}} vector<1x16x16xi32>, vector<1x16x16xi32>
// CHECK-NOT: vector.contract
module attributes {transform.with_named_sequence} {
@@ -192,16 +483,68 @@ module attributes {transform.with_named_sequence} {
// -----
-!vecA = vector<1x16x16x2xbf16>
-!vecB = vector<1x16x16x2xbf16>
-!vecC = vector<16x16xf32>
+!vecA = vector<1x16x16x4xi8>
+!vecB = vector<1x16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<1x32x16x4xi8>
+!memrefB = memref<1x16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_invalid_vc_kind(  // negative case: the contraction below uses kind = mul rather than add
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : i8  // padding operand for the two i8 reads
+ %32 = ub.poison : i32  // padding operand for the i32 accumulator read
+
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<mul>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_invalid_vc_kind
+// CHECK-NOT: amx.tile_load {{.*}} !amx.tile<16x64xi8>
+// CHECK-NOT: amx.tile_muli
+// CHECK-NOT: amx.tile_store {{.*}} !amx.tile<16x16xi32>
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {  // %arg1 is only read, never consumed
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op  // collect the func.func ops nested under %arg1
+ transform.apply_patterns to %func {
+ transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product  // the single pattern set this test exercises
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x32x16x2xbf16>
+!vecB = vector<1x16x32x2xbf16>
+!vecC = vector<1x32x32xf32>
!memrefA = memref<1x32x16x2xbf16>
!memrefB = memref<1x16x32x2xbf16>
-!memrefC = memref<32x32xf32>
+!memrefC = memref<1x32x32xf32>
#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
-#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
-func.func @amx(
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_wrong_dimensions(
%arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
{
%c0 = arith.constant 0 : index
@@ -213,11 +556,113 @@ func.func @amx(
%2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
!memrefB, !vecB
+ %3 = vector.transfer_read %arg2[%c0, %c0, %c0], %32 {in_bounds = [true, true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_wrong_dimensions
+// CHECK-NOT: amx.tile_load
+// CHECK-NOT: amx.tile_mulf
+// CHECK-NOT: amx.tile_store
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {  // %arg1 is only read, never consumed
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op  // collect the func.func ops nested under %arg1
+ transform.apply_patterns to %func {
+ transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product  // the single pattern set this test exercises
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<16x16x4xi8>
+!vecB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefB = memref<16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_no_memref_source_LHS(  // negative case: the LHS operand is a vector function argument, not a transfer_read result
+ %arg0: !vecA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : i8  // padding operand for the i8 read
+ %32 = ub.poison : i32  // padding operand for the i32 accumulator read
+
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :  // only the RHS is read from a memref
+ !memrefB, !vecB
+
%3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
%4 = vector.contract {
indexing_maps = [#map, #map1, #map2],
- iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %2, %3 : !vecA, !vecB into !vecC
+
+ vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_no_memref_source_LHS
+// CHECK-NOT: amx.tile_load {{.*}} !amx.tile<16x64xi8>
+// CHECK-NOT: amx.tile_muli
+// CHECK-NOT: amx.tile_store {{.*}} !amx.tile<16x16xi32>
+// CHECK: vector.contract
+
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {  // %arg1 is only read, never consumed
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op  // collect the func.func ops nested under %arg1
+ transform.apply_patterns to %func {
+ transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product  // the single pattern set this test exercises
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<16x64xi8>
+!vecB = vector<64x16xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x64xi8>
+!memrefB = memref<64x32xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d1, d2, d3) -> (d1, d3)>
+#map1 = affine_map<(d1, d2, d3) -> (d3, d2)>
+#map2 = affine_map<(d1, d2, d3) -> (d1, d2)>
+func.func @negative_no_vnni_packed(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : i8
+ %32 = ub.poison : i32
+
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+
+ %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+ %4 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>}
%1, %2, %3 : !vecA, !vecB into !vecC
@@ -226,9 +671,12 @@ func.func @amx(
return %arg2 : !memrefC
}
-// CHECK-LABEL: @amx
-// CHECK: amx.tile_mulf
-// CHECK-NOT: vector.contract
+// CHECK-LABEL: @negative_no_vnni_packed
+// CHECK-NOT: amx.tile_load {{.*}} !amx.tile<16x64xi8>
+// CHECK-NOT: amx.tile_muli
+// CHECK-NOT: amx.tile_store {{.*}} !amx.tile<16x16xi32>
+// CHECK: vector.contract
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -240,3 +688,156 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+!vecAB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<16x16x4xi8, strided<[256, 4, 1], offset: ?>>
+!memrefB = memref<16x32x4xi8, strided<[512, 4, 1], offset: ?>>
+!memrefC = memref<16x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+func.func @negative_VCs_LHS_src_differ(%arg0: memref<64x64x4xi8>, %arg1: memref<64x128x4xi8>, %arg2: memref<64x128xi32>, %arg13: memref<64x64x4xi8>) {
+ %0 = ub.poison : i32
+ %1 = ub.poison : i8
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ scf.for %arg3 = %c0 to %c64 step %c16 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+
+ %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1] :
+ memref<64x128xi32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %4:2 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+
+ %subview_0 = memref.subview %arg0[%arg3, %arg5, 0] [16, 16, 4] [1, 1, 1] :
+ memref<64x64x4xi8> to !memrefA
+ %subview_negative = memref.subview %arg13[%arg3, %arg5, 0] [16, 16, 4] [1, 1, 1] :
+ memref<64x64x4xi8> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg4, 0] [16, 32, 4] [1, 1, 1] :
+ memref<64x128x4xi8> to !memrefB
+ %5 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefA, !vecAB
+ %odd_load = vector.transfer_read %subview_negative[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefA, !vecAB
+ %6 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefB, !vecAB
+ %7 = vector.transfer_read %subview_1[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} :
+ !memrefB, !vecAB
+
+ %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %5, %6, %arg6 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %odd_load, %7, %arg7 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+ scf.yield %8, %9 : !vecC, !vecC
+ }
+
+ vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ }
+ }
+
+ return
+}
+
+// CHECK-LABEL: @negative_VCs_LHS_src_differ
+// CHECK-NOT: amx.tile_zero : !amx.tile<16x16xi32>
+// CHECK-NOT: scf.for {{.*}} -> (!amx.tile<16x16xi32>, !amx.tile<16x16xi32>) {
+// CHECK-NOT: amx.tile_load
+// CHECK-NOT: amx.tile_muli
+// CHECK-NOT: scf.yield {{.*}} !amx.tile<16x16xi32>, !amx.tile<16x16xi32>
+// CHECK: scf.for {{.*}} -> (vector<16x16xi32>, vector<16x16xi32>) {
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x16x16x4xi8>
+!vecB = vector<1x16x32x4xi8>
+!vecC = vector<1x16x32xi32>
+!memrefA = memref<1x16x16x4xi8, strided<[16384, 256, 4, 1], offset: ?>>
+!memrefB = memref<1x16x32x4xi8, strided<[32768, 512, 4, 1], offset: ?>>
+!memrefC = memref<1x16x32xi32, strided<[8192, 128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>
+
+func.func @negative_wrong_N_dim(%arg0: memref<16x64x64x4xi8>, %arg1: memref<16x64x128x4xi8>, %arg2: memref<16x64x128xi32>) {
+ %0 = ub.poison : i32
+ %1 = ub.poison : i8
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ scf.for %arg3 = %c0 to %c64 step %c16 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+ scf.for %arg5 = %c0 to %c16 step %c1 {
+
+ %subview = memref.subview %arg2[%arg5, %arg3, %arg4] [1, 16, 32] [1, 1, 1] :
+ memref<16x64x128xi32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+ !memrefC, !vecC
+ %3 = scf.for %arg6 = %c0 to %c64 step %c16 iter_args(%arg7 = %2) -> (!vecC) {
+ %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg6, 0] [1, 16, 16, 4] [1, 1, 1, 1] :
+ memref<16x64x64x4xi8> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg6, %arg4, 0] [1, 16, 32, 4] [1, 1, 1, 1] :
+ memref<16x64x128x4xi8> to !memrefB
+ %4 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %5 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+ %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %4, %5, %arg7 {unroll_shape = array<i64: 1, 4, 16, 32, 16>} : !vecA, !vecB into !vecC
+ scf.yield %6 : !vecC
+ }
+ vector.transfer_write %3, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+ !vecC, !memrefC
+ }
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: @negative_wrong_N_dim
+// CHECK-NOT: amx.tile_zero : !amx.tile<16x16xi32>
+// CHECK-NOT: amx.tile_load
+// CHECK-NOT: amx.tile_muli
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
>From 58f0eda53f2fe0e6d705798711ae20e2ffed5625 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 3 Mar 2026 00:51:26 -0800
Subject: [PATCH 5/7] added a -ve test-cases
---
.../AMX/vector-contract-to-tiled-dp.mlir | 84 +++++++++++++++++++
1 file changed, 84 insertions(+)
diff --git a/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
index ccc1a0f0c9bdc..e95f739ea0197 100644
--- a/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
@@ -841,3 +841,87 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+!vecAB = vector<1x1x16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<1x1x16x16x4xi8, strided<[262144, 16384, 256, 4, 1], offset: ?>>
+!memrefB = memref<1x1x16x32x4xi8, strided<[524288, 32768, 512, 4, 1], offset: ?>>
+!memrefC = memref<16x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d5, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
+
+func.func @negative_reduction_loop_depth_3(%arg0: memref<2x16x64x64x4xi8>, %arg1: memref<2x16x64x128x4xi8>, %arg2: memref<64x128xi32>) attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
+ %0 = ub.poison : i32
+ %1 = ub.poison : i8
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c2 = arith.constant 2 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ scf.for %arg3 = %c0 to %c64 step %c16 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+ %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1] :
+ memref<64x128xi32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+
+ %4:2 = scf.for %arg5 = %c0 to %c2 step %c1 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+ %5:2 = scf.for %arg8 = %c0 to %c16 step %c1 iter_args(%arg9 = %arg6, %arg10 = %arg7) -> (!vecC, !vecC) {
+ %6:2 = scf.for %arg11 = %c0 to %c64 step %c16 iter_args(%arg12 = %arg9, %arg13 = %arg10) -> (!vecC, !vecC) {
+ %subview_0 = memref.subview %arg0[%arg5, %arg8, %arg3, %arg11, 0] [1, 1, 16, 16, 4] [1, 1, 1, 1, 1] :
+ memref<2x16x64x64x4xi8> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg8, %arg11, %arg4, 0] [1, 1, 16, 32, 4] [1, 1, 1, 1, 1] :
+ memref<2x16x64x128x4xi8> to !memrefB
+ %7 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true, true]} :
+ !memrefA, !vecAB
+ %8 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true, true]} :
+ !memrefB, !vecAB
+ %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true, true]} :
+ !memrefB, !vecAB
+
+ %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %7, %8, %arg12 {unroll_shape = array<i64: 1, 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %7, %9, %arg13 {unroll_shape = array<i64: 1, 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+ scf.yield %10, %11 : !vecC, !vecC
+ }
+ scf.yield %6#0, %6#1 : !vecC, !vecC
+ }
+ scf.yield %5#0, %5#1 : !vecC, !vecC
+ }
+
+ vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+ !vecC, !memrefC
+ }
+ }
+ return
+}
+
+
+// CHECK-LABEL: @negative_reduction_loop_depth_3
+// CHECK-NOT: amx.tile_zero : !amx.tile<16x16xi32>
+// CHECK-NOT: amx.tile_load
+// CHECK-NOT: amx.tile_muli
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
>From 0b8d1298999184b2b403d57e4317dca17b40a7c5 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 3 Mar 2026 01:46:34 -0800
Subject: [PATCH 6/7] remove the new AMX related files
---
.../AMX/TransformOps/AMXTransformOps.h | 31 ----------
.../AMX/TransformOps/AMXTransformOps.td | 32 -----------
.../Dialect/AMX/TransformOps/CMakeLists.txt | 4 --
.../AMX/TransformOps/AMXTransformOps.cpp | 57 -------------------
.../Dialect/AMX/TransformOps/CMakeLists.txt | 17 ------
5 files changed, 141 deletions(-)
delete mode 100644 mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
delete mode 100644 mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
delete mode 100644 mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
delete mode 100644 mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
delete mode 100644 mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
deleted file mode 100644
index 8806635df8eb5..0000000000000
--- a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
+++ /dev/null
@@ -1,31 +0,0 @@
-//===- AMXTransformOps.h - AMX transform ops --------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
-#define MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
-
-#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
-#include "mlir/IR/OpImplementation.h"
-
-//===----------------------------------------------------------------------===//
-// AMX Transform Operations
-//===----------------------------------------------------------------------===//
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h.inc"
-
-namespace mlir {
-class DialectRegistry;
-
-namespace amx {
-void registerTransformDialectExtension(DialectRegistry &registry);
-
-} // namespace amx
-} // namespace mlir
-
-#endif // MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
deleted file mode 100644
index 74bcba3c37e57..0000000000000
--- a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
+++ /dev/null
@@ -1,32 +0,0 @@
-//===- AMXTransformOps.td - AMX transform ops --*- tablegen -*-------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef AMX_TRANSFORM_OPS
-#define AMX_TRANSFORM_OPS
-
-include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/IR/OpBase.td"
-include "mlir/Dialect/Transform/IR/TransformAttrs.td"
-include "mlir/Dialect/Transform/IR/TransformTypes.td"
-include "mlir/IR/RegionKindInterface.td"
-
-def ApplyVectorContractToPackedTypeTiledDotProductPatternsOp : Op<Transform_Dialect,
- "apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product",
- [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
- let description = [{
- Collect patterns to lower a BF16/Int8 type vector.contract operation
- to a BF16/Int8 tiled dot-product.
- }];
-
- let assemblyFormat = "attr-dict";
-}
-
-
-#endif // AMX_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
deleted file mode 100644
index 41255b936be71..0000000000000
--- a/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-set(LLVM_TARGET_DEFINITIONS AMXTransformOps.td)
-mlir_tablegen(AMXTransformOps.h.inc -gen-op-decls)
-mlir_tablegen(AMXTransformOps.cpp.inc -gen-op-defs)
-add_mlir_dialect_tablegen_target(MLIRAMXTransformOpsIncGen)
diff --git a/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp b/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
deleted file mode 100644
index 6ff573407b42b..0000000000000
--- a/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
+++ /dev/null
@@ -1,57 +0,0 @@
-//===- AMXTransformOps.cpp ------------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h"
-#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Dialect/AMX/AMXDialect.h"
-#include "mlir/Dialect/AMX/Transforms.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Transform/IR/TransformDialect.h"
-#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/RegionKindInterface.h"
-
-using namespace mlir;
-using namespace mlir::amx;
-using namespace mlir::transform;
-
-void mlir::transform::ApplyVectorContractToPackedTypeTiledDotProductPatternsOp::
- populatePatterns(RewritePatternSet &patterns) {
- amx::populateVectorContractToPackedTypeTiledDotProductPatterns(patterns);
-}
-
-//===----------------------------------------------------------------------===//
-// Transform op registration
-//===----------------------------------------------------------------------===//
-
-namespace {
-class AMXTransformDialectExtension
- : public transform::TransformDialectExtension<
- AMXTransformDialectExtension> {
-public:
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AMXTransformDialectExtension)
-
- AMXTransformDialectExtension() {
- declareGeneratedDialect<amx::AMXDialect>();
- declareGeneratedDialect<LLVM::LLVMDialect>();
- registerTransformOps<
-#define GET_OP_LIST
-#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.cpp.inc"
- >();
- }
-};
-} // namespace
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.cpp.inc"
-
-void mlir::amx::registerTransformDialectExtension(DialectRegistry &registry) {
- registry.addExtensions<AMXTransformDialectExtension>();
-}
diff --git a/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt
deleted file mode 100644
index 30b4304586ab7..0000000000000
--- a/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt
+++ /dev/null
@@ -1,17 +0,0 @@
-add_mlir_dialect_library(MLIRAMXTransformOps
- AMXTransformOps.cpp
-
- DEPENDS
- MLIRAMXTransformOpsIncGen
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRLLVMCommonConversion
- MLIRLLVMDialect
- MLIRVectorDialect
- MLIRSideEffectInterfaces
- MLIRTransformDialect
- MLIRTransformDialectUtils
- MLIRAMXDialect
- MLIRAMXTransforms
- )
>From 06fc39f126a511fe620bfc61c7decb6d34f17033 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 4 Mar 2026 22:21:35 -0800
Subject: [PATCH 7/7] removing couple of debugger statements.
---
.../Transforms/VectorContractToPackedTypeTiledDotProduct.cpp | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
index 4bdb5ff83bb74..d5c17528f1564 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
@@ -382,15 +382,12 @@ struct VectorContractToPackedTypeTiledDotProduct
"The rest dims should be 1.");
Location loc = contractOp.getLoc();
- llvm::outs() << "Reaching-here1" << "\n";
+
auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
contractOp.getLhs(), true);
- llvm::outs() << "Reaching-here2" << "\n";
if (failed(srcIndxLhs))
return rewriter.notifyMatchFailure(contractOp,
"The LHS src is not a MemRef type.");
- llvm::outs() << "Reaching-here3" << "\n";
-
auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
More information about the Mlir-commits
mailing list