[Mlir-commits] [mlir] [mlir][ArmSME] Move ArmSME -> intrinsics lowerings to `convert-arm-sme-to-llvm` pass (PR #72890)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Nov 21 02:42:32 PST 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/72890

>From c0430a20b8fd9a74d0e8ab834cc6c42eb48d5758 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 20 Nov 2023 14:06:57 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Move ArmSME -> intrinsics lowerings to
 convert-arm-sme-to-llvm pass (NFC)

This gives more flexibility with when these lowerings are performed,
without also lowering unrelated vector ops.
---
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h    |  26 +++++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  17 ++-
 .../Dialect/ArmSME/Transforms/Transforms.h    |   9 --
 .../ArmSMEToLLVM/ArmSMEToLLVM.cpp}            | 108 +++++++++++-------
 .../Conversion/ArmSMEToLLVM/CMakeLists.txt    |  16 +++
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  10 --
 .../Dialect/ArmSME/Transforms/CMakeLists.txt  |   1 -
 .../Dialect/ArmSME/arm-sme-to-llvm-casts.mlir |   2 +-
 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir |   2 +-
 mlir/test/Dialect/ArmSME/enable-arm-za.mlir   |   6 +-
 mlir/test/Dialect/ArmSME/tile-zero-masks.mlir |   2 +-
 .../Dialect/ArmSME/vector-ops-to-llvm.mlir    |  70 +++++-------
 .../Dialect/Linalg/CPU/ArmSME/fill-2d.mlir    |   2 +-
 .../Linalg/CPU/ArmSME/matmul-transpose-a.mlir |   2 +-
 .../CPU/ArmSME/load-store-128-bit-tile.mlir   |   2 +-
 .../Vector/CPU/ArmSME/test-load-vertical.mlir |   2 +-
 .../CPU/ArmSME/test-outerproduct-f32.mlir     |   2 +-
 .../CPU/ArmSME/test-outerproduct-f64.mlir     |   2 +-
 .../CPU/ArmSME/test-transfer-read-2d.mlir     |   2 +-
 .../CPU/ArmSME/test-transfer-write-2d.mlir    |   2 +-
 .../Vector/CPU/ArmSME/test-transpose.mlir     |   2 +-
 .../Dialect/Vector/CPU/ArmSME/tile_fill.mlir  |   2 +-
 .../Vector/CPU/ArmSME/vector-load-store.mlir  |   2 +-
 .../Dialect/Vector/CPU/ArmSME/vector-ops.mlir |   2 +-
 26 files changed, 173 insertions(+), 122 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
 rename mlir/lib/{Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp => Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp} (86%)
 create mode 100644 mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt

diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
new file mode 100644
index 000000000000000..ce778581b2cee37
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -0,0 +1,26 @@
+//===- ArmSMEToLLVM.h - Convert ArmSME to LLVM dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
+#define MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTARMSMETOLLVM
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
+std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 3078d909a8946dd..a25fd17ea923fb5 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -15,6 +15,7 @@
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
+#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
 #include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 626f5f3d19d307e..a0cc05319bb7299 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1241,6 +1241,19 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ArmSMEToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
+  let summary = "Lower the operations from the ArmSME dialect into the LLVM "
+                "dialect";
+  let constructor = "mlir::createConvertArmSMEToLLVMPass()";
+    let dependentDialects = [
+      "arm_sme::ArmSMEDialect",
+      "LLVM::LLVMDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // VectorToLLVM
 //===----------------------------------------------------------------------===//
@@ -1280,10 +1293,6 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of ArmSVE dialect while lowering the vector "
        "dialect.">,
-    Option<"armSME", "enable-arm-sme",
-           "bool", /*default=*/"false",
-           "Enables the use of ArmSME dialect while lowering the vector "
-       "dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index fae04513859938b..8ea3e1e57b7caa5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -20,15 +20,6 @@ void populateVectorTransferLoweringPatterns(LLVMTypeConverter &converter,
                                             RewritePatternSet &patterns);
 } // namespace arm_sme
 
