[Mlir-commits] [mlir] [mlir][amx] Restore conversion interface for AMX (PR #143871)
Adam Siemieniuk
llvmlistbot at llvm.org
Thu Jun 12 04:13:13 PDT 2025
https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/143871
Restores the mistakenly removed AMX conversion interface, which ensures that the custom tile type is converted to its LLVM equivalent within other operations, such as control flow.
Follow-up fix for #140559.
>From 6448c1c7cd02452f814f216fbf72089c74b9a882 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 12 Jun 2025 13:05:38 +0200
Subject: [PATCH] [mlir][amx] Restore conversion interface for AMX
Restores the mistakenly removed AMX conversion interface, which ensures
that the custom tile type is converted to its LLVM equivalent within
other operations, such as control flow.
Follow-up fix for #140559.
---
mlir/include/mlir/Dialect/AMX/Transforms.h | 3 +++
mlir/include/mlir/InitAllExtensions.h | 2 ++
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 19 ++++++++++++++++++
mlir/test/Target/LLVMIR/amx.mlir | 20 +++++++++++++++++++
4 files changed, 44 insertions(+)
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 4a751d99ceeee..7391ec2ff6b14 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -25,6 +25,9 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
 /// intrinsics.
 void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
 
+/// Register LLVM conversion interface for AMX dialect.
+void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 7dcbabe8aafa3..f356b91b1b6c0 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -32,6 +32,7 @@
 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
 #include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
@@ -85,6 +86,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertOpenMPToLLVMInterface(registry);
   registerConvertSCFToEmitCInterface(registry);
   ub::registerConvertUBToLLVMInterface(registry);
+  registerConvertAMXToLLVMInterface(registry);
   gpu::registerConvertGpuToLLVMInterface(registry);
   NVVM::registerConvertGpuToNVVMInterface(registry);
   vector::registerConvertVectorToLLVMInterface(registry);
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 7471dc797e0fc..37aebc9fab3eb 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -60,3 +60,22 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns(
 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
   target.addIllegalDialect<AMXDialect>();
 }
+
+namespace {
+/// Implement the interface to convert AMX to LLVM.
+struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
+    dialect->addInterfaces<AMXToLLVMDialectInterface>();
+  });
+}
diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir
index 094475040436d..abdf2fe3bd534 100644
--- a/mlir/test/Target/LLVMIR/amx.mlir
+++ b/mlir/test/Target/LLVMIR/amx.mlir
@@ -88,3 +88,23 @@ func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>,
   amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !amx.tile<16x16xi32>
   return
 }
+
+// CHECK-LABEL: define void @amx_tile_type_through_cf
+func.func @amx_tile_type_through_cf(%src: memref<?x?xi8>, %out: memref<?x?xi8>,
+    %idx: index, %cond: i1) {
+  cf.cond_br %cond, ^bb1, ^bb2
+^bb1: // pred: ^bb0
+  // CHECK: call x86_amx @llvm.x86.tileloadd64.internal
+  %0 = amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  cf.br ^bb3(%0 : !amx.tile<16x64xi8>)
+^bb2: // pred: ^bb0
+  // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
+  %1 = amx.tile_zero : !amx.tile<16x64xi8>
+  cf.br ^bb3(%1 : !amx.tile<16x64xi8>)
+^bb3(%2: !amx.tile<16x64xi8>): // 2 preds: ^bb1, ^bb2
+  cf.br ^bb4
+^bb4: // pred: ^bb3
+  // CHECK: call void @llvm.x86.tilestored64.internal
+  amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !amx.tile<16x64xi8>
+  return
+}
More information about the Mlir-commits
mailing list