[Mlir-commits] [mlir] [mlir][ArmSME] Remove ArmSMETypeConverter (and configure LLVM one instead) (PR #73639)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Nov 29 08:07:00 PST 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/73639

>From 73afcbbb4e6bfe8f08791b9e18af295b722ba025 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 27 Nov 2023 19:08:02 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Remove ArmSMETypeConverter (and configure
 LLVM one instead)

This patch removes the ArmSMETypeConverter, and instead updates
`configureArmSMEToLLVMConversionLegality()` to add an ArmSME vector type
conversion to the existing LLVMTypeConverter. This makes it easier to
add these patterns to an existing `-to-llvm` lowering pass.
---
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h    |  7 +++---
 .../mlir/Dialect/ArmSME/Transforms/Passes.h   |  8 -------
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  | 20 +++++++++++------
 .../ArmSME/Transforms/ArmSMETypeConverter.cpp | 22 -------------------
 .../Dialect/ArmSME/Transforms/CMakeLists.txt  |  1 -
 5 files changed, 16 insertions(+), 42 deletions(-)
 delete mode 100644 mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp

diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index fe851d17867dff5..b2130742e0f71cf 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -20,17 +20,16 @@ class RewritePatternSet;
 #define GEN_PASS_DECL_CONVERTARMSMETOLLVM
 #include "mlir/Conversion/Passes.h.inc"
 
-using arm_sme::ArmSMETypeConverter;
-
 /// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
 std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
 
 /// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
-void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
+void configureArmSMEToLLVMConversionLegality(ConversionTarget &target,
+                                             LLVMTypeConverter &typeConverter);
 
 /// Populate the given list with patterns that convert from the ArmSME dialect
 /// to LLVM intrinsics.
-void populateArmSMEToLLVMConversionPatterns(ArmSMETypeConverter &converter,
+void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                             RewritePatternSet &patterns);
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 6f7617f5411c57f..4c9907be4086b1f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,14 +32,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
 /// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
 std::unique_ptr<Pass> createTileAllocationPass();
 
-//===----------------------------------------------------------------------===//
-// Type ArmSMETypeConverter pass.
-//===----------------------------------------------------------------------===//
-class ArmSMETypeConverter : public LLVMTypeConverter {
-public:
-  ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
-};
-
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index e409dc57fb020e2..12f9df7ed4de2f4 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -540,10 +540,8 @@ struct ConvertArmSMEToLLVMPass
   void runOnOperation() override {
     LLVMConversionTarget target(getContext());
     RewritePatternSet patterns(&getContext());
-    ArmSMETypeConverter converter(&getContext(),
-                                  LowerToLLVMOptions(&getContext()));
-
-    configureArmSMEToLLVMConversionLegality(target);
+    LLVMTypeConverter converter(&getContext());
+    configureArmSMEToLLVMConversionLegality(target, converter);
     populateArmSMEToLLVMConversionPatterns(converter, patterns);
 
     if (failed(applyPartialConversion(getOperation(), target,
@@ -554,7 +552,8 @@ struct ConvertArmSMEToLLVMPass
 
 } // namespace
 
-void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
+void mlir::configureArmSMEToLLVMConversionLegality(
+    ConversionTarget &target, LLVMTypeConverter &typeConverter) {
   target.addIllegalDialect<arm_sme::ArmSMEDialect>();
   target.addLegalOp<
       arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
@@ -574,10 +573,17 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
       arm_sme::aarch64_sme_mopa>();
   target.addLegalDialect<arith::ArithDialect>();
   target.addLegalOp<UnrealizedConversionCastOp>();
+  typeConverter.addConversion([&](VectorType type) -> std::optional<Type> {
+    // There's no LLVM type for SME tiles, but after lowering to intrinsics all
+    // SME vector types should be eliminated.
+    if (arm_sme::isValidSMETileVectorType(type))
+      return type;
+    return std::nullopt;
+  });
 }
 
-void mlir::populateArmSMEToLLVMConversionPatterns(
-    ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
+void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                                  RewritePatternSet &patterns) {
   patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
                MoveVectorToTileSliceConversion, StoreTileSliceConversion,
                OuterProductOpConversion, ZeroOpConversion>(converter);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
deleted file mode 100644
index 1cefc220ecf1035..000000000000000
--- a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
+++ /dev/null
@@ -1,22 +0,0 @@
-//===- ArmSMETypeConverter.cpp - Convert builtin to LLVM dialect types ----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-
-using namespace mlir;
-arm_sme::ArmSMETypeConverter::ArmSMETypeConverter(
-    MLIRContext *ctx, const LowerToLLVMOptions &options)
-    : LLVMTypeConverter(ctx, options) {
-  // Disable LLVM type conversion for vectors. This is to prevent 2-d scalable
-  // vectors (common in the context of ArmSME), e.g.
-  //    `vector<[16]x[16]xi8>`,
-  // entering the LLVM Type converter. LLVM does not support arrays of scalable
-  // vectors, but in the case of SME such types are effectively eliminated when
-  // emitting ArmSME LLVM IR intrinsics.
-  addConversion([&](VectorType type) { return type; });
-}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index e2407d9f48f7061..f0854d0678e106d 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,5 +1,4 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
-  ArmSMETypeConverter.cpp
   EnableArmStreaming.cpp
   TileAllocation.cpp
 

>From 3de9bba284916897cd80d1c5747b46213a18b6ec Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 29 Nov 2023 15:52:09 +0000
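Subject: [PATCH 2/2] Fixup

Move the ArmSME tile vector type conversion out of
`configureArmSMEToLLVMConversionLegality()` (restoring its original
single-argument signature) and register it in
`populateArmSMEToLLVMConversionPatterns()` instead.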
---
 .../mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h    |  3 +--
 mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  | 14 +++++++-------
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index b2130742e0f71cf..eab871ab499983f 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -24,8 +24,7 @@ class RewritePatternSet;
 std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
 
 /// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
-void configureArmSMEToLLVMConversionLegality(ConversionTarget &target,
-                                             LLVMTypeConverter &typeConverter);
+void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
 
 /// Populate the given list with patterns that convert from the ArmSME dialect
 /// to LLVM intrinsics.
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 12f9df7ed4de2f4..3dd9584ab139b30 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -541,7 +541,7 @@ struct ConvertArmSMEToLLVMPass
     LLVMConversionTarget target(getContext());
     RewritePatternSet patterns(&getContext());
     LLVMTypeConverter converter(&getContext());
-    configureArmSMEToLLVMConversionLegality(target, converter);
+    configureArmSMEToLLVMConversionLegality(target);
     populateArmSMEToLLVMConversionPatterns(converter, patterns);
 
     if (failed(applyPartialConversion(getOperation(), target,
@@ -552,8 +552,7 @@ struct ConvertArmSMEToLLVMPass
 
 } // namespace
 
-void mlir::configureArmSMEToLLVMConversionLegality(
-    ConversionTarget &target, LLVMTypeConverter &typeConverter) {
+void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
   target.addIllegalDialect<arm_sme::ArmSMEDialect>();
   target.addLegalOp<
       arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
@@ -573,17 +572,18 @@ void mlir::configureArmSMEToLLVMConversionLegality(
       arm_sme::aarch64_sme_mopa>();
   target.addLegalDialect<arith::ArithDialect>();
   target.addLegalOp<UnrealizedConversionCastOp>();
-  typeConverter.addConversion([&](VectorType type) -> std::optional<Type> {
+}
+
+void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                                  RewritePatternSet &patterns) {
+  converter.addConversion([&](VectorType type) -> std::optional<Type> {
     // There's no LLVM type for SME tiles, but after lowering to intrinsics all
     // SME vector types should be eliminated.
     if (arm_sme::isValidSMETileVectorType(type))
       return type;
     return std::nullopt;
   });
-}
 
-void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                                  RewritePatternSet &patterns) {
   patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
                MoveVectorToTileSliceConversion, StoreTileSliceConversion,
                OuterProductOpConversion, ZeroOpConversion>(converter);



More information about the Mlir-commits mailing list