-/// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM
-/// intrinsics.
-void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
-                                                 RewritePatternSet &patterns);
-
-/// Configure the target to support lowering ArmSME ops to ops that map to LLVM
-/// intrinsics.
-void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target);
-
 } // namespace mlir
 
 #endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
similarity index 86%
rename from mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
rename to mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 6ccb652ecbbc29e..66eee98cd23e4be 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -1,24 +1,36 @@
-//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===//
+//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+//
+// This file implements lowering of ArmSME operations to LLVM intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
 
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
 
 using namespace mlir;
-using namespace mlir::arm_sme;
 
 namespace {
 
@@ -40,11 +52,11 @@ namespace {
 ///  The 'arm_sme.cast_tile_to_vector' (which models the return) and the
 ///  'arith.shli' (which generates the mask) will be folded away after tile
 ///  allocation and canonization.
-struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
-  using ConvertOpToLLVMPattern<ZeroOp>::ConvertOpToLLVMPattern;
+struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
+  using ConvertOpToLLVMPattern<arm_sme::ZeroOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(ZeroOp zero, OpAdaptor adaptor,
+  matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = zero.getLoc();
 
@@ -121,7 +133,7 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
 };
 
 /// Lower `arm_sme.load_tile_slice` to SME intrinsics.
-struct LoadTileSliceToArmSMELowering
+struct LoadTileSliceConversion
     : public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
   using ConvertOpToLLVMPattern<
       arm_sme::LoadTileSliceOp>::ConvertOpToLLVMPattern;
@@ -220,7 +232,7 @@ struct LoadTileSliceToArmSMELowering
 };
 
 /// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
-struct StoreTileSliceToArmSMELowering
+struct StoreTileSliceConversion
     : public ConvertOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
   using ConvertOpToLLVMPattern<
       arm_sme::StoreTileSliceOp>::ConvertOpToLLVMPattern;
@@ -313,7 +325,7 @@ struct StoreTileSliceToArmSMELowering
 };
 
 /// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
-struct MoveVectorToTileSliceToArmSMELowering
+struct MoveVectorToTileSliceConversion
     : public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
   using ConvertOpToLLVMPattern<
       arm_sme::MoveVectorToTileSliceOp>::ConvertOpToLLVMPattern;
@@ -373,7 +385,7 @@ struct MoveVectorToTileSliceToArmSMELowering
 };
 
 /// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
-struct MoveTileSliceToVectorArmSMELowering
+struct MoveTileSliceToVectorConversion
     : public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
   using ConvertOpToLLVMPattern<
       arm_sme::MoveTileSliceToVectorOp>::ConvertOpToLLVMPattern;
@@ -456,7 +468,8 @@ struct OuterProductOpConversion
       //   * half-precision   - +sme2p1,+b16b16
       //
       // It should be possible to control lowering based on target features.
-      // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
+      // [1]
+      // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
       if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
         return false;
 
@@ -475,7 +488,7 @@ struct OuterProductOpConversion
     };
 
     // TODO: Support CombiningKind::Sub for outer products.
-    if (outerProductOp.getKind() != CombiningKind::Add)
+    if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
       return outerProductOp.emitError("unsupported kind");
 
     auto resultVectorType = outerProductOp.getResultType();
@@ -522,32 +535,49 @@ struct OuterProductOpConversion
 
 } // namespace
 
