[Mlir-commits] [mlir] 65305ae - [mlir][ArmSME] Insert intrinsics to enable/disable ZA
Cullen Rhodes
llvmlistbot at llvm.org
Fri Jun 16 02:45:49 PDT 2023
Author: Cullen Rhodes
Date: 2023-06-16T09:40:48Z
New Revision: 65305aeab99ad8ea09dd85e28a41c657152a08fb
URL: https://github.com/llvm/llvm-project/commit/65305aeab99ad8ea09dd85e28a41c657152a08fb
DIFF: https://github.com/llvm/llvm-project/commit/65305aeab99ad8ea09dd85e28a41c657152a08fb.diff
LOG: [mlir][ArmSME] Insert intrinsics to enable/disable ZA
This patch adds two LLVM intrinsics to the ArmSME dialect:
* llvm.aarch64.sme.za.enable
* llvm.aarch64.sme.za.disable
for enabling and disabling the ZA storage array [1], as well as patterns for inserting
them during legalization to LLVM at the start and end of functions if
the function has the 'arm_za' attribute (D152695).
In the future ZA should probably be automatically enabled/disabled when
lowering from vector to SME, but this should be sufficient for now at
least until we have patterns lowering to SME instructions that use ZA.
N.B. The backend function attribute 'aarch64_pstate_za_new' can be used to
manage ZA state (as was originally tried in D152694), but it emits calls
to the following SME support routines [2] for the lazy-save mechanism
[3]:
* __arm_tpidr2_restore
* __arm_tpidr2_save
These will soon be added to compiler-rt but there's currently no public
implementation, and using this attribute would introduce an MLIR
dependency on compiler-rt. Furthermore, this mechanism is for routines
with ZA enabled calling other routines with it also enabled. We can
choose not to enable ZA in the compiler when this is the case.
Depends on D152695
[1] https://developer.arm.com/documentation/ddi0616/aa
[2] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#sme-support-routines
[3] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#the-za-lazy-saving-scheme
Reviewed By: awarzynski, dcaballe
Differential Revision: https://reviews.llvm.org/D153050
Added:
mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/ArmSME/enable-arm-za.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
mlir/test/Target/LLVMIR/arm-sme.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 9e39137d970a8..7a6c6c7fa8f3b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1092,6 +1092,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm", "ModuleOp"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+ Option<"armSME", "enable-arm-sme",
+ "bool", /*default=*/"false",
+ "Enables the use of ArmSME dialect while lowering the vector "
+ "dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 45a0ad77129c6..d0072b6149cd9 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -119,4 +119,7 @@ def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">;
def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
+// Intrinsics that enable and disable the ZA storage array.
+def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
+def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
+
#endif // ARMSME_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
new file mode 100644
index 0000000000000..c8d9e6704ac57
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -0,0 +1,29 @@
+//===- Transforms.h - ArmSME Dialect Transformation Entrypoints -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H
+#define MLIR_DIALECT_ARMSME_TRANSFORMS_H
+
+namespace mlir {
+
+class LLVMConversionTarget;
+class LLVMTypeConverter;
+class RewritePatternSet;
+
+/// Collect patterns that insert the ZA enable/disable intrinsics at the start
+/// of 'func.func' bodies and before 'func.return' ops.
+void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
+                                                 RewritePatternSet &patterns);
+
+/// Mark the ZA enable/disable intrinsics legal, and make 'arm_za' functions
+/// and their returns illegal until those intrinsics have been inserted.
+void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index b7fadeab09211..e4a5528c29892 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -15,6 +15,8 @@ add_mlir_conversion_library(MLIRVectorToLLVM
LINK_LIBS PUBLIC
MLIRArithDialect
MLIRArmNeonDialect
+ MLIRArmSMEDialect
+ MLIRArmSMETransforms
MLIRArmSVEDialect
MLIRArmSVETransforms
MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 3f1b107f6f8e0..acc4244ce9bb8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,8 @@
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -49,6 +51,8 @@ struct LowerVectorToLLVMPass
registry.insert<arm_neon::ArmNeonDialect>();
if (armSVE)
registry.insert<arm_sve::ArmSVEDialect>();
+ if (armSME)
+ registry.insert<arm_sme::ArmSMEDialect>();
if (amx)
registry.insert<amx::AMXDialect>();
if (x86Vector)
@@ -102,6 +106,10 @@ void LowerVectorToLLVMPass::runOnOperation() {
configureArmSVELegalizeForExportTarget(target);
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
}
+ if (armSME) {
+ configureArmSMELegalizeForExportTarget(target);
+ populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
+ }
if (amx) {
configureAMXLegalizeForExportTarget(target);
populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 2b616b5effcc5..efcb17fd2076a 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRArmSMETransforms
EnableArmStreaming.cpp
+ LegalizeForLLVMExport.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -8,6 +9,8 @@ add_mlir_dialect_library(MLIRArmSMETransforms
MLIRArmSMETransformsIncGen
LINK_LIBS PUBLIC
+ MLIRArmSMEDialect
MLIRFuncDialect
+ MLIRLLVMCommonConversion
MLIRPass
)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
new file mode 100644
index 0000000000000..3fe9e78a85a3d
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -0,0 +1,78 @@
+//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+/// Enable ZA via 'llvm.aarch64.sme.za.enable' at the start of a 'func.func'.
+struct EnableZAPattern : public OpRewritePattern<func::FuncOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(func::FuncOp op,
+                                PatternRewriter &rewriter) const final {
+    if (op.isExternal())
+      return failure(); // A declaration has no body to insert into.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToStart(&op.front());
+    rewriter.create<arm_sme::aarch64_sme_za_enable>(op->getLoc());
+    rewriter.updateRootInPlace(op, [] {}); // Mark 'op' as updated in place.
+    return success();
+  }
+};
+
+/// Disable ZA via 'llvm.aarch64.sme.za.disable' before a 'func.return'.
+struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(func::ReturnOp op,
+                                PatternRewriter &rewriter) const final {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(op);
+    rewriter.create<arm_sme::aarch64_sme_za_disable>(op->getLoc());
+    rewriter.updateRootInPlace(op, [] {}); // Mark 'op' as updated in place.
+    return success();
+  }
+};
+} // namespace
+
+/// Register the ZA enable/disable patterns; 'converter' is currently unused.
+void mlir::populateArmSMELegalizeForLLVMExportPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
+}
+
+void mlir::configureArmSMELegalizeForExportTarget(
+    LLVMConversionTarget &target) {
+  target.addLegalOp<arm_sme::aarch64_sme_za_enable,
+                    arm_sme::aarch64_sme_za_disable>();
+
+  // Mark 'func.func' ops as legal if either:
+  //   1. no 'arm_za' attribute is present or the function has no body, or
+  //   2. the 'arm_za' attribute is present and the first op in the function
+  //      is an 'arm_sme::aarch64_sme_za_enable' intrinsic.
+  target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp funcOp) {
+    if (!funcOp->hasAttr("arm_za") || funcOp.isExternal())
+      return true;
+    return isa<arm_sme::aarch64_sme_za_enable>(funcOp.front().front());
+  });
+
+  // Mark 'func.return' ops as legal if either:
+  //   1. no 'arm_za' function attribute is present, or
+  //   2. this return directly follows an 'arm_sme::aarch64_sme_za_disable'.
+  //      Checked per return so that every exit from the function disables ZA.
+  target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp returnOp) {
+    if (!returnOp->getParentOp()->hasAttr("arm_za"))
+      return true;
+    Operation *prev = returnOp->getPrevNode();
+    return prev && isa<arm_sme::aarch64_sme_za_disable>(prev);
+  });
+}
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
new file mode 100644
index 0000000000000..ae0bbdc6d1894
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -enable-arm-streaming=enable-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefixes=CHECK,ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefixes=CHECK,DISABLE-ZA
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefixes=CHECK,NO-ARM-STREAMING
+
+// CHECK-LABEL: @arm_za
+func.func @arm_za() {
+  // ENABLE-ZA: arm_sme.intr.za.enable
+  // ENABLE-ZA-NEXT: arm_sme.intr.za.disable
+  // ENABLE-ZA-NEXT: return
+  // DISABLE-ZA-NOT: arm_sme.intr.za.enable
+  // DISABLE-ZA-NOT: arm_sme.intr.za.disable
+  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.enable
+  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.disable
+  return
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 096d6194071cf..453a887c6d7fd 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -223,3 +223,14 @@ llvm.func @arm_sme_store(%nxv1i1 : vector<[1]xi1>,
(vector<[16]xi1>, !llvm.ptr<i8>, i32, i32) -> ()
llvm.return
}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_toggle_za
+llvm.func @arm_sme_toggle_za() {
+  // CHECK: call void @llvm.aarch64.sme.za.enable()
+  "arm_sme.intr.za.enable"() : () -> ()
+  // CHECK-NEXT: call void @llvm.aarch64.sme.za.disable()
+  "arm_sme.intr.za.disable"() : () -> ()
+  llvm.return
+}
More information about the Mlir-commits
mailing list