[Mlir-commits] [mlir] 3fa5ee6 - [mlir][ArmSME] Introduce custom TypeConverter for ArmSME

Andrzej Warzynski llvmlistbot at llvm.org
Tue Jul 18 02:35:55 PDT 2023


Author: Andrzej Warzynski
Date: 2023-07-18T09:35:32Z
New Revision: 3fa5ee67babad11a88943ede42a4123299acf31a

URL: https://github.com/llvm/llvm-project/commit/3fa5ee67babad11a88943ede42a4123299acf31a
DIFF: https://github.com/llvm/llvm-project/commit/3fa5ee67babad11a88943ede42a4123299acf31a.diff

LOG: [mlir][ArmSME] Introduce custom TypeConverter for ArmSME

At the moment, SME-to-LLVM lowerings rely entirely on
`LLVMTypeConverter`. This patch introduces a dedicated `TypeConverter`
that inherits from `LLVMTypeConverter` (it will also be used when
lowering ArmSME Ops to LLVM).

The new type converter merely disables lowerings for `VectorType` to
prevent 2-d scalable vectors (common in the context of ArmSME), e.g.

   `vector<[16]x[16]xi8>`,

from entering the LLVM type converter. LLVM does not support arrays of
scalable vectors, hence the need for this specialisation. In the case of
SME, such types are effectively eliminated when emitting LLVM IR
intrinsics for SME.

Differential Revision: https://reviews.llvm.org/D155365

Added: 
    mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp

Modified: 
    mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
    mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 133968b60665b0..ab5c179f2dd779 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H
 #define MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H
 
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -16,6 +17,9 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace arm_sme {
+//===----------------------------------------------------------------------===//
+// The EnableArmStreaming pass.
+//===----------------------------------------------------------------------===//
 // Options for Armv9 Streaming SVE mode. By default, streaming-mode is part of
 // the function interface (ABI) and the caller manages PSTATE.SM on entry/exit.
 // In a locally streaming function PSTATE.SM is kept internal and the callee
@@ -33,6 +37,14 @@ createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default,
 /// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
 std::unique_ptr<Pass> createTileAllocationPass();
 
+//===----------------------------------------------------------------------===//
+// The ArmSMETypeConverter type converter.
+//===----------------------------------------------------------------------===//
+class ArmSMETypeConverter : public LLVMTypeConverter {
+public:
+  ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
+};
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index e4a5528c298924..bb92d6506054f2 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRVectorToLLVM
   MLIRArmNeonDialect
   MLIRArmSMEDialect
   MLIRArmSMETransforms
+  MLIRVectorToArmSME
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index acc4244ce9bb87..04570a750822ae 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms.h"
@@ -96,6 +97,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
   target.addLegalDialect<arith::ArithDialect>();
   target.addLegalDialect<memref::MemRefDialect>();
   target.addLegalOp<UnrealizedConversionCastOp>();
+  arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options);
+
   if (armNeon) {
     // TODO: we may or may not want to include in-dialect lowering to
     // LLVM-compatible operations here. So far, all operations in the dialect
@@ -108,7 +111,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
   }
   if (armSME) {
     configureArmSMELegalizeForExportTarget(target);
-    populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
+    populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns);
   }
   if (amx) {
     configureAMXLegalizeForExportTarget(target);

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
new file mode 100644
index 00000000000000..1cefc220ecf103
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
@@ -0,0 +1,22 @@
+//===- ArmSMETypeConverter.cpp - Convert builtin to LLVM dialect types ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+
+using namespace mlir;
+arm_sme::ArmSMETypeConverter::ArmSMETypeConverter(
+    MLIRContext *ctx, const LowerToLLVMOptions &options)
+    : LLVMTypeConverter(ctx, options) {
+  // Disable LLVM type conversion for vectors. This is to prevent 2-d scalable
+  // vectors (common in the context of ArmSME), e.g.
+  //    `vector<[16]x[16]xi8>`,
+  // entering the LLVM Type converter. LLVM does not support arrays of scalable
+  // vectors, but in the case of SME such types are effectively eliminated when
+  // emitting ArmSME LLVM IR intrinsics.
+  addConversion([&](VectorType type) { return type; });
+}

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 247da2a3a4aa11..991beae0bec9cf 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
+  ArmSMETypeConverter.cpp
   EnableArmStreaming.cpp
   LegalizeForLLVMExport.cpp
   TileAllocation.cpp


        


More information about the Mlir-commits mailing list