-void mlir::configureArmSMELegalizeForExportTarget(
-    LLVMConversionTarget &target) {
-  target.addLegalOp<
-      scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
-      arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
-      arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
-      arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
-      arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
-      arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
-      arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
-      arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
-      arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
-      arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
-      arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
-      arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
-      arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
-      arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
-      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
-  target.addLegalOp<GetTileID>();
-  target.addIllegalOp<vector::OuterProductOp>();
-}
+namespace {
+
+struct ConvertArmSMEToLLVMPass
+    : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    arm_sme::ArmSMETypeConverter converter(&getContext(),
+                                           LowerToLLVMOptions(&getContext()));
+
+    patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
+                 MoveVectorToTileSliceConversion, StoreTileSliceConversion,
+                 OuterProductOpConversion, ZeroOpConversion>(converter);
+
+    LLVMConversionTarget target(getContext());
+    target.addLegalDialect<arith::ArithDialect>();
+    target.addLegalOp<UnrealizedConversionCastOp>();
+    target.addLegalOp<
+        scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
+        arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
+        arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
+        arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
+        arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
+        arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
+        arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
+        arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
+        arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
+        arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
+        arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
+        arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
+        arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
+        arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
+        arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
+    target.addLegalOp<arm_sme::GetTileID>();
+    target.addIllegalOp<vector::OuterProductOp>();
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
 
-void mlir::populateArmSMELegalizeForLLVMExportPatterns(
-    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<
-      LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
-      MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
-      OuterProductOpConversion, ZeroOpConversion>(converter);
+std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
+  return std::make_unique<ConvertArmSMEToLLVMPass>();
 }
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
new file mode 100644
index 000000000000000..9914f39e17a1a91
--- /dev/null
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRArmSMEToLLVM
+  ArmSMEToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSMEToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArmSMETransforms
+  MLIRArmSMEDialect
+  MLIRArmSMEUtils
+  MLIRTransforms
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect)
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 822ce5aca255510..c3a2481975040c9 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -6,6 +6,7 @@ add_subdirectory(ArithToLLVM)
 add_subdirectory(ArithToSPIRV)
 add_subdirectory(ArmNeon2dToIntr)
 add_subdirectory(ArmSMEToSCF)
+add_subdirectory(ArmSMEToLLVM)
 add_subdirectory(AsyncToLLVM)
 add_subdirectory(BufferizationToMemRef)
 add_subdirectory(ComplexToLibm)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 4c6d0672d4108ef..ff8e78a668e0f10 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,9 +14,6 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
-#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -52,8 +49,6 @@ struct LowerVectorToLLVMPass
       registry.insert<arm_neon::ArmNeonDialect>();
     if (armSVE)
       registry.insert<arm_sve::ArmSVEDialect>();
-    if (armSME)
-      registry.insert<arm_sme::ArmSMEDialect>();
     if (amx)
       registry.insert<amx::AMXDialect>();
     if (x86Vector)
