[Mlir-commits] [mlir] [mlir][amx] Restore conversion interface for AMX (PR #143871)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 12 04:13:41 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

<details>
<summary>Changes</summary>

Restores the mistakenly removed AMX interface, which ensures that the custom tile type is converted to its LLVM equivalent within other operations such as control flow.

Fix after #140559.

---
Full diff: https://github.com/llvm/llvm-project/pull/143871.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/AMX/Transforms.h (+3) 
- (modified) mlir/include/mlir/InitAllExtensions.h (+2) 
- (modified) mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp (+19) 
- (modified) mlir/test/Target/LLVMIR/amx.mlir (+20) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 4a751d99ceeee..7391ec2ff6b14 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -25,6 +25,9 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
 /// intrinsics.
 void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
 
+/// Register LLVM conversion interface for AMX dialect.
+void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 7dcbabe8aafa3..f356b91b1b6c0 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -32,6 +32,7 @@
 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
 #include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
@@ -85,6 +86,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertOpenMPToLLVMInterface(registry);
   registerConvertSCFToEmitCInterface(registry);
   ub::registerConvertUBToLLVMInterface(registry);
+  registerConvertAMXToLLVMInterface(registry);
   gpu::registerConvertGpuToLLVMInterface(registry);
   NVVM::registerConvertGpuToNVVMInterface(registry);
   vector::registerConvertVectorToLLVMInterface(registry);
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 7471dc797e0fc..37aebc9fab3eb 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -60,3 +60,22 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns(
 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
   target.addIllegalDialect<AMXDialect>();
 }
+
+namespace {
+/// Implement the interface to convert AMX to LLVM.
+struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
+    dialect->addInterfaces<AMXToLLVMDialectInterface>();
+  });
+}
diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir
index 094475040436d..abdf2fe3bd534 100644
--- a/mlir/test/Target/LLVMIR/amx.mlir
+++ b/mlir/test/Target/LLVMIR/amx.mlir
@@ -88,3 +88,23 @@ func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>,
   amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !amx.tile<16x16xi32>
   return
 }
+
+// CHECK-LABEL: define void @amx_tile_type_through_cf
+func.func @amx_tile_type_through_cf(%src: memref<?x?xi8>, %out: memref<?x?xi8>,
+    %idx: index, %cond: i1) {
+  cf.cond_br %cond, ^bb1, ^bb2
+^bb1:  // pred: ^bb0
+  // CHECK: call x86_amx @llvm.x86.tileloadd64.internal
+  %0 = amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  cf.br ^bb3(%0 : !amx.tile<16x64xi8>)
+^bb2:  // pred: ^bb0
+  // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
+  %1 = amx.tile_zero : !amx.tile<16x64xi8>
+  cf.br ^bb3(%1 : !amx.tile<16x64xi8>)
+^bb3(%2: !amx.tile<16x64xi8>):  // 2 preds: ^bb1, ^bb2
+  cf.br ^bb4
+^bb4:  // pred: ^bb3
+  // CHECK: call void @llvm.x86.tilestored64.internal
+  amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !amx.tile<16x64xi8>
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/143871


More information about the Mlir-commits mailing list