[Mlir-commits] [mlir] 01e40a8 - [mlir][ArmSME] Remove ArmSMETypeConverter (and configure LLVM one instead) (#73639)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 4 09:02:52 PST 2023


Author: Benjamin Maxwell
Date: 2023-12-04T17:02:48Z
New Revision: 01e40a8a3d40d7595d2cd95363c27d84b31e5cd2

URL: https://github.com/llvm/llvm-project/commit/01e40a8a3d40d7595d2cd95363c27d84b31e5cd2
DIFF: https://github.com/llvm/llvm-project/commit/01e40a8a3d40d7595d2cd95363c27d84b31e5cd2.diff

LOG: [mlir][ArmSME] Remove ArmSMETypeConverter (and configure LLVM one instead) (#73639)

This patch removes the ArmSMETypeConverter, and instead updates
`populateArmSMEToLLVMConversionPatterns()` to add an ArmSME vector type
conversion to the existing LLVMTypeConverter. This makes it easier to
add these patterns to an existing `-to-llvm` lowering pass.

Added: 
    mlir/unittests/Dialect/ArmSME/CMakeLists.txt
    mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp

Modified: 
    mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
    mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
    mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
    mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
    mlir/unittests/Dialect/CMakeLists.txt

Removed: 
    mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index fe851d17867df..eab871ab49998 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -20,8 +20,6 @@ class RewritePatternSet;
 #define GEN_PASS_DECL_CONVERTARMSMETOLLVM
 #include "mlir/Conversion/Passes.h.inc"
 
-using arm_sme::ArmSMETypeConverter;
-
 /// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
 std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
 
@@ -30,7 +28,7 @@ void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
 
 /// Populate the given list with patterns that convert from the ArmSME dialect
 /// to LLVM intrinsics.
-void populateArmSMEToLLVMConversionPatterns(ArmSMETypeConverter &converter,
+void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                             RewritePatternSet &patterns);
 
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 21a97e9cbc794..aef2959265a7c 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,14 +32,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
 /// Pass that allocates tile IDs to ArmSME operations.
 std::unique_ptr<Pass> createTileAllocationPass();
 
-//===----------------------------------------------------------------------===//
-// Type ArmSMETypeConverter pass.
-//===----------------------------------------------------------------------===//
-class ArmSMETypeConverter : public LLVMTypeConverter {
-public:
-  ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
-};
-
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index a28b8ef7f7fce..f9d6f04a811f3 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -538,9 +538,7 @@ struct ConvertArmSMEToLLVMPass
   void runOnOperation() override {
     LLVMConversionTarget target(getContext());
     RewritePatternSet patterns(&getContext());
-    ArmSMETypeConverter converter(&getContext(),
-                                  LowerToLLVMOptions(&getContext()));
-
+    LLVMTypeConverter converter(&getContext());
     configureArmSMEToLLVMConversionLegality(target);
     populateArmSMEToLLVMConversionPatterns(converter, patterns);
 
@@ -573,8 +571,16 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
   target.addLegalOp<UnrealizedConversionCastOp>();
 }
 
-void mlir::populateArmSMEToLLVMConversionPatterns(
-    ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
+void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                                  RewritePatternSet &patterns) {
+  converter.addConversion([&](VectorType type) -> std::optional<Type> {
+    // There's no LLVM type for SME tiles, but after lowering to intrinsics all
+    // SME vector types should be eliminated.
+    if (arm_sme::isValidSMETileVectorType(type))
+      return type;
+    return std::nullopt;
+  });
+
   patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
                MoveVectorToTileSliceConversion, StoreTileSliceConversion,
                OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
deleted file mode 100644
index 1cefc220ecf10..0000000000000
--- a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
+++ /dev/null
@@ -1,22 +0,0 @@
-//===- ArmSMETypeConverter.cpp - Convert builtin to LLVM dialect types ----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-
-using namespace mlir;
-arm_sme::ArmSMETypeConverter::ArmSMETypeConverter(
-    MLIRContext *ctx, const LowerToLLVMOptions &options)
-    : LLVMTypeConverter(ctx, options) {
-  // Disable LLVM type conversion for vectors. This is to prevent 2-d scalable
-  // vectors (common in the context of ArmSME), e.g.
-  //    `vector<[16]x[16]xi8>`,
-  // entering the LLVM Type converter. LLVM does not support arrays of scalable
-  // vectors, but in the case of SME such types are effectively eliminated when
-  // emitting ArmSME LLVM IR intrinsics.
-  addConversion([&](VectorType type) { return type; });
-}

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 7b6b2e77dcebf..96eb584420438 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,5 +1,4 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
-  ArmSMETypeConverter.cpp
   EnableArmStreaming.cpp
   TileAllocation.cpp
 

diff  --git a/mlir/unittests/Dialect/ArmSME/CMakeLists.txt b/mlir/unittests/Dialect/ArmSME/CMakeLists.txt
new file mode 100644
index 0000000000000..affd435ef7bfc
--- /dev/null
+++ b/mlir/unittests/Dialect/ArmSME/CMakeLists.txt
@@ -0,0 +1,5 @@
+add_mlir_unittest(MLIRArmSMETests
+  TileTypeConversionTest.cpp)
+target_link_libraries(MLIRArmSMETests
+  PRIVATE
+  MLIRArmSMEToLLVM)

diff  --git a/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp b/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp
new file mode 100644
index 0000000000000..305f879489813
--- /dev/null
+++ b/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp
@@ -0,0 +1,51 @@
+//===- TileTypeConversionTest.cpp - Tests ArmSME tile type conversion -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+class ArmSMETest : public ::testing::Test {
+protected:
+  ArmSMETest() { context.getOrLoadDialect<mlir::arm_sme::ArmSMEDialect>(); }
+
+  mlir::MLIRContext context;
+};
+
+TEST_F(ArmSMETest, TestTileTypeConversion) {
+  LLVMTypeConverter llvmConverter(&context);
+  LLVMTypeConverter llvmConverterWithArmSMEConversion(&context);
+
+  RewritePatternSet patterns(&context);
+  populateArmSMEToLLVMConversionPatterns(llvmConverterWithArmSMEConversion,
+                                         patterns);
+
+  Type i32 = IntegerType::get(&context, 32);
+  auto smeTileType = VectorType::get({4, 4}, i32, {true, true});
+
+  // An unmodified LLVMTypeConverter should fail to convert an ArmSME tile type.
+  {
+    SmallVector<Type> convertedType;
+    ASSERT_TRUE(failed(llvmConverter.convertType(smeTileType, convertedType)));
+  }
+
+  // An LLVMTypeConverter configured by populateArmSMEToLLVMConversionPatterns
+  // should return the ArmSME tile vector type unchanged.
+  {
+    SmallVector<Type> convertedType;
+    ASSERT_TRUE(succeeded(llvmConverterWithArmSMEConversion.convertType(
+        smeTileType, convertedType)));
+    ASSERT_EQ(ArrayRef<Type>(convertedType), ArrayRef<Type>{smeTileType});
+  }
+}

diff  --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index fbb73e8f499a3..2dec4ba3c001e 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -6,6 +6,7 @@ target_link_libraries(MLIRDialectTests
   MLIRIR
   MLIRDialect)
 
+add_subdirectory(ArmSME)
 add_subdirectory(Index)
 add_subdirectory(LLVMIR)
 add_subdirectory(MemRef)


        


More information about the Mlir-commits mailing list