@@ -96,7 +91,6 @@ void LowerVectorToLLVMPass::runOnOperation() {
   target.addLegalDialect<arith::ArithDialect>();
   target.addLegalDialect<memref::MemRefDialect>();
   target.addLegalOp<UnrealizedConversionCastOp>();
-  arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options);
 
   if (armNeon) {
     // TODO: we may or may not want to include in-dialect lowering to
@@ -108,10 +102,6 @@ void LowerVectorToLLVMPass::runOnOperation() {
     configureArmSVELegalizeForExportTarget(target);
     populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
   }
-  if (armSME) {
-    configureArmSMELegalizeForExportTarget(target);
-    populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns);
-  }
   if (amx) {
     configureAMXLegalizeForExportTarget(target);
     populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 8f485db4e8438b1..e2407d9f48f7061 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,7 +1,6 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
   ArmSMETypeConverter.cpp
   EnableArmStreaming.cpp
-  LegalizeForLLVMExport.cpp
   TileAllocation.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
index 2c26c62ad42481e..65996e81c42d909 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -split-input-file | FileCheck %s
 
 // This test verifies the temporary casts that are emitted when lowering to
 // intrinsics to preserve data flow are correct. Canonicalization will remove
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 8fdcf69958244f3..fa62332bc3f5b17 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
 
 // Test conversion of ArmSME ops to LLVM intrinsics.
 
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
index 0f31278eefd1550..ba650b031e6110b 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=DISABLE-ZA
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=NO-ARM-STREAMING
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=DISABLE-ZA
+// RUN: mlir-opt %s -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=NO-ARM-STREAMING
 
 // CHECK-LABEL: @declaration
 func.func private @declaration()
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index 26cd91bd3e8956a..2378f4234aef1ef 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" \
+// RUN: mlir-opt %s -convert-arm-sme-to-llvm \
 // RUN:           -allocate-arm-sme-tiles -canonicalize      \
 // RUN:           -allow-unregistered-dialect                \
 // RUN: | FileCheck %s
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 721ff8f2c3589d4..c288f786f89a947 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // vector.transfer_write
@@ -17,9 +17,8 @@
 // CHECK-DAG:  %[[EXT_TILE_ID:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
 // CHECK-DAG:  %[[TILE_MASK:.*]] = arith.shli %[[C255]], %[[EXT_TILE_ID]] : i32
 // CHECK-DAG:  "arm_sme.intr.zero"(%[[TILE_MASK]]) : (i32) -> ()
-// CHECK-DAG:  %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
-// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
-// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
+// CHECK-DAG:  %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
 // CHECK:        %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -58,9 +57,8 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
 // CHECK-DAG:  %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
 // CHECK-DAG:  %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK-DAG:  %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
-// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
-// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
+// CHECK-DAG:  %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
 // CHECK-NEXT:   %[[TILE_SLICE_PLUS_OFF0:.*]] = arith.addi %[[TILE_SLICE]], %[[C123]] : index
 // CHECK-NEXT:   %[[TILE_SLICE_PLUS_OFF0_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE_PLUS_OFF0]] : index to i64
@@ -92,9 +90,8 @@ func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
 // CHECK-DAG:  %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
-// CHECK-DAG:  %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
-// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
-// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
+// CHECK-DAG:  %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
 // CHECK-NEXT:   %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE]], %[[SVL_B]] : index
 // CHECK-NEXT:   %[[TILE_SLICE_IDX_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE_IDX]] : index to i64
@@ -255,9 +252,8 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
 // CHECK-DAG:  %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
 // CHECK-DAG:  %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
-// CHECK-DAG:  %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
-// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
-// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
+// CHECK-DAG:  %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
 // CHECK:        %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64
 // CHECK-NEXT:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
@@ -466,14 +462,8 @@ func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vec
 // CHECK-LABEL: @vector_outerproduct_masked_f32
 // CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
 func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0 : index, %dim1 : index) {
-  // CHECK: %[[DIM0_I32:.*]] = arith.index_cast %[[DIM0]] : index to i32
-  // CHECK: %[[INSERT_DIM0:.*]] = llvm.insertelement %[[DIM0_I32]], {{.*}} : vector<[4]xi32>
-  // CHECK: %[[SPLAT_DIM0:.*]] = llvm.shufflevector %[[INSERT_DIM0]], {{.*}} : vector<[4]xi32>
-  // CHECK: %[[LHS_MASK:.*]] = arith.cmpi slt, %{{.*}}, %[[SPLAT_DIM0]] : vector<[4]xi32>
-  // CHECK: %[[DIM1_I32:.*]] = arith.index_cast %[[DIM1]] : index to i32
-  // CHECK: %[[INSERT_DIM1:.*]] = llvm.insertelement %[[DIM1_I32]], {{.*}} : vector<[4]xi32>
-  // CHECK: %[[SPLAT_DIM1:.*]] = llvm.shufflevector %[[INSERT_DIM1]], {{.*}} : vector<[4]xi32>
-  // CHECK: %[[RHS_MASK:.*]] = arith.cmpi slt, %{{.*}}, %[[SPLAT_DIM1]] : vector<[4]xi32>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
   // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[4]x[4]xf32> to i32
   // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE]], %[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
   %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
@@ -486,8 +476,8 @@ func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<
 // CHECK-LABEL: @vector_outerproduct_masked_f16
 // CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>,
 func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0 : index, %dim1 : index) {
-  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
-  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+  // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
+  // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
   // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16>
@@ -499,8 +489,8 @@ func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<
 // CHECK-LABEL: @vector_outerproduct_masked_bf16
 // CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>,
 func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0 : index, %dim1 : index) {
-  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
-  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+  // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
+  // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
   // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16>
@@ -512,8 +502,8 @@ func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vecto
 // CHECK-LABEL: @vector_outerproduct_masked_f64
 // CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
 func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0 : index, %dim1 : index) {
-  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
-  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
+  // CHECK: vector.create_mask {{.*}} : vector<[2]xi1>
+  // CHECK: vector.create_mask {{.*}} : vector<[2]xi1>
   // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
   %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
@@ -522,9 +512,9 @@ func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<
 
 // -----
 
-// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
 func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
-  // expected-error@+1 {{AXPY operations not supported}}
+  // expected-error@+2 {{AXPY operations not supported}}
+  // expected-error@+1 {{failed to legalize operation 'vector.outerproduct'}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
   return %0 : vector<[2]xf64>
 }
