[Mlir-commits] [mlir] 7a2fdc6 - [mlir][ArmSME] Dialect and Intrinsic Op Definition

Prabhdeep Singh Soni llvmlistbot at llvm.org
Wed Jun 14 14:12:23 PDT 2023


Author: Frank (Fang) Gao
Date: 2023-06-14T17:11:49-04:00
New Revision: 7a2fdc685f609730af29e5e969843e9eb71a184c

URL: https://github.com/llvm/llvm-project/commit/7a2fdc685f609730af29e5e969843e9eb71a184c
DIFF: https://github.com/llvm/llvm-project/commit/7a2fdc685f609730af29e5e969843e9eb71a184c.diff

LOG: [mlir][ArmSME] Dialect and Intrinsic Op Definition

This patch creates the ArmSME dialect, and provides the intrinsic op
definition necessary for lowering to LLVM IR.

This covers most instructions that interact with the ZA tile register;
SME2 instructions are not covered.

Source: https://developer.arm.com/documentation/ddi0616/latest

Reviewed By: awarzynski, c-rhodes

Differential Revision: https://reviews.llvm.org/D152878

Added: 
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
    mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
    mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h
    mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
    mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
    mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp
    mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt
    mlir/test/Target/LLVMIR/arm-sme.mlir

Modified: 
    mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt
    mlir/include/mlir/InitAllDialects.h
    mlir/include/mlir/Target/LLVMIR/Dialect/All.h
    mlir/lib/Dialect/ArmSME/CMakeLists.txt
    mlir/lib/Target/LLVMIR/CMakeLists.txt
    mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
    mlir/test/mlir-opt/commandline.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt
index e31af32661164..9f57627c321fb 100644
--- a/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt
@@ -1 +1,2 @@
+add_subdirectory(IR)
 add_subdirectory(Transforms)

diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
new file mode 100644
index 0000000000000..a69d32610357c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -0,0 +1,27 @@
+//===- ArmSME.h - MLIR Dialect for Arm SME ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the ArmSME dialect and its (generated) operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_IR_ARMSME_H
+#define MLIR_DIALECT_ARMSME_IR_ARMSME_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h.inc"
+
+#endif // MLIR_DIALECT_ARMSME_IR_ARMSME_H

diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
new file mode 100644
index 0000000000000..45a0ad77129c6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -0,0 +1,154 @@
+//===-- ArmSME.td - ArmSME dialect operation definitions ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the ArmSME dialect and contains intrinsic ops to lower to
+// LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARMSME_OPS
+#define ARMSME_OPS
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// ArmSME dialect definition
+//===----------------------------------------------------------------------===//
+
+def ArmSME_Dialect : Dialect {
+  let name = "arm_sme";
+  let cppNamespace = "::mlir::arm_sme";
+  let summary = "Basic dialect to target Arm SME architectures";
+  let description = [{
+    This dialect contains the definitions necessary to target Arm SME
+    scalable matrix operations.
+
+    Sources:
+    https://developer.arm.com/documentation/ddi0616
+    https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// ArmSME Intrinsic op definitions
+//===----------------------------------------------------------------------===//
+
+def MOPPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2], [I1]>;
+def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
+                                              [I8, I16, BF16, F16, F32, F64]>;
+def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
+
+// Governing-predicate types for the individual load/store intrinsics. Those
+// intrinsics are not overloaded, so each accepts exactly one predicate type,
+// fixed by the element size it transfers (b: 8-bit, h: 16-bit, w: 32-bit,
+// d: 64-bit, q: 128-bit). LDSTPredicate above accepts any of them.
+def LDSTPredicate8 : ScalableVectorOfLengthAndType<[16], [I1]>;
+def LDSTPredicate16 : ScalableVectorOfLengthAndType<[8], [I1]>;
+def LDSTPredicate32 : ScalableVectorOfLengthAndType<[4], [I1]>;
+def LDSTPredicate64 : ScalableVectorOfLengthAndType<[2], [I1]>;
+def LDSTPredicate128 : ScalableVectorOfLengthAndType<[1], [I1]>;
+
+// Base class for ArmSME intrinsic ops: the op "arm_sme.intr.<mnemonic>" lowers
+// to the LLVM intrinsic `llvm.aarch64.sme.<mnemonic>`. None produce results.
+class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
+                    list<Trait> traits = []>
+    : LLVM_IntrOpBase<
+          /*Dialect dialect=*/ArmSME_Dialect,
+          /*string opName=*/"intr." # mnemonic,
+          /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
+          /*list<int> overloadedResults=*/[],
+          /*list<int> overloadedOperands=*/overloadedOperands,
+          /*list<Trait> traits=*/traits,
+          /*int numResults=*/0>;
+
+// Zero
+def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
+                            Arguments<(ins Arg<I32, "Tile mask">)>;
+
+// Outer products (MOPA/MOPS), overloaded on the vector operand type.
+class ArmSME_IntrMopOverloadedOp<string mnemonic>
+    : ArmSME_IntrOp<mnemonic, [4]>,
+      Arguments<(ins Arg<I32, "Virtual tile ID">,
+                 Arg<MOPPredicate, "LHS predicate">,
+                 Arg<MOPPredicate, "RHS predicate">,
+                 Arg<MOPVector, "LHS vector operand">,
+                 Arg<MOPVector, "RHS vector operand">)>;
+
+def LLVM_aarch64_sme_mopa : ArmSME_IntrMopOverloadedOp<"mopa">;
+def LLVM_aarch64_sme_mops : ArmSME_IntrMopOverloadedOp<"mops">;
+def LLVM_aarch64_sme_mopa_wide : ArmSME_IntrMopOverloadedOp<"mopa.wide">;
+def LLVM_aarch64_sme_mops_wide : ArmSME_IntrMopOverloadedOp<"mops.wide">;
+def LLVM_aarch64_sme_smopa_wide : ArmSME_IntrMopOverloadedOp<"smopa.wide">;
+def LLVM_aarch64_sme_smops_wide : ArmSME_IntrMopOverloadedOp<"smops.wide">;
+def LLVM_aarch64_sme_umopa_wide : ArmSME_IntrMopOverloadedOp<"umopa.wide">;
+def LLVM_aarch64_sme_umops_wide : ArmSME_IntrMopOverloadedOp<"umops.wide">;
+def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
+def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
+def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
+def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
+
+// Loads
+class ArmSME_IntrLoadOp<string mnemonic, Type predicate = LDSTPredicate>
+    : ArmSME_IntrOp<mnemonic>,
+      Arguments<(ins Arg<predicate, "Vector predicate">,
+                 Arg<LLVM_AnyPointer, "Load address", [MemRead]>,
+                 Arg<I32, "Virtual tile ID">,
+                 Arg<I32, "Tile slice">)>;
+
+def LLVM_aarch64_sme_ld1b_horiz
+    : ArmSME_IntrLoadOp<"ld1b.horiz", LDSTPredicate8>;
+def LLVM_aarch64_sme_ld1h_horiz
+    : ArmSME_IntrLoadOp<"ld1h.horiz", LDSTPredicate16>;
+def LLVM_aarch64_sme_ld1w_horiz
+    : ArmSME_IntrLoadOp<"ld1w.horiz", LDSTPredicate32>;
+def LLVM_aarch64_sme_ld1d_horiz
+    : ArmSME_IntrLoadOp<"ld1d.horiz", LDSTPredicate64>;
+def LLVM_aarch64_sme_ld1q_horiz
+    : ArmSME_IntrLoadOp<"ld1q.horiz", LDSTPredicate128>;
+def LLVM_aarch64_sme_ld1b_vert
+    : ArmSME_IntrLoadOp<"ld1b.vert", LDSTPredicate8>;
+def LLVM_aarch64_sme_ld1h_vert
+    : ArmSME_IntrLoadOp<"ld1h.vert", LDSTPredicate16>;
+def LLVM_aarch64_sme_ld1w_vert
+    : ArmSME_IntrLoadOp<"ld1w.vert", LDSTPredicate32>;
+def LLVM_aarch64_sme_ld1d_vert
+    : ArmSME_IntrLoadOp<"ld1d.vert", LDSTPredicate64>;
+def LLVM_aarch64_sme_ld1q_vert
+    : ArmSME_IntrLoadOp<"ld1q.vert", LDSTPredicate128>;
+
+// Stores
+class ArmSME_IntrStoreOp<string mnemonic, Type predicate = LDSTPredicate>
+    : ArmSME_IntrOp<mnemonic>,
+      Arguments<(ins Arg<predicate, "Vector predicate">,
+                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>,
+                 Arg<I32, "Virtual tile ID">,
+                 Arg<I32, "Tile slice">)>;
+
+def LLVM_aarch64_sme_st1b_horiz
+    : ArmSME_IntrStoreOp<"st1b.horiz", LDSTPredicate8>;
+def LLVM_aarch64_sme_st1h_horiz
+    : ArmSME_IntrStoreOp<"st1h.horiz", LDSTPredicate16>;
+def LLVM_aarch64_sme_st1w_horiz
+    : ArmSME_IntrStoreOp<"st1w.horiz", LDSTPredicate32>;
+def LLVM_aarch64_sme_st1d_horiz
+    : ArmSME_IntrStoreOp<"st1d.horiz", LDSTPredicate64>;
+def LLVM_aarch64_sme_st1q_horiz
+    : ArmSME_IntrStoreOp<"st1q.horiz", LDSTPredicate128>;
+def LLVM_aarch64_sme_st1b_vert
+    : ArmSME_IntrStoreOp<"st1b.vert", LDSTPredicate8>;
+def LLVM_aarch64_sme_st1h_vert
+    : ArmSME_IntrStoreOp<"st1h.vert", LDSTPredicate16>;
+def LLVM_aarch64_sme_st1w_vert
+    : ArmSME_IntrStoreOp<"st1w.vert", LDSTPredicate32>;
+def LLVM_aarch64_sme_st1d_vert
+    : ArmSME_IntrStoreOp<"st1d.vert", LDSTPredicate64>;
+def LLVM_aarch64_sme_st1q_vert
+    : ArmSME_IntrStoreOp<"st1q.vert", LDSTPredicate128>;
+
+#endif // ARMSME_OPS

diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..d20ee65e62e7d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_dialect(ArmSME arm_sme ArmSME)
+add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme)
+
+set(LLVM_TARGET_DEFINITIONS ArmSME.td)
+mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRArmSMEConversionsIncGen)

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 0baaa7b5d5315..db15dff136cd1 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
 #include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -117,6 +118,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   pdl_interp::PDLInterpDialect,
                   quant::QuantizationDialect,
                   spirv::SPIRVDialect,
