[Mlir-commits] [mlir] [mlir][amx] Lower vector.contract to packed type tiled dot-product. (PR #182810)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 2 23:52:59 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Arun Thangamani (arun-thmn)
<details>
<summary>Changes</summary>
A transform pass that lowers `vector.contract` operations to (a) `amx.tile_mulf` for BF16, or (b) `amx.tile_muli` for Int8 packed types.
---
Patch is 71.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/182810.diff
12 Files Affected:
- (modified) mlir/include/mlir/Dialect/AMX/CMakeLists.txt (+2)
- (added) mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h (+31)
- (added) mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td (+32)
- (added) mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt (+4)
- (modified) mlir/include/mlir/Dialect/AMX/Transforms.h (+10)
- (modified) mlir/lib/Dialect/AMX/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp (+57)
- (added) mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt (+17)
- (modified) mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp (+687)
- (modified) mlir/lib/RegisterAllExtensions.cpp (+2)
- (added) mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir (+843)
``````````diff
diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
index f875c78d240cc..b211cb3b53fd7 100644
--- a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
@@ -3,3 +3,5 @@ add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)
add_mlir_interface(AMXInterfaces)
add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
new file mode 100644
index 0000000000000..8806635df8eb5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
@@ -0,0 +1,31 @@
+//===- AMXTransformOps.h - AMX transform ops --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
+#define MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// AMX Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace amx {
+void registerTransformDialectExtension(DialectRegistry &registry);
+
+} // namespace amx
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
new file mode 100644
index 0000000000000..74bcba3c37e57
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
@@ -0,0 +1,32 @@
+//===- AMXTransformOps.td - AMX transform ops --*- tablegen -*-------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef AMX_TRANSFORM_OPS
+#define AMX_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/RegionKindInterface.td"
+
+def ApplyVectorContractToPackedTypeTiledDotProductPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Collect patterns to lower a BF16/Int8 type vector.contract operation
+ to a BF16/Int8 tiled dot-product.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+
+#endif // AMX_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..41255b936be71
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS AMXTransformOps.td)
+mlir_tablegen(AMXTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(AMXTransformOps.cpp.inc -gen-op-defs)
+add_mlir_dialect_tablegen_target(MLIRAMXTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 7391ec2ff6b14..50eab54ac58ab 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -28,6 +28,16 @@ void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
/// Register LLVM conversion interface for AMX dialect.
void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
+namespace amx {
+
+// A set of patterns for lowering 32-bit packed vector contraction operations
+// to their corresponding packed-type tiled dot-product operations, using
+// AMX, ultimately targeting the relevant x86 LLVM intrinsics (e.g., BF16 and
+// Int8).
+void populateVectorContractToPackedTypeTiledDotProductPatterns(
+ RewritePatternSet &patterns);
+} // namespace amx
+
} // namespace mlir
#endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/AMX/CMakeLists.txt b/mlir/lib/Dialect/AMX/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/AMX/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp b/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
new file mode 100644
index 0000000000000..6ff573407b42b
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
@@ -0,0 +1,57 @@
+//===- AMXTransformOps.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::amx;
+using namespace mlir::transform;
+
+void mlir::transform::ApplyVectorContractToPackedTypeTiledDotProductPatternsOp::
+ populatePatterns(RewritePatternSet &patterns) {
+ amx::populateVectorContractToPackedTypeTiledDotProductPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class AMXTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ AMXTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AMXTransformDialectExtension)
+
+ AMXTransformDialectExtension() {
+ declareGeneratedDialect<amx::AMXDialect>();
+ declareGeneratedDialect<LLVM::LLVMDialect>();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.cpp.inc"
+
+void mlir::amx::registerTransformDialectExtension(DialectRegistry &registry) {
+ registry.addExtensions<AMXTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..30b4304586ab7
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRAMXTransformOps
+ AMXTransformOps.cpp
+
+ DEPENDS
+ MLIRAMXTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRVectorDialect
+ MLIRSideEffectInterfaces
+ MLIRTransformDialect
+ MLIRTransformDialectUtils
+ MLIRAMXDialect
+ MLIRAMXTransforms
+ )
diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
index e827bc475e930..e422f275c7742 100644
--- a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRAMXTransforms
LegalizeForLLVMExport.cpp
+ VectorContractToPackedTypeTiledDotProduct.cpp
LINK_LIBS PUBLIC
MLIRAMXDialect
diff --git a/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
new file mode 100644
index 0000000000000..af4f3ef1c934a
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
@@ -0,0 +1,687 @@
+//===- VectorContractToPackedTypeTiledDotProduct.cpp ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Casting.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::amx;
+
+namespace {
+
+// Function to collapse the last two dimensions (vnni and k) to help the
+// amx.tile_load to correctly load the packed element type.
+static Value collapseInnerDims(OpBuilder &builder, mlir::Location loc,
+ Value input) {
+ ShapedType inputType = cast<ShapedType>(input.getType());
+ int64_t firstDimToCollapse = inputType.getRank() - 2;
+
+ if (inputType.getRank() == 1)
+ return input;
+
+ SmallVector<ReassociationIndices> reassociation;
+ for (int64_t i = 0; i < firstDimToCollapse; ++i)
+ reassociation.push_back(ReassociationIndices{i});
+
+ ReassociationIndices collapsedIndices;
+ for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
+ collapsedIndices.push_back(i);
+
+ reassociation.push_back(collapsedIndices);
+ return memref::CollapseShapeOp::create(builder, loc, input, reassociation);
+}
+
+// Get the MemRef source and offset index for the operands of
+// vector.contract.
+static FailureOr<std::pair<Value, SmallVector<Value>>>
+getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
+ bool isNotAcc) {
+ Operation *defOp = operand.getDefiningOp();
+ if (!defOp)
+ return failure();
+
+ Value srcBuff;
+ SmallVector<OpFoldResult> indexVals;
+ llvm::TypeSwitch<Operation *>(operand.getDefiningOp())
+ .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+ indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+ readOp.getIndices().end());
+ srcBuff = readOp.getOperand(0);
+ });
+
+ if (!srcBuff)
+ return failure();
+
+ if (isNotAcc) {
+ indexVals.pop_back();
+ }
+
+ SmallVector<Value> indices;
+ indices.reserve(indexVals.size());
+
+ for (OpFoldResult ofr : indexVals) {
+ indices.push_back(
+ mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+ }
+
+ if (isNotAcc) {
+ srcBuff = collapseInnerDims(rewriter, loc, srcBuff);
+ }
+
+ return std::make_pair(srcBuff, indices);
+}
+
+// Function to validate the vector.contract operation.
+static LogicalResult validateContractOps(OpBuilder &rewriter,
+ vector::ContractionOp contractOp,
+ unsigned int blockingFactor,
+ Value srcBuffLhs, Value srcBuffRhs,
+ bool srcValidate) {
+
+ if (srcValidate) {
+ // Get the MemRef buffer of LHS operand.
+ auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+ contractOp.getLhs(), false);
+ if (failed(srcIndxLhs))
+ return failure();
+ auto [buffLhs, indicesLhs] = *srcIndxLhs;
+
+ // Get the MemRef buffer of RHS operand.
+ auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+ contractOp.getRhs(), false);
+ if (failed(srcIndxRhs))
+ return failure();
+ auto [buffRhs, indicesRhs] = *srcIndxRhs;
+
+ // Return failure if the MemRef buffers do not match.
+ if (buffLhs != srcBuffLhs)
+ return failure();
+
+ if (buffRhs != srcBuffRhs)
+ return failure();
+ }
+
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ if (!accTy)
+ return failure();
+
+ // The Accumulator dims should be 16 or 1. Like <1x16x16> or <16x16>.
+ ArrayRef<int64_t> accShape = accTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimAcc;
+ llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
+ [](int64_t dim) { return (dim != 16 && dim != 1); });
+
+ if (nonUnitDimAcc.size() != 0)
+ return failure();
+
+ // The LHS dims should be 16 or vnni or 1. Like <1x16x16x2> or
+ // <16x16x4>. The vnni dims should be 2 or 4.
+ VectorType lhsTy = contractOp.getLhsType();
+ ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimLhs;
+ llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
+ [](int64_t dim) { return (dim != 16 && dim != 1); });
+
+ if (nonUnitDimLhs.size() != 1)
+ return failure();
+
+ if (nonUnitDimLhs[0] != blockingFactor)
+ return failure();
+
+ // The RHS dims should be 16 or vnni or 1. Like <1x16x16x2> or
+ // <16x16x4>. The vnni dims should be 2 or 4.
+ VectorType rhsTy = contractOp.getRhsType();
+ ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimRhs;
+ llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
+ [](int64_t dim) { return (dim != 16 && dim != 1); });
+
+ if (nonUnitDimRhs.size() != 1)
+ return failure();
+
+ if (nonUnitDimRhs[0] != blockingFactor)
+ return failure();
+
+ return success();
+}
+
+// Returns the loop index position to get mapped during the
+// MemRef type clone.
+static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
+ Value iv = loop.getInductionVar();
+
+ Value srcBuff;
+ llvm::TypeSwitch<Operation *>(operand.getDefiningOp())
+ .Case<TransferReadOp, LoadOp>(
+ [&](auto readOp) { srcBuff = readOp.getOperand(0); });
+
+ auto subview = srcBuff.getDefiningOp<memref::SubViewOp>();
+ if (!subview)
+ return 0;
+
+ auto offsets = subview.getOffsets();
+
+ for (auto it : llvm::enumerate(offsets)) {
+ if (it.value() == iv)
+ return it.index();
+ }
+
+ return 0;
+}
+
+// Creates amx.tile_loads.
+static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
+ Value operand, Value mat, Type ipType,
+ bool rhs, unsigned int offset) {
+
+ auto srcIndx = getSrcIndxValue(rewriter, loc, operand, false);
+ auto [srcBuff, indices] = *srcIndx;
+ indices.pop_back();
+
+ if (rhs) {
+ auto cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
+ indices[indices.size() - 1] = arith::MulIOp::create(
+ rewriter, loc, indices[indices.size() - 1], cOffset);
+ }
+
+ amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
+ auto load = amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
+ return load;
+}
+
+// Creates tiled amx dot-products.
+static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
+ SmallVector<vector::ContractionOp> ops,
+ Value matA, Value matB, Type ipType,
+ Type opType, ValueRange accIterArgs,
+ unsigned int offset) {
+
+ auto subviewCollapseLhs = collapseInnerDims(rewriter, loc, matA);
+ auto subviewCollapseRhs = collapseInnerDims(rewriter, loc, matB);
+
+ SmallVector<Value> accumulators;
+ // Maps each vector transfer_read or load operation to its equivalent
+ // amx.tile_load operation.
+ llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+
+ // Iterate over the contraction operations and compute the tiled dot-product.
+ for (size_t i = 0; i < ops.size(); i++) {
+
+ Operation *readOpLhs = ops[i].getLhs().getDefiningOp();
+ amx::TileLoadOp tilesLhs;
+ auto itLhs = readsToTileLoads.find(readOpLhs);
+ if (itLhs != readsToTileLoads.end()) {
+ tilesLhs = itLhs->second;
+ } else {
+ tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(),
+ subviewCollapseLhs, ipType, false, offset);
+ readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
+ }
+
+ Operation *readOpRhs = ops[i].getRhs().getDefiningOp();
+ amx::TileLoadOp tilesRhs;
+ auto itRhs = readsToTileLoads.find(readOpRhs);
+ if (itRhs != readsToTileLoads.end()) {
+ tilesRhs = itRhs->second;
+ } else {
+ tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(),
+ subviewCollapseRhs, ipType, true, offset);
+ readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
+ }
+
+ auto accTileType = amx::TileType::get({16, 16}, opType);
+
+ Value dp;
+ if (ipType.isBF16())
+ dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
+ tilesRhs, accIterArgs[i]);
+
+ if (ipType.isSignlessInteger(8))
+ dp = amx::TileMulIOp::create(rewriter, loc, accTileType, tilesLhs,
+ tilesRhs, accIterArgs[i]);
+
+ accumulators.push_back(dp);
+ }
+ return accumulators;
+}
+
+static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
+ Type opType, scf::ForOp outerLoop,
+ int64_t size) {
+ rewriter.setInsertionPoint(outerLoop);
+
+ SmallVector<Value> loopItrArgs;
+ auto zeroTileType = amx::TileType::get({16, 16}, opType);
+
+ for (int i = 0; i < size; i++) {
+ auto zeroTile = amx::TileZeroOp::create(rewriter, loc, zeroTileType);
+ loopItrArgs.push_back(zeroTile);
+ }
+ return loopItrArgs;
+}
+
+// Implements tiled dot-product operation for a vector.contract operation or a
+// sequence of vector.contracts inside the reduction loops.
+//
+// For example - for Int8 type:
+// ```
+// vector.transfer_read %arg0 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
+// vector.transfer_read %arg1 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
+// vector.contract <16x16x4xi8>, <16x16x4xi8> into <16x16xi32>
+// vector.transfer_write %arg2 {{.}*} : vector<16x16xi32>, memref<32x32xi32>
+// ```
+// to
+// ```
+// amx.tile_load %arg0 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
+// amx.tile_load %arg1 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
+// amx.tile_muli !amx.tile<16x64xi8> -> !amx.tile<16x16xi32>
+// amx.tile_store %arg2{{.}*} : memref<32x32xi32>, !amx.tile<16x16xi32>
+// ```
+struct VectorContractToPackedTypeTiledDotProduct
+ : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind.");
+
+ unsigned int blockingFactor =
+ contractOp.getLhsType().getElementType().isBF16() ? 2 : 4;
+ bool isVnni = x86vector::isInVnniLayout(contractOp.getOperation(),
+ contractOp.getIndexingMapsArray(),
+ blockin...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/182810
More information about the Mlir-commits
mailing list