[Mlir-commits] [mlir] [mlir][amx] Vector to AMX conversion pass (PR #151121)
Adam Siemieniuk
llvmlistbot at llvm.org
Tue Jul 29 03:33:25 PDT 2025
https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/151121
Adds a pass for Vector to AMX operation conversion.
Initially, a direct rewrite of vector contractions in packed VNNI layout is supported. Operations are expected to already have AMX-compatible shapes for the rewrite to apply.
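For illustration, a minimal sketch based on the included tests: an f16 contraction in packed VNNI layout such as

    #map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
    #map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
    #map2 = affine_map<(m, n, k, vnni) -> (m, n)>
    %0 = vector.contract
      {kind = #vector.kind<add>,
       indexing_maps = [#map, #map1, #map2],
       iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
      %A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>

is lowered, through intermediate buffers (AMX tiles cannot be accessed as vectors directly), into roughly:

    %bufA = memref.alloca() : memref<4x8x2xf16>
    vector.transfer_write %A, %bufA ...
    %bufA2d = memref.collapse_shape %bufA [[0], [1, 2]]
        : memref<4x8x2xf16> into memref<4x16xf16>
    %tileA = amx.tile_load %bufA2d ...
    // ... B and C are loaded into tiles the same way ...
    %tileRes = amx.tile_mulf %tileA, %tileB, %tileC ...
    // the result tile is stored to a buffer and read back as a vector

The pass can be exercised with `mlir-opt -convert-vector-to-amx`.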
From c17d62d6be3d035cc82383b4633cb26227009274 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 28 Jul 2025 14:25:51 +0200
Subject: [PATCH] [mlir][amx] Vector to AMX conversion pass
Adds a pass for Vector to AMX operation conversion.
Initially, a direct rewrite of vector contractions in packed VNNI
layout is supported. Operations are expected to already have
AMX-compatible shapes for the rewrite to apply.
---
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 13 +
.../mlir/Conversion/VectorToAMX/VectorToAMX.h | 26 ++
mlir/lib/Conversion/CMakeLists.txt | 1 +
.../lib/Conversion/VectorToAMX/CMakeLists.txt | 19 ++
.../Conversion/VectorToAMX/VectorToAMX.cpp | 287 +++++++++++++++++
.../VectorToAMX/contract-to-amx.mlir | 291 ++++++++++++++++++
7 files changed, 638 insertions(+)
create mode 100644 mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
create mode 100644 mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
create mode 100644 mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 3dc48b2201cf2..91b2ecf8922a3 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -75,6 +75,7 @@
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cf7596cc8a928..20ead98acc371 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1515,6 +1515,19 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
];
}
+//===----------------------------------------------------------------------===//
+// VectorToAMX
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToAMX : Pass<"convert-vector-to-amx"> {
+ let summary = "Lower the operations from the vector dialect into the AMX "
+ "dialect";
+ let dependentDialects = [
+ "affine::AffineDialect", "amx::AMXDialect", "arith::ArithDialect",
+ "memref::MemRefDialect", "scf::SCFDialect", "vector::VectorDialect"
+ ];
+}
+
//===----------------------------------------------------------------------===//
// XeVMToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
new file mode 100644
index 0000000000000..b075ac92990a2
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
@@ -0,0 +1,26 @@
+//===- VectorToAMX.h - Convert vector to AMX dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
+#define MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTVECTORTOAMX
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Collect a set of patterns to convert vector ops to AMX ops.
+void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 785cb8293810c..171f7169fd41d 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -68,6 +68,7 @@ add_subdirectory(TosaToSCF)
add_subdirectory(TosaToTensor)
add_subdirectory(UBToLLVM)
add_subdirectory(UBToSPIRV)
+add_subdirectory(VectorToAMX)
add_subdirectory(VectorToArmSME)
add_subdirectory(VectorToGPU)
add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
new file mode 100644
index 0000000000000..2d4b2b6e9283c
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRVectorToAMX
+ VectorToAMX.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToAMX
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRAMXDialect
+ MLIRAffineUtils
+ MLIRArithDialect
+ MLIRLinalgUtils
+ MLIRMemRefDialect
+ MLIRSCFDialect
+ MLIRTransforms
+ MLIRVectorDialect
+ )
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
new file mode 100644
index 0000000000000..fc24275a1467c
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -0,0 +1,287 @@
+//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of vector operations to AMX dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOAMX
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Return true if vector shape is compatible with AMX tiles.
+/// The validation accounts for VNNI packing.
+static bool verifyAmxShape(VectorType vec) {
+ // Check overall shape:
+ // - 2D for plain layout input or output
+ // - 3D for VNNI packed input
+ if (vec.getRank() != 2 && vec.getRank() != 3)
+ return false;
+
+ ArrayRef<int64_t> shape = vec.getShape();
+ int64_t rows = shape[0];
+ int64_t cols = shape[1];
+ unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();
+
+ // 3D shape indicates VNNI packed layout.
+ if (vec.getRank() == 3) {
+ int64_t vnniFactor = 32 / elemBitWidth;
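+ // E.g., 16-bit f16/bf16 packs 2 elements per 32-bit tuple; 8-bit i8 packs 4.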
+ if (shape.back() != vnniFactor)
+ return false;
+ cols *= vnniFactor;
+ }
+
+ // AMX tile supports up to 16 rows of 64 bytes each.
+ constexpr unsigned maxRows = 16;
+ constexpr unsigned maxBitsPerRow = 64 * 8;
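+ // E.g., a full tile holds 16x32 f16/bf16 elements or 16x64 i8 elements.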
+ return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
+}
+
+/// Checks if contraction operands are in AMX-compatible packed VNNI layout.
+static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
+ vector::ContractionOp contractOp) {
+ VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
+ if (!accType || accType.getRank() != 2)
+ return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+
+ // Expect 3D inputs for VNNI packed data.
+ VectorType lhsType = contractOp.getLhs().getType();
+ VectorType rhsType = contractOp.getRhs().getType();
+ if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects lhs and rhs 3D vectors");
+
+ // Check if shapes are compatible with AMX tile.
+ if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
+ !verifyAmxShape(accType))
+ return rewriter.notifyMatchFailure(contractOp, "Invalid operand shape");
+
+ // Validate affine maps.
+ //
+ // Iterators can be ordered arbitrarily. Indexing map positions are based on
+ // operands' target shapes.
+ // The matrix layouts must match the following:
+ // - matrix A - [M]x[K/vnniFactor]x[vnniFactor]
+ // - matrix B - [K/vnniFactor]x[N]x[vnniFactor]
+ // - matrix C - [M]x[N]
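+ // E.g., for f16 operands (vnniFactor 2): A vector<4x8x2xf16>,
+ // B vector<8x16x2xf16>, C vector<4x16xf32>.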
+ SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray();
+ AffineMap mapA = indexingMaps[0];
+ AffineMap mapB = indexingMaps[1];
+ if (mapA.getNumInputs() != 4 || mapA.getNumResults() != 3 ||
+ mapB.getNumResults() != 3)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Invalid input indexing maps");
+ FailureOr<linalg::ContractionDimensions> dims =
+ linalg::inferContractionDims(indexingMaps);
+ if (failed(dims))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Failed to infer contraction dims");
+ // Two reduction dimensions are expected:
+ // - one for the K dimension
+ // - one for the VNNI factor
+ if (dims->k.size() != 2)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expected two reduction dims");
+ assert(dims->m.size() == 1 && dims->n.size() == 1 &&
+ "Invalid parallel contraction dims");
+
+ SmallVector<vector::IteratorType> iteratorTypes =
+ contractOp.getIteratorTypesArray();
+ // Check VNNI dim maps - the innermost dim for A and B inputs.
+ auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(2));
+ auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(2));
+ if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
+ iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI dim map");
+ // Check K dim maps - non-transposed row-major layout.
+ auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(1));
+ auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(0));
+ if (!redDimA || !redDimB || redDimA != redDimB ||
+ iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid K dim map");
+ // Check M and N dim maps - map to non-transposed output.
+ AffineMap mapC = indexingMaps[2];
+ auto mDimC = dyn_cast<AffineDimExpr>(mapC.getResult(0));
+ auto nDimC = dyn_cast<AffineDimExpr>(mapC.getResult(1));
+ if (!mDimC || !nDimC)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid acc maps");
+ auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.getResult(0));
+ if (!parallelDimA ||
+ iteratorTypes[parallelDimA.getPosition()] !=
+ vector::IteratorType::parallel ||
+ parallelDimA != mDimC)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid M dim map");
+ auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(1));
+ if (!parallelDimB ||
+ iteratorTypes[parallelDimB.getPosition()] !=
+ vector::IteratorType::parallel ||
+ parallelDimB != nDimC)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid N dim map");
+
+ return success();
+}
+
+/// Validate contraction operands for AMX lowering.
+static LogicalResult validateOperands(PatternRewriter &rewriter,
+ vector::ContractionOp contractOp) {
+ VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
+ if (!accType)
+ return rewriter.notifyMatchFailure(contractOp, "Expects vector acc");
+
+ // Check if operand types are compatible with AMX compute ops.
+ bool validElemTypes = false;
+ Type lhsElemType = contractOp.getLhs().getType().getElementType();
+ Type rhsElemType = contractOp.getRhs().getType().getElementType();
+ Type accElemType = accType.getElementType();
+ if (accElemType.isInteger(32)) {
+ validElemTypes = lhsElemType.isInteger(8) && rhsElemType.isInteger(8);
+ } else if (accElemType.isF32()) {
+ validElemTypes = (lhsElemType.isF16() && rhsElemType.isF16()) ||
+ (lhsElemType.isBF16() && rhsElemType.isBF16());
+ }
+ if (!validElemTypes)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Invalid combination of operand types");
+
+ if (failed(isAmxVnniLayout(rewriter, contractOp)))
+ return failure();
+
+ return success();
+}
+
+/// Collapses the two innermost dimensions together.
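+/// For example, memref<4x8x2xf16> is collapsed into memref<4x16xf16>.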
+static Value collapseLastDim(PatternRewriter &rewriter,
+ TypedValue<MemRefType> memref) {
+ int64_t rank = memref.getType().getRank();
+ SmallVector<ReassociationIndices> reassocIndices;
+ for (auto i : llvm::seq<int64_t>(0, rank - 2))
+ reassocIndices.push_back({i});
+ reassocIndices.push_back({rank - 2, rank - 1});
+ return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref,
+ reassocIndices);
+}
+
+/// Loads vector values to an AMX tile.
+static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
+ TypedValue<VectorType> vec) {
+ Location loc = vec.getLoc();
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+ // Transfer the vector to a tile through an intermediate buffer.
+ VectorType vecTy = vec.getType();
+ Value buf = memref::AllocaOp::create(
+ rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
+ SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
+ vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);
+
+ // Collapse the VNNI dimension in case of packing.
+ bool isPacked = vecTy.getRank() == 3;
+ if (isPacked)
+ buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));
+
+ ArrayRef<int64_t> shape = vecTy.getShape();
+ int64_t rows = shape[0];
+ int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), int64_t{1},
+ std::multiplies<int64_t>());
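+ // All trailing dims fold into tile columns,
+ // e.g., vector<4x8x2xf16> becomes a 4x16 tile.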
+ auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
+
+ return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
+ {zeroIndex, zeroIndex});
+}
+
+/// Stores an AMX tile in a vector.
+static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
+ TypedValue<amx::TileType> tile) {
+ Location loc = tile.getLoc();
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+ // Transfer the tile to a vector through an intermediate buffer.
+ amx::TileType tileTy = tile.getType();
+ Value buf = memref::AllocaOp::create(
+ rewriter, loc,
+ MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
+ SmallVector<Value> indices(2, zeroIndex);
+ amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
+
+ auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
+ return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
+}
+
+struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = contractOp.getLoc();
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+ if (failed(validateOperands(rewriter, contractOp)))
+ return failure();
+
+ TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs());
+ TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs());
+ auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
+ assert(acc && "Invalid accumulator type");
+ TypedValue<amx::TileType> accTile = loadTile(rewriter, acc);
+
+ TypedValue<amx::TileType> tileMul;
+ if (acc.getType().getElementType().isFloat()) {
+ tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
+ lhsTile, rhsTile, accTile);
+ } else {
+ tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
+ lhsTile, rhsTile, accTile);
+ }
+
+ Value res = storeTile(rewriter, tileMul);
+ rewriter.replaceOp(contractOp, res);
+
+ return success();
+ }
+};
+
+struct ConvertVectorToAMXPass
+ : public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> {
+ void runOnOperation() override {
+ MLIRContext &ctx = getContext();
+ RewritePatternSet patterns(&ctx);
+ populateVectorToAMXConversionPatterns(patterns);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+} // namespace
+
+void mlir::populateVectorToAMXConversionPatterns(RewritePatternSet &patterns) {
+ patterns.add<ContractionToAMX>(patterns.getContext());
+}
diff --git a/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
new file mode 100644
index 0000000000000..ad23964a15dd2
--- /dev/null
+++ b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
@@ -0,0 +1,291 @@
+// RUN: mlir-opt %s -convert-vector-to-amx -split-input-file | FileCheck %s
+
+/// VNNI format is Intel's packed data layout.
+/// For matrix multiplication, elements from the reduction dimension `k`
+/// are packed into 32-bit tuples. Then the appropriate AMX operations can
+/// perform tile multiplication directly on the packed data.
+///
+/// These packed elements are represented in the indexing maps by a separate
+/// reduction dimension `vnni`.
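+///
+/// For example, with bf16 elements (two per 32-bit tuple), a row-major
+/// A[M][K] is repacked into A[M][K/2][2].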
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @contract_vnni_f16(%A: vector<4x8x2xf16>, %B: vector<8x16x2xf16>,
+ %C: vector<4x16xf32>) -> vector<4x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @contract_vnni_f16(
+// CHECK-SAME: %[[A:.+]]: vector<4x8x2xf16>,
+// CHECK-SAME: %[[B:.+]]: vector<8x16x2xf16>,
+// CHECK-SAME: %[[C:.+]]: vector<4x16xf32>
+
+/// AMX hardware has no direct access to vector registers. Thus, data must
+/// be transferred through intermediate buffers.
+///
+/// Load A vector into an AMX tile
+// CHECK: %[[A_BUF:.+]] = memref.alloca() : memref<4x8x2xf16>
+// CHECK: vector.transfer_write %[[A]], %[[A_BUF]]
+// CHECK: %[[A_BUF_2D:.+]] = memref.collapse_shape %[[A_BUF]]
+// CHECK-SAME: {{\[}}[0], [1, 2]] : memref<4x8x2xf16> into memref<4x16xf16>
+// CHECK: %[[A_TILE:.+]] = amx.tile_load %[[A_BUF_2D]]
+
+/// Load B vector into an AMX tile
+// CHECK: %[[B_BUF:.+]] = memref.alloca() : memref<8x16x2xf16>
+// CHECK: vector.transfer_write %[[B]], %[[B_BUF]]
+// CHECK: %[[B_BUF_2D:.+]] = memref.collapse_shape %[[B_BUF]]
+// CHECK-SAME: {{\[}}[0], [1, 2]] : memref<8x16x2xf16> into memref<8x32xf16>
+// CHECK: %[[B_TILE:.+]] = amx.tile_load %[[B_BUF_2D]]
+
+/// Load C vector into an AMX tile
+// CHECK: %[[C_BUF:.+]] = memref.alloca() : memref<4x16xf32>
+// CHECK: vector.transfer_write %[[C]], %[[C_BUF]]
+// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_BUF]]
+
+/// Perform tile multiplication
+// CHECK: %[[RES:.+]] = amx.tile_mulf
+// CHECK-SAME: %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
+
+/// Store the result tile and read it back into a vector
+// CHECK: %[[RES_BUF:.+]] = memref.alloca() : memref<4x16xf32>
+// CHECK: amx.tile_store %[[RES_BUF]]{{.*}}, %[[RES]]
+// CHECK: %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]]
+
+// CHECK: return %[[RES_VEC]]
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @contract_vnni_bf16(%A: vector<4x8x2xbf16>, %B: vector<8x16x2xbf16>,
+ %C: vector<4x16xf32>) -> vector<4x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x8x2xbf16>, vector<8x16x2xbf16> into vector<4x16xf32>
+ return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @contract_vnni_bf16(
+// CHECK-COUNT-3: amx.tile_load
+// CHECK: amx.tile_mulf
+// CHECK: amx.tile_store
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @contract_vnni_i8(%A: vector<4x16x4xi8>, %B: vector<16x8x4xi8>,
+ %C: vector<4x8xi32>) -> vector<4x8xi32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x16x4xi8>, vector<16x8x4xi8> into vector<4x8xi32>
+ return %0 : vector<4x8xi32>
+}
+
+// CHECK-LABEL: @contract_vnni_i8(
+// CHECK-COUNT-3: amx.tile_load
+// CHECK: amx.tile_muli
+// CHECK: amx.tile_store
+
+// -----
+
+#map = affine_map<(vnni, m, k, n) -> (m, k, vnni)>
+#map1 = affine_map<(vnni, m, k, n) -> (k, n, vnni)>
+#map2 = affine_map<(vnni, m, k, n) -> (m, n)>
+func.func @contract_shuffled_iterators(%A: vector<4x16x4xi8>, %B: vector<16x8x4xi8>,
+ %C: vector<4x8xi32>) -> vector<4x8xi32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "reduction", "parallel"]}
+ %A, %B, %C : vector<4x16x4xi8>, vector<16x8x4xi8> into vector<4x8xi32>
+ return %0 : vector<4x8xi32>
+}
+
+// CHECK-LABEL: @contract_shuffled_iterators(
+// CHECK-COUNT-3: amx.tile_load
+// CHECK: amx.tile_muli
+// CHECK: amx.tile_store
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_invalid_kind(%A: vector<4x8x2xf16>, %B: vector<8x16x2xf16>,
+ %C: vector<4x16xf32>) -> vector<4x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<mul>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_kind(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, k, vnni) -> (k, m, vnni)>
+#map2 = affine_map<(m, k, vnni) -> ()>
+func.func @negative_non_vector_acc(%A: vector<4x8x2xf16>, %B: vector<8x4x2xf16>,
+ %C: f32) -> f32 {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x8x2xf16>, vector<8x4x2xf16> into f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: @negative_non_vector_acc(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_invalid_operand_types(%A: vector<4x8x2xf32>, %B: vector<8x16x2xf32>,
+ %C: vector<4x16xf32>) -> vector<4x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x8x2xf32>, vector<8x16x2xf32> into vector<4x16xf32>
+ return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_operand_types(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k) -> (m, k)>
+#map1 = affine_map<(m, n, k) -> (k, n)>
+#map2 = affine_map<(m, n, k) -> (m, n)>
+func.func @negative_non_packed_layout(%A: vector<4x16xf16>, %B: vector<16x16xf16>,
+ %C: vector<4x16xf32>) -> vector<4x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ %A, %B, %C : vector<4x16xf16>, vector<16x16xf16> into vector<4x16xf32>
+ return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_non_packed_layout(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_invalid_vnni_factor(%A: vector<4x2x4xf16>, %B: vector<2x2x4xf16>,
+ %C: vector<4x2xf32>) -> vector<4x2xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x2x4xf16>, vector<2x2x4xf16> into vector<4x2xf32>
+ return %0 : vector<4x2xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_vnni_factor(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_too_many_rows(%A: vector<32x8x2xf16>, %B: vector<8x16x2xf16>,
+ %C: vector<32x16xf32>) -> vector<32x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<32x8x2xf16>, vector<8x16x2xf16> into vector<32x16xf32>
+ return %0 : vector<32x16xf32>
+}
+
+// CHECK-LABEL: @negative_too_many_rows(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_too_wide_rows(%A: vector<4x32x2xf16>, %B: vector<32x16x2xf16>,
+ %C: vector<4x16xf32>) -> vector<4x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x32x2xf16>, vector<32x16x2xf16> into vector<4x16xf32>
+ return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_too_wide_rows(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (k, vnni, m)>
+#map1 = affine_map<(m, n, k, vnni) -> (n, k, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_input_dim_permutation(%A: vector<2x2x2xf16>,
+ %B: vector<2x2x2xf16>, %C: vector<2x2xf32>) -> vector<2x2xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<2x2x2xf16>, vector<2x2x2xf16> into vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
+
+// CHECK-LABEL: @negative_input_dim_permutation(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (n, m)>
+func.func @negative_output_dim_permutation(%A: vector<4x8x2xf16>,
+ %B: vector<8x16x2xf16>, %C: vector<16x4xf32>) -> vector<16x4xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<16x4xf32>
+ return %0 : vector<16x4xf32>
+}
+
+// CHECK-LABEL: @negative_output_dim_permutation(
+// CHECK-NOT: amx
+// CHECK: vector.contract