@@ -655,11 +645,10 @@ func.func @vector_insert_slice_f64(%tile: vector<[2]x[2]xf64>, %slice: vector<[2
 func.func @vector_insert_element_i32(%tile: vector<[4]x[4]xi32>, %el: i32, %row: index, %col: index) -> vector<[4]x[4]xi32> {
   // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
   // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-  // CHECK-NEXT: %[[COL_I32:.*]] = builtin.unrealized_conversion_cast %[[COL]] : index to i64
   // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
   // CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
   // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
-  // CHECK-NEXT: %[[NEW_SLICE:.*]] = llvm.insertelement %[[EL]], %[[SLICE]]{{\[}}%[[COL_I32]] : i64] : vector<[4]xi32>
+  // CHECK-NEXT: %[[NEW_SLICE:.*]] = vector.insert %[[EL]], %[[SLICE]] [%[[COL]]] : i32 into vector<[4]xi32>
   // CHECK-NEXT: %[[SLICE_INDEX:.*]] = arith.index_castui %[[ROW]] : index to i32
   // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
   %new_tile = vector.insert %el, %tile[%row, %col] : i32 into vector<[4]x[4]xi32>
@@ -846,11 +835,10 @@ func.func @vector_extract_slice_f64(%tile: vector<[2]x[2]xf64>, %row: index) ->
 func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col: index) -> i32 {
   // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
   // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-  // CHECK-NEXT: %[[COL_I32:.*]] = builtin.unrealized_conversion_cast %[[COL]] : index to i64
   // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
   // CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
   // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
-  // CHECK-NEXT: %[[EL:.*]] = llvm.extractelement %[[SLICE]]{{\[}}%[[COL_I32]] : i64] : vector<[4]xi32>
+  // CHECK-NEXT: %[[EL:.*]] = vector.extract %[[SLICE]]{{\[}}%[[COL]]] : i32 from vector<[4]xi32>
   %el = vector.extract %tile[%row, %col] : i32 from vector<[4]x[4]xi32>
   return %el : i32
 }
@@ -860,7 +848,7 @@ func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col:
 // CHECK-LABEL: @vector_extract_element_i8
 func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %col: index) -> i8 {
   // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
-  // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[16]xi8>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i8 from vector<[16]xi8>
   %el = vector.extract %tile[%row, %col] : i8 from vector<[16]x[16]xi8>
   return %el : i8
 }
@@ -870,7 +858,7 @@ func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %
 // CHECK-LABEL: @vector_extract_element_i16
 func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %col: index) -> i16 {
   // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
-  // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[8]xi16>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i16 from vector<[8]xi16>
   %el = vector.extract %tile[%row, %col] : i16 from vector<[8]x[8]xi16>
   return %el : i16
 }
@@ -880,7 +868,7 @@ func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %
 // CHECK-LABEL: @vector_extract_element_i64
 func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, %col: index) -> i64 {
   // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
-  // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[2]xi64>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i64 from vector<[2]xi64>
   %el = vector.extract %tile[%row, %col] : i64 from vector<[2]x[2]xi64>
   return %el : i64
 }
@@ -890,7 +878,7 @@ func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, %
 // CHECK-LABEL: @vector_extract_element_i128
 func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index, %col: index) -> i128 {
   // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
-  // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[1]xi128>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i128 from vector<[1]xi128>
   %el = vector.extract %tile[%row, %col] : i128 from vector<[1]x[1]xi128>
   return %el : i128
 }
@@ -900,7 +888,7 @@ func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index,
 // CHECK-LABEL: @vector_extract_element_f16
 func.func @vector_extract_element_f16(%tile: vector<[8]x[8]xf16>, %row: index, %col: index) -> f16 {
   // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
-  // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[8]xf16>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f16 from vector<[8]xf16>
   %el = vector.extract %tile[%row, %col] : f16 from vector<[8]x[8]xf16>
   return %el : f16
 }