+                  arm_sme::ArmSMEDialect,
                   arm_sve::ArmSVEDialect,
                   vector::VectorDialect,
                   NVVM::NVVMDialect,

diff  --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
index cd7f76ff669a4..65c1c515d4443 100644
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
@@ -35,6 +36,7 @@ class DialectRegistry;
 static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
   registerArmNeonDialectTranslation(registry);
   registerAMXDialectTranslation(registry);
+  registerArmSMEDialectTranslation(registry);
   registerArmSVEDialectTranslation(registry);
   registerBuiltinDialectTranslation(registry);
   registerGPUDialectTranslation(registry);

diff  --git a/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h
new file mode 100644
index 0000000000000..205d9b6326032
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h
@@ -0,0 +1,31 @@
+//=======- ArmSMEToLLVMIRTranslation.h - ArmSME to LLVM IR --*- C++ -*-=======//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Registration calls for the ArmSME dialect to LLVM IR translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H
+#define MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H
+
+namespace mlir {
+
+class DialectRegistry;
+class MLIRContext;
+
+/// Register the ArmSME dialect and the translation from it to LLVM IR in
+/// the given registry.
+void registerArmSMEDialectTranslation(DialectRegistry &registry);
+
+/// Register the ArmSME dialect and the translation from it to LLVM IR in the
+/// dialect registry associated with the given context.
+void registerArmSMEDialectTranslation(MLIRContext &context);
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H

diff  --git a/mlir/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/CMakeLists.txt
index e31af32661164..9f57627c321fb 100644
--- a/mlir/lib/Dialect/ArmSME/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/CMakeLists.txt
@@ -1 +1,2 @@
+add_subdirectory(IR)
 add_subdirectory(Transforms)

diff  --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
new file mode 100644
index 0000000000000..7f5aa61aa327e
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -0,0 +1,34 @@
+//===- ArmSME.cpp - MLIR ArmSME dialect implementation --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the ArmSME dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+//===----------------------------------------------------------------------===//
+// Tablegen Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.cpp.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
+
+/// Registers every operation generated from ArmSME.td with this dialect.
+void ArmSMEDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
+      >();
+}

