[Mlir-commits] [mlir] [mlir][ArmSME] Split the Op definition (nfc) (PR #67985)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Oct 2 06:44:52 PDT 2023


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/67985

Move the definitions of LLVM intrinsic Ops for ArmSME into a dedicated
file. To facilitate this, the dialect definition together with various
shared definitions are moved to ArmSMEBase.td.

This change will allow us to refactor the ArmSME dialect documentation.
In particular, we will be able to categorise the Ops into "regular" and
"intrinsic" ops. Also, it will be easier to add some custom
documentation as opposed to relying on auto-generated docs that simply
list the available Ops.

The documentation will be updated in a forthcoming patch. Only
non-functional changes.


>From 127403577d2366b38d5b1b536ec78b85e102b11b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 2 Oct 2023 09:56:55 +0000
Subject: [PATCH] [mlir][ArmSME] Split the Op definition (nfc)

Move the definitions of LLVM intrinsic Ops for ArmSME into a dedicated
file. To facilitate this, the dialect definition together with various
shared definitions are moved to ArmSMEBase.td.

This change will allow us to refactor the ArmSME dialect documentation.
In particular, we will be able to categorise the Ops into "regular" and
"intrinsic" ops. Also, it will be easier to add some custom
documentation as opposed to relying on auto-generated docs that simply
list the available Ops.

The documentation will be updated in a forthcoming patch. Only
non-functional changes.
---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h  |   3 +
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td | 147 +-----------------
 .../mlir/Dialect/ArmSME/IR/ArmSMEBase.td      |  52 +++++++
 .../Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td   | 137 ++++++++++++++++
 .../mlir/Dialect/ArmSME/IR/CMakeLists.txt     |   7 +
 mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp         |   6 +
 .../Transforms/LegalizeForLLVMExport.cpp      |   8 +-
 .../ArmSME/ArmSMEToLLVMIRTranslation.cpp      |   1 +
 8 files changed, 211 insertions(+), 150 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEBase.td
 create mode 100644 mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index f947fc8fe1631b8..dcb18be4f05ead4 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -31,4 +31,7 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h.inc"
 
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.h.inc"
+
 #endif // MLIR_DIALECT_ARMSME_IR_ARMSME_H
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 66a432ea1b171e0..4d89f62e28e0ac4 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -14,33 +14,13 @@
 #ifndef ARMSME_OPS
 #define ARMSME_OPS
 
+include "ArmSMEBase.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 
-//===----------------------------------------------------------------------===//
-// ArmSME dialect definition
-//===----------------------------------------------------------------------===//
-
-def ArmSME_Dialect : Dialect {
-  let name = "arm_sme";
-  let cppNamespace = "::mlir::arm_sme";
-  let summary = "Basic dialect to target Arm SME architectures";
-  let description = [{
-    This dialect contains the definitions necessary to target Arm SME
-    scalable matrix operations.
-
-    Sources:
-    https://developer.arm.com/documentation/ddi0616
-    https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
-  }];
-  let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect",
-                           "memref::MemRefDialect"];
-  let useDefaultAttributePrinterParser = 1;
-}
-
 //===----------------------------------------------------------------------===//
 // ArmSME type definitions
 //===----------------------------------------------------------------------===//
@@ -65,12 +45,6 @@ def nxnxv2f64  : SMETileType<F64,  [2,  2 ], "vector<[2]x[2]xf64>">;
 def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
                          nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
 
