[Mlir-commits] [mlir] [mlir][amx] Vector to AMX conversion pass (PR #151121)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Jul 29 03:46:34 PDT 2025


https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/151121

>From c17d62d6be3d035cc82383b4633cb26227009274 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 28 Jul 2025 14:25:51 +0200
Subject: [PATCH 1/2] [mlir][amx] Vector to AMX conversion pass

Adds a pass for Vector to AMX operation conversion.

Initially, a direct rewrite for vector contraction in packed VNNI
layout is supported. Operations are expected to already be in shapes
which are AMX-compatible for the rewriting to occur.
---
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  13 +
 .../mlir/Conversion/VectorToAMX/VectorToAMX.h |  26 ++
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 .../lib/Conversion/VectorToAMX/CMakeLists.txt |  19 ++
 .../Conversion/VectorToAMX/VectorToAMX.cpp    | 287 +++++++++++++++++
 .../VectorToAMX/contract-to-amx.mlir          | 291 ++++++++++++++++++
 7 files changed, 638 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
 create mode 100644 mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
 create mode 100644 mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 3dc48b2201cf2..91b2ecf8922a3 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -75,6 +75,7 @@
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cf7596cc8a928..20ead98acc371 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1515,6 +1515,19 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// VectorToAMX
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToAMX : Pass<"convert-vector-to-amx"> {
+  let summary = "Lower the operations from the vector dialect into the AMX "
+                "dialect";
+  let description = [{
+    Rewrites vector dialect operations into AMX dialect operations.
+
+    Currently, `vector.contract` operations with the `add` combining kind are
+    rewritten into AMX tile loads, a tile multiplication and a tile store when
+    their operands already have AMX-compatible shapes in the packed VNNI
+    layout. Other contractions are left unchanged.
+  }];
+  let dependentDialects = [
+    "affine::AffineDialect", "amx::AMXDialect", "arith::ArithDialect",
+    "memref::MemRefDialect", "scf::SCFDialect", "vector::VectorDialect"
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // XeVMToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
new file mode 100644
index 0000000000000..b075ac92990a2
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
@@ -0,0 +1,26 @@
+//===- VectorToAMX.h - Convert vector to AMX dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
+#define MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+// TableGen-generated declarations for the ConvertVectorToAMX pass defined in
+// Passes.td.
+#define GEN_PASS_DECL_CONVERTVECTORTOAMX
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Collect a set of patterns to convert from the vector to AMX ops.
+///
+/// Currently this adds a pattern that rewrites vector.contract ops, which are
+/// already in an AMX-compatible packed VNNI layout, into AMX tile operations.
+void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 785cb8293810c..171f7169fd41d 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -68,6 +68,7 @@ add_subdirectory(TosaToSCF)
 add_subdirectory(TosaToTensor)
 add_subdirectory(UBToLLVM)
 add_subdirectory(UBToSPIRV)
+add_subdirectory(VectorToAMX)
 add_subdirectory(VectorToArmSME)
 add_subdirectory(VectorToGPU)
 add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
new file mode 100644
index 0000000000000..2d4b2b6e9283c
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
@@ -0,0 +1,19 @@
+# Library implementing the vector to AMX conversion pass and its patterns.
+add_mlir_conversion_library(MLIRVectorToAMX
+  VectorToAMX.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToAMX
+
+  DEPENDS
+  # Presumably builds the generated Passes.h.inc included by VectorToAMX.cpp.
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAMXDialect
+  # NOTE(review): VectorToAMX.cpp includes Affine/ViewLikeInterfaceUtils.h
+  # but does not appear to use it; confirm MLIRAffineUtils is still needed.
+  MLIRAffineUtils
+  MLIRArithDialect
+  # Presumably linked for linalg::inferContractionDims (LinalgInterfaces.h).
+  MLIRLinalgUtils
+  MLIRMemRefDialect
+  # NOTE(review): no scf ops appear to be created by the patterns; confirm.
+  MLIRSCFDialect
+  MLIRTransforms
+  MLIRVectorDialect
+  )
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
new file mode 100644
index 0000000000000..fc24275a1467c
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -0,0 +1,287 @@
+//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of vector operations to AMX dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOAMX
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Return true if vector shape is compatible with AMX tiles.
+/// The validation accounts for VNNI packing.
+///
+/// Accepted shapes:
+///   - 2D [rows]x[cols] for plain layout input or output,
+///   - 3D [rows]x[cols]x[vnniFactor] for VNNI packed input, where
+///     vnniFactor = 32 / element bit width; a packed row then holds
+///     cols * vnniFactor elements.
+/// The resulting tile must fit in 16 rows of 64 bytes each. Scalable vectors
+/// are rejected because their size is not known statically.
+static bool verifyAmxShape(VectorType vec) {
+  // Scalable dims have no fixed size, so the static per-dim sizes checked
+  // below (and later used to build the transfer buffers) would not describe
+  // the actual vector length.
+  if (vec.isScalable())
+    return false;
+
+  // Check overall shape:
+  //   - 2D for plain layout input or output
+  //   - 3D for VNNI packed input
+  if (vec.getRank() != 2 && vec.getRank() != 3)
+    return false;
+
+  ArrayRef<int64_t> shape = vec.getShape();
+  int64_t rows = shape[0];
+  int64_t cols = shape[1];
+  unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();
+
+  // 3D shape indicates VNNI packed layout.
+  if (vec.getRank() == 3) {
+    // Elements are packed into 32-bit groups. Zero-width or wider elements
+    // cannot be packed, and a zero width would divide by zero below.
+    if (elemBitWidth == 0 || elemBitWidth > 32)
+      return false;
+    int64_t vnniFactor = 32 / elemBitWidth;
+    if (shape.back() != vnniFactor)
+      return false;
+    cols *= vnniFactor;
+  }
+
+  // AMX tile supports up to 16 rows of 64 bytes each.
+  constexpr unsigned maxRows = 16;
+  constexpr unsigned maxBitsPerRow = 64 * 8;
+  return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
+}
+
+/// Checks if contraction operands are in AMX-compatible packed VNNI layout.
+///
+/// The following are required:
+///   - a 2D vector accumulator and 3D lhs/rhs vectors,
+///   - operand shapes that pass verifyAmxShape,
+///   - indexing maps over exactly four dims (M, N, K and VNNI, with iterators
+///     in any order) that describe the layouts
+///       A: [M]x[K/vnniFactor]x[vnniFactor]
+///       B: [K/vnniFactor]x[N]x[vnniFactor]
+///       C: [M]x[N]
+/// Returns success if all checks pass. Otherwise it reports a match failure
+/// through \p rewriter and returns failure.
+static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
+                                     vector::ContractionOp contractOp) {
+  VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
+  if (!accType || accType.getRank() != 2)
+    return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+
+  // Expect 3D inputs for VNNI packed data.
+  VectorType lhsType = contractOp.getLhs().getType();
+  VectorType rhsType = contractOp.getRhs().getType();
+  if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "Expects lhs and rhs 3D vectors");
+
+  // Check if shapes are compatible with AMX tile.
+  if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
+      !verifyAmxShape(accType))
+    return rewriter.notifyMatchFailure(contractOp, "Invalid operand shape");
+
+  // Validate affine maps.
+  //
+  // Iterators can be ordered arbitrarily. Indexing map positions are based on
+  // operands' target shapes.
+  // The matrix layouts must match the following:
+  //   - matrix A - [M]x[K/vnniFactor]x[vnniFactor]
+  //   - matrix B - [K/vnniFactor]x[N]x[vnniFactor]
+  //   - matrix C - [M]x[N]
+  SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray();
+  AffineMap mapA = indexingMaps[0];
+  AffineMap mapB = indexingMaps[1];
+  // Exactly four loop dims (M, N, K, VNNI) and 3D maps for both inputs. This
+  // rejects extra (e.g. batch) dimensions.
+  if (mapA.getNumInputs() != 4 || mapA.getNumResults() != 3 ||
+      mapB.getNumResults() != 3)
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "Invalid input indexing maps");
+  FailureOr<linalg::ContractionDimensions> dims =
+      linalg::inferContractionDims(indexingMaps);
+  if (failed(dims))
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "Failed to infer contraction dims");
+  // Two reduction dimensions are expected:
+  //   - one for the K dimension
+  //   - one for the VNNI factor
+  if (dims->k.size() != 2)
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "Expected two reduction dims");
+  // NOTE(review): with four dims and two reductions this presumably always
+  // holds for a verified vector.contract. Confirm; a match failure would be
+  // more defensive than an assert.
+  assert(dims->m.size() == 1 && dims->n.size() == 1 &&
+         "Invalid parallel contraction dims");
+
+  SmallVector<vector::IteratorType> iteratorTypes =
+      contractOp.getIteratorTypesArray();
+  // Check VNNI dim maps - the innermost dim for A and B inputs.
+  auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(2));
+  auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(2));
+  if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
+      iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
+    return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI dim map");
+  // Check K dim maps - non-transposed row-major layout.
+  auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(1));
+  auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(0));
+  if (!redDimA || !redDimB || redDimA != redDimB ||
+      iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
+    return rewriter.notifyMatchFailure(contractOp, "Invalid K dim map");
+  // Check M and N dim maps - map to non-transposed output.
+  AffineMap mapC = indexingMaps[2];
+  auto mDimC = dyn_cast<AffineDimExpr>(mapC.getResult(0));
+  auto nDimC = dyn_cast<AffineDimExpr>(mapC.getResult(1));
+  if (!mDimC || !nDimC)
+    return rewriter.notifyMatchFailure(contractOp, "Invalid acc maps");
+  // A's outermost dim must be the parallel M dim and the first dim of C.
+  auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.getResult(0));
+  if (!parallelDimA ||
+      iteratorTypes[parallelDimA.getPosition()] !=
+          vector::IteratorType::parallel ||
+      parallelDimA != mDimC)
+    return rewriter.notifyMatchFailure(contractOp, "Invalid M dim map");
+  // B's middle dim must be the parallel N dim and the second dim of C.
+  auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(1));
+  if (!parallelDimB ||
+      iteratorTypes[parallelDimB.getPosition()] !=
+          vector::IteratorType::parallel ||
+      parallelDimB != nDimC)
+    return rewriter.notifyMatchFailure(contractOp, "Invalid N dim map");
+
+  return success();
+}
+
+/// Checks whether \p contractOp can be lowered to AMX tile operations.
+///
+/// The accumulator must be a vector, and the element types must form one of
+/// the supported combinations (lhs x rhs -> acc):
+///   - i8   x i8   -> i32
+///   - f16  x f16  -> f32
+///   - bf16 x bf16 -> f32
+/// The operands must also pass isAmxVnniLayout. On mismatch, a match failure
+/// is reported through \p rewriter.
+static LogicalResult validateOperands(PatternRewriter &rewriter,
+                                      vector::ContractionOp contractOp) {
+  auto accVecType = dyn_cast<VectorType>(contractOp.getAcc().getType());
+  if (!accVecType)
+    return rewriter.notifyMatchFailure(contractOp, "Expects vector acc");
+
+  Type accElem = accVecType.getElementType();
+  Type lhsElem = contractOp.getLhs().getType().getElementType();
+  Type rhsElem = contractOp.getRhs().getType().getElementType();
+
+  // Element type combinations supported by the AMX compute ops.
+  bool isInt8Mul =
+      accElem.isInteger(32) && lhsElem.isInteger(8) && rhsElem.isInteger(8);
+  bool isF16Mul = accElem.isF32() && lhsElem.isF16() && rhsElem.isF16();
+  bool isBF16Mul = accElem.isF32() && lhsElem.isBF16() && rhsElem.isBF16();
+  if (!isInt8Mul && !isF16Mul && !isBF16Mul)
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "Invalid combination of operand types");
+
+  return isAmxVnniLayout(rewriter, contractOp);
+}
+
+/// Folds the two innermost dimensions of \p memref into a single dimension
+/// and keeps every outer dimension as is, e.g.
+/// memref<4x8x2xf16> -> memref<4x16xf16>. Expects a memref of rank >= 2.
+static Value collapseLastDim(PatternRewriter &rewriter,
+                             TypedValue<MemRefType> memref) {
+  int64_t rank = memref.getType().getRank();
+  int64_t innerStart = rank - 2;
+  SmallVector<ReassociationIndices> groups;
+  for (int64_t dim = 0; dim < innerStart; ++dim)
+    groups.push_back({dim});
+  groups.push_back({innerStart, innerStart + 1});
+  return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref,
+                                         groups);
+}
+
+/// Loads vector values to an AMX tile.
+///
+/// The vector is written to an intermediate memref.alloca buffer of the same
+/// shape, and the tile is loaded from that buffer. A 3D VNNI packed vector
+/// [rows]x[cols]x[vnni] is first collapsed to a 2D [rows]x[cols * vnni]
+/// buffer, matching the returned tile type.
+static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
+                                          TypedValue<VectorType> vec) {
+  Location loc = vec.getLoc();
+  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+  // Transfer the vector to a tile through an intermediate buffer.
+  VectorType vecTy = vec.getType();
+  Value buf = memref::AllocaOp::create(
+      rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
+  SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
+  vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);
+
+  // Collapse the VNNI dimension in case of packing.
+  bool isPacked = vecTy.getRank() == 3;
+  if (isPacked)
+    buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));
+
+  // Tile rows map to the outermost dim; all remaining dims are flattened into
+  // tile columns. Seed the fold with an int64_t: std::accumulate computes in
+  // the type of its initial value, so a plain `1` would truncate to `int`.
+  ArrayRef<int64_t> shape = vecTy.getShape();
+  int64_t rows = shape[0];
+  int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), int64_t{1},
+                                 std::multiplies<int64_t>());
+  auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
+
+  return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
+                                 {zeroIndex, zeroIndex});
+}
+
+/// Stores an AMX tile in a vector.
+///
+/// The tile is stored into an intermediate memref.alloca buffer that has the
+/// tile's shape and element type. The buffer is then read back with
+/// vector.transfer_read into a vector of that same shape.
+static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
+                                        TypedValue<amx::TileType> tile) {
+  amx::TileType tileTy = tile.getType();
+  ArrayRef<int64_t> tileShape = tileTy.getShape();
+  Type elemTy = tileTy.getElementType();
+  Location loc = tile.getLoc();
+
+  Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<Value> offsets(2, zero);
+
+  auto bufTy = MemRefType::get(tileShape, elemTy);
+  Value buffer = memref::AllocaOp::create(rewriter, loc, bufTy);
+  amx::TileStoreOp::create(rewriter, loc, buffer, offsets, tile);
+
+  auto resultTy = VectorType::get(tileShape, elemTy);
+  return vector::TransferReadOp::create(rewriter, loc, resultTy, buffer,
+                                        offsets, {});
+}
+
+/// Rewrites a vector.contract whose operands already have an AMX-compatible
+/// packed VNNI layout into AMX tile operations.
+///
+/// The lhs, rhs and accumulator vectors are each moved into AMX tiles. The
+/// tiles are multiplied with amx.tile_mulf (floating-point accumulator) or
+/// amx.tile_muli (otherwise). The result tile is then moved back into a
+/// vector, which replaces the contraction. Only the `add` combining kind is
+/// supported.
+struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    // All checks run before any IR is created.
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind");
+    if (failed(validateOperands(rewriter, contractOp)))
+      return failure();
+
+    // validateOperands has already rejected non-vector accumulators.
+    auto accVec = cast<TypedValue<VectorType>>(contractOp.getAcc());
+    Location loc = contractOp.getLoc();
+
+    // Move the operands into tiles: lhs, then rhs, then accumulator.
+    TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs());
+    TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs());
+    TypedValue<amx::TileType> accTile = loadTile(rewriter, accVec);
+    amx::TileType resultTy = accTile.getType();
+
+    TypedValue<amx::TileType> product;
+    if (accVec.getType().getElementType().isFloat())
+      product = amx::TileMulFOp::create(rewriter, loc, resultTy, lhsTile,
+                                        rhsTile, accTile);
+    else
+      product = amx::TileMulIOp::create(rewriter, loc, resultTy, lhsTile,
+                                        rhsTile, accTile);
+
+    // Move the result tile back into a vector that replaces the contraction.
+    Value replacement = storeTile(rewriter, product);
+    rewriter.replaceOp(contractOp, replacement);
+    return success();
+  }
+};
+
+/// Pass that greedily applies the vector to AMX conversion patterns to the
+/// operation it is run on.
+struct ConvertVectorToAMXPass
+    : public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorToAMXConversionPatterns(patterns);
+
+    // Propagate a failure reported by the greedy rewrite driver.
+    LogicalResult rewriteResult =
+        applyPatternsGreedily(getOperation(), std::move(patterns));
+    if (failed(rewriteResult))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Registers the vector.contract to AMX rewrite pattern in \p patterns.
+void mlir::populateVectorToAMXConversionPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<ContractionToAMX>(ctx);
+}
diff --git a/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
new file mode 100644
index 0000000000000..ad23964a15dd2
--- /dev/null
+++ b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
@@ -0,0 +1,291 @@
+// RUN: mlir-opt %s -convert-vector-to-amx -split-input-file | FileCheck %s
+
+/// VNNI format is Intel's packed data layout.
+/// For matrix multiplication, elements from the reduction dimension `k`
+/// are packed into 32-bit tuples. Then the appropriate AMX operations can
+/// perform tile multiplication directly on the packed data.
+///
+/// These packed elements are represented in the indexing maps by a separate
+/// reduction dimension `vnni`.
+
+/// Positive case: f16 inputs in VNNI layout (factor 2) with an f32
+/// accumulator lower to tile loads, a floating-point tile multiply and a
+/// tile store.
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @contract_vnni_f16(%A: vector<4x8x2xf16>, %B: vector<8x16x2xf16>,
+    %C: vector<4x16xf32>) -> vector<4x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @contract_vnni_f16(
+// CHECK-SAME:    %[[A:.+]]: vector<4x8x2xf16>,
+// CHECK-SAME:    %[[B:.+]]: vector<8x16x2xf16>,
+// CHECK-SAME:    %[[C:.+]]: vector<4x16xf32>
+
+/// AMX hardware has no direct access to the vector registers. Thus, data
+/// must be transferred through intermediate buffers.
+///
+/// Load A vector into an AMX tile
+// CHECK:       %[[A_BUF:.+]] = memref.alloca() : memref<4x8x2xf16>
+// CHECK:       vector.transfer_write %[[A]], %[[A_BUF]]
+// CHECK:       %[[A_BUF_2D:.+]] = memref.collapse_shape %[[A_BUF]]
+// CHECK-SAME:    {{\[}}[0], [1, 2]] : memref<4x8x2xf16> into memref<4x16xf16>
+// CHECK:       %[[A_TILE:.+]] = amx.tile_load %[[A_BUF_2D]]
+
+/// Load B vector into an AMX tile
+// CHECK:       %[[B_BUF:.+]] = memref.alloca() : memref<8x16x2xf16>
+// CHECK:       vector.transfer_write %[[B]], %[[B_BUF]]
+// CHECK:       %[[B_BUF_2D:.+]] = memref.collapse_shape %[[B_BUF]]
+// CHECK-SAME:    {{\[}}[0], [1, 2]] : memref<8x16x2xf16> into memref<8x32xf16>
+// CHECK:       %[[B_TILE:.+]] = amx.tile_load %[[B_BUF_2D]]
+
+/// Load C vector into an AMX tile
+// CHECK:       %[[C_BUF:.+]] = memref.alloca() : memref<4x16xf32>
+// CHECK:       vector.transfer_write %[[C]], %[[C_BUF]]
+// CHECK:       %[[C_TILE:.+]] = amx.tile_load %[[C_BUF]]
+
+/// Perform tile multiplication
+// CHECK:       %[[RES:.+]] = amx.tile_mulf
+// CHECK-SAME:    %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
+
+/// Store the result tile and read it back into a vector
+// CHECK:       %[[RES_BUF:.+]] = memref.alloca() : memref<4x16xf32>
+// CHECK:       amx.tile_store %[[RES_BUF]]{{.*}}, %[[RES]]
+// CHECK:       %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]]
+
+// CHECK:       return %[[RES_VEC]]
+
+// -----
+
+/// Positive case: bf16 inputs in VNNI layout (factor 2) with an f32
+/// accumulator use the floating-point tile multiply.
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @contract_vnni_bf16(%A: vector<4x8x2xbf16>, %B: vector<8x16x2xbf16>,
+    %C: vector<4x16xf32>) -> vector<4x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x8x2xbf16>, vector<8x16x2xbf16> into vector<4x16xf32>
+  return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @contract_vnni_bf16(
+// CHECK-COUNT-3: amx.tile_load
+// CHECK: amx.tile_mulf
+// CHECK: amx.tile_store
+
+// -----
+
+/// Positive case: i8 inputs in VNNI layout (factor 4) with an i32
+/// accumulator use the integer tile multiply.
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @contract_vnni_i8(%A: vector<4x16x4xi8>, %B: vector<16x8x4xi8>,
+    %C: vector<4x8xi32>) -> vector<4x8xi32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x16x4xi8>, vector<16x8x4xi8> into vector<4x8xi32>
+  return %0 : vector<4x8xi32>
+}
+
+// CHECK-LABEL: @contract_vnni_i8(
+// CHECK-COUNT-3: amx.tile_load
+// CHECK: amx.tile_muli
+// CHECK: amx.tile_store
+
+// -----
+
+/// Iterator order does not matter: dims are identified through the indexing
+/// maps, so a contraction with shuffled iterators still lowers.
+#map = affine_map<(vnni, m, k, n) -> (m, k, vnni)>
+#map1 = affine_map<(vnni, m, k, n) -> (k, n, vnni)>
+#map2 = affine_map<(vnni, m, k, n) -> (m, n)>
+func.func @contract_shuffled_iterators(%A: vector<4x16x4xi8>, %B: vector<16x8x4xi8>,
+    %C: vector<4x8xi32>) -> vector<4x8xi32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "reduction", "parallel"]}
+    %A, %B, %C : vector<4x16x4xi8>, vector<16x8x4xi8> into vector<4x8xi32>
+  return %0 : vector<4x8xi32>
+}
+
+// CHECK-LABEL: @contract_shuffled_iterators(
+// CHECK-COUNT-3: amx.tile_load
+// CHECK: amx.tile_muli
+// CHECK: amx.tile_store
+
+// -----
+
+/// Negative case: only the `add` combining kind is supported.
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_invalid_kind(%A: vector<4x8x2xf16>, %B: vector<8x16x2xf16>,
+    %C: vector<4x16xf32>) -> vector<4x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<mul>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_kind(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+/// Negative case: a scalar (non-vector) accumulator is not supported.
+#map = affine_map<(m, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, k, vnni) -> (k, m, vnni)>
+#map2 = affine_map<(m, k, vnni) -> ()>
+func.func @negative_non_vector_acc(%A: vector<4x8x2xf16>, %B: vector<8x4x2xf16>,
+    %C: f32) -> f32 {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x8x2xf16>, vector<8x4x2xf16> into f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @negative_non_vector_acc(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+/// Negative case: f32 inputs are not a supported operand type combination.
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_invalid_operand_types(%A: vector<4x8x2xf32>, %B: vector<8x16x2xf32>,
+    %C: vector<4x16xf32>) -> vector<4x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x8x2xf32>, vector<8x16x2xf32> into vector<4x16xf32>
+  return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_operand_types(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+/// Negative case: 2D inputs without VNNI packing are not lowered.
+#map = affine_map<(m, n, k) -> (m, k)>
+#map1 = affine_map<(m, n, k) -> (k, n)>
+#map2 = affine_map<(m, n, k) -> (m, n)>
+func.func @negative_non_packed_layout(%A: vector<4x16xf16>, %B: vector<16x16xf16>,
+    %C: vector<4x16xf32>) -> vector<4x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"]}
+    %A, %B, %C : vector<4x16xf16>, vector<16x16xf16> into vector<4x16xf32>
+  return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_non_packed_layout(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+/// Negative case: the VNNI dim must equal 32 / element bit width, which is 2
+/// for f16 (4 is used here).
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_invalid_vnni_factor(%A: vector<4x2x4xf16>, %B: vector<2x2x4xf16>,
+    %C: vector<4x2xf32>) -> vector<4x2xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x2x4xf16>, vector<2x2x4xf16> into vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_vnni_factor(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+/// Negative case: a tile holds at most 16 rows (32 rows are used here).
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_too_many_rows(%A: vector<32x8x2xf16>, %B: vector<8x16x2xf16>,
+    %C: vector<32x16xf32>) -> vector<32x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<32x8x2xf16>, vector<8x16x2xf16> into vector<32x16xf32>
+  return %0 : vector<32x16xf32>
+}
+
+// CHECK-LABEL: @negative_too_many_rows(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+/// Negative case: a tile row holds at most 64 bytes. A packed row of A here
+/// is 32x2 f16 elements, i.e. 128 bytes.
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_too_wide_rows(%A: vector<4x32x2xf16>, %B: vector<32x16x2xf16>,
+    %C: vector<4x16xf32>) -> vector<4x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x32x2xf16>, vector<32x16x2xf16> into vector<4x16xf32>
+  return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_too_wide_rows(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+/// Negative case: inputs whose dims are permuted away from the expected
+/// [M]x[K]x[vnni] (A) and [K]x[N]x[vnni] (B) layouts are not lowered.
+#map = affine_map<(m, n, k, vnni) -> (k, vnni, m)>
+#map1 = affine_map<(m, n, k, vnni) -> (n, k, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_input_dim_permutation(%A: vector<2x2x2xf16>,
+    %B: vector<2x2x2xf16>, %C: vector<2x2xf32>) -> vector<2x2xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<2x2x2xf16>, vector<2x2x2xf16> into vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+
+// CHECK-LABEL: @negative_input_dim_permutation(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+/// Negative case: the accumulator must be in non-transposed [M]x[N] layout.
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (n, m)>
+func.func @negative_output_dim_permutation(%A: vector<4x8x2xf16>,
+    %B: vector<8x16x2xf16>, %C: vector<16x4xf32>) -> vector<16x4xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<16x4xf32>
+  return %0 : vector<16x4xf32>
+}
+
+// CHECK-LABEL: @negative_output_dim_permutation(
+// CHECK-NOT: amx
+// CHECK: vector.contract

>From c909b55c24c7d01cc5c4d4732cfd1a5972797926 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 29 Jul 2025 12:46:21 +0200
Subject: [PATCH 2/2] Extra test case

---
 .../VectorToAMX/contract-to-amx.mlir          | 23 +++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
index ad23964a15dd2..4fb88dd165126 100644
--- a/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
+++ b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
@@ -162,8 +162,8 @@ func.func @negative_non_vector_acc(%A: vector<4x8x2xf16>, %B: vector<8x4x2xf16>,
 #map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
 #map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
 #map2 = affine_map<(m, n, k, vnni) -> (m, n)>
-func.func @negative_invalid_operand_types(%A: vector<4x8x2xf32>, %B: vector<8x16x2xf32>,
-    %C: vector<4x16xf32>) -> vector<4x16xf32> {
+func.func @negative_invalid_operand_types(%A: vector<4x8x2xf32>,
+    %B: vector<8x16x2xf32>, %C: vector<4x16xf32>) -> vector<4x16xf32> {
   %0 = vector.contract
     {kind = #vector.kind<add>,
     indexing_maps = [#map, #map1, #map2],
@@ -216,6 +216,25 @@ func.func @negative_invalid_vnni_factor(%A: vector<4x2x4xf16>, %B: vector<2x2x4x
 
 // -----
 
+/// Negative case: batched contractions (an extra batch dim, 4D inputs) are
+/// not lowered.
+#map = affine_map<(batch, m, n, k, vnni) -> (batch, m, k, vnni)>
+#map1 = affine_map<(batch, m, n, k, vnni) -> (batch, k, n, vnni)>
+#map2 = affine_map<(batch, m, n, k, vnni) -> (batch, m, n)>
+func.func @negative_invalid_operands_shapes(%A: vector<1x4x8x2xf16>,
+    %B: vector<1x8x16x2xf16>, %C: vector<1x4x16xf32>) -> vector<1x4x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<1x4x8x2xf16>, vector<1x8x16x2xf16> into vector<1x4x16xf32>
+  return %0 : vector<1x4x16xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_operands_shapes(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
 #map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
 #map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
 #map2 = affine_map<(m, n, k, vnni) -> (m, n)>



More information about the Mlir-commits mailing list