[Mlir-commits] [mlir] [mlir][x86] Move AMX dialect into X86 dialect (PR #183717)
llvmlistbot at llvm.org
Fri Feb 27 02:15:06 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-sparse
Author: Adam Siemieniuk (adam-smnk)
<details>
<summary>Changes</summary>
Unifies the two dialects that define x86 operations into a single one. The AMX dialect is moved into X86 in line with other x86 extensions.
The two dialects are simply merged together. X86 dialect refactoring will be addressed separately.
List of changes:
- operations: 'amx.tile_*' => 'x86.amx.tile_*'
- types: '!amx.tile' => '!x86.amx.tile'
- namespace: 'mlir::amx' => 'mlir::x86::amx'
- test define: 'MLIR_RUN_AMX_TESTS' => 'MLIR_RUN_X86_AMX_TESTS'
- vector lowering: the separate 'enable-amx' option is removed; AMX lowering is enabled together with X86
The MLIR AMX tests are now nested under the X86 directory. To enable the AMX integration tests, both 'MLIR_RUN_X86_AMX_TESTS' and 'MLIR_RUN_X86_TESTS' must be defined.
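For out-of-tree textual IR, the operation and type renames above amount to a prefix change. The following is a minimal C++ sketch of such a migration. It assumes IR is handled as plain text, and `migrateAmxSpellings` is a hypothetical helper, not part of MLIR or of this patch. It does not touch the C++ namespace or the test defines.

```cpp
#include <cctype>
#include <cstddef>
#include <string>

// Hypothetical migration helper (not part of MLIR): rewrites old AMX
// spellings in textual IR to the X86-nested ones, e.g.
//   amx.tile_zero : !amx.tile<16x16xbf16>
// becomes
//   x86.amx.tile_zero : !x86.amx.tile<16x16xbf16>
std::string migrateAmxSpellings(const std::string &ir) {
  const std::string oldPrefix = "amx.tile";
  const std::string newPrefix = "x86.amx.tile";
  std::string out;
  out.reserve(ir.size());
  std::size_t pos = 0;
  while (pos < ir.size()) {
    std::size_t hit = ir.find(oldPrefix, pos);
    if (hit == std::string::npos) {
      out.append(ir, pos, std::string::npos); // copy the remaining tail
      break;
    }
    out.append(ir, pos, hit - pos); // copy text before the match unchanged
    // Rewrite only when the match starts a token. A preceding identifier
    // character or '.' (as in "x86.amx.tile") means it must be kept as is.
    bool startsToken = true;
    if (hit > 0) {
      unsigned char prev = static_cast<unsigned char>(ir[hit - 1]);
      startsToken = !(std::isalnum(prev) || prev == '_' || prev == '.');
    }
    out += startsToken ? newPrefix : oldPrefix;
    pos = hit + oldPrefix.size();
  }
  return out;
}
```

Checking the preceding character makes the rewrite idempotent, so files that were already migrated are left unchanged.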
---
Patch is 163.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/183717.diff
56 Files Affected:
- (modified) mlir/Maintainers.md (-1)
- (modified) mlir/docs/TargetLLVMIR.md (+2-2)
- (removed) mlir/include/mlir-c/Dialect/AMX.h (-25)
- (modified) mlir/include/mlir/Conversion/Passes.td (+5-9)
- (modified) mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h (+2-2)
- (removed) mlir/include/mlir/Dialect/AMX/AMX.td (-440)
- (removed) mlir/include/mlir/Dialect/AMX/AMXDialect.h (-34)
- (removed) mlir/include/mlir/Dialect/AMX/AMXInterfaces.td (-31)
- (removed) mlir/include/mlir/Dialect/AMX/CMakeLists.txt (-5)
- (removed) mlir/include/mlir/Dialect/AMX/Transforms.h (-33)
- (modified) mlir/include/mlir/Dialect/CMakeLists.txt (-1)
- (modified) mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h (-5)
- (modified) mlir/include/mlir/Dialect/X86/Transforms.h (+5-2)
- (modified) mlir/include/mlir/Dialect/X86/X86.td (+384)
- (modified) mlir/include/mlir/Dialect/X86/X86Dialect.h (+13)
- (removed) mlir/lib/CAPI/Dialect/AMX.cpp (-13)
- (modified) mlir/lib/CAPI/Dialect/CMakeLists.txt (-9)
- (modified) mlir/lib/Conversion/VectorToAMX/CMakeLists.txt (+1-1)
- (modified) mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp (+32-30)
- (modified) mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt (-2)
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (-8)
- (removed) mlir/lib/Dialect/AMX/CMakeLists.txt (-2)
- (removed) mlir/lib/Dialect/AMX/IR/AMXDialect.cpp (-318)
- (removed) mlir/lib/Dialect/AMX/IR/CMakeLists.txt (-15)
- (removed) mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt (-9)
- (removed) mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp (-70)
- (modified) mlir/lib/Dialect/CMakeLists.txt (-1)
- (modified) mlir/lib/Dialect/X86/IR/X86Dialect.cpp (+285)
- (modified) mlir/lib/Dialect/X86/Transforms/LegalizeForLLVMExport.cpp (+24-1)
- (modified) mlir/lib/RegisterAllDialects.cpp (-2)
- (modified) mlir/lib/RegisterAllExtensions.cpp (+2-2)
- (modified) mlir/test/CMakeLists.txt (+2-2)
- (modified) mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir (+14-14)
- (modified) mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir (+9-9)
- (modified) mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir (-1)
- (removed) mlir/test/Dialect/AMX/invalid.mlir (-158)
- (removed) mlir/test/Dialect/AMX/roundtrip.mlir (-77)
- (removed) mlir/test/Dialect/AMX/side-effects.mlir (-32)
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+9-9)
- (added) mlir/test/Dialect/X86/AMX/invalid.mlir (+158)
- (renamed) mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir (+34-34)
- (added) mlir/test/Dialect/X86/AMX/roundtrip.mlir (+77)
- (added) mlir/test/Dialect/X86/AMX/side-effects.mlir (+32)
- (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/lit.local.cfg (+1-1)
- (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf-full.mlir (+6-6)
- (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf.mlir (+11-11)
- (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-ext.mlir (+21-21)
- (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-full.mlir (+6-6)
- (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli.mlir (+11-11)
- (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero-block.mlir (+3-3)
- (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero.mlir (+3-3)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/X86/dot.mlir (+4-3)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/X86/sparse-dot-product.mlir (+11-13)
- (modified) mlir/test/Target/LLVMIR/amx.mlir (+42-42)
- (modified) mlir/test/lit.site.cfg.py.in (+1-1)
- (modified) mlir/test/mlir-opt/commandline.mlir (-1)
``````````diff
diff --git a/mlir/Maintainers.md b/mlir/Maintainers.md
index a023ee0ea1bba..181541a0f3a93 100644
--- a/mlir/Maintainers.md
+++ b/mlir/Maintainers.md
@@ -104,7 +104,6 @@ available, should be contacted first, as they're more active in those areas.
* ‘arm_neon’ Dialect ([@banach-space](https://github.com/banach-space))
* ‘arm_sve’ Dialect ([@banach-space](https://github.com/banach-space))
* ‘ArmSME’ Dialect ([@banach-space](https://github.com/banach-space))
-* ‘amx’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
* ‘x86’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
* ‘vcix’ Dialect ([@mshockwave](https://github.com/mshockwave))
diff --git a/mlir/docs/TargetLLVMIR.md b/mlir/docs/TargetLLVMIR.md
index 2bdf400a7759f..2bcacbb4ee946 100644
--- a/mlir/docs/TargetLLVMIR.md
+++ b/mlir/docs/TargetLLVMIR.md
@@ -5,8 +5,8 @@ overall flow is two-stage:
1. **conversion** of the IR to a set of dialects translatable to LLVM IR, for
example [LLVM Dialect](Dialects/LLVM.md) or one of the hardware-specific
- dialects derived from LLVM IR intrinsics such as [AMX](Dialects/AMX.md),
- [X86](Dialects/X86.md) or [ArmNeon](Dialects/ArmNeon.md);
+ dialects derived from LLVM IR intrinsics such as [X86](Dialects/X86.md)
+ or [ArmNeon](Dialects/ArmNeon.md);
2. **translation** of MLIR dialects to LLVM IR.
This flow allows the non-trivial transformation to be performed within MLIR
diff --git a/mlir/include/mlir-c/Dialect/AMX.h b/mlir/include/mlir-c/Dialect/AMX.h
deleted file mode 100644
index ac4695a107ae6..0000000000000
--- a/mlir/include/mlir-c/Dialect/AMX.h
+++ /dev/null
@@ -1,25 +0,0 @@
-//===-- mlir-c/Dialect/AMX.h - C API for AMX Dialect --------*- C -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM
-// Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_C_DIALECT_AMX_H
-#define MLIR_C_DIALECT_AMX_H
-
-#include "mlir-c/IR.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(AMX, amx);
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif // MLIR_C_DIALECT_AMX_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index ecc22abb0f935..e77860897399f 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1521,8 +1521,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
operations. The lowering pass provides several options to control
the kinds of optimizations that are allowed. It also provides options
that enable the use of one or more architectural-specific dialects
- (AMX, X86, ArmNeon, ArmSVE, etc.) in combination with the
- architectural-neutral vector dialect lowering.
+ (X86, ArmNeon, ArmSVE, etc.) in combination with the architectural-neutral
+ vector dialect lowering.
}];
// Override explicitly in C++ to allow conditional dialect dependence.
@@ -1544,10 +1544,6 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"vector access are naturally aligned. If operations have an "
"alignment attribute set, the alignment attribute takes priority "
"over this option ">,
- Option<"amx", "enable-amx",
- "bool", /*default=*/"false",
- "Enables the use of AMX dialect while lowering the vector "
- "dialect.">,
Option<"armNeon", "enable-arm-neon",
"bool", /*default=*/"false",
"Enables the use of ArmNeon dialect while lowering the vector "
@@ -1626,10 +1622,10 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
//===----------------------------------------------------------------------===//
def ConvertVectorToAMX : Pass<"convert-vector-to-amx"> {
- let summary = "Lower the operations from the vector dialect into the AMX "
- "dialect";
+ let summary = "Lower the operations from the vector dialect into the X86 "
+ "dialect AMX operations";
let dependentDialects = [
- "affine::AffineDialect", "amx::AMXDialect", "arith::ArithDialect",
+ "affine::AffineDialect", "x86::X86Dialect", "arith::ArithDialect",
"memref::MemRefDialect", "scf::SCFDialect", "vector::VectorDialect"
];
}
diff --git a/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
index b075ac92990a2..6b178e02684c0 100644
--- a/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
+++ b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
@@ -1,4 +1,4 @@
-//===- VectorToAMX.h - Convert vector to AMX dialect ------------*- C++ -*-===//
+//===- VectorToAMX.h - Convert vector to X86 dialect AMX ops ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -18,7 +18,7 @@ class RewritePatternSet;
#define GEN_PASS_DECL_CONVERTVECTORTOAMX
#include "mlir/Conversion/Passes.h.inc"
-/// Collect a set of patterns to convert from the vector to AMX ops.
+/// Collect a set of patterns to convert from the vector to X86 AMX ops.
void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
deleted file mode 100644
index cace63d32fd80..0000000000000
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ /dev/null
@@ -1,440 +0,0 @@
-//===-- AMX.td - AMX dialect operation definitions *- tablegen -*----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the basic operations for the AMX dialect.
-//
-// The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
-// multiply unit (TMUL), a tile control register (TILECFG), and eight
-// tile registers TMM0 through TMM7 (TILEDATA).
-//
-// The AMX dialect provides a bridge between MLIR concepts, such as
-// 2-d vector, operations, and memrefs, and the lower level details
-// of Intel AMX, such as configuration setup, tile sizes, instructions,
-// and tile release.
-//
-// Note that since configuration changes (implicit at dialect level) are
-// costly, it is highly recommended to use the AMX dialect on same-shaped
-// vectors, at least within a single method.
-//
-// https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef AMX
-#define AMX
-
-include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
-include "mlir/Dialect/AMX/AMXInterfaces.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/IR/AttrTypeBase.td"
-include "mlir/IR/BuiltinTypes.td"
-
-//===----------------------------------------------------------------------===//
-// AMX dialect definition.
-//===----------------------------------------------------------------------===//
-
-def AMX_Dialect : Dialect {
- let name = "amx";
- let cppNamespace = "::mlir::amx";
- let description = [{
- The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
- multiply unit (TMUL), a tile control register (TILECFG), and eight
- tile registers TMM0 through TMM7 (TILEDATA).
-
- This `AMX` dialect provides a bridge between MLIR concepts such as
- vectors and memrefs and the lower level LLVM IR support of AMX.
-
- Note that since configuration changes (implicit at dialect level) are
- costly, it is highly recommended to use the AMX dialect on same-shaped
- vectors, at least within a single method.
-
- For details, see the Intel documentation:
- https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
- }];
- let useDefaultTypePrinterParser = 1;
-}
-
-//===----------------------------------------------------------------------===//
-// AMX Tile definition.
-//===----------------------------------------------------------------------===//
-
-class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
- : TypeDef<AMX_Dialect, typeName, traits> {
- let mnemonic = typeMnemonic;
-}
-
-def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
- let cppFunctionName = "isValidTileTypeElementType";
-}
-
-def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
- let summary = "AMX 2D tile to be used by AMX opertaions.";
-
- let description = [{
- This type is used to represent values in AMX tile registers. All AMX operations
- work on AMX tiles and these tiles cannot be used in other operations directly.
- LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
- element type for IR verification and lowering to LLVMIR dialect.
- }];
-
- let parameters = (ins
- ArrayRefParameter<"int64_t">:$shape,
- AMX_TileTypeElementType:$elementType
- );
-
- let builders = [
- TypeBuilderWithInferredContext<(ins
- "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
- return $_get(elementType.getContext(), shape, elementType);
- }]>
- ];
-
- let extraClassDeclaration = [{
- /// Returns if this type is ranked (always true).
- bool hasRank() const { return true; }
-
- /// Clone this tile type with the given shape and element type. If the
- /// provided shape is `std::nullopt`, the current shape of the type is used.
- TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
- Type elementType) const {
- return get(shape.value_or(getShape()), elementType);
- }
- }];
-
- let hasCustomAssemblyFormat = 1;
- let skipDefaultBuilders = 1;
-}
-
-def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
- CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
-
-class AMXTileOf<list<Type> allowedTypes> :
- ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
- "::mlir::amx::TileType">;
-
-def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
-
-def AMXTileF32 : AMXTileOf<[F32]>;
-
-def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
-
-def AMXTileI32 : AMXTileOf<[I32]>;
-
-def AMXTileI8 : AMXTileOf<[I8]>;
-
-//===----------------------------------------------------------------------===//
-// AMX Op and IntrOp definitions.
-//===----------------------------------------------------------------------===//
-
-class AMX_Op<string mnemonic, list<Trait> traits = []> :
- Op<AMX_Dialect, mnemonic, traits> {}
-
-//===----------------------------------------------------------------------===//
-// AMX Op definitions
-//===----------------------------------------------------------------------===//
-
-//
-// Tile reset.
-//
-
-def TileZeroOp : AMX_Op<"tile_zero", [
- AMXIntrinsicOpInterface,
- MemoryEffects<[MemWrite]>
- ]> {
- let summary = "tile zero operation";
- let description = [{
- Zeroes the destination tile, with the shape defined by the 2-dim
- vector type of the result.
-
- The operation is eventually lowered into the "tilezero" instruction
- with the corresponding tile configuration.
-
- With the write memory effect, each `amx.tile_zero` operation serves as
- a compilation hint to use a separate tile register.
-
- Example:
-
- ```mlir
- %0 = amx.tile_zero : !amx.tile<16x16xbf16>
- ```
- }];
- let results = (outs AnyAMXTile:$res);
- let extraClassDeclaration = [{
- TileType getTileType() {
- return ::llvm::cast<TileType>(getRes().getType());
- }
-
- std::string getIntrinsicName() {
- return "llvm.x86.tilezero.internal";
- }
- SmallVector<Value> getIntrinsicOperands(
- ::mlir::ArrayRef<Value> operands,
- const ::mlir::LLVMTypeConverter &typeConverter,
- ::mlir::RewriterBase &rewriter);
- }];
- let assemblyFormat = "attr-dict `:` qualified(type($res))";
- let hasVerifier = 1;
-}
-
-//
-// Tile memory operations.
-//
-
-def TileLoadOp : AMX_Op<"tile_load", [
- AMXIntrinsicOpInterface,
- MemoryEffects<[MemWrite]>,
- AttrSizedOperandSegments
- ]> {
- let summary = "tile load operation";
- let description = [{
- Loads a tile from memory defined by a `base` and `indices`, with the
- shape defined by the 2-dim vector type of the result.
- The tile's rows are populated by reading contiguous elements starting
- at the `base`. For each tile row, the `base` is incremented by `stride`
- number of elements.
-
- The tile is loaded using the following indexing scheme:
-
- ```
- for row in enumerate(tile_rows):
- mem_row = base[i0, i1, ..., iN + row * stride]
- for col in enumerate(tile_cols):
- tile[row, col] = mem_row[col]
- ```
-
- If the `stride` is not provided, then the `base` buffer must be at least
- 2-dimensional, and the `stride` is automatically inferred and corresponds
- to the stride of the buffer's second innermost dimension.
-
- The operation is eventually lowered into the "tileloadd" instruction
- with the corresponding tile configuration.
-
- With the write memory effect, each `amx.tile_load` operation serves as
- a compilation hint to use a separate tile register.
-
- Example:
-
- ```mlir
- // Tile load from a 2-D memref with implicit stride.
- %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
-
- // Tile load from a 1-D memref with explicit stride.
- %0 = amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8>
- ```
- }];
- let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
- Variadic<Index>:$indices,
- Optional<Index>:$stride);
- let results = (outs AnyAMXTile:$res);
- let builders = [
- OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)>
- ];
- let extraClassDeclaration = [{
- MemRefType getMemRefType() {
- return ::llvm::cast<MemRefType>(getBase().getType());
- }
- TileType getTileType() {
- return ::llvm::cast<TileType>(getRes().getType());
- }
-
- std::string getIntrinsicName() {
- return "llvm.x86.tileloadd64.internal";
- }
- SmallVector<Value> getIntrinsicOperands(
- ::mlir::ArrayRef<Value> operands,
- const ::mlir::LLVMTypeConverter &typeConverter,
- ::mlir::RewriterBase &rewriter);
- }];
- let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict"
- "`:` type($base) `into` qualified(type($res))";
- let hasVerifier = 1;
-}
-
-def TileStoreOp : AMX_Op<"tile_store", [
- AMXIntrinsicOpInterface,
- AttrSizedOperandSegments
- ]> {
- let summary = "tile store operation";
- let description = [{
- Stores a tile to memory defined by a `base` and `indices`, with the
- shape defined by the 2-dim vector type of the value.
- The tile's rows are written contiguously to the buffer starting at
- the `base`. For each tile row, the `base` is incremented by `stride`
- number of elements.
-
- The tile is stored using the following indexing scheme:
-
- ```
- for row in enumerate(tile_rows):
- mem_row = base[i0, i1, ..., iN + row * stride]
- for col in enumerate(tile_cols):
- mem_row[col] = tile[row, col]
- ```
-
- If the `stride` is not provided, then the `base` buffer must be at least
- 2-dimensional, and the `stride` is automatically inferred and corresponds
- to the stride of the buffer's second innermost dimension.
-
- The operation is eventually lowered into the "tilestored" instruction
- with the corresponding tile configuration.
-
- Example:
-
- ```mlir
- // Tile store to a 2-D memref with implicit stride.
- amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
-
- // Tile store to a 1-D memref with explicit stride.
- amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !amx.tile<16x64xi8>
- ```
- }];
- let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
- Variadic<Index>:$indices,
- AnyAMXTile:$val,
- Optional<Index>:$stride);
- let builders = [
- OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)>
- ];
- let extraClassDeclaration = [{
- MemRefType getMemRefType() {
- return ::llvm::cast<MemRefType>(getBase().getType());
- }
- TileType getTileType() {
- return ::llvm::cast<TileType>(getVal().getType());
- }
-
- std::string getIntrinsicName() {
- return "llvm.x86.tilestored64.internal";
- }
- SmallVector<Value> getIntrinsicOperands(
- ::mlir::ArrayRef<Value> operands,
- const ::mlir::LLVMTypeConverter &typeConverter,
- ::mlir::RewriterBase &rewriter);
- }];
- let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?"
- "attr-dict `:` type($base) `,` qualified(type($val))";
- let hasVerifier = 1;
-}
-
-//
-// Tile arithmetic operations.
-//
-
-def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
- AMXIntrinsicOpInterface,
- AllTypesMatch<["acc", "res"]>
- ]> {
- let summary = "tile multiplication operation (floating-point)";
- let description = [{
- Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
- into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
- pairs of "bf16").
-
- The operation is eventually lowered into the "tdpbf16ps" instruction with
- the corresponding tile configuration.
-
- Example:
-
- ```mlir
- %0 = amx.tile_mulf %a, %b, %c
- : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
- ```
- }];
- let arguments = (ins AMXTileF16OrBF16:$lhs,
- AMXTileF16OrBF16:$rhs,
- AMXTileF32:$acc);
- let results = (outs AMXTileF32:$res);
- let extraClassDeclaration = [{
- TileType getLhsTileType() {
- return ::llvm::cast<TileType>(getLhs().getType());
- }
- TileType getRhsTileType() {
- return ::llvm::cast<TileType>(getRhs().getType());
- }
- TileType getTileType() {
- return ::llvm::cast<TileType>(getRes().getType());
- }
-
- std::string getIntrinsicName() {
- std::string intr = "llvm.x86.tdp";
- auto elementType =
- getLhsTileType().getElementType();
- intr += elementType.isF16() ? "fp16" : "bf16";
- intr += "ps.internal";
- return intr;
- }
- SmallVector<Value> getIntrinsicOperands(
- ::mlir::ArrayRef<Value> operands,
- const ::mlir::LLVMTypeConverter &typeConverter,
- ::mlir::RewriterBase &rewriter);
- }];
- let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
- "qualified(type($lhs)) `,` qualified(type($rhs))"
- " `,` qualified(type($acc)) ";
- let hasVerifier = 1;
-}
-
-def TileMulIOp : AMX_Op<"tile_muli", [Pure,
- AMXIntrinsicOpInterface,
- AllTypesMatch<["acc", "res"]>
- ]> {
- let summary = "tile multiplication operation (integer)";
- let description = [{
- Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
- into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
- combinations (4 bytes packed into dwords in the columns of both the
- source operand tiles; the zero or sign extension is specified with
- the attributes and default to sign extended).
-
- The operation is eventually lowered into one of the "tdpbssd",
- "tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding
- tile configuration.
-
- Example:
-
- ```mlir
- %0 = amx.tile_muli %a zext, %b zext, %c
- : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
- ```
- }];
- let arguments = (ins AMXTileI8:$lhs,
- AMXTileI8:$rhs,
- AMXTileI32:$acc,
- UnitAttr:$isZextLhs,
- UnitAttr:$isZextRhs
- );
- let results = (outs AMXTileI32:$res);
- let extraClassDeclaration = [{
- TileType getLhsTileType() {
- return ::llvm::cast<TileType>(getLhs().getType());
- }
- TileType getRhsT...
[truncated]
``````````
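The `amx.tile_load` and `amx.tile_store` descriptions in the removed `AMX.td` above specify their indexing with pseudocode. The sketch below models that scheme on a flat row-major `std::vector`. `tileLoad`, `tileStore`, and the linearized `offset` parameter are illustrative assumptions, not MLIR APIs. For a contiguous 2-D buffer with implicit stride, `stride` equals the number of columns.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical scalar model (not MLIR code) of the tile_load indexing scheme.
// `offset` stands for the linearized base[i0, ..., iN] position; `stride` is
// the distance, in elements, between the starts of consecutive tile rows.
template <typename T>
std::vector<T> tileLoad(const std::vector<T> &buffer, std::size_t offset,
                        std::size_t stride, std::size_t rows,
                        std::size_t cols) {
  std::vector<T> tile(rows * cols);
  for (std::size_t row = 0; row < rows; ++row) {
    std::size_t memRow = offset + row * stride; // base[..., iN + row * stride]
    for (std::size_t col = 0; col < cols; ++col)
      tile[row * cols + col] = buffer.at(memRow + col); // tile[row, col] = mem_row[col]
  }
  return tile;
}

// Mirror of tileLoad for the tile_store scheme; .at() throws if a row
// would run past the end of the buffer.
template <typename T>
void tileStore(std::vector<T> &buffer, std::size_t offset, std::size_t stride,
               const std::vector<T> &tile, std::size_t rows,
               std::size_t cols) {
  for (std::size_t row = 0; row < rows; ++row) {
    std::size_t memRow = offset + row * stride;
    for (std::size_t col = 0; col < cols; ++col)
      buffer.at(memRow + col) = tile[row * cols + col]; // mem_row[col] = tile[row, col]
  }
}
```

For example, loading a 2x3 tile at position [1, 2] of a 4x6 row-major buffer uses `offset = 1 * 6 + 2` and `stride = 6`.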
</details>
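The `amx.tile_mulf` and `amx.tile_muli` examples pair a 16x32xbf16 (or 16x64xi8) lhs and rhs with a 16x16 accumulator. Two bf16 values, or four i8 values, share one 32-bit column slot. The C++ sketch below assumes that packing factor. It checks the "m x k" by "k x n" relation in packed units and selects one of the four instruction names listed in the `tile_muli` description from its `zext` flags. `TileShape`, `isValidMulShape`, and `muliInstruction` are hypothetical. The dialect's actual verifier and integer intrinsic-name code are not shown in the truncated diff.

```cpp
#include <cstdint>
#include <string>

// Hypothetical model (not the dialect's verifier) of the tile_mul* shapes.
struct TileShape {
  std::int64_t rows;
  std::int64_t cols;
};

// `packing` is the number of source elements per 32-bit column slot:
// 2 for bf16/f16 ("pairs of bf16"), 4 for i8 ("4 bytes packed into dwords").
// This model also rejects column counts that are not a multiple of it.
bool isValidMulShape(TileShape lhs, TileShape rhs, TileShape acc,
                     std::int64_t packing) {
  if (lhs.cols % packing != 0 || rhs.cols % packing != 0)
    return false;
  std::int64_t m = lhs.rows, k = lhs.cols / packing; // lhs is m x k (packed)
  std::int64_t rhsK = rhs.rows, n = rhs.cols / packing; // rhs is k x n (packed)
  return acc.rows == m && acc.cols == n && rhsK == k;
}

// Assumes Intel's naming: the first letter after "tdpb" describes the lhs
// bytes, the second the rhs bytes; 's' = sign-extended (the default),
// 'u' = zero-extended (`zext`).
std::string muliInstruction(bool zextLhs, bool zextRhs) {
  std::string name = "tdpb";
  name += zextLhs ? 'u' : 's';
  name += zextRhs ? 'u' : 's';
  name += 'd';
  return name;
}
```

Under this model, the `amx.tile_muli %a zext, %b zext, %c` example would select `tdpbuud`.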
https://github.com/llvm/llvm-project/pull/183717