[Mlir-commits] [mlir] [mlir][ArmSME] Remove ArmSMETypeConverter (and configure LLVM one instead) (PR #73639)
Benjamin Maxwell
llvmlistbot at llvm.org
Mon Dec 4 02:50:48 PST 2023
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/73639
>From c457e2a2bcd6583760ad6f7afae10b1c3a117cd1 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 27 Nov 2023 19:08:02 +0000
Subject: [PATCH 1/3] [mlir][ArmSME] Remove ArmSMETypeConverter (and configure
LLVM one instead)
This patch removes the ArmSMETypeConverter, and instead updates
`configureArmSMEToLLVMConversionLegality()` to add an ArmSME vector type
conversion to the existing LLVMTypeConverter. This makes it easier to
add these patterns to an existing `-to-llvm` lowering pass.
---
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h | 7 +++---
.../mlir/Dialect/ArmSME/Transforms/Passes.h | 8 -------
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 20 +++++++++++------
.../ArmSME/Transforms/ArmSMETypeConverter.cpp | 22 -------------------
.../Dialect/ArmSME/Transforms/CMakeLists.txt | 1 -
5 files changed, 16 insertions(+), 42 deletions(-)
delete mode 100644 mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index fe851d17867df..b2130742e0f71 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -20,17 +20,16 @@ class RewritePatternSet;
#define GEN_PASS_DECL_CONVERTARMSMETOLLVM
#include "mlir/Conversion/Passes.h.inc"
-using arm_sme::ArmSMETypeConverter;
-
/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
/// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
-void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
+void configureArmSMEToLLVMConversionLegality(ConversionTarget &target,
+ LLVMTypeConverter &typeConverter);
/// Populate the given list with patterns that convert from the ArmSME dialect
/// to LLVM intrinsics.
-void populateArmSMEToLLVMConversionPatterns(ArmSMETypeConverter &converter,
+void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 21a97e9cbc794..aef2959265a7c 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,14 +32,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
/// Pass that allocates tile IDs to ArmSME operations.
std::unique_ptr<Pass> createTileAllocationPass();
-//===----------------------------------------------------------------------===//
-// Type ArmSMETypeConverter pass.
-//===----------------------------------------------------------------------===//
-class ArmSMETypeConverter : public LLVMTypeConverter {
-public:
- ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
-};
-
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index a28b8ef7f7fce..dbbcd9e643566 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -538,10 +538,8 @@ struct ConvertArmSMEToLLVMPass
void runOnOperation() override {
LLVMConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
- ArmSMETypeConverter converter(&getContext(),
- LowerToLLVMOptions(&getContext()));
-
- configureArmSMEToLLVMConversionLegality(target);
+ LLVMTypeConverter converter(&getContext());
+ configureArmSMEToLLVMConversionLegality(target, converter);
populateArmSMEToLLVMConversionPatterns(converter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
@@ -552,7 +550,8 @@ struct ConvertArmSMEToLLVMPass
} // namespace
-void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
+void mlir::configureArmSMEToLLVMConversionLegality(
+ ConversionTarget &target, LLVMTypeConverter &typeConverter) {
target.addIllegalDialect<arm_sme::ArmSMEDialect>();
target.addLegalOp<
arm_sme::MaterializeSSATileOp, arm_sme::aarch64_sme_zero,
@@ -571,10 +570,17 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
+ typeConverter.addConversion([&](VectorType type) -> std::optional<Type> {
+ // There's no LLVM type for SME tiles, but after lowering to intrinsics all
+ // SME vector types should be eliminated.
+ if (arm_sme::isValidSMETileVectorType(type))
+ return type;
+ return std::nullopt;
+ });
}
-void mlir::populateArmSMEToLLVMConversionPatterns(
- ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
+void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
deleted file mode 100644
index 1cefc220ecf10..0000000000000
--- a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
+++ /dev/null
@@ -1,22 +0,0 @@
-//===- ArmSMETypeConverter.cpp - Convert builtin to LLVM dialect types ----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-
-using namespace mlir;
-arm_sme::ArmSMETypeConverter::ArmSMETypeConverter(
- MLIRContext *ctx, const LowerToLLVMOptions &options)
- : LLVMTypeConverter(ctx, options) {
- // Disable LLVM type conversion for vectors. This is to prevent 2-d scalable
- // vectors (common in the context of ArmSME), e.g.
- // `vector<[16]x[16]xi8>`,
- // entering the LLVM Type converter. LLVM does not support arrays of scalable
- // vectors, but in the case of SME such types are effectively eliminated when
- // emitting ArmSME LLVM IR intrinsics.
- addConversion([&](VectorType type) { return type; });
-}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 7b6b2e77dcebf..96eb584420438 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,5 +1,4 @@
add_mlir_dialect_library(MLIRArmSMETransforms
- ArmSMETypeConverter.cpp
EnableArmStreaming.cpp
TileAllocation.cpp
>From 8b74abfa3c77e9dc927790a7d9a0e052b0833d6e Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 29 Nov 2023 15:52:09 +0000
Subject: [PATCH 2/3] Fixup: move ArmSME tile type conversion into populateArmSMEToLLVMConversionPatterns()
---
.../mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h | 3 +--
mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 14 +++++++-------
2 files changed, 8 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index b2130742e0f71..eab871ab49998 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -24,8 +24,7 @@ class RewritePatternSet;
std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
/// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
-void configureArmSMEToLLVMConversionLegality(ConversionTarget &target,
- LLVMTypeConverter &typeConverter);
+void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
/// Populate the given list with patterns that convert from the ArmSME dialect
/// to LLVM intrinsics.
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index dbbcd9e643566..f9d6f04a811f3 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -539,7 +539,7 @@ struct ConvertArmSMEToLLVMPass
LLVMConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
LLVMTypeConverter converter(&getContext());
- configureArmSMEToLLVMConversionLegality(target, converter);
+ configureArmSMEToLLVMConversionLegality(target);
populateArmSMEToLLVMConversionPatterns(converter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
@@ -550,8 +550,7 @@ struct ConvertArmSMEToLLVMPass
} // namespace
-void mlir::configureArmSMEToLLVMConversionLegality(
- ConversionTarget &target, LLVMTypeConverter &typeConverter) {
+void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
target.addIllegalDialect<arm_sme::ArmSMEDialect>();
target.addLegalOp<
arm_sme::MaterializeSSATileOp, arm_sme::aarch64_sme_zero,
@@ -570,17 +569,18 @@ void mlir::configureArmSMEToLLVMConversionLegality(
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
- typeConverter.addConversion([&](VectorType type) -> std::optional<Type> {
+}
+
+void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
+ converter.addConversion([&](VectorType type) -> std::optional<Type> {
// There's no LLVM type for SME tiles, but after lowering to intrinsics all
// SME vector types should be eliminated.
if (arm_sme::isValidSMETileVectorType(type))
return type;
return std::nullopt;
});
-}
-void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
- RewritePatternSet &patterns) {
patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
>From 5e83afd0171c212a4d5c15db8e263650ac25e0f0 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 4 Dec 2023 10:49:21 +0000
Subject: [PATCH 3/3] Add ArmSME type conversion unit test
---
mlir/unittests/Dialect/ArmSME/CMakeLists.txt | 5 ++
.../Dialect/ArmSME/TileTypeConversionTest.cpp | 51 +++++++++++++++++++
mlir/unittests/Dialect/CMakeLists.txt | 1 +
3 files changed, 57 insertions(+)
create mode 100644 mlir/unittests/Dialect/ArmSME/CMakeLists.txt
create mode 100644 mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp
diff --git a/mlir/unittests/Dialect/ArmSME/CMakeLists.txt b/mlir/unittests/Dialect/ArmSME/CMakeLists.txt
new file mode 100644
index 0000000000000..affd435ef7bfc
--- /dev/null
+++ b/mlir/unittests/Dialect/ArmSME/CMakeLists.txt
@@ -0,0 +1,5 @@
+add_mlir_unittest(MLIRArmSMETests
+ TileTypeConversionTest.cpp)
+target_link_libraries(MLIRArmSMETests
+ PRIVATE
+ MLIRArmSMEToLLVM)
diff --git a/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp b/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp
new file mode 100644
index 0000000000000..4574766e1cd7f
--- /dev/null
+++ b/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp
@@ -0,0 +1,51 @@
+//===- TileTypeConversionTest.cpp - Tests ArmSME tile type conversion -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+class ArmSMETest : public ::testing::Test {
+protected:
+  MLIRContext context;
+  // Load the ArmSME dialect into `context` before each test body runs.
+  ArmSMETest() { context.getOrLoadDialect<arm_sme::ArmSMEDialect>(); }
+};
+
+TEST_F(ArmSMETest, TestTileTypeConversion) {
+  LLVMTypeConverter llvmConverter(&context);
+  LLVMTypeConverter llvmConverterWithArmSMEConversion(&context);
+  // Populating the ArmSME patterns also adds the tile type conversion.
+  RewritePatternSet patterns(&context);
+  populateArmSMEToLLVMConversionPatterns(llvmConverterWithArmSMEConversion,
+                                         patterns);
+
+  Type i32 = IntegerType::get(&context, 32);
+  auto smeTileType = VectorType::get({4, 4}, i32, {true, true});
+
+  // An unmodified LLVMTypeConverter should fail to convert an ArmSME tile type.
+  {
+    SmallVector<Type> convertedType;
+    ASSERT_TRUE(failed(llvmConverter.convertType(smeTileType, convertedType)));
+  }
+
+  // An updated LLVMTypeConverter should return the ArmSME tile vector type
+  // unchanged.
+  {
+    SmallVector<Type> convertedType;
+    ASSERT_TRUE(succeeded(llvmConverterWithArmSMEConversion.convertType(
+        smeTileType, convertedType)));
+    ASSERT_EQ(ArrayRef<Type>(convertedType), ArrayRef<Type>{smeTileType});
+  }
+}
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index fbb73e8f499a3..2dec4ba3c001e 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -6,6 +6,7 @@ target_link_libraries(MLIRDialectTests
MLIRIR
MLIRDialect)
+add_subdirectory(ArmSME)
add_subdirectory(Index)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
More information about the Mlir-commits
mailing list