diff  --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..afe69de713306
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRArmSMEDialect
+  ArmSME.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME
+
+  DEPENDS
+  MLIRArmSMEIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+  MLIRSideEffectInterfaces
+)

diff  --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index f2d95949a9740..868ccbbb10620 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -46,6 +46,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
 
   LINK_LIBS PUBLIC
   MLIRArmNeonToLLVMIRTranslation
+  MLIRArmSMEToLLVMIRTranslation
   MLIRArmSVEToLLVMIRTranslation
   MLIRAMXToLLVMIRTranslation
   MLIRBuiltinToLLVMIRTranslation

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp
new file mode 100644
index 0000000000000..1b57b9979af28
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp
@@ -0,0 +1,56 @@
+//======- ArmSMEToLLVMIRTranslation.cpp - Translate ArmSME to LLVM IR -=======//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the translation from ArmSME dialect ops to LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the ArmSME dialect to LLVM IR.
+class ArmSMEDialectLLVMIRTranslationInterface
+    : public LLVMTranslationDialectInterface {
+public:
+  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+  /// Translates `op` to LLVM IR with the generated ArmSME conversions, using
+  /// `builder` and `moduleTranslation`; unmatched ops return failure().
+  LogicalResult
+  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) const final {
+    Operation &opInst = *op; // Presumably referenced by the generated code.
+#include "mlir/Dialect/ArmSME/IR/ArmSMEConversions.inc"
+
+    return failure();
+  }
+};
+} // namespace
+
+void mlir::registerArmSMEDialectTranslation(DialectRegistry &registry) {
+  registry.insert<arm_sme::ArmSMEDialect>();
+  registry.addExtension(+[](MLIRContext *ctx, arm_sme::ArmSMEDialect *dialect) {
+    dialect->addInterfaces<ArmSMEDialectLLVMIRTranslationInterface>();
+  });
+}
+
+void mlir::registerArmSMEDialectTranslation(MLIRContext &context) {
+  DialectRegistry registry;
+  registerArmSMEDialectTranslation(registry);
+  context.appendDialectRegistry(registry);
+}

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt
new file mode 100644
index 0000000000000..d34cebf487271
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_translation_library(MLIRArmSMEToLLVMIRTranslation
+  ArmSMEToLLVMIRTranslation.cpp
+
+  DEPENDS
+  MLIRArmSMEConversionsIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRArmSMEDialect
+  MLIRLLVMDialect
+  MLIRSupport
+  MLIRTargetLLVMIRExport
+  )

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
index f27810feed824..fb0e5cd0649f6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(ArmNeon)
+add_subdirectory(ArmSME)
 add_subdirectory(ArmSVE)
 add_subdirectory(AMX)
 add_subdirectory(Builtin)