@@ -910,7 +898,7 @@ func.func @vector_extract_element_f16(%tile: vector<[8]x[8]xf16>, %row: index, %
 // CHECK-LABEL: @vector_extract_element_bf16
 func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index, %col: index) -> bf16 {
   // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
-  // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[8]xbf16>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : bf16 from vector<[8]xbf16>
   %el = vector.extract %tile[%row, %col] : bf16 from vector<[8]x[8]xbf16>
   return %el : bf16
 }
@@ -920,7 +908,7 @@ func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index,
 // CHECK-LABEL: @vector_extract_element_f32
 func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %col: index) -> f32 {
   // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
-  // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[4]xf32>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f32 from vector<[4]xf32>
   %el = vector.extract %tile[%row, %col] : f32 from vector<[4]x[4]xf32>
   return %el : f32
 }
@@ -930,7 +918,7 @@ func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %
 // CHECK-LABEL: @vector_extract_element_f64
 func.func @vector_extract_element_f64(%tile: vector<[2]x[2]xf64>, %row: index, %col: index) -> f64 {
   // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
-  // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[2]xf64>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f64 from vector<[2]xf64>
   %el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
   return %el : f64
 }
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index efe4da7d3c50c6f..18b95cf2fdf843c 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -5,7 +5,7 @@
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" \
 // RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// RUN:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// RUN:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=entry -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
index ab74f0100474263..f189fd97d66cde7 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
@@ -4,7 +4,7 @@
 // RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // RUN:   -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
-// RUN:   -convert-vector-to-llvm=enable-arm-sme \
+// RUN:   -convert-arm-sme-to-llvm \
 // RUN:   -convert-vector-to-llvm=enable-arm-sve \
 // RUN:   -cse -canonicalize -allocate-arm-sme-tiles -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
index 32e7e6b79ce09b9..59b4a7e6a52f9b0 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
@@ -2,7 +2,7 @@
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index 44cf23f41b63254..0c186cc373a3b32 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -2,7 +2,7 @@
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index f1ecf768ebe83db..442a70cacd66508 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -2,7 +2,7 @@
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 5c907bb1675e462..74b51dcc9b4df3a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -2,7 +2,7 @@
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme-f64f64 \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
index ccc08289570afc5..82f38b4dbfa9d1f 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -2,7 +2,7 @@
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index f35f83dcec0daa2..3b218aefcd415ff 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -2,7 +2,7 @@
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
index 39b5ef2ade4b0c0..e2cbe735fa4ff06 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -2,7 +2,7 @@
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index baf2046722b9e0c..6e33a421bf799a2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// RUN:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// RUN:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:  -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index 8878dca8bdcb6b1..961bb274d1e3352 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -2,7 +2,7 @@
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index a890aaa6f309d15..25ef1799e63adb1 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,7 +1,7 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" \
+// DEFINE:   -convert-arm-sme-to-llvm \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \

>From 58f680848e68b8cef6f1e0489106ac80ac64ffc4 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 21 Nov 2023 10:37:56 +0000
Subject: [PATCH 2/2] Fixups