-def SVEVector : ScalableVectorOfRankAndLengthAndType<
-  [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
-
-def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
-  [1], [16, 8, 4, 2, 1], [I1]>;
-
 // A type constraint that verifies the bitwidth of the scalar integer returned
 // from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
 def TileElementWidthMatchesTileID : TypesMatchWith<
@@ -538,123 +512,4 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// ArmSME Intrinsic op definitions
-//===----------------------------------------------------------------------===//
-
-def MOPPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2], [I1]>;
-def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
-                                              [I8, I16, BF16, F16, F32, F64]>;
-def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
-
-class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
-                    list<Trait> traits = [], int numResults = 0,
-                    list<int> overloadedResults = []>
-    : LLVM_IntrOpBase<
-          /*Dialect dialect=*/ArmSME_Dialect,
-          /*string opName=*/"intr." # mnemonic,
-          /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
-          /*list<int> overloadedResults=*/overloadedResults,
-          /*list<int> overloadedOperands=*/overloadedOperands,
-          /*list<Trait> traits=*/traits,
-          /*int numResults=*/numResults>;
-
-// Zero
-def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
-                            Arguments<(ins Arg<I32, "Tile mask">)>;
-
-// MOP's
-class ArmSME_IntrMopOverloadedOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic, [4]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">,
-                 Arg<MOPPredicate, "LHS predicate">,
-                 Arg<MOPPredicate, "RHS predicate">,
-                 Arg<MOPVector, "LHS vector operand">,
-                 Arg<MOPVector, "RHS vector operand">)>;
-
-def LLVM_aarch64_sme_mopa : ArmSME_IntrMopOverloadedOp<"mopa">;
-def LLVM_aarch64_sme_mops : ArmSME_IntrMopOverloadedOp<"mops">;
-def LLVM_aarch64_sme_mopa_wide : ArmSME_IntrMopOverloadedOp<"mopa.wide">;
-def LLVM_aarch64_sme_mops_wide : ArmSME_IntrMopOverloadedOp<"mops.wide">;
-def LLVM_aarch64_sme_smopa_wide : ArmSME_IntrMopOverloadedOp<"smopa.wide">;
-def LLVM_aarch64_sme_smops_wide : ArmSME_IntrMopOverloadedOp<"smops.wide">;
-def LLVM_aarch64_sme_umopa_wide : ArmSME_IntrMopOverloadedOp<"umopa.wide">;
-def LLVM_aarch64_sme_umops_wide : ArmSME_IntrMopOverloadedOp<"umops.wide">;
-def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
-def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
-def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
-def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
-
-// Loads
-class ArmSME_IntrLoadOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic>,
-      Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
-                 Arg<LLVM_AnyPointer, "Load address">,
-                 Arg<I32, "Virtual tile ID">,
-                 Arg<I32, "Tile slice">)>;
-
-def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
-def LLVM_aarch64_sme_ld1h_horiz : ArmSME_IntrLoadOp<"ld1h.horiz">;
-def LLVM_aarch64_sme_ld1w_horiz : ArmSME_IntrLoadOp<"ld1w.horiz">;
-def LLVM_aarch64_sme_ld1d_horiz : ArmSME_IntrLoadOp<"ld1d.horiz">;
-def LLVM_aarch64_sme_ld1q_horiz : ArmSME_IntrLoadOp<"ld1q.horiz">;
-def LLVM_aarch64_sme_ld1b_vert : ArmSME_IntrLoadOp<"ld1b.vert">;
-def LLVM_aarch64_sme_ld1h_vert : ArmSME_IntrLoadOp<"ld1h.vert">;
-def LLVM_aarch64_sme_ld1w_vert : ArmSME_IntrLoadOp<"ld1w.vert">;
-def LLVM_aarch64_sme_ld1d_vert : ArmSME_IntrLoadOp<"ld1d.vert">;
-def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
-
-// Stores
-class ArmSME_IntrStoreOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic>,
-      Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
-                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>,
-                 Arg<I32, "Virtual tile ID">,
-                 Arg<I32, "Tile slice">)>;
-
-def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
-def LLVM_aarch64_sme_st1h_horiz : ArmSME_IntrStoreOp<"st1h.horiz">;
-def LLVM_aarch64_sme_st1w_horiz : ArmSME_IntrStoreOp<"st1w.horiz">;
-def LLVM_aarch64_sme_st1d_horiz : ArmSME_IntrStoreOp<"st1d.horiz">;
-def LLVM_aarch64_sme_st1q_horiz : ArmSME_IntrStoreOp<"st1q.horiz">;
-def LLVM_aarch64_sme_st1b_vert : ArmSME_IntrStoreOp<"st1b.vert">;
-def LLVM_aarch64_sme_st1h_vert : ArmSME_IntrStoreOp<"st1h.vert">;
-def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">;
-def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
-def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
-
-def LLVM_aarch64_sme_str
-    : ArmSME_IntrOp<"str">,
-      Arguments<(ins Arg<I32, "Index">,
-                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
-
-// Vector to tile slice
-class LLVM_aarch64_sme_write<string direction>
-    : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
-                    [AllShapesMatch<["pg", "vector"]>]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">,
-                     Arg<I32, "Tile slice">,
-                     Arg<SVEPredicate, "Vector predicate">:$pg,
-                     Arg<SVEVector, "Vector operand">:$vector)>;
-
-// Tile slice to vector
-class LLVM_aarch64_sme_read<string direction>
-    : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
-                    [AllShapesMatch<["vector", "pg", "res"]>,
-                     AllElementTypesMatch<["vector", "res"]>],
-                    /*numResults=*/1, /*overloadedResults=*/[0]>,
-      Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
-                     Arg<SVEPredicate, "Vector predicate">:$pg,
-                     Arg<I32, "Virtual tile ID">,
-                     Arg<I32, "Tile slice">)>;
-
-def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
-def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
-
-def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
-def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
-
-def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
-def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
-
 #endif // ARMSME_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEBase.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEBase.td