diff  --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
new file mode 100644
index 0000000000000..096d6194071cf
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -0,0 +1,225 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @arm_sme_zero
+llvm.func @arm_sme_zero() {
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: call void @llvm.aarch64.sme.zero(i32 0)
+  "arm_sme.intr.zero"(%c0) : (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_fmopa
+llvm.func @arm_sme_fmopa(%nxv2f64 : vector<[2]xf64>,
+                         %nxv4f32 : vector<[4]xf32>,
+                         %nxv8f16 : vector<[8]xf16>,
+                         %nxv8bf16: vector<[8]xbf16>,
+                         %nxv2i1  : vector<[2]xi1>,
+                         %nxv4i1  : vector<[4]xi1>,
+                         %nxv8i1  : vector<[8]xi1>) {
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: call void @llvm.aarch64.sme.mopa.nxv2f64
+  "arm_sme.intr.mopa"(%c0, %nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) :
+    (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.mopa.nxv4f32
+  "arm_sme.intr.mopa"(%c0, %nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) :
+    (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.mopa.wide.nxv8f16
+  "arm_sme.intr.mopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.mopa.wide.nxv8bf16
+  "arm_sme.intr.mopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_imopa
+llvm.func @arm_sme_imopa(%nxv8i16 : vector<[8]xi16>,
+                         %nxv16i8 : vector<[16]xi8>,
+                         %nxv8i1  : vector<[8]xi1>,
+                         %nxv16i1 : vector<[16]xi1>) {
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: call void @llvm.aarch64.sme.smopa.wide.nxv8i16
+  "arm_sme.intr.smopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.umopa.wide.nxv8i16
+  "arm_sme.intr.umopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.sumopa.wide.nxv8i16
+  "arm_sme.intr.sumopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv8i16
+  "arm_sme.intr.usmopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.smopa.wide.nxv16i8
+  "arm_sme.intr.smopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.umopa.wide.nxv16i8
+  "arm_sme.intr.umopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.sumopa.wide.nxv16i8
+  "arm_sme.intr.sumopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv16i8
+  "arm_sme.intr.usmopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_fmops
+llvm.func @arm_sme_fmops(%nxv2f64 : vector<[2]xf64>,
+                         %nxv4f32 : vector<[4]xf32>,
+                         %nxv8f16 : vector<[8]xf16>,
+                         %nxv8bf16: vector<[8]xbf16>,
+                         %nxv2i1  : vector<[2]xi1>,
+                         %nxv4i1  : vector<[4]xi1>,
+                         %nxv8i1  : vector<[8]xi1>) {
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: call void @llvm.aarch64.sme.mops.nxv2f64
+  "arm_sme.intr.mops"(%c0, %nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) :
+    (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.mops.nxv4f32
+  "arm_sme.intr.mops"(%c0, %nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) :
+    (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.mops.wide.nxv8f16
+  "arm_sme.intr.mops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.mops.wide.nxv8bf16
+  "arm_sme.intr.mops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_imops
+llvm.func @arm_sme_imops(%nxv8i16 : vector<[8]xi16>,
+                         %nxv16i8 : vector<[16]xi8>,
+                         %nxv8i1  : vector<[8]xi1>,
+                         %nxv16i1 : vector<[16]xi1>) {
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: call void @llvm.aarch64.sme.smops.wide.nxv8i16
+  "arm_sme.intr.smops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.umops.wide.nxv8i16
+  "arm_sme.intr.umops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.sumops.wide.nxv8i16
+  "arm_sme.intr.sumops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv8i16
+  "arm_sme.intr.usmops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.smops.wide.nxv16i8
+  "arm_sme.intr.smops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.umops.wide.nxv16i8
+  "arm_sme.intr.umops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.sumops.wide.nxv16i8
+  "arm_sme.intr.sumops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv16i8
+  "arm_sme.intr.usmops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load
+llvm.func @arm_sme_load(%nxv1i1  : vector<[1]xi1>,
+                        %nxv2i1  : vector<[2]xi1>,
+                        %nxv4i1  : vector<[4]xi1>,
+                        %nxv8i1  : vector<[8]xi1>,
+                        %nxv16i1 : vector<[16]xi1>,
+                        %p8      : !llvm.ptr<i8>,
+                        %p16     : !llvm.ptr<i16>,
+                        %p32     : !llvm.ptr<i32>,
+                        %p64     : !llvm.ptr<i64>,
+                        %p128    : !llvm.ptr<i128>) {
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: call void @llvm.aarch64.sme.ld1q.horiz
+  "arm_sme.intr.ld1q.horiz"(%nxv1i1, %p128, %c0, %c0) :
+              (vector<[1]xi1>, !llvm.ptr<i128>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1d.horiz
+  "arm_sme.intr.ld1d.horiz"(%nxv2i1, %p64, %c0, %c0) :
+              (vector<[2]xi1>, !llvm.ptr<i64>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1w.horiz
+  "arm_sme.intr.ld1w.horiz"(%nxv4i1, %p32, %c0, %c0) :
+              (vector<[4]xi1>, !llvm.ptr<i32>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1h.horiz
+  "arm_sme.intr.ld1h.horiz"(%nxv8i1, %p16, %c0, %c0) :
+              (vector<[8]xi1>, !llvm.ptr<i16>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1b.horiz
+  "arm_sme.intr.ld1b.horiz"(%nxv16i1, %p8, %c0, %c0) :
+              (vector<[16]xi1>, !llvm.ptr<i8>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1q.vert
+  "arm_sme.intr.ld1q.vert"(%nxv1i1, %p128, %c0, %c0) :
+              (vector<[1]xi1>, !llvm.ptr<i128>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1d.vert
+  "arm_sme.intr.ld1d.vert"(%nxv2i1, %p64, %c0, %c0) :
+              (vector<[2]xi1>, !llvm.ptr<i64>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1w.vert
+  "arm_sme.intr.ld1w.vert"(%nxv4i1, %p32, %c0, %c0) :
+              (vector<[4]xi1>, !llvm.ptr<i32>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1h.vert
+  "arm_sme.intr.ld1h.vert"(%nxv8i1, %p16, %c0, %c0) :
+              (vector<[8]xi1>, !llvm.ptr<i16>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1b.vert
+  "arm_sme.intr.ld1b.vert"(%nxv16i1, %p8, %c0, %c0) :
+              (vector<[16]xi1>, !llvm.ptr<i8>, i32, i32) -> ()
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store
+llvm.func @arm_sme_store(%nxv1i1  : vector<[1]xi1>,
+                         %nxv2i1  : vector<[2]xi1>,
+                         %nxv4i1  : vector<[4]xi1>,
+                         %nxv8i1  : vector<[8]xi1>,
+                         %nxv16i1 : vector<[16]xi1>,
+                         %p8      : !llvm.ptr<i8>,
+                         %p16     : !llvm.ptr<i16>,
+                         %p32     : !llvm.ptr<i32>,
+                         %p64     : !llvm.ptr<i64>,
+                         %p128    : !llvm.ptr<i128>) {
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: call void @llvm.aarch64.sme.st1q.horiz
+  "arm_sme.intr.st1q.horiz"(%nxv1i1, %p128, %c0, %c0) :
+              (vector<[1]xi1>, !llvm.ptr<i128>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1d.horiz
+  "arm_sme.intr.st1d.horiz"(%nxv2i1, %p64, %c0, %c0) :
+              (vector<[2]xi1>, !llvm.ptr<i64>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1w.horiz
+  "arm_sme.intr.st1w.horiz"(%nxv4i1, %p32, %c0, %c0) :
+              (vector<[4]xi1>, !llvm.ptr<i32>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1h.horiz
+  "arm_sme.intr.st1h.horiz"(%nxv8i1, %p16, %c0, %c0) :
+              (vector<[8]xi1>, !llvm.ptr<i16>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1b.horiz
+  "arm_sme.intr.st1b.horiz"(%nxv16i1, %p8, %c0, %c0) :
+              (vector<[16]xi1>, !llvm.ptr<i8>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1q.vert
+  "arm_sme.intr.st1q.vert"(%nxv1i1, %p128, %c0, %c0) :
+              (vector<[1]xi1>, !llvm.ptr<i128>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1d.vert
+  "arm_sme.intr.st1d.vert"(%nxv2i1, %p64, %c0, %c0) :
+              (vector<[2]xi1>, !llvm.ptr<i64>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1w.vert
+  "arm_sme.intr.st1w.vert"(%nxv4i1, %p32, %c0, %c0) :
+              (vector<[4]xi1>, !llvm.ptr<i32>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1h.vert
+  "arm_sme.intr.st1h.vert"(%nxv8i1, %p16, %c0, %c0) :
+              (vector<[8]xi1>, !llvm.ptr<i16>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1b.vert
+  "arm_sme.intr.st1b.vert"(%nxv16i1, %p8, %c0, %c0) :
+              (vector<[16]xi1>, !llvm.ptr<i8>, i32, i32) -> ()
+  llvm.return
+}

diff  --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index 1b2ab3f47b81a..7400f46dd6f0c 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -6,6 +6,7 @@
 // CHECK-SAME: amx
 // CHECK-SAME: arith
 // CHECK-SAME: arm_neon
+// CHECK-SAME: arm_sme
 // CHECK-SAME: arm_sve
 // CHECK-SAME: async
 // CHECK-SAME: bufferization


        


More information about the Mlir-commits mailing list