[Mlir-commits] [mlir] [mlir][ArmSME] Move ArmSME -> intrinsics lowerings to `convert-arm-sme-to-llvm` pass (PR #72890)
Benjamin Maxwell
llvmlistbot at llvm.org
Tue Nov 21 02:42:32 PST 2023
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/72890
From c0430a20b8fd9a74d0e8ab834cc6c42eb48d5758 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 20 Nov 2023 14:06:57 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Move ArmSME -> intrinsics lowerings to
convert-arm-sme-to-llvm pass (NFC)
This gives more flexibility over when these lowerings are performed,
without also having to lower unrelated vector ops.
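For downstream pipelines this means the ArmSME lowering can be scheduled as its
own step (e.g. `-convert-arm-sme-to-scf -convert-arm-sme-to-llvm`) instead of
through `-convert-vector-to-llvm="enable-arm-sme"`. A minimal sketch of how such
a pipeline might be assembled programmatically, assuming the standard
mlir::PassManager API; the function name and composition below are illustrative
and not part of this patch:

  #include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
  #include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
  #include "mlir/Pass/PassManager.h"

  // Illustrative pipeline: lower structured ArmSME ops to SCF loops first,
  // then lower the remaining ArmSME ops to LLVM intrinsics, without touching
  // unrelated vector ops.
  static void buildArmSMELoweringPipeline(mlir::OpPassManager &pm) {
    pm.addPass(mlir::createConvertArmSMEToSCFPass());
    pm.addPass(mlir::createConvertArmSMEToLLVMPass());
  }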
---
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h | 26 +++++
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 17 ++-
.../Dialect/ArmSME/Transforms/Transforms.h | 9 --
.../ArmSMEToLLVM/ArmSMEToLLVM.cpp} | 108 +++++++++++-------
.../Conversion/ArmSMEToLLVM/CMakeLists.txt | 16 +++
mlir/lib/Conversion/CMakeLists.txt | 1 +
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 10 --
.../Dialect/ArmSME/Transforms/CMakeLists.txt | 1 -
.../Dialect/ArmSME/arm-sme-to-llvm-casts.mlir | 2 +-
mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 2 +-
mlir/test/Dialect/ArmSME/enable-arm-za.mlir | 6 +-
mlir/test/Dialect/ArmSME/tile-zero-masks.mlir | 2 +-
.../Dialect/ArmSME/vector-ops-to-llvm.mlir | 70 +++++-------
.../Dialect/Linalg/CPU/ArmSME/fill-2d.mlir | 2 +-
.../Linalg/CPU/ArmSME/matmul-transpose-a.mlir | 2 +-
.../CPU/ArmSME/load-store-128-bit-tile.mlir | 2 +-
.../Vector/CPU/ArmSME/test-load-vertical.mlir | 2 +-
.../CPU/ArmSME/test-outerproduct-f32.mlir | 2 +-
.../CPU/ArmSME/test-outerproduct-f64.mlir | 2 +-
.../CPU/ArmSME/test-transfer-read-2d.mlir | 2 +-
.../CPU/ArmSME/test-transfer-write-2d.mlir | 2 +-
.../Vector/CPU/ArmSME/test-transpose.mlir | 2 +-
.../Dialect/Vector/CPU/ArmSME/tile_fill.mlir | 2 +-
.../Vector/CPU/ArmSME/vector-load-store.mlir | 2 +-
.../Dialect/Vector/CPU/ArmSME/vector-ops.mlir | 2 +-
26 files changed, 173 insertions(+), 122 deletions(-)
create mode 100644 mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
rename mlir/lib/{Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp => Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp} (86%)
create mode 100644 mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
new file mode 100644
index 000000000000000..ce778581b2cee37
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -0,0 +1,26 @@
+//===- ArmSMEToLLVM.h - Convert ArmSME to LLVM dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
+#define MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTARMSMETOLLVM
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Create a pass to convert a subset of ArmSME ops to SCF.
+std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 3078d909a8946dd..a25fd17ea923fb5 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -15,6 +15,7 @@
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
+#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 626f5f3d19d307e..a0cc05319bb7299 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1241,6 +1241,19 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
];
}
+//===----------------------------------------------------------------------===//
+// ArmSMEToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
+ let summary = "Lower the operations from the ArmSME dialect into the LLVM "
+ "dialect";
+ let constructor = "mlir::createConvertArmSMEToLLVMPass()";
+ let dependentDialects = [
+ "arm_sme::ArmSMEDialect",
+ "LLVM::LLVMDialect"];
+}
+
//===----------------------------------------------------------------------===//
// VectorToLLVM
//===----------------------------------------------------------------------===//
@@ -1280,10 +1293,6 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
- Option<"armSME", "enable-arm-sme",
- "bool", /*default=*/"false",
- "Enables the use of ArmSME dialect while lowering the vector "
- "dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index fae04513859938b..8ea3e1e57b7caa5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -20,15 +20,6 @@ void populateVectorTransferLoweringPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
} // namespace arm_sme
-/// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM
-/// intrinsics.
-void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
- RewritePatternSet &patterns);
-
-/// Configure the target to support lowering ArmSME ops to ops that map to LLVM
-/// intrinsics.
-void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target);
-
} // namespace mlir
#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
similarity index 86%
rename from mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
rename to mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 6ccb652ecbbc29e..66eee98cd23e4be 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -1,24 +1,36 @@
-//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===//
+//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of ArmSME operations to LLVM intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
using namespace mlir;
-using namespace mlir::arm_sme;
namespace {
@@ -40,11 +52,11 @@ namespace {
/// The 'arm_sme.cast_tile_to_vector' (which models the return) and the
/// 'arith.shli' (which generates the mask) will be folded away after tile
/// allocation and canonization.
-struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
- using ConvertOpToLLVMPattern<ZeroOp>::ConvertOpToLLVMPattern;
+struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
+ using ConvertOpToLLVMPattern<arm_sme::ZeroOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(ZeroOp zero, OpAdaptor adaptor,
+ matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = zero.getLoc();
@@ -121,7 +133,7 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
};
/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
-struct LoadTileSliceToArmSMELowering
+struct LoadTileSliceConversion
: public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
using ConvertOpToLLVMPattern<
arm_sme::LoadTileSliceOp>::ConvertOpToLLVMPattern;
@@ -220,7 +232,7 @@ struct LoadTileSliceToArmSMELowering
};
/// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
-struct StoreTileSliceToArmSMELowering
+struct StoreTileSliceConversion
: public ConvertOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
using ConvertOpToLLVMPattern<
arm_sme::StoreTileSliceOp>::ConvertOpToLLVMPattern;
@@ -313,7 +325,7 @@ struct StoreTileSliceToArmSMELowering
};
/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
-struct MoveVectorToTileSliceToArmSMELowering
+struct MoveVectorToTileSliceConversion
: public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
using ConvertOpToLLVMPattern<
arm_sme::MoveVectorToTileSliceOp>::ConvertOpToLLVMPattern;
@@ -373,7 +385,7 @@ struct MoveVectorToTileSliceToArmSMELowering
};
/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
-struct MoveTileSliceToVectorArmSMELowering
+struct MoveTileSliceToVectorConversion
: public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
using ConvertOpToLLVMPattern<
arm_sme::MoveTileSliceToVectorOp>::ConvertOpToLLVMPattern;
@@ -456,7 +468,8 @@ struct OuterProductOpConversion
// * half-precision - +sme2p1,+b16b16
//
// It should be possible to control lowering based on target features.
- // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
+ // [1]
+ // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
return false;
@@ -475,7 +488,7 @@ struct OuterProductOpConversion
};
// TODO: Support CombiningKind::Sub for outer products.
- if (outerProductOp.getKind() != CombiningKind::Add)
+ if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
return outerProductOp.emitError("unsupported kind");
auto resultVectorType = outerProductOp.getResultType();
@@ -522,32 +535,49 @@ struct OuterProductOpConversion
} // namespace
-void mlir::configureArmSMELegalizeForExportTarget(
- LLVMConversionTarget &target) {
- target.addLegalOp<
- scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
- arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
- arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
- arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
- arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
- arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
- arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
- arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
- arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
- arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
- arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
- arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
- arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
- arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
- arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
- target.addLegalOp<GetTileID>();
- target.addIllegalOp<vector::OuterProductOp>();
-}
+namespace {
+
+struct ConvertArmSMEToLLVMPass
+ : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ arm_sme::ArmSMETypeConverter converter(&getContext(),
+ LowerToLLVMOptions(&getContext()));
+
+ patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
+ MoveVectorToTileSliceConversion, StoreTileSliceConversion,
+ OuterProductOpConversion, ZeroOpConversion>(converter);
+
+ LLVMConversionTarget target(getContext());
+ target.addLegalDialect<arith::ArithDialect>();
+ target.addLegalOp<UnrealizedConversionCastOp>();
+ target.addLegalOp<
+ scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
+ arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
+ arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
+ arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
+ arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
+ arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
+ arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
+ arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
+ arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
+ arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
+ arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
+ arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
+ arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
+ arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
+ arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
+ target.addLegalOp<arm_sme::GetTileID>();
+ target.addIllegalOp<vector::OuterProductOp>();
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
-void mlir::populateArmSMELegalizeForLLVMExportPatterns(
- LLVMTypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<
- LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
- MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
- OuterProductOpConversion, ZeroOpConversion>(converter);
+std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
+ return std::make_unique<ConvertArmSMEToLLVMPass>();
}
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
new file mode 100644
index 000000000000000..9914f39e17a1a91
--- /dev/null
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRArmSMEToLLVM
+ ArmSMEToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSMEToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRArmSMETransforms
+ MLIRArmSMEDialect
+ MLIRArmSMEUtils
+ MLIRTransforms
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect)
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 822ce5aca255510..c3a2481975040c9 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -6,6 +6,7 @@ add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
add_subdirectory(ArmSMEToSCF)
+add_subdirectory(ArmSMEToLLVM)
add_subdirectory(AsyncToLLVM)
add_subdirectory(BufferizationToMemRef)
add_subdirectory(ComplexToLibm)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 4c6d0672d4108ef..ff8e78a668e0f10 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,9 +14,6 @@
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
-#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -52,8 +49,6 @@ struct LowerVectorToLLVMPass
registry.insert<arm_neon::ArmNeonDialect>();
if (armSVE)
registry.insert<arm_sve::ArmSVEDialect>();
- if (armSME)
- registry.insert<arm_sme::ArmSMEDialect>();
if (amx)
registry.insert<amx::AMXDialect>();
if (x86Vector)
@@ -96,7 +91,6 @@ void LowerVectorToLLVMPass::runOnOperation() {
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
- arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options);
if (armNeon) {
// TODO: we may or may not want to include in-dialect lowering to
@@ -108,10 +102,6 @@ void LowerVectorToLLVMPass::runOnOperation() {
configureArmSVELegalizeForExportTarget(target);
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
}
- if (armSME) {
- configureArmSMELegalizeForExportTarget(target);
- populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns);
- }
if (amx) {
configureAMXLegalizeForExportTarget(target);
populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 8f485db4e8438b1..e2407d9f48f7061 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,7 +1,6 @@
add_mlir_dialect_library(MLIRArmSMETransforms
ArmSMETypeConverter.cpp
EnableArmStreaming.cpp
- LegalizeForLLVMExport.cpp
TileAllocation.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
index 2c26c62ad42481e..65996e81c42d909 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -split-input-file | FileCheck %s
// This test verifies the temporary casts that are emitted when lowering to
// intrinsics to preserve data flow are correct. Canonicalization will remove
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 8fdcf69958244f3..fa62332bc3f5b17 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
// Test conversion of ArmSME ops to LLVM intrinsics.
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
index 0f31278eefd1550..ba650b031e6110b 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=DISABLE-ZA
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=NO-ARM-STREAMING
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=DISABLE-ZA
+// RUN: mlir-opt %s -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=NO-ARM-STREAMING
// CHECK-LABEL: @declaration
func.func private @declaration()
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index 26cd91bd3e8956a..2378f4234aef1ef 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" \
+// RUN: mlir-opt %s -convert-arm-sme-to-llvm \
// RUN: -allocate-arm-sme-tiles -canonicalize \
// RUN: -allow-unregistered-dialect \
// RUN: | FileCheck %s
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 721ff8f2c3589d4..c288f786f89a947 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
//===----------------------------------------------------------------------===//
// vector.transfer_write
@@ -17,9 +17,8 @@
// CHECK-DAG: %[[EXT_TILE_ID:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
// CHECK-DAG: %[[TILE_MASK:.*]] = arith.shli %[[C255]], %[[EXT_TILE_ID]] : i32
// CHECK-DAG: "arm_sme.intr.zero"(%[[TILE_MASK]]) : (i32) -> ()
-// CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
-// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
-// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
// CHECK: %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64
// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -58,9 +57,8 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
// CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
// CHECK-DAG: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
-// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
-// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
// CHECK-NEXT: %[[TILE_SLICE_PLUS_OFF0:.*]] = arith.addi %[[TILE_SLICE]], %[[C123]] : index
// CHECK-NEXT: %[[TILE_SLICE_PLUS_OFF0_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE_PLUS_OFF0]] : index to i64
@@ -92,9 +90,8 @@ func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
-// CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
-// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
-// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
// CHECK-NEXT: %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE]], %[[SVL_B]] : index
// CHECK-NEXT: %[[TILE_SLICE_IDX_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE_IDX]] : index to i64
@@ -255,9 +252,8 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
// CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
// CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
-// CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
-// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
-// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
// CHECK: %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64
// CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
@@ -466,14 +462,8 @@ func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vec
// CHECK-LABEL: @vector_outerproduct_masked_f32
// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0 : index, %dim1 : index) {
- // CHECK: %[[DIM0_I32:.*]] = arith.index_cast %[[DIM0]] : index to i32
- // CHECK: %[[INSERT_DIM0:.*]] = llvm.insertelement %[[DIM0_I32]], {{.*}} : vector<[4]xi32>
- // CHECK: %[[SPLAT_DIM0:.*]] = llvm.shufflevector %[[INSERT_DIM0]], {{.*}} : vector<[4]xi32>
- // CHECK: %[[LHS_MASK:.*]] = arith.cmpi slt, %{{.*}}, %[[SPLAT_DIM0]] : vector<[4]xi32>
- // CHECK: %[[DIM1_I32:.*]] = arith.index_cast %[[DIM1]] : index to i32
- // CHECK: %[[INSERT_DIM1:.*]] = llvm.insertelement %[[DIM1_I32]], {{.*}} : vector<[4]xi32>
- // CHECK: %[[SPLAT_DIM1:.*]] = llvm.shufflevector %[[INSERT_DIM1]], {{.*}} : vector<[4]xi32>
- // CHECK: %[[RHS_MASK:.*]] = arith.cmpi slt, %{{.*}}, %[[SPLAT_DIM1]] : vector<[4]xi32>
+ // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
+ // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[4]x[4]xf32> to i32
// CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE]], %[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
%mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
@@ -486,8 +476,8 @@ func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<
// CHECK-LABEL: @vector_outerproduct_masked_f16
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>,
func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0 : index, %dim1 : index) {
- // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
- // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+ // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
+ // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
%mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
%result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16>
@@ -499,8 +489,8 @@ func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<
// CHECK-LABEL: @vector_outerproduct_masked_bf16
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>,
func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0 : index, %dim1 : index) {
- // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
- // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+ // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
+ // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
%mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
%result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16>
@@ -512,8 +502,8 @@ func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vecto
// CHECK-LABEL: @vector_outerproduct_masked_f64
// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0 : index, %dim1 : index) {
- // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
- // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
+ // CHECK: vector.create_mask {{.*}} : vector<[2]xi1>
+ // CHECK: vector.create_mask {{.*}} : vector<[2]xi1>
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
%mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
%result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
@@ -522,9 +512,9 @@ func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<
// -----
-// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
- // expected-error@+1 {{AXPY operations not supported}}
+ // expected-error@+2 {{AXPY operations not supported}}
+ // expected-error@+1 {{failed to legalize operation 'vector.outerproduct'}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
return %0 : vector<[2]xf64>
}
@@ -655,11 +645,10 @@ func.func @vector_insert_slice_f64(%tile: vector<[2]x[2]xf64>, %slice: vector<[2
func.func @vector_insert_element_i32(%tile: vector<[4]x[4]xi32>, %el: i32, %row: index, %col: index) -> vector<[4]x[4]xi32> {
// CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
// CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
- // CHECK-NEXT: %[[COL_I32:.*]] = builtin.unrealized_conversion_cast %[[COL]] : index to i64
// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
// CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
// CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
- // CHECK-NEXT: %[[NEW_SLICE:.*]] = llvm.insertelement %[[EL]], %[[SLICE]]{{\[}}%[[COL_I32]] : i64] : vector<[4]xi32>
+ // CHECK-NEXT: %[[NEW_SLICE:.*]] = vector.insert %[[EL]], %[[SLICE]] [%[[COL]]] : i32 into vector<[4]xi32>
// CHECK-NEXT: %[[SLICE_INDEX:.*]] = arith.index_castui %[[ROW]] : index to i32
// CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
%new_tile = vector.insert %el, %tile[%row, %col] : i32 into vector<[4]x[4]xi32>
@@ -846,11 +835,10 @@ func.func @vector_extract_slice_f64(%tile: vector<[2]x[2]xf64>, %row: index) ->
func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col: index) -> i32 {
// CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
// CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
- // CHECK-NEXT: %[[COL_I32:.*]] = builtin.unrealized_conversion_cast %[[COL]] : index to i64
// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
// CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
// CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
- // CHECK-NEXT: %[[EL:.*]] = llvm.extractelement %[[SLICE]]{{\[}}%[[COL_I32]] : i64] : vector<[4]xi32>
+ // CHECK-NEXT: %[[EL:.*]] = vector.extract %[[SLICE]]{{\[}}%[[COL]]] : i32 from vector<[4]xi32>
%el = vector.extract %tile[%row, %col] : i32 from vector<[4]x[4]xi32>
return %el : i32
}
@@ -860,7 +848,7 @@ func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col:
// CHECK-LABEL: @vector_extract_element_i8
func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %col: index) -> i8 {
// CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
- // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[16]xi8>
+ // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i8 from vector<[16]xi8>
%el = vector.extract %tile[%row, %col] : i8 from vector<[16]x[16]xi8>
return %el : i8
}
@@ -870,7 +858,7 @@ func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %
// CHECK-LABEL: @vector_extract_element_i16
func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %col: index) -> i16 {
// CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
- // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[8]xi16>
+ // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i16 from vector<[8]xi16>
%el = vector.extract %tile[%row, %col] : i16 from vector<[8]x[8]xi16>
return %el : i16
}
@@ -880,7 +868,7 @@ func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %
// CHECK-LABEL: @vector_extract_element_i64
func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, %col: index) -> i64 {
// CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
- // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[2]xi64>
+ // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i64 from vector<[2]xi64>
%el = vector.extract %tile[%row, %col] : i64 from vector<[2]x[2]xi64>
return %el : i64
}
@@ -890,7 +878,7 @@ func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, %
// CHECK-LABEL: @vector_extract_element_i128
func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index, %col: index) -> i128 {
// CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
- // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[1]xi128>
+ // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i128 from vector<[1]xi128>
%el = vector.extract %tile[%row, %col] : i128 from vector<[1]x[1]xi128>
return %el : i128
}
@@ -900,7 +888,7 @@ func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index,
// CHECK-LABEL: @vector_extract_element_f16
func.func @vector_extract_element_f16(%tile: vector<[8]x[8]xf16>, %row: index, %col: index) -> f16 {
// CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
- // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[8]xf16>
+ // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f16 from vector<[8]xf16>
%el = vector.extract %tile[%row, %col] : f16 from vector<[8]x[8]xf16>
return %el : f16
}
@@ -910,7 +898,7 @@ func.func @vector_extract_element_f16(%tile: vector<[8]x[8]xf16>, %row: index, %
// CHECK-LABEL: @vector_extract_element_bf16
func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index, %col: index) -> bf16 {
// CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
- // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[8]xbf16>
+ // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : bf16 from vector<[8]xbf16>
%el = vector.extract %tile[%row, %col] : bf16 from vector<[8]x[8]xbf16>
return %el : bf16
}
@@ -920,7 +908,7 @@ func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index,
// CHECK-LABEL: @vector_extract_element_f32
func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %col: index) -> f32 {
// CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
- // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[4]xf32>
+ // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f32 from vector<[4]xf32>
%el = vector.extract %tile[%row, %col] : f32 from vector<[4]x[4]xf32>
return %el : f32
}
@@ -930,7 +918,7 @@ func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %
// CHECK-LABEL: @vector_extract_element_f64
func.func @vector_extract_element_f64(%tile: vector<[2]x[2]xf64>, %row: index, %col: index) -> f64 {
// CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
- // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[2]xf64>
+ // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f64 from vector<[2]xf64>
%el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
return %el : f64
}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index efe4da7d3c50c6f..18b95cf2fdf843c 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -5,7 +5,7 @@
// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// RUN: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// RUN: -convert-arm-sme-to-llvm -cse -canonicalize \
// RUN: -allocate-arm-sme-tiles -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN: -e=entry -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
index ab74f0100474263..f189fd97d66cde7 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
@@ -4,7 +4,7 @@
// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
-// RUN: -convert-vector-to-llvm=enable-arm-sme \
+// RUN: -convert-arm-sme-to-llvm \
// RUN: -convert-vector-to-llvm=enable-arm-sve \
// RUN: -cse -canonicalize -allocate-arm-sme-tiles -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
index 32e7e6b79ce09b9..59b4a7e6a52f9b0 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
@@ -2,7 +2,7 @@
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index 44cf23f41b63254..0c186cc373a3b32 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -2,7 +2,7 @@
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index f1ecf768ebe83db..442a70cacd66508 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -2,7 +2,7 @@
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
// DEFINE: %{run} = %mcr_aarch64_cmd %t \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 5c907bb1675e462..74b51dcc9b4df3a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -2,7 +2,7 @@
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
// DEFINE: %{run} = %mcr_aarch64_cmd %t \
// DEFINE: -march=aarch64 -mattr=+sve,+sme-f64f64 \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
index ccc08289570afc5..82f38b4dbfa9d1f 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -2,7 +2,7 @@
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index f35f83dcec0daa2..3b218aefcd415ff 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -2,7 +2,7 @@
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
index 39b5ef2ade4b0c0..e2cbe735fa4ff06 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -2,7 +2,7 @@
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index baf2046722b9e0c..6e33a421bf799a2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// RUN: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// RUN: -convert-arm-sme-to-llvm -cse -canonicalize \
// RUN: -allocate-arm-sme-tiles -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN: -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index 8878dca8bdcb6b1..961bb274d1e3352 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -2,7 +2,7 @@
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index a890aaa6f309d15..25ef1799e63adb1 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,7 +1,7 @@
// DEFINE: %{entry_point} = entry
// DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE: -convert-vector-to-llvm="enable-arm-sme" \
+// DEFINE: -convert-arm-sme-to-llvm \
// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
From 58f680848e68b8cef6f1e0489106ac80ac64ffc4 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 21 Nov 2023 10:37:56 +0000
Subject: [PATCH 2/2] Fixups
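The fixups split the lowering into reusable `configureArmSMEToLLVMConversionLegality`
and `populateArmSMEToLLVMConversionPatterns` entry points, so that other dialect
conversions can fold the ArmSME -> intrinsics patterns into their own pipelines. A
minimal sketch of such reuse, assuming the standard dialect conversion APIs; the
wrapper function name is hypothetical and not part of the patch:

  #include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
  #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
  #include "mlir/Transforms/DialectConversion.h"

  // Illustrative reuse of the new hooks within another conversion.
  static mlir::LogicalResult lowerArmSMEToIntrinsics(mlir::Operation *op) {
    mlir::MLIRContext *ctx = op->getContext();
    mlir::LLVMConversionTarget target(*ctx);
    mlir::arm_sme::ArmSMETypeConverter converter(ctx,
                                                 mlir::LowerToLLVMOptions(ctx));
    mlir::RewritePatternSet patterns(ctx);
    // Mark ArmSME ops illegal (except the casts and intrinsics) and add the
    // ArmSME -> LLVM intrinsic conversion patterns.
    mlir::configureArmSMEToLLVMConversionLegality(target);
    mlir::populateArmSMEToLLVMConversionPatterns(converter, patterns);
    return mlir::applyPartialConversion(op, target, std::move(patterns));
  }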
---
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h | 14 ++++-
mlir/include/mlir/Conversion/Passes.td | 7 ++-
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 63 ++++++++++---------
.../Dialect/ArmSME/vector-ops-to-llvm.mlir | 7 +--
4 files changed, 54 insertions(+), 37 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index ce778581b2cee37..fe851d17867dff5 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -11,6 +11,8 @@
#include <memory>
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+
namespace mlir {
class Pass;
class RewritePatternSet;
@@ -18,9 +20,19 @@ class RewritePatternSet;
#define GEN_PASS_DECL_CONVERTARMSMETOLLVM
#include "mlir/Conversion/Passes.h.inc"
-/// Create a pass to convert a subset of ArmSME ops to SCF.
+using arm_sme::ArmSMETypeConverter;
+
+/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
+/// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
+void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
+
+/// Populate the given list with patterns that convert from the ArmSME dialect
+/// to LLVM intrinsics.
+void populateArmSMEToLLVMConversionPatterns(ArmSMETypeConverter &converter,
+ RewritePatternSet &patterns);
+
} // namespace mlir
#endif // MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index a0cc05319bb7299..06756ff3df0bb3b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1249,9 +1249,10 @@ def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
let summary = "Lower the operations from the ArmSME dialect into the LLVM "
"dialect";
let constructor = "mlir::createConvertArmSMEToLLVMPass()";
- let dependentDialects = [
- "arm_sme::ArmSMEDialect",
- "LLVM::LLVMDialect"];
+ let dependentDialects = [
+ "arm_sme::ArmSMEDialect",
+ "LLVM::LLVMDialect"
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 66eee98cd23e4be..e409dc57fb020e2 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -16,8 +16,6 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -540,35 +538,13 @@ namespace {
struct ConvertArmSMEToLLVMPass
: public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
void runOnOperation() override {
+ LLVMConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
- arm_sme::ArmSMETypeConverter converter(&getContext(),
- LowerToLLVMOptions(&getContext()));
-
- patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
- MoveVectorToTileSliceConversion, StoreTileSliceConversion,
- OuterProductOpConversion, ZeroOpConversion>(converter);
+ ArmSMETypeConverter converter(&getContext(),
+ LowerToLLVMOptions(&getContext()));
- LLVMConversionTarget target(getContext());
- target.addLegalDialect<arith::ArithDialect>();
- target.addLegalOp<UnrealizedConversionCastOp>();
- target.addLegalOp<
- scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
- arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
- arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
- arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
- arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
- arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
- arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
- arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
- arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
- arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
- arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
- arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
- arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
- arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
- arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
- target.addLegalOp<arm_sme::GetTileID>();
- target.addIllegalOp<vector::OuterProductOp>();
+ configureArmSMEToLLVMConversionLegality(target);
+ populateArmSMEToLLVMConversionPatterns(converter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -578,6 +554,35 @@ struct ConvertArmSMEToLLVMPass
} // namespace
+void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
+ target.addIllegalDialect<arm_sme::ArmSMEDialect>();
+ target.addLegalOp<
+ arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
+ arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
+ arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
+ arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
+ arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
+ arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
+ arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
+ arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
+ arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
+ arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
+ arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
+ arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
+ arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
+ arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
+ arm_sme::aarch64_sme_mopa>();
+ target.addLegalDialect<arith::ArithDialect>();
+ target.addLegalOp<UnrealizedConversionCastOp>();
+}
+
+void mlir::populateArmSMEToLLVMConversionPatterns(
+ ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
+ patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
+ MoveVectorToTileSliceConversion, StoreTileSliceConversion,
+ OuterProductOpConversion, ZeroOpConversion>(converter);
+}
+
std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
return std::make_unique<ConvertArmSMEToLLVMPass>();
}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index c288f786f89a947..77ac071ef67de9a 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -513,8 +513,7 @@ func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<
// -----
func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
- // expected-error@+2 {{AXPY operations not supported}}
- // expected-error@+1 {{failed to legalize operation 'vector.outerproduct'}}
+ // expected-error@+1 {{AXPY operations not supported}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
return %0 : vector<[2]xf64>
}
@@ -522,6 +521,7 @@ func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f
// -----
func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
+ // expected-error@+2 {{failed to legalize operation 'arm_sme.outerproduct'}}
// expected-error@+1 {{unsupported type}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
"prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
@@ -530,7 +530,6 @@ func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : v
// -----
func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
- // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
// expected-error@+1 {{unsupported kind}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
@@ -539,7 +538,7 @@ func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : v
// -----
func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
- // expected-error@+1 {{failed to legalize operation 'vector.outerproduct'}}
+ // CHECK: vector.outerproduct
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
}