new file mode 100644
index 000000000000000..36f3ae44ff72a2e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEBase.td
@@ -0,0 +1,52 @@
+//===-- ArmSMEBase.td - ArmSME dialect definitions --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definition of the ArmSME dialect as well as some
+// shared definitions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARMSME_OP_BASE
+#define ARMSME_OP_BASE
+
+include "mlir/IR/DialectBase.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// ArmSME Dialect
+//===----------------------------------------------------------------------===//
+
+def ArmSME_Dialect : Dialect {
+  let name = "arm_sme";
+  let cppNamespace = "::mlir::arm_sme";
+  let summary = "Basic dialect to target Arm SME architectures";
+  let description = [{
+    This dialect contains the definitions necessary to target Arm SME
+    scalable matrix operations.
+
+    Sources:
+    https://developer.arm.com/documentation/ddi0616
+    https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
+  }];
+  let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect",
+                           "memref::MemRefDialect"];
+  let useDefaultAttributePrinterParser = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// ArmSME type definitions
+//===----------------------------------------------------------------------===//
+
+def SVEVector : ScalableVectorOfRankAndLengthAndType<
+  [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
+
+def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
+  [1], [16, 8, 4, 2, 1], [I1]>;
+
+
+#endif // ARMSME_OP_BASE
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
new file mode 100644
index 000000000000000..e3a051a171400eb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -0,0 +1,137 @@
+//===-- ArmSMEIntrinsicOps.td ------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of the intrinsic Ops for the ArmSME dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARMSME_INTRINSICS_OPS
+#define ARMSME_INTRINSICS_OPS
+
+include "ArmSMEBase.td"
+
+//===----------------------------------------------------------------------===//
+// ArmSME Intrinsic op definitions
+//===----------------------------------------------------------------------===//
+
+def MOPPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2], [I1]>;
+def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
+                                              [I8, I16, BF16, F16, F32, F64]>;
+def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
+
+class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
+                    list<Trait> traits = [], int numResults = 0,
+                    list<int> overloadedResults = []>
+    : LLVM_IntrOpBase<
+          /*Dialect dialect=*/ArmSME_Dialect,
+          /*string opName=*/"intr." # mnemonic,
+          /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
+          /*list<int> overloadedResults=*/overloadedResults,
+          /*list<int> overloadedOperands=*/overloadedOperands,
+          /*list<Trait> traits=*/traits,
+          /*int numResults=*/numResults>;
+
+// Zero
+def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
+                            Arguments<(ins Arg<I32, "Tile mask">)>;
+
+// MOP's
+class ArmSME_IntrMopOverloadedOp<string mnemonic>
+    : ArmSME_IntrOp<mnemonic, [4]>,
+      Arguments<(ins Arg<I32, "Virtual tile ID">,
+                 Arg<MOPPredicate, "LHS predicate">,
+                 Arg<MOPPredicate, "RHS predicate">,
+                 Arg<MOPVector, "LHS vector operand">,
+                 Arg<MOPVector, "RHS vector operand">)>;
+
+def LLVM_aarch64_sme_mopa : ArmSME_IntrMopOverloadedOp<"mopa">;
+def LLVM_aarch64_sme_mops : ArmSME_IntrMopOverloadedOp<"mops">;
+def LLVM_aarch64_sme_mopa_wide : ArmSME_IntrMopOverloadedOp<"mopa.wide">;
+def LLVM_aarch64_sme_mops_wide : ArmSME_IntrMopOverloadedOp<"mops.wide">;
+def LLVM_aarch64_sme_smopa_wide : ArmSME_IntrMopOverloadedOp<"smopa.wide">;
+def LLVM_aarch64_sme_smops_wide : ArmSME_IntrMopOverloadedOp<"smops.wide">;
+def LLVM_aarch64_sme_umopa_wide : ArmSME_IntrMopOverloadedOp<"umopa.wide">;
+def LLVM_aarch64_sme_umops_wide : ArmSME_IntrMopOverloadedOp<"umops.wide">;
+def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
+def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
+def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
+def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
+
+// Loads
+class ArmSME_IntrLoadOp<string mnemonic>
+    : ArmSME_IntrOp<mnemonic>,
+      Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
+                 Arg<LLVM_AnyPointer, "Load address">,
+                 Arg<I32, "Virtual tile ID">,
+                 Arg<I32, "Tile slice">)>;
+
+def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
+def LLVM_aarch64_sme_ld1h_horiz : ArmSME_IntrLoadOp<"ld1h.horiz">;
+def LLVM_aarch64_sme_ld1w_horiz : ArmSME_IntrLoadOp<"ld1w.horiz">;
+def LLVM_aarch64_sme_ld1d_horiz : ArmSME_IntrLoadOp<"ld1d.horiz">;
+def LLVM_aarch64_sme_ld1q_horiz : ArmSME_IntrLoadOp<"ld1q.horiz">;
+def LLVM_aarch64_sme_ld1b_vert : ArmSME_IntrLoadOp<"ld1b.vert">;
+def LLVM_aarch64_sme_ld1h_vert : ArmSME_IntrLoadOp<"ld1h.vert">;
+def LLVM_aarch64_sme_ld1w_vert : ArmSME_IntrLoadOp<"ld1w.vert">;
+def LLVM_aarch64_sme_ld1d_vert : ArmSME_IntrLoadOp<"ld1d.vert">;
+def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
+
+// Stores
+class ArmSME_IntrStoreOp<string mnemonic>
+    : ArmSME_IntrOp<mnemonic>,
+      Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
+                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>,
+                 Arg<I32, "Virtual tile ID">,
+                 Arg<I32, "Tile slice">)>;
+
+def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
+def LLVM_aarch64_sme_st1h_horiz : ArmSME_IntrStoreOp<"st1h.horiz">;
+def LLVM_aarch64_sme_st1w_horiz : ArmSME_IntrStoreOp<"st1w.horiz">;
+def LLVM_aarch64_sme_st1d_horiz : ArmSME_IntrStoreOp<"st1d.horiz">;
+def LLVM_aarch64_sme_st1q_horiz : ArmSME_IntrStoreOp<"st1q.horiz">;
+def LLVM_aarch64_sme_st1b_vert : ArmSME_IntrStoreOp<"st1b.vert">;
+def LLVM_aarch64_sme_st1h_vert : ArmSME_IntrStoreOp<"st1h.vert">;
+def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">;
+def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
+def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
+
+def LLVM_aarch64_sme_str
+    : ArmSME_IntrOp<"str">,
+      Arguments<(ins Arg<I32, "Index">,
+                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
+
+// Vector to tile slice
+class LLVM_aarch64_sme_write<string direction>
+    : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
+                    [AllShapesMatch<["pg", "vector"]>]>,
+      Arguments<(ins Arg<I32, "Virtual tile ID">,
+                     Arg<I32, "Tile slice">,
+                     Arg<SVEPredicate, "Vector predicate">:$pg,
+                     Arg<SVEVector, "Vector operand">:$vector)>;
+
+// Tile slice to vector
+class LLVM_aarch64_sme_read<string direction>
+    : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+                    [AllShapesMatch<["vector", "pg", "res"]>,
+                     AllElementTypesMatch<["vector", "res"]>],
+                    /*numResults=*/1, /*overloadedResults=*/[0]>,
+      Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
+                     Arg<SVEPredicate, "Vector predicate">:$pg,
+                     Arg<I32, "Virtual tile ID">,
+                     Arg<I32, "Tile slice">)>;
+
+def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
+def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
+
+def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
+def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
+
+def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
+def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
+
+#endif // ARMSME_INTRINSICS_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index 617809e482b2caa..62b61aa90246a01 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -10,3 +10,10 @@ mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs)
 mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
 mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
 add_public_tablegen_target(MLIRArmSMEAttrDefsIncGen)
