[Mlir-commits] [mlir] [mlir][ArmSME] Remove ArmSMETypeConverter (and configure LLVM one instead) (PR #73639)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 28 04:13:02 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sme
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
This patch removes the ArmSMETypeConverter and instead updates `configureArmSMEToLLVMConversionLegality()` to add an ArmSME vector type conversion to the LLVMTypeConverter passed in by the caller. This makes it easier to add these patterns to an existing `-to-llvm` lowering pass.
---
Full diff: https://github.com/llvm/llvm-project/pull/73639.diff
5 Files Affected:
- (modified) mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h (+3-4)
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (-8)
- (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+13-7)
- (removed) mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp (-22)
- (modified) mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt (-1)
``````````diff
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index fe851d17867dff5..b2130742e0f71cf 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -20,17 +20,16 @@ class RewritePatternSet;
#define GEN_PASS_DECL_CONVERTARMSMETOLLVM
#include "mlir/Conversion/Passes.h.inc"
-using arm_sme::ArmSMETypeConverter;
-
/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
/// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
-void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
+void configureArmSMEToLLVMConversionLegality(ConversionTarget &target,
+ LLVMTypeConverter &typeConverter);
/// Populate the given list with patterns that convert from the ArmSME dialect
/// to LLVM intrinsics.
-void populateArmSMEToLLVMConversionPatterns(ArmSMETypeConverter &converter,
+void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 6f7617f5411c57f..4c9907be4086b1f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,14 +32,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
/// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
std::unique_ptr<Pass> createTileAllocationPass();
-//===----------------------------------------------------------------------===//
-// Type ArmSMETypeConverter pass.
-//===----------------------------------------------------------------------===//
-class ArmSMETypeConverter : public LLVMTypeConverter {
-public:
- ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
-};
-
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index e409dc57fb020e2..12f9df7ed4de2f4 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -540,10 +540,8 @@ struct ConvertArmSMEToLLVMPass
void runOnOperation() override {
LLVMConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
- ArmSMETypeConverter converter(&getContext(),
- LowerToLLVMOptions(&getContext()));
-
- configureArmSMEToLLVMConversionLegality(target);
+ LLVMTypeConverter converter(&getContext());
+ configureArmSMEToLLVMConversionLegality(target, converter);
populateArmSMEToLLVMConversionPatterns(converter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
@@ -554,7 +552,8 @@ struct ConvertArmSMEToLLVMPass
} // namespace
-void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
+void mlir::configureArmSMEToLLVMConversionLegality(
+ ConversionTarget &target, LLVMTypeConverter &typeConverter) {
target.addIllegalDialect<arm_sme::ArmSMEDialect>();
target.addLegalOp<
arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
@@ -574,10 +573,17 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
arm_sme::aarch64_sme_mopa>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
+ typeConverter.addConversion([&](VectorType type) -> std::optional<Type> {
+ // There's no LLVM type for SME tiles, but after lowering to intrinsics all
+ // SME vector types should be eliminated.
+ if (arm_sme::isValidSMETileVectorType(type))
+ return type;
+ return std::nullopt;
+ });
}
-void mlir::populateArmSMEToLLVMConversionPatterns(
- ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
+void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
OuterProductOpConversion, ZeroOpConversion>(converter);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
deleted file mode 100644
index 1cefc220ecf1035..000000000000000
--- a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
+++ /dev/null
@@ -1,22 +0,0 @@
-//===- ArmSMETypeConverter.cpp - Convert builtin to LLVM dialect types ----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-
-using namespace mlir;
-arm_sme::ArmSMETypeConverter::ArmSMETypeConverter(
- MLIRContext *ctx, const LowerToLLVMOptions &options)
- : LLVMTypeConverter(ctx, options) {
- // Disable LLVM type conversion for vectors. This is to prevent 2-d scalable
- // vectors (common in the context of ArmSME), e.g.
- // `vector<[16]x[16]xi8>`,
- // entering the LLVM Type converter. LLVM does not support arrays of scalable
- // vectors, but in the case of SME such types are effectively eliminated when
- // emitting ArmSME LLVM IR intrinsics.
- addConversion([&](VectorType type) { return type; });
-}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index e2407d9f48f7061..f0854d0678e106d 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,5 +1,4 @@
add_mlir_dialect_library(MLIRArmSMETransforms
- ArmSMETypeConverter.cpp
EnableArmStreaming.cpp
TileAllocation.cpp
``````````
</details>
https://github.com/llvm/llvm-project/pull/73639
More information about the Mlir-commits mailing list