[Mlir-commits] [mlir] [mlir][amx] Restore conversion interface for AMX (PR #143871)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 12 04:13:41 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Adam Siemieniuk (adam-smnk)
<details>
<summary>Changes</summary>
Restores the mistakenly removed AMX conversion-to-LLVM interface, which ensures that the custom tile type is converted to its LLVM equivalent within other operations, such as control flow.
Follow-up fix to #140559.
---
Full diff: https://github.com/llvm/llvm-project/pull/143871.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/AMX/Transforms.h (+3)
- (modified) mlir/include/mlir/InitAllExtensions.h (+2)
- (modified) mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp (+19)
- (modified) mlir/test/Target/LLVMIR/amx.mlir (+20)
``````````diff
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 4a751d99ceeee..7391ec2ff6b14 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -25,6 +25,9 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
/// intrinsics.
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
+/// Register LLVM conversion interface for AMX dialect.
+void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
+
} // namespace mlir
#endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 7dcbabe8aafa3..f356b91b1b6c0 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -32,6 +32,7 @@
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
@@ -85,6 +86,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertOpenMPToLLVMInterface(registry);
registerConvertSCFToEmitCInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
+ registerConvertAMXToLLVMInterface(registry);
gpu::registerConvertGpuToLLVMInterface(registry);
NVVM::registerConvertGpuToNVVMInterface(registry);
vector::registerConvertVectorToLLVMInterface(registry);
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 7471dc797e0fc..37aebc9fab3eb 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -60,3 +60,22 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns(
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
target.addIllegalDialect<AMXDialect>();
}
+
+namespace {
+/// Implement the interface to convert AMX to LLVM.
+struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+ using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+
+ void populateConvertToLLVMConversionPatterns(
+ ConversionTarget &target, LLVMTypeConverter &typeConverter,
+ RewritePatternSet &patterns) const final {
+ populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
+ }
+};
+} // namespace
+
+void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
+ dialect->addInterfaces<AMXToLLVMDialectInterface>();
+ });
+}
diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir
index 094475040436d..abdf2fe3bd534 100644
--- a/mlir/test/Target/LLVMIR/amx.mlir
+++ b/mlir/test/Target/LLVMIR/amx.mlir
@@ -88,3 +88,23 @@ func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>,
amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !amx.tile<16x16xi32>
return
}
+
+// CHECK-LABEL: define void @amx_tile_type_through_cf
+func.func @amx_tile_type_through_cf(%src: memref<?x?xi8>, %out: memref<?x?xi8>,
+ %idx: index, %cond: i1) {
+ cf.cond_br %cond, ^bb1, ^bb2
+^bb1: // pred: ^bb0
+ // CHECK: call x86_amx @llvm.x86.tileloadd64.internal
+ %0 = amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
+ cf.br ^bb3(%0 : !amx.tile<16x64xi8>)
+^bb2: // pred: ^bb0
+ // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
+ %1 = amx.tile_zero : !amx.tile<16x64xi8>
+ cf.br ^bb3(%1 : !amx.tile<16x64xi8>)
+^bb3(%2: !amx.tile<16x64xi8>): // 2 preds: ^bb1, ^bb2
+ cf.br ^bb4
+^bb4: // pred: ^bb3
+ // CHECK: call void @llvm.x86.tilestored64.internal
+ amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !amx.tile<16x64xi8>
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/143871
More information about the Mlir-commits
mailing list