+
+# Generate declarations and definitions of ArmSME intrinsics
+set(LLVM_TARGET_DEFINITIONS ArmSMEIntrinsicOps.td)
+mlir_tablegen(ArmSMEIntrinsicOps.h.inc -gen-op-decls)
+mlir_tablegen(ArmSMEIntrinsicOps.cpp.inc -gen-op-defs)
+mlir_tablegen(ArmSMEIntrinsicConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(ArmSMEIntrinsicOpsIncGen)
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 101cb750f4a6f30..e7fce6a7fe6f1f5 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -31,6 +31,9 @@ using namespace mlir::arm_sme;
 #define GET_OP_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
 
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.cpp.inc"
+
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMETypes.cpp.inc"
 
@@ -46,6 +49,9 @@ void ArmSMEDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
+      ,
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.cpp.inc"
       >();
 }
 
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index e75e958e18a2cfd..51cbd5423ecf49b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -187,10 +187,10 @@ struct LoadTileSliceToArmSMELowering
     auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
 
     auto tileI32 = castTileIDToI32(tile, loc, rewriter);
-    arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
+    TileSliceLayout layout = loadTileSliceOp.getLayout();
 
     // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
-    if (layout == arm_sme::TileSliceLayout::Horizontal) {
+    if (layout == TileSliceLayout::Horizontal) {
       switch (tileElementWidth) {
       default:
         llvm_unreachable("unexpected element type!");
@@ -292,9 +292,9 @@ struct StoreTileSliceToArmSMELowering
     auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
 
     Value tileI32 = castTileIDToI32(tile, loc, rewriter);
-    arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
+    TileSliceLayout layout = storeTileSliceOp.getLayout();
 
-    if (layout == arm_sme::TileSliceLayout::Horizontal) {
+    if (layout == TileSliceLayout::Horizontal) {
       switch (tileElementWidth) {
       default:
         llvm_unreachable("unexpected element type!");
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp
index 1b57b9979af28ad..589f24c7312a264 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp
@@ -36,6 +36,7 @@ class ArmSMEDialectLLVMIRTranslationInterface
                    LLVM::ModuleTranslation &moduleTranslation) const final {
     Operation &opInst = *op;
 #include "mlir/Dialect/ArmSME/IR/ArmSMEConversions.inc"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicConversions.inc"
 
     return failure();
   }



More information about the Mlir-commits mailing list