[Mlir-commits] [mlir] [mlir][ArmSME] Use ArmSMETypeConverter for all VectorToLLVM patterns (PR #65261)

Cullen Rhodes llvmlistbot at llvm.org
Mon Sep 4 06:11:49 PDT 2023


https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/65261:

LLVMTypeConverter::convertVectorType asserts on n-D scalable vectors to prevent generating illegal LLVM IR, since LLVM doesn't support arrays of scalable vectors. The ArmSMETypeConverter disables this conversion, but is only used for ArmSME dialect conversions that rewrite higher-level custom ArmSME ops to intrinsics.

This is problematic if we want to lower Vector ops directly to ArmSME intrinsics, as the assert fires for any op whose dialect conversion patterns use the LLVMTypeConverter, such as the patterns registered in ConvertVectorToLLVMPass (e.g. via populateVectorToLLVMConversionPatterns).
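To make the assert concrete, below is a small standalone C++ model of how the LLVM type converter maps vector types. It does not use MLIR: `VecType`, `anyScalable` and `convertVectorTypeModel` are hypothetical names, and the strings only approximate the LLVM dialect syntax. The point is the shape of the mapping. The innermost dimension becomes a 1-D vector and every outer dimension wraps it in an LLVM array, so an n-D vector with a scalable dimension has no legal form. Where the real converter asserts, this sketch returns `std::nullopt`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for mlir::VectorType: a shape, a
// per-dimension "is scalable" flag and an element type name.
struct VecType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
  std::string elemType;
};

bool anyScalable(const VecType &t) {
  for (bool s : t.scalableDims)
    if (s)
      return true;
  return false;
}

// Models LLVMTypeConverter::convertVectorType: the innermost dimension
// becomes a 1-D vector and each outer dimension wraps it in an LLVM array.
// Where the real converter asserts (an n-D vector with any scalable
// dimension), this model returns std::nullopt instead.
std::optional<std::string> convertVectorTypeModel(const VecType &t) {
  if (t.shape.empty()) // 0-D vectors become 1-element 1-D vectors.
    return "vector<1x" + t.elemType + ">";
  if (t.shape.size() > 1 && anyScalable(t))
    return std::nullopt; // LLVM has no arrays of scalable vectors.

  const size_t last = t.shape.size() - 1;
  std::string dim = std::to_string(t.shape[last]);
  std::string result = "vector<" +
                       (t.scalableDims[last] ? "[" + dim + "]" : dim) + "x" +
                       t.elemType + ">";
  if (last == 0)
    return result;
  // Wrap the 1-D vector in one array per outer dimension, innermost first.
  for (size_t i = last; i-- > 0;)
    result = "array<" + std::to_string(t.shape[i]) + " x " + result + ">";
  return "!llvm." + result;
}
```

For example, `vector<4x8xf32>` maps to `!llvm.array<4 x vector<8xf32>>`, while `vector<[16]x[16]xi8>` has nowhere to go.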

There are three options to get around this:

  1. Avoid the generic VectorToLLVM dialect conversion patterns (and thus the assert) altogether by first lowering Vector ops to custom ArmSME ops.
  2. Disable the generic VectorToLLVM dialect conversion patterns if ArmSME is enabled.
  3. Disable n-D scalable vector type conversion for all dialect conversion patterns if ArmSME is enabled.

Option 1 is already done for several Vector ops such as vector.load and vector.store as part of ConvertVectorToArmSME, but where possible we'd like to avoid bloating the ArmSME dialect by having to mirror all the Vector ops.

Option 2 is undesirable as the generic conversions should only be disabled for the 2-D scalable vector types the ArmSME patterns apply to. We'd still like Vector ops with other types to be lowered via the default path when ArmSME is enabled.
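For illustration, the check that singles out "the 2-D scalable vector types the ArmSME patterns apply to" can be modelled as below. This is a simplified sketch, assuming that a valid SME tile type is a 2-D vector whose dimensions are both scalable and both equal to the number of elements in a 128-bit vector of the element type (e.g. `vector<[16]x[16]xi8>` or `vector<[4]x[4]xf32>`). It is not a copy of `arm_sme::isValidSMETileVectorType`; `VecType`, `elementBitWidth` and `isSMETileTypeModel` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for mlir::VectorType.
struct VecType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
  std::string elemType;
};

// Bit width of the element types assumed to be legal in an SME tile;
// returns 0 for anything else.
unsigned elementBitWidth(const std::string &elem) {
  if (elem == "i8")
    return 8;
  if (elem == "i16" || elem == "f16" || elem == "bf16")
    return 16;
  if (elem == "i32" || elem == "f32")
    return 32;
  if (elem == "i64" || elem == "f64")
    return 64;
  if (elem == "i128" || elem == "f128")
    return 128;
  return 0;
}

// Simplified SME tile check: rank 2, both dimensions scalable, and both
// equal to the minimum element count of a 128-bit vector of the element type.
bool isSMETileTypeModel(const VecType &t) {
  if (t.shape.size() != 2 || !t.scalableDims[0] || !t.scalableDims[1])
    return false;
  unsigned bits = elementBitWidth(t.elemType);
  if (bits == 0)
    return false;
  const int64_t minElts = 128 / bits;
  return t.shape[0] == minElts && t.shape[1] == minElts;
}
```

Everything this predicate rejects, such as fixed-size n-D vectors or 1-D scalable vectors, is meant to keep following the default VectorToLLVM path.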

This patch therefore implements option 3 to use the ArmSMETypeConverter for all VectorToLLVM conversion patterns when ArmSME is enabled.
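The mechanism can be sketched without MLIR as follows. All class and function names here are hypothetical stand-ins. The model relies on one documented property of `mlir::TypeConverter`: the most recently added conversion is tried first. The derived converter registers its own callback, which keeps SME tile types unchanged and delegates everything else to the base conversion. A single converter is then chosen from the `armSME` flag and shared by all patterns. The tile check is deliberately reduced to two element types, and the base conversion returns a placeholder string rather than a real LLVM type.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for mlir::VectorType.
struct VecType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
  std::string elemType;
};

// Prints a VecType in MLIR-like syntax, e.g. "vector<[16]x[16]xi8>".
std::string spell(const VecType &t) {
  std::string s = "vector<";
  for (size_t i = 0; i < t.shape.size(); ++i) {
    std::string d = std::to_string(t.shape[i]);
    s += (t.scalableDims[i] ? "[" + d + "]" : d) + "x";
  }
  return s + t.elemType + ">";
}

// Callbacks are tried most-recently-added first, like mlir::TypeConverter.
// In this model std::nullopt means "not handled, try the next callback".
class TypeConverterModel {
public:
  using Callback = std::function<std::optional<std::string>(const VecType &)>;
  virtual ~TypeConverterModel() = default;
  void addConversion(Callback cb) { callbacks.push_back(std::move(cb)); }
  std::optional<std::string> convertType(const VecType &t) const {
    for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it)
      if (auto result = (*it)(t))
        return result;
    return std::nullopt;
  }

private:
  std::vector<Callback> callbacks;
};

// Stand-in for LLVMTypeConverter. The real converter asserts on n-D
// scalable vectors; here that case just yields no conversion.
class LLVMConverterModel : public TypeConverterModel {
public:
  LLVMConverterModel() {
    addConversion([this](const VecType &t) { return convertVectorType(t); });
  }

protected:
  std::optional<std::string> convertVectorType(const VecType &t) const {
    bool scalable = false;
    for (bool s : t.scalableDims)
      scalable = scalable || s;
    if (t.shape.size() > 1 && scalable)
      return std::nullopt;
    return "llvm(" + spell(t) + ")"; // placeholder for a real LLVM type
  }
};

// Stand-in for ArmSMETypeConverter: its callback is registered last, so it
// runs first; it keeps tile types and delegates everything else.
class ArmSMEConverterModel : public LLVMConverterModel {
public:
  ArmSMEConverterModel() {
    addConversion([this](const VecType &t) { return convertVectorType(t); });
  }

protected:
  // Hides (does not override) the base member function.
  std::optional<std::string> convertVectorType(const VecType &t) const {
    // Simplified tile check: only i8 and f32 tiles are modelled.
    bool isTile = t.shape.size() == 2 && t.scalableDims[0] &&
                  t.scalableDims[1] && t.shape[0] == t.shape[1] &&
                  ((t.elemType == "i8" && t.shape[0] == 16) ||
                   (t.elemType == "f32" && t.shape[0] == 4));
    if (isTile)
      return spell(t);
    return LLVMConverterModel::convertVectorType(t);
  }
};

// Chooses the converter once from the flag; the unique_ptr owns it.
std::unique_ptr<LLVMConverterModel> makeConverter(bool armSME) {
  if (armSME)
    return std::make_unique<ArmSMEConverterModel>();
  return std::make_unique<LLVMConverterModel>();
}
```

With `armSME` set, `vector<[16]x[16]xi8>` passes through unchanged while `vector<4x8xf32>` still takes the default path. Owning the chosen converter through a `std::unique_ptr` means it is released when the caller is done with it.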

Depends on #65254

From f5d1818634b782645692d71f2b18dbd2a8d27b04 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Fri, 1 Sep 2023 09:43:39 +0000
Subject: [PATCH] [mlir][ArmSME] Use ArmSMETypeConverter for all VectorToLLVM
 patterns

LLVMTypeConverter::convertVectorType asserts on n-D scalable vectors to
prevent generating illegal LLVM IR, since LLVM doesn't support arrays of
scalable vectors. The ArmSMETypeConverter disables this conversion, but
is only used for ArmSME dialect conversions that rewrite higher-level
custom ArmSME ops to intrinsics.

This is problematic if we want to lower Vector ops directly to ArmSME
intrinsics, as the assert fires for ops that have dialect conversion
patterns (defined in ConvertVectorToLLVMPass, e.g.
populateVectorToLLVMConversionPatterns) that use the LLVMTypeConverter.

There are three options to get around this:

  1. Avoid the generic VectorToLLVM dialect conversion patterns (and
  thus the assert) altogether by first lowering Vector ops to custom
  ArmSME ops.
  2. Disable the generic VectorToLLVM dialect conversion patterns if
  ArmSME is enabled.
  3. Disable n-D scalable vector type conversion for all dialect
  conversion patterns if ArmSME is enabled.

Option 1 is already done for several Vector ops such as vector.load and
vector.store as part of ConvertVectorToArmSME, but where possible we'd
like to avoid bloating the ArmSME dialect by having to mirror all the
Vector ops.

Option 2 is undesirable as the generic conversions should only be
disabled for the 2-D scalable vector types the ArmSME patterns apply to.
We'd still like Vector ops with other types to get lowered via the
default path when ArmSME is enabled.

This patch therefore implements option 3 to use the ArmSMETypeConverter
for all VectorToLLVM conversion patterns when ArmSME is enabled.
---
 .../Conversion/LLVMCommon/TypeConverter.h     |  7 +++---
 .../mlir/Dialect/ArmSME/Transforms/Passes.h   | 13 +++++++++++
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  | 23 +++++++++++--------
 .../ArmSME/Transforms/ArmSMETypeConverter.cpp | 15 ++++++------
 .../VectorToLLVM/vector-to-llvm.mlir          |  1 +
 5 files changed, 40 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index ed174699314e8d..43db7987e650a7 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -238,14 +238,15 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a memref type to a bare pointer to the memref element type.
   Type convertMemRefToBarePtr(BaseMemRefType type) const;
 
-  /// Convert a 1D vector type into an LLVM vector type.
-  Type convertVectorType(VectorType type) const;
-
   /// Options for customizing the llvm lowering.
   LowerToLLVMOptions options;
 
   /// Data layout analysis mapping scopes to layouts active in them.
   const DataLayoutAnalysis *dataLayoutAnalysis;
+
+protected:
+  /// Convert an n-D vector type into an LLVM vector or array of vectors.
+  Type convertVectorType(VectorType type) const;
 };
 
 /// Callback to convert function argument types. It converts a MemRef function
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index ab5c179f2dd779..ad3c010816fa3e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -43,6 +43,19 @@ std::unique_ptr<Pass> createTileAllocationPass();
 class ArmSMETypeConverter : public LLVMTypeConverter {
 public:
   ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
+
+protected:
+  /// Convert an n-D vector type to an LLVM vector type.
+  ///
+  /// Disables type conversion of legal 2-D scalable vector types such as
+  /// `vector<[16]x[16]xi8>` for ArmSME, since LLVM does not support arrays of
+  /// scalable vectors and the LLVM type converter asserts on such types to
+  /// prevent generation of illegal LLVM IR. When lowering to ArmSME these types
+  /// should be eliminated before lowering to LLVM.
+  ///
+  /// Types unrelated to ArmSME are converted by
+  /// `LLVMTypeConverter::convertVectorType`.
+  Type convertVectorType(VectorType type) const;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 04570a750822ae..c534ef6e408b8d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -83,21 +83,26 @@ void LowerVectorToLLVMPass::runOnOperation() {
   // Convert to the LLVM IR dialect.
   LowerToLLVMOptions options(&getContext());
   options.useOpaquePointers = useOpaquePointers;
-  LLVMTypeConverter converter(&getContext(), options);
+
+  std::unique_ptr<LLVMTypeConverter> converter;
+  if (armSME)
+    converter.reset(new arm_sme::ArmSMETypeConverter(&getContext(), options));
+  else
+    converter.reset(new LLVMTypeConverter(&getContext(), options));
+
   RewritePatternSet patterns(&getContext());
   populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
   populateVectorTransferLoweringPatterns(patterns);
-  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
+  populateVectorToLLVMMatrixConversionPatterns(*converter, patterns);
   populateVectorToLLVMConversionPatterns(
-      converter, patterns, reassociateFPReductions, force32BitVectorIndices);
-  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
+      *converter, patterns, reassociateFPReductions, force32BitVectorIndices);
+  populateVectorToLLVMMatrixConversionPatterns(*converter, patterns);
 
   // Architecture specific augmentations.
   LLVMConversionTarget target(getContext());
   target.addLegalDialect<arith::ArithDialect>();
   target.addLegalDialect<memref::MemRefDialect>();
   target.addLegalOp<UnrealizedConversionCastOp>();
-  arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options);
 
   if (armNeon) {
     // TODO: we may or may not want to include in-dialect lowering to
@@ -107,19 +112,19 @@ void LowerVectorToLLVMPass::runOnOperation() {
   }
   if (armSVE) {
     configureArmSVELegalizeForExportTarget(target);
-    populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
+    populateArmSVELegalizeForLLVMExportPatterns(*converter, patterns);
   }
   if (armSME) {
     configureArmSMELegalizeForExportTarget(target);
-    populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns);
+    populateArmSMELegalizeForLLVMExportPatterns(*converter, patterns);
   }
   if (amx) {
     configureAMXLegalizeForExportTarget(target);
-    populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
+    populateAMXLegalizeForLLVMExportPatterns(*converter, patterns);
   }
   if (x86Vector) {
     configureX86VectorLegalizeForExportTarget(target);
-    populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns);
+    populateX86VectorLegalizeForLLVMExportPatterns(*converter, patterns);
   }
 
   if (failed(
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
index 1cefc220ecf103..65da2a7a75d29c 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
@@ -7,16 +7,17 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
 
 using namespace mlir;
 arm_sme::ArmSMETypeConverter::ArmSMETypeConverter(
     MLIRContext *ctx, const LowerToLLVMOptions &options)
     : LLVMTypeConverter(ctx, options) {
-  // Disable LLVM type conversion for vectors. This is to prevent 2-d scalable
-  // vectors (common in the context of ArmSME), e.g.
-  //    `vector<[16]x[16]xi8>`,
-  // entering the LLVM Type converter. LLVM does not support arrays of scalable
-  // vectors, but in the case of SME such types are effectively eliminated when
-  // emitting ArmSME LLVM IR intrinsics.
-  addConversion([&](VectorType type) { return type; });
+  addConversion([&](VectorType type) { return convertVectorType(type); });
+}
+
+Type arm_sme::ArmSMETypeConverter::convertVectorType(VectorType type) const {
+  if (arm_sme::isValidSMETileVectorType(type))
+    return type;
+  return LLVMTypeConverter::convertVectorType(type);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 514594240d22a1..3f897fbf01b7bc 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1 enable-arm-sme' -split-input-file | FileCheck %s
 
 
 func.func @bitcast_f32_to_i32_vector_0d(%input: vector<f32>) -> vector<i32> {