[Mlir-commits] [mlir] [mlir][ArmSME] Name arguments of SME intrinsics (NFC) (PR #69608)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Oct 19 08:55:49 PDT 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/69608

>From 54aef262f4fea973d61f8168d7f18c4ac62ad4b5 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 19 Oct 2023 15:26:18 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Name arguments of SME intrinsics (NFC)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This makes the docs a little nicer to read, as these otherwise show up
as "«unnamed»".

The extra include is needed because naming the arguments causes getters to be
generated, and those getters use the LLVM types.
---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h  |  1 +
 .../Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td   | 48 +++++++++----------
 2 files changed, 25 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index b27ceca215dad42..fe1f9062a37ef51 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_ARMSME_IR_ARMSME_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index feeac3b8a0355f9..cc61047af690f3b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -38,16 +38,16 @@ class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
 
 // Zero
 def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
-                            Arguments<(ins Arg<I32, "Tile mask">)>;
+                            Arguments<(ins Arg<I32, "Tile mask">:$tile_mask)>;
 
 // MOP's
 class ArmSME_IntrMopOverloadedOp<string mnemonic>
     : ArmSME_IntrOp<mnemonic, [4]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">,
-                 Arg<MOPPredicate, "LHS predicate">,
-                 Arg<MOPPredicate, "RHS predicate">,
-                 Arg<MOPVector, "LHS vector operand">,
-                 Arg<MOPVector, "RHS vector operand">)>;
+      Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+                 Arg<MOPPredicate, "LHS predicate">:$lhs_predicate,
+                 Arg<MOPPredicate, "RHS predicate">:$rhs_predicate,
+                 Arg<MOPVector, "LHS vector operand">:$lhs_operand,
+                 Arg<MOPVector, "RHS vector operand">:$rhs_operand)>;
 
 def LLVM_aarch64_sme_mopa : ArmSME_IntrMopOverloadedOp<"mopa">;
 def LLVM_aarch64_sme_mops : ArmSME_IntrMopOverloadedOp<"mops">;
@@ -65,10 +65,10 @@ def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
 // Loads
 class ArmSME_IntrLoadOp<string mnemonic>
     : ArmSME_IntrOp<mnemonic>,
-      Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
-                 Arg<LLVM_AnyPointer, "Load address">,
-                 Arg<I32, "Virtual tile ID">,
-                 Arg<I32, "Tile slice">)>;
+      Arguments<(ins Arg<LDSTPredicate, "Vector predicate">:$predicate,
+                 Arg<LLVM_AnyPointer, "Load address">:$load_address,
+                 Arg<I32, "Virtual tile ID">:$tile_id,
+                 Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
 def LLVM_aarch64_sme_ld1h_horiz : ArmSME_IntrLoadOp<"ld1h.horiz">;
@@ -84,10 +84,10 @@ def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
 // Stores
 class ArmSME_IntrStoreOp<string mnemonic>
     : ArmSME_IntrOp<mnemonic>,
-      Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
-                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>,
-                 Arg<I32, "Virtual tile ID">,
-                 Arg<I32, "Tile slice">)>;
+      Arguments<(ins Arg<LDSTPredicate, "Vector predicate">:$predicate,
+                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address,
+                 Arg<I32, "Virtual tile ID">:$tile_id,
+                 Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
 def LLVM_aarch64_sme_st1h_horiz : ArmSME_IntrStoreOp<"st1h.horiz">;
@@ -102,28 +102,28 @@ def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
 
 def LLVM_aarch64_sme_str
     : ArmSME_IntrOp<"str">,
-      Arguments<(ins Arg<I32, "Index">,
-                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
+      Arguments<(ins Arg<I32, "Index">:$index,
+                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address)>;
 
 // Vector to tile slice
 class LLVM_aarch64_sme_write<string direction>
     : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
-                    [AllShapesMatch<["pg", "vector"]>]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">,
-                     Arg<I32, "Tile slice">,
-                     Arg<SVEPredicate, "Vector predicate">:$pg,
+                    [AllShapesMatch<["predicate", "vector"]>]>,
+      Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+                     Arg<I32, "Tile slice">:$tile_slice_index,
+                     Arg<SVEPredicate, "Vector predicate">:$predicate,
                      Arg<SVEVector, "Vector operand">:$vector)>;
 
 // Tile slice to vector
 class LLVM_aarch64_sme_read<string direction>
     : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
-                    [AllShapesMatch<["vector", "pg", "res"]>,
+                    [AllShapesMatch<["vector", "predicate", "res"]>,
                      AllElementTypesMatch<["vector", "res"]>],
                     /*numResults=*/1, /*overloadedResults=*/[0]>,
       Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
-                     Arg<SVEPredicate, "Vector predicate">:$pg,
-                     Arg<I32, "Virtual tile ID">,
-                     Arg<I32, "Tile slice">)>;
+                     Arg<SVEPredicate, "Vector predicate">:$predicate,
+                     Arg<I32, "Virtual tile ID">:$tile_id,
+                     Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
 def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;

>From 4a07e623290ca4145efa303557d9c3c67e7fe872 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 19 Oct 2023 15:54:53 +0000
Subject: [PATCH 2/2] Fixup: update names in tests

---
 mlir/test/Target/LLVMIR/arm-sme-invalid.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index ae99ac5e02d62f0..b3202b26f8e1e3d 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -5,7 +5,7 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
                                                 %nxv4i1 : vector<[4]xi1>,
                                                 %nxv16i8 : vector<[16]xi8>) {
   %tile = llvm.mlir.constant(0 : index) : i32
-  // expected-error @+1 {{failed to verify that all of {pg, vector} have same shape}}
+  // expected-error @+1 {{failed to verify that all of {predicate, vector} have same shape}}
   "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv16i8) :
       (i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
   llvm.return
@@ -17,7 +17,7 @@ llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
   %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
 ) -> vector<[3]xf32> {
   %tile = llvm.mlir.constant(0 : index) : i32
-  // expected-error @+1 {{failed to verify that all of {vector, pg, res} have same shape}}
+  // expected-error @+1 {{failed to verify that all of {vector, predicate, res} have same shape}}
   %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) :
       (vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xf32>
   llvm.return %res : vector<[3]xf32>



More information about the Mlir-commits mailing list