[Mlir-commits] [mlir] [mlir][amx] Lower vector.contract to packed type tiled dot-product. (PR #182810)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 2 23:52:59 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Arun Thangamani (arun-thmn)

<details>
<summary>Changes</summary>

Adds a transform op that applies patterns lowering the `vector.contract` operation to (a) `amx.tile_mulf` for BF16, or (b) `amx.tile_muli` for Int8 packed types.

---

Patch is 71.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/182810.diff


12 Files Affected:

- (modified) mlir/include/mlir/Dialect/AMX/CMakeLists.txt (+2) 
- (added) mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h (+31) 
- (added) mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td (+32) 
- (added) mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt (+4) 
- (modified) mlir/include/mlir/Dialect/AMX/Transforms.h (+10) 
- (modified) mlir/lib/Dialect/AMX/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp (+57) 
- (added) mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt (+17) 
- (modified) mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp (+687) 
- (modified) mlir/lib/RegisterAllExtensions.cpp (+2) 
- (added) mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir (+843) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
index f875c78d240cc..b211cb3b53fd7 100644
--- a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
@@ -3,3 +3,5 @@ add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)
 
 add_mlir_interface(AMXInterfaces)
 add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
new file mode 100644
index 0000000000000..8806635df8eb5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
@@ -0,0 +1,31 @@
+//===- AMXTransformOps.h - AMX transform ops --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
+#define MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// AMX Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace amx {
+/// Registers the AMX transform dialect extension, which provides the AMX
+/// transform ops (e.g. the vector.contract -> AMX tiled dot-product pattern
+/// op), with `registry`.
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace amx
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
new file mode 100644
index 0000000000000..74bcba3c37e57
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
@@ -0,0 +1,32 @@
+//===- AMXTransformOps.td - AMX transform ops --*- tablegen -*-------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef AMX_TRANSFORM_OPS
+#define AMX_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/RegionKindInterface.td"
+
+def ApplyVectorContractToPackedTypeTiledDotProductPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns to lower a BF16/Int8 (VNNI-packed) vector.contract
+    operation to an AMX tiled dot-product (`amx.tile_mulf` for BF16,
+    `amx.tile_muli` for Int8).
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+#endif // AMX_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..41255b936be71
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS AMXTransformOps.td)
+mlir_tablegen(AMXTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(AMXTransformOps.cpp.inc -gen-op-defs)
+add_mlir_dialect_tablegen_target(MLIRAMXTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 7391ec2ff6b14..50eab54ac58ab 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -28,6 +28,16 @@ void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
 /// Register LLVM conversion interface for AMX dialect.
 void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
 
+namespace amx {
+
+/// Populates `patterns` with rewrites that lower 32-bit packed (VNNI layout)
+/// vector.contract operations on BF16 or Int8 inputs to the corresponding AMX
+/// tiled dot-product operations (`amx.tile_mulf` / `amx.tile_muli`), which
+/// ultimately target the relevant x86 LLVM intrinsics.
+void populateVectorContractToPackedTypeTiledDotProductPatterns(
+    RewritePatternSet &patterns);
+} // namespace amx
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/AMX/CMakeLists.txt b/mlir/lib/Dialect/AMX/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/AMX/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp b/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
new file mode 100644
index 0000000000000..6ff573407b42b
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
@@ -0,0 +1,57 @@
+//===- AMXTransformOps.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::amx;
+using namespace mlir::transform;
+
+/// Adds the AMX packed-type vector.contract lowering patterns to `patterns`.
+/// All of the work is delegated to the AMX Transforms library.
+void mlir::transform::ApplyVectorContractToPackedTypeTiledDotProductPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  mlir::amx::populateVectorContractToPackedTypeTiledDotProductPatterns(
+      patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Transform dialect extension that registers the AMX transform ops generated
+/// from AMXTransformOps.td.
+class AMXTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          AMXTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AMXTransformDialectExtension)
+
+  AMXTransformDialectExtension() {
+    // Declare the AMX and LLVM dialects as dialects that these transform ops
+    // may generate, so they are loaded before the ops run.
+    declareGeneratedDialect<amx::AMXDialect>();
+    declareGeneratedDialect<LLVM::LLVMDialect>();
+    // The list of ops to register is expanded from the TableGen output
+    // (GET_OP_LIST section of AMXTransformOps.cpp.inc).
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+// TableGen-generated op class definitions.
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.cpp.inc"
+
+/// Adds the AMX transform dialect extension to `registry`.
+void mlir::amx::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<AMXTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..30b4304586ab7
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+# AMX transform-dialect ops (defined in AMXTransformOps.td). Links against
+# MLIRAMXTransforms, which provides the rewrite patterns these ops apply.
+add_mlir_dialect_library(MLIRAMXTransformOps
+  AMXTransformOps.cpp
+
+  DEPENDS
+  MLIRAMXTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRVectorDialect
+  MLIRSideEffectInterfaces
+  MLIRTransformDialect
+  MLIRTransformDialectUtils
+  MLIRAMXDialect
+  MLIRAMXTransforms
+  )
diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
index e827bc475e930..e422f275c7742 100644
--- a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRAMXTransforms
   LegalizeForLLVMExport.cpp
+  VectorContractToPackedTypeTiledDotProduct.cpp
 
   LINK_LIBS PUBLIC
   MLIRAMXDialect
diff --git a/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
new file mode 100644
index 0000000000000..af4f3ef1c934a
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
@@ -0,0 +1,687 @@
+//===- VectorContractToPackedTypeTiledDotProduct.cpp ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Casting.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::amx;
+
+namespace {
+
+// Collapses the two innermost dimensions of the memref `input` (the VNNI
+// dimension and the one preceding it) into a single dimension, so that
+// amx.tile_load can read the packed element layout as plain 2-D rows.
+// Buffers without a static rank, or with rank < 2, have no pair of inner
+// dimensions to merge and are returned unchanged.
+static Value collapseInnerDims(OpBuilder &builder, mlir::Location loc,
+                               Value input) {
+  auto inputType = cast<ShapedType>(input.getType());
+  // Guard before calling getRank(): unranked types assert there, and rank 0
+  // would otherwise yield negative reassociation indices.
+  if (!inputType.hasRank() || inputType.getRank() < 2)
+    return input;
+
+  int64_t firstDimToCollapse = inputType.getRank() - 2;
+  SmallVector<ReassociationIndices> reassociation;
+  reassociation.reserve(firstDimToCollapse + 1);
+  // Every outer dimension stays in its own group ...
+  for (int64_t i = 0; i < firstDimToCollapse; ++i)
+    reassociation.push_back(ReassociationIndices{i});
+  // ... and the two innermost dimensions are merged into one group.
+  reassociation.push_back(
+      ReassociationIndices{firstDimToCollapse, firstDimToCollapse + 1});
+
+  return memref::CollapseShapeOp::create(builder, loc, input, reassociation);
+}
+
+// Returns the source buffer and the read indices of the vector.transfer_read
+// or vector.load that defines `operand`.
+//
+// When `isNotAcc` is set, the innermost index is dropped and the returned
+// buffer has its two innermost dimensions collapsed (see collapseInnerDims).
+// Fails if `operand` has no defining op, if that op is not a
+// vector.transfer_read / vector.load, or if there is no index to drop.
+static FailureOr<std::pair<Value, SmallVector<Value>>>
+getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
+                bool isNotAcc) {
+  Operation *defOp = operand.getDefiningOp();
+  if (!defOp)
+    return failure();
+
+  Value srcBuff;
+  SmallVector<OpFoldResult> indexVals;
+  llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
+      [&](auto readOp) {
+        indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+                                              readOp.getIndices().end());
+        srcBuff = readOp.getOperand(0);
+      });
+
+  if (!srcBuff)
+    return failure();
+
+  if (isNotAcc) {
+    // pop_back() on an empty SmallVector is UB; treat it as a non-match.
+    if (indexVals.empty())
+      return failure();
+    indexVals.pop_back();
+  }
+
+  SmallVector<Value> indices;
+  indices.reserve(indexVals.size());
+  for (OpFoldResult ofr : indexVals)
+    indices.push_back(
+        mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+
+  if (isNotAcc)
+    srcBuff = collapseInnerDims(rewriter, loc, srcBuff);
+
+  return std::make_pair(srcBuff, indices);
+}
+
+// Checks that `contractOp` has the operand shapes the AMX lowering supports.
+//
+// If `srcValidate` is set, the LHS/RHS must additionally be read (via
+// vector.transfer_read / vector.load) from exactly `srcBuffLhs` /
+// `srcBuffRhs`.
+// Shape rules:
+//   * accumulator: a vector whose dims are all 16 or 1 (e.g. <16x16>,
+//     <1x16x16>);
+//   * LHS and RHS: all dims 16 or 1, except exactly one dim that equals
+//     `blockingFactor` (the VNNI dim, e.g. <16x16x2> or <1x16x16x4>).
+static LogicalResult validateContractOps(OpBuilder &rewriter,
+                                         vector::ContractionOp contractOp,
+                                         unsigned int blockingFactor,
+                                         Value srcBuffLhs, Value srcBuffRhs,
+                                         bool srcValidate) {
+  // Gathers the dims of `shape` that are neither 16 nor 1.
+  auto collectIrregularDims = [](ArrayRef<int64_t> shape) {
+    SmallVector<int64_t> irregular;
+    for (int64_t dim : shape)
+      if (dim != 16 && dim != 1)
+        irregular.push_back(dim);
+    return irregular;
+  };
+
+  // True when exactly one dim deviates from 16/1 and it is the VNNI factor.
+  auto hasPackedOperandShape = [&](VectorType operandTy) {
+    SmallVector<int64_t> irregular = collectIrregularDims(operandTy.getShape());
+    return irregular.size() == 1 && irregular.front() == blockingFactor;
+  };
+
+  if (srcValidate) {
+    Location loc = contractOp.getLoc();
+
+    FailureOr<std::pair<Value, SmallVector<Value>>> lhsSrc =
+        getSrcIndxValue(rewriter, loc, contractOp.getLhs(), false);
+    if (failed(lhsSrc))
+      return failure();
+
+    FailureOr<std::pair<Value, SmallVector<Value>>> rhsSrc =
+        getSrcIndxValue(rewriter, loc, contractOp.getRhs(), false);
+    if (failed(rhsSrc))
+      return failure();
+
+    // Both operands must come from the buffers the caller expects.
+    if (lhsSrc->first != srcBuffLhs || rhsSrc->first != srcBuffRhs)
+      return failure();
+  }
+
+  auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
+  if (!accTy || !collectIrregularDims(accTy.getShape()).empty())
+    return failure();
+
+  return success(hasPackedOperandShape(contractOp.getLhsType()) &&
+                 hasPackedOperandShape(contractOp.getRhsType()));
+}
+
+// Returns the position, among the offsets of the memref.subview that feeds
+// the vector.transfer_read / vector.load defining `operand`, at which `loop`'s
+// induction variable is used. Used to pick the index to remap when the MemRef
+// type is cloned.
+//
+// Falls back to 0 in each of these cases:
+//   * `operand` is not defined by such a read;
+//   * the read's source is not a memref.subview;
+//   * the induction variable does not appear among the offsets.
+// NOTE(review): 0 is also a valid match position, so callers cannot tell a
+// fallback from a real hit at offset 0.
+static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
+  // Block arguments have no defining op; TypeSwitch must not see a null op.
+  Operation *defOp = operand.getDefiningOp();
+  if (!defOp)
+    return 0;
+
+  Value srcBuff;
+  llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
+      [&](auto readOp) { srcBuff = readOp.getOperand(0); });
+  // Not a supported read: there is no source buffer to inspect.
+  if (!srcBuff)
+    return 0;
+
+  auto subview = srcBuff.getDefiningOp<memref::SubViewOp>();
+  if (!subview)
+    return 0;
+
+  Value iv = loop.getInductionVar();
+  for (auto it : llvm::enumerate(subview.getOffsets()))
+    if (it.value() == iv)
+      return it.index();
+
+  return 0;
+}
+
+// Creates the amx.tile_load that replaces the vector read defining `operand`.
+//
+// Parameters:
+//   * `operand`: must be defined by a vector.transfer_read / vector.load with
+//     at least two indices.
+//   * `mat`: expected to be the operand's buffer with its two innermost dims
+//     already collapsed.
+//   * `rhs`: when set, the remaining innermost index is multiplied by
+//     `offset` so it addresses the collapsed dimension.
+//
+// The load uses the read's indices with the innermost (VNNI) index dropped.
+// The resulting tile has shape 16 x (16 * offset) and element type `ipType`.
+static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
+                                       Value operand, Value mat, Type ipType,
+                                       bool rhs, unsigned int offset) {
+  FailureOr<std::pair<Value, SmallVector<Value>>> srcIndx =
+      getSrcIndxValue(rewriter, loc, operand, false);
+  assert(succeeded(srcIndx) &&
+         "expected operand defined by vector.transfer_read/vector.load");
+  SmallVector<Value> indices = std::move(srcIndx->second);
+  assert(indices.size() >= 2 && "expected at least two read indices");
+
+  // Drop the innermost (VNNI) index: in `mat` that dim is merged with the
+  // preceding one.
+  indices.pop_back();
+
+  // NOTE(review): only the RHS index is scaled by `offset`. If the LHS inner
+  // index can be non-zero in its collapsed buffer, it may need the same
+  // scaling. Confirm against the tests.
+  if (rhs) {
+    auto cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
+    indices.back() =
+        arith::MulIOp::create(rewriter, loc, indices.back(), cOffset);
+  }
+
+  amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
+  return amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
+}
+
+// Emits one AMX tiled dot-product per contraction in `ops`.
+//
+// Inputs:
+//   * `matA` / `matB`: the LHS / RHS buffers. Their two innermost dims are
+//     collapsed once, up front, and every amx.tile_load reads from the
+//     collapsed view.
+//   * `accIterArgs[i]`: the accumulator tile of the i-th contraction.
+//
+// A vector read shared by several contractions is turned into a single
+// amx.tile_load. BF16 inputs use amx.tile_mulf; Int8 inputs use
+// amx.tile_muli.
+//
+// Returns the new 16x16 accumulator tiles (element type `opType`), in the
+// same order as `ops`.
+static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
+                                        ArrayRef<vector::ContractionOp> ops,
+                                        Value matA, Value matB, Type ipType,
+                                        Type opType, ValueRange accIterArgs,
+                                        unsigned int offset) {
+  assert(accIterArgs.size() >= ops.size() &&
+         "expected one accumulator per contraction");
+
+  Value subviewCollapseLhs = collapseInnerDims(rewriter, loc, matA);
+  Value subviewCollapseRhs = collapseInnerDims(rewriter, loc, matB);
+
+  // Maps each vector read (transfer_read / load) to the amx.tile_load that
+  // replaces it, so a read shared between contractions is loaded once.
+  llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+  auto getOrCreateTileLoad = [&](Value operand, Value mat,
+                                 bool isRhs) -> amx::TileLoadOp {
+    Operation *readOp = operand.getDefiningOp();
+    if (auto it = readsToTileLoads.find(readOp); it != readsToTileLoads.end())
+      return it->second;
+    amx::TileLoadOp load =
+        createTileLoads(rewriter, loc, operand, mat, ipType, isRhs, offset);
+    readsToTileLoads.try_emplace(readOp, load);
+    return load;
+  };
+
+  // The accumulator tile type is the same for every contraction.
+  auto accTileType = amx::TileType::get({16, 16}, opType);
+
+  SmallVector<Value> accumulators;
+  accumulators.reserve(ops.size());
+  for (size_t i = 0, e = ops.size(); i < e; ++i) {
+    vector::ContractionOp op = ops[i];
+    amx::TileLoadOp tilesLhs =
+        getOrCreateTileLoad(op.getLhs(), subviewCollapseLhs, /*isRhs=*/false);
+    amx::TileLoadOp tilesRhs =
+        getOrCreateTileLoad(op.getRhs(), subviewCollapseRhs, /*isRhs=*/true);
+
+    Value dp;
+    if (ipType.isBF16())
+      dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
+                                   tilesRhs, accIterArgs[i]);
+    else if (ipType.isSignlessInteger(8))
+      dp = amx::TileMulIOp::create(rewriter, loc, accTileType, tilesLhs,
+                                   tilesRhs, accIterArgs[i]);
+    // Any other element type would silently push a null accumulator.
+    assert(dp && "expected BF16 or Int8 input element type");
+
+    accumulators.push_back(dp);
+  }
+  return accumulators;
+}
+
+// Creates `size` amx.tile_zero ops and returns their results in creation
+// order.
+//
+// Each result is a 16x16 tile of element type `opType`. The ops are inserted
+// immediately before `outerLoop`, and `rewriter`'s insertion point is left
+// there.
+static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
+                                          Type opType, scf::ForOp outerLoop,
+                                          int64_t size) {
+  auto tileTy = amx::TileType::get({16, 16}, opType);
+  rewriter.setInsertionPoint(outerLoop);
+
+  SmallVector<Value> zeroTiles;
+  for (int64_t idx = 0; idx < size; ++idx)
+    zeroTiles.push_back(amx::TileZeroOp::create(rewriter, loc, tileTy));
+  return zeroTiles;
+}
+
+// Implements tiled dot-product operation for a vector.contract operation or a
+// sequence of vector.contracts inside the reduction loops.
+//
+// For example, for the Int8 type:
+// ```
+//   vector.transfer_read %arg0 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
+//   vector.transfer_read %arg1 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
+//   vector.contract <16x16x4xi8>, <16x16x4xi8> into <16x16xi32>
+//   vector.transfer_write %arg2 {{.}*} : vector<16x16xi32>, memref<32x32xi32>
+// ```
+// to
+// ```
+//   amx.tile_load %arg0 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
+//   amx.tile_load %arg1 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
+//   amx.tile_muli !amx.tile<16x64xi8> -> !amx.tile<16x16xi32>
+//   amx.tile_store %arg2{{.}*} : memref<32x32xi32>, !amx.tile<16x16xi32>
+// ```
+struct VectorContractToPackedTypeTiledDotProduct
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind.");
+
+    unsigned int blockingFactor =
+        contractOp.getLhsType().getElementType().isBF16() ? 2 : 4;
+    bool isVnni = x86vector::isInVnniLayout(contractOp.getOperation(),
+                                            contractOp.getIndexingMapsArray(),
+                                            blockin...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/182810


More information about the Mlir-commits mailing list