---
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h    | 14 ++++-
 mlir/include/mlir/Conversion/Passes.td        |  7 ++-
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  | 63 ++++++++++---------
 .../Dialect/ArmSME/vector-ops-to-llvm.mlir    |  7 +--
 4 files changed, 54 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index ce778581b2cee37..fe851d17867dff5 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -11,6 +11,8 @@
 
 #include <memory>
 
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+
 namespace mlir {
 class Pass;
 class RewritePatternSet;
@@ -18,9 +20,19 @@ class RewritePatternSet;
 #define GEN_PASS_DECL_CONVERTARMSMETOLLVM
 #include "mlir/Conversion/Passes.h.inc"
 
-/// Create a pass to convert a subset of ArmSME ops to SCF.
+using arm_sme::ArmSMETypeConverter;
+
+/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
 std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
 
+/// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
+void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
+
+/// Populate the given list with patterns that convert from the ArmSME dialect
+/// to LLVM intrinsics.
+void populateArmSMEToLLVMConversionPatterns(ArmSMETypeConverter &converter,
+                                            RewritePatternSet &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index a0cc05319bb7299..06756ff3df0bb3b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1249,9 +1249,10 @@ def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
   let summary = "Lower the operations from the ArmSME dialect into the LLVM "
                 "dialect";
   let constructor = "mlir::createConvertArmSMEToLLVMPass()";
-    let dependentDialects = [
-      "arm_sme::ArmSMEDialect",
-      "LLVM::LLVMDialect"];
+  let dependentDialects = [
+    "arm_sme::ArmSMEDialect",
+    "LLVM::LLVMDialect"
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 66eee98cd23e4be..e409dc57fb020e2 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -16,8 +16,6 @@
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -540,35 +538,13 @@ namespace {
 struct ConvertArmSMEToLLVMPass
     : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
   void runOnOperation() override {
+    LLVMConversionTarget target(getContext());
     RewritePatternSet patterns(&getContext());
-    arm_sme::ArmSMETypeConverter converter(&getContext(),
-                                           LowerToLLVMOptions(&getContext()));
-
-    patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
-                 MoveVectorToTileSliceConversion, StoreTileSliceConversion,
-                 OuterProductOpConversion, ZeroOpConversion>(converter);
+    ArmSMETypeConverter converter(&getContext(),
+                                  LowerToLLVMOptions(&getContext()));
 
-    LLVMConversionTarget target(getContext());
-    target.addLegalDialect<arith::ArithDialect>();
-    target.addLegalOp<UnrealizedConversionCastOp>();
-    target.addLegalOp<
-        scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
-        arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
-        arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
-        arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
-        arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
-        arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
-        arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
-        arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
-        arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
-        arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
-        arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
-        arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
-        arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
-        arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
-        arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
-    target.addLegalOp<arm_sme::GetTileID>();
-    target.addIllegalOp<vector::OuterProductOp>();
+    configureArmSMEToLLVMConversionLegality(target);
+    populateArmSMEToLLVMConversionPatterns(converter, patterns);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -578,6 +554,35 @@ struct ConvertArmSMEToLLVMPass
 
 } // namespace
 
+void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
+  target.addIllegalDialect<arm_sme::ArmSMEDialect>();
+  target.addLegalOp<
+      arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
+      arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
+      arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
+      arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
+      arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
+      arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
+      arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
+      arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
+      arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
+      arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
+      arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
+      arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
+      arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
+      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
+      arm_sme::aarch64_sme_mopa>();
+  target.addLegalDialect<arith::ArithDialect>();
+  target.addLegalOp<UnrealizedConversionCastOp>();
+}
+
+void mlir::populateArmSMEToLLVMConversionPatterns(
+    ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
+               MoveVectorToTileSliceConversion, StoreTileSliceConversion,
+               OuterProductOpConversion, ZeroOpConversion>(converter);
+}
+
 std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
   return std::make_unique<ConvertArmSMEToLLVMPass>();
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index c288f786f89a947..77ac071ef67de9a 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -513,8 +513,7 @@ func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<
 // -----
 
 func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
-  // expected-error@+2 {{AXPY operations not supported}}
-  // expected-error@+1 {{failed to legalize operation 'vector.outerproduct'}}
+  // expected-error@+1 {{AXPY operations not supported}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
   return %0 : vector<[2]xf64>
 }
@@ -522,6 +521,7 @@ func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f
 // -----
 
 func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.outerproduct'}}
   // expected-error@+1 {{unsupported type}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
   "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
@@ -530,7 +530,6 @@ func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : v
 // -----
 
 func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
-  // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
   // expected-error@+1 {{unsupported kind}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
@@ -539,7 +538,7 @@ func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : v
 // -----
 
 func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
-  // expected-error@+1 {{failed to legalize operation 'vector.outerproduct'}}
+  // CHECK: vector.outerproduct
   %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
 }



More information about the Mlir-commits mailing list