[Mlir-commits] [mlir] [mlir][ArmSME] Add custom vector.print lowering for SME tiles (PR #66691)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Sep 20 07:23:28 PDT 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/66691

>From 896ab7de80ddaf7ec36bb39a0f1d3f45950a42dd Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 20 Sep 2023 12:50:24 +0000
Subject: [PATCH 1/3] [mlir][ArmSME] Add tile slice to vector intrinsics

Add support for the following tile slice to vector (MOVA) intrinsics to the
ArmSME dialect:

  llvm.aarch64.sme.read.vert
  llvm.aarch64.sme.read.horiz

This also slightly updates ArmSME_IntrOp to support return values.
---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td |  23 ++-
 mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp         |   1 +
 mlir/test/Target/LLVMIR/arm-sme-invalid.mlir  |  24 ++++
 mlir/test/Target/LLVMIR/arm-sme.mlir          | 134 ++++++++++++++++++
 4 files changed, 178 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..00e1fefc0521a78 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -469,15 +469,16 @@ def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
 def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
 
 class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
-                    list<Trait> traits = []>
+                    list<Trait> traits = [], int numResults = 0,
+                    list<int> overloadedResults = []>
     : LLVM_IntrOpBase<
           /*Dialect dialect=*/ArmSME_Dialect,
           /*string opName=*/"intr." # mnemonic,
           /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
-          /*list<int> overloadedResults=*/[],
+          /*list<int> overloadedResults=*/overloadedResults,
           /*list<int> overloadedOperands=*/overloadedOperands,
           /*list<Trait> traits=*/traits,
-          /*int numResults=*/0>;
+          /*int numResults=*/numResults>;
 
 // Zero
 def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
@@ -548,7 +549,7 @@ def LLVM_aarch64_sme_str
       Arguments<(ins Arg<I32, "Index">,
                  Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
 
-// Vector to tile
+// Vector to tile slice
 class LLVM_aarch64_sme_write<string direction>
     : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
                     [AllShapesMatch<["pg", "vector"]>]>,
@@ -557,9 +558,23 @@ class LLVM_aarch64_sme_write<string direction>
                      Arg<SVEPredicate, "Vector predicate">:$pg,
                      Arg<SVEVector, "Vector operand">:$vector)>;
 
+// Tile slice to vector
+class LLVM_aarch64_sme_read<string direction>
+    : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+                    [AllShapesMatch<["vector", "pg", "res"]>,
+                     AllElementTypesMatch<["vector", "res"]>],
+                    /*numResults*/1, /*overloadedResults*/[0]>,
+      Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
+                     Arg<SVEPredicate, "Vector predicate">:$pg,
+                     Arg<I32, "Virtual tile ID">,
+                     Arg<I32, "Tile slice">)>;
+
 def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
 def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
 
+def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
+def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
+
 def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
 def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
 
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 750627421215dfb..7cbc382b0050a6e 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 
 using namespace mlir;
 using namespace mlir::arm_sme;
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index e119e1f1a404416..ae99ac5e02d62f0 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -10,3 +10,27 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
       (i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
   llvm.return
 }
+
+// -----
+
+llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
+  %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
+) -> vector<[3]xf32> {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // expected-error @+1 {{failed to verify that all of {vector, pg, res} have same shape}}
+  %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) :
+      (vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xf32>
+  llvm.return %res : vector<[3]xf32>
+}
+
+// -----
+
+llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
+  %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv4f32 : vector<[4]xf32>
+) -> vector<[4]xi32> {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // expected-error @+1 {{failed to verify that all of {vector, res} have same element type}}
+  %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice) :
+      (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  llvm.return %res : vector<[4]xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 9bb6b0c6574fcdb..c318e6d2d37f7fd 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -334,3 +334,137 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
       (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
   llvm.return
 }
+
+// -----
+
+llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
+llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
+llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
+llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
+llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
+llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
+llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
+llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
+llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
+
+llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
+                                              %nxv16i1 : vector<[16]xi1>,
+                                              %nxv8i1 : vector<[8]xi1>,
+                                              %nxv4i1 : vector<[4]xi1>,
+                                              %nxv2i1 : vector<[2]xi1>,
+                                              %nxv1i1 : vector<[1]xi1>,
+                                              %nxv16i8 : vector<[16]xi8>,
+                                              %nxv8i16 : vector<[8]xi16>,
+                                              %nxv4i32 : vector<[4]xi32>,
+                                              %nxv2i64 : vector<[2]xi64>,
+                                              %nxv1i128 : vector<[1]xi128>,
+                                              %nxv8f16 : vector<[8]xf16>,
+                                              %nxv8bf16 : vector<[8]xbf16>,
+                                              %nxv4f32 : vector<[4]xf32>,
+                                              %nxv2f64 : vector<[2]xf64>) {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
+  %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
+    : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+  llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
+  // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.horiz.nxv8i16
+  %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+  llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.horiz.nxv4i32
+  %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
+  // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.horiz.nxv2i64
+  %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+  llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
+  // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.horiz.nxv1i128
+  %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tile, %tileslice)
+    : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+  llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
+  // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.horiz.nxv8f16
+  %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+  llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
+  // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.horiz.nxv8bf16
+  %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+  llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.horiz.nxv4f32
+  %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+  llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
+  // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.horiz.nxv2f64
+  %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+  llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
+llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
+llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
+llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
+llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
+llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
+llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
+llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
+llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
+
+llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
+                                              %nxv16i1 : vector<[16]xi1>,
+                                              %nxv8i1 : vector<[8]xi1>,
+                                              %nxv4i1 : vector<[4]xi1>,
+                                              %nxv2i1 : vector<[2]xi1>,
+                                              %nxv1i1 : vector<[1]xi1>,
+                                              %nxv16i8 : vector<[16]xi8>,
+                                              %nxv8i16 : vector<[8]xi16>,
+                                              %nxv4i32 : vector<[4]xi32>,
+                                              %nxv2i64 : vector<[2]xi64>,
+                                              %nxv1i128 : vector<[1]xi128>,
+                                              %nxv8f16 : vector<[8]xf16>,
+                                              %nxv8bf16 : vector<[8]xbf16>,
+                                              %nxv4f32 : vector<[4]xf32>,
+                                              %nxv2f64 : vector<[2]xf64>) {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
+  %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)
+    : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+  llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
+  // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.vert.nxv8i16
+  %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+  llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.vert.nxv4i32
+  %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
+  // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.vert.nxv2i64
+  %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+  llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
+  // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.vert.nxv1i128
+  %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tile, %tileslice)
+    : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+  llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
+  // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.vert.nxv8f16
+  %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+  llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
+  // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.vert.nxv8bf16
+  %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+  llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.vert.nxv4f32
+  %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+  llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
+  // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.vert.nxv2f64
+  %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+  llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
+  llvm.return
+}

>From 996abcf91e6c639f4cd484aab4b9d6b01c1b6b7e Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 20 Sep 2023 13:51:07 +0000
Subject: [PATCH 2/3] [mlir][ArmSME] Add `enable_arm_streaming_ignore`
 attribute

This attribute makes the `enable_arm_streaming` pass ignore a function
(i.e. not add the enable streaming/ZA attributes). The main use case for
this is to prevent helper functions within tests from being made
streaming functions.
---
 mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp | 6 +++++-
 mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir        | 8 ++++++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index 97c38b546349510..1d3a090e861013b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -52,6 +52,8 @@ using namespace mlir::arm_sme;
 static constexpr char kArmStreamingAttr[] = "arm_streaming";
 static constexpr char kArmLocallyStreamingAttr[] = "arm_locally_streaming";
 static constexpr char kArmZAAttr[] = "arm_za";
+static constexpr char kEnableArmStreamingIgnoreAttr[] =
+    "enable_arm_streaming_ignore";
 
 namespace {
 struct EnableArmStreamingPass
@@ -61,7 +63,9 @@ struct EnableArmStreamingPass
     this->enableZA = enableZA;
   }
   void runOnOperation() override {
-    std::string attr;
+    if (getOperation()->getAttr(kEnableArmStreamingIgnoreAttr))
+      return;
+    StringRef attr;
     switch (mode) {
     case ArmStreaming::Default:
       attr = kArmStreamingAttr;
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index f5cc83192f9f6e1..e7bbe8c0047687d 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -9,3 +9,11 @@
 // CHECK-ENABLE-ZA-LABEL: @arm_streaming
 // CHECK-ENABLE-ZA-SAME: attributes {arm_streaming, arm_za}
 func.func @arm_streaming() { return }
+
+// CHECK-LABEL: @not_arm_streaming
+// CHECK-SAME: attributes {enable_arm_streaming_ignore}
+// CHECK-LOCALLY-LABEL: @not_arm_streaming
+// CHECK-LOCALLY-SAME: attributes {enable_arm_streaming_ignore}
+// CHECK-ENABLE-ZA-LABEL: @not_arm_streaming
+// CHECK-ENABLE-ZA-SAME: attributes {enable_arm_streaming_ignore}
+func.func @not_arm_streaming() attributes {enable_arm_streaming_ignore} { return }

>From 04c77e6786fd9f5b223b29a8e182409c7f0b0244 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 20 Sep 2023 14:11:32 +0000
Subject: [PATCH 3/3] [mlir][ArmSME] Add custom vector.print lowering for SME
 tiles

This adds a custom lowering for SME that loops over each row of the
tile, extracting it via an SME MOVA, then printing with a normal 1D
vector.print.

This makes writing SME integration tests easier and less verbose.
---
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |  6 ++
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    | 93 ++++++++++++++++++-
 .../Transforms/LegalizeForLLVMExport.cpp      | 17 ----
 mlir/lib/Dialect/ArmSME/Utils/Utils.cpp       | 14 +++
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           | 22 +++++
 .../CPU/ArmSME/test-outerproduct-f32.mlir     | 46 ++-------
 .../CPU/ArmSME/test-outerproduct-f64.mlir     | 26 +-----
 .../Dialect/Vector/CPU/ArmSME/tile_fill.mlir  | 28 +-----
 8 files changed, 150 insertions(+), 102 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 9e8ad48b3c2db94..0941592497beaae 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -34,6 +34,12 @@ bool isValidSMETileElementType(Type type);
 /// otherwise.
 bool isValidSMETileVectorType(VectorType vType);
 
+/// Extends or truncates `tile` to an i32 that can be passed as the `tile`
+/// operand of the SME intrinsics. `tile` must be produced by an
+/// `arm_sme::GetTileID` or `arm_sme::CastVectorToTile` op returning an
+/// 8/16/32/64/128-bit scalar integer. Returns `tile` unchanged if already i32.
+Value castTileIDToI32(Value tile, Location loc, RewriterBase &rewriter);
+
 } // namespace arm_sme
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 4028a7ad0870b51..dce478129c869ce 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -190,11 +190,94 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
   }
 };
 
+/// Lowers `vector.print` of an SME tile into a loop over the tile's rows,
+/// reading each row with a MOVA and printing it with a 1-D `vector.print`.
+///
+///  BEFORE:
+///  ```mlir
+///  vector.print %tile : vector<[4]x[4]xf32>
+///  ```
+///  AFTER:
+///  ```mlir
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %c4 = arith.constant 4 : index
+///  %ptrue = arith.constant dense<true> : vector<[4]xi1>
+///  %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xf32> to i32
+///  %vscale = vector.vscale
+///  %svl_s = arith.muli %c4, %vscale : index
+///  %cst = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+///  scf.for %i = %c0 to %svl_s step %c1 {
+///    %slice_idx = arith.index_cast %i : index to i32
+///    %tile_slice = "arm_sme.intr.read.horiz"
+///        (%cst, %ptrue, %tile_id, %slice_idx)
+///      : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+///    vector.print %tile_slice : vector<[4]xf32>
+///  }
+///  ```
+struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
+  using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::PrintOp printOp,
+                                PatternRewriter &rewriter) const override {
+    if (!printOp.getSource())
+      return failure();
+
+    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
+    if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
+      return failure();
+
+    auto loc = printOp.getLoc();
+
+    // All-true predicate with one lane per row element (tile dimension 1).
+    auto predicateType =
+        VectorType::get(vectorType.getDimSize(1), rewriter.getI1Type(), true);
+    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(predicateType, true));
+
+    // Get the tile ID of the source tile, then cast it to i32 for the MOVA.
+    auto tileId =
+        rewriter.create<arm_sme::CastVectorToTile>(loc, printOp.getSource());
+    auto tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
+
+    // Zero vector passed as the destination/fallback operand of read.horiz.
+    auto rowType = VectorType::get(vectorType.getDimSize(1),
+                                   vectorType.getElementType(), true);
+    auto zeroVector = rewriter.create<arith::ConstantOp>(
+        loc, rowType, rewriter.getZeroAttr(rowType));
+
+    // Loop over every tile row: [0, minTileRows * vscale) with step 1.
+    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+    auto minTileRows =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    {
+      // Loop body: insert at the start of the scf.for body block.
+      rewriter.setInsertionPointToStart(forOp.getBody());
+      // Read the current row; the intrinsic takes the slice index as i32.
+      auto rowIndex = forOp.getInductionVar();
+      auto rowIndexI32 = rewriter.create<arith::IndexCastOp>(
+          loc, rewriter.getI32Type(), rowIndex);
+      auto tileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
+          loc, rowType, zeroVector, allTruePredicate, tileIdI32, rowIndexI32);
+      // Print the row, reusing the original op's punctuation.
+      rewriter.create<vector::PrintOp>(loc, tileSlice,
+                                       printOp.getPunctuation());
+    }
+
+    rewriter.eraseOp(printOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
-  patterns.add<TileLoadOpConversion, TileStoreOpConversion>(
-      patterns.getContext());
+  patterns.add<TileLoadOpConversion, TileStoreOpConversion,
+               TileVectorPrintOpConversion>(patterns.getContext());
 }
 
 namespace {
@@ -208,6 +291,12 @@ struct ConvertArmSMEToSCFPass
     target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
                            arith::ArithDialect, scf::SCFDialect>();
     target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
+    target.addDynamicallyLegalOp<vector::PrintOp>([](vector::PrintOp op) {
+      if (!op.getSource())
+        return true;
+      VectorType vectorType = dyn_cast<VectorType>(op.getPrintType());
+      return !vectorType || !arm_sme::isValidSMETileVectorType(vectorType);
+    });
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 6c8843fbb4546e6..e4d1292358eb6d6 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -49,23 +49,6 @@ struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
   }
 };
 
-/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
-/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
-/// integer, to an i32 that can be passed as the `tile` parameter to the SME
-/// intrinsics. Or returns `tile` if already i32.
-Value castTileIDToI32(Value tile, Location loc,
-                      ConversionPatternRewriter &rewriter) {
-  assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
-             tile.getDefiningOp())) &&
-         "expected ArmSME GetTileID or CastVectorToTile op!");
-  unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
-  if (tileElementWidth < 32)
-    return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
-  if (tileElementWidth > 32)
-    return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
-  return tile;
-}
-
 /// Lower 'arm_sme.zero' to SME intrinsics.
 ///
 ///  BEFORE:
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index b8a47951cc7bbba..f17077ff8565d59 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 
 using namespace mlir;
@@ -42,3 +43,16 @@ bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
 
   return true;
 }
+
+Value mlir::arm_sme::castTileIDToI32(Value tile, Location loc,
+                                     RewriterBase &rewriter) {
+  assert((isa_and_nonnull<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
+             tile.getDefiningOp())) &&
+         "expected ArmSME GetTileID or CastVectorToTile op!");
+  unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
+  if (tileElementWidth < 32)
+    return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
+  if (tileElementWidth > 32)
+    return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
+  return tile;
+}
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 9ab1d79794d7659..6d2fda268476373 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -36,3 +36,25 @@ func.func @arm_sme_tile_store(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi3
   arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
+
+// -----
+
+func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
+{
+  vector.print %tile : vector<[4]x[4]xf32>
+  return
+}
+// CHECK-LABEL:   func.func @arm_sme_tile_print(
+// CHECK-SAME:                                  %[[TILE:.*]]: vector<[4]x[4]xf32>) {
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[VSCALE:.*]] = vector.vscale
+// CHECK-DAG:       %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK-DAG:       %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
+// CHECK-DAG:       %[[ZERO_VECTOR:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+// CHECK-DAG:       %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-NEXT:      scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT:        %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_cast %[[TILE_SLICE_INDEX]] : index to i32
+// CHECK-NEXT:        %[[TILE_SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VECTOR]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+// CHECK-NEXT:        vector.print %[[TILE_SLICE]] : vector<[4]xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 00f1f6fd3fa8e19..4265ca0f599281c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -16,7 +16,7 @@
 
 llvm.func @printCString(!llvm.ptr<i8>)
 
-func.func @printTileBegin() {
+func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
   %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
   %1 = llvm.mlir.constant(0 : index) : i64
   %2 = llvm.getelementptr %0[%1, %1]
@@ -25,7 +25,7 @@ func.func @printTileBegin() {
   return
 }
 
-func.func @printTileEnd() {
+func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
   %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
   %1 = llvm.mlir.constant(0 : index) : i64
   %2 = llvm.getelementptr %0[%1, %1]
@@ -41,20 +41,8 @@ func.func @test_outerproduct_no_accumulator_4x4xf32() {
   %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
   %tile = vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32>
 
-  // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
-  %vscale = vector.vscale
-  %min_elts_s = arith.constant 4 : index
-  %svl_s = arith.muli %min_elts_s, %vscale : index
-  %za_s_size = arith.muli %svl_s, %svl_s : index
-
-  // Allocate memory.
-  %mem = memref.alloca(%za_s_size) : memref<?xf32>
-
-  // Store the tile to memory.
-  vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
-
-  // Reload and print. The smallest SVL is 128-bits so the tile will be at
-  // least 4x4xf32.
+  // Print the tile. The smallest SVL is 128-bits so the tile will be at least
+  // 4x4xf32.
   //
   // WITHOUT-ACC:      TILE BEGIN
   // WITHOUT-ACC-NEXT: ( 0, 0, 0, 0
@@ -63,10 +51,7 @@ func.func @test_outerproduct_no_accumulator_4x4xf32() {
   // WITHOUT-ACC-NEXT: ( 0, 3, 6, 9
   // WITHOUT-ACC:      TILE END
   func.call @printTileBegin() : () -> ()
-  scf.for %i = %c0 to %za_s_size step %svl_s {
-    %tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
-    vector.print %tileslice : vector<[4]xf32>
-  }
+  vector.print %tile : vector<[4]x[4]xf32>
   func.call @printTileEnd() : () -> ()
 
   return
@@ -81,20 +66,8 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
   %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
   %tile = vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>
 
-  // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
-  %vscale = vector.vscale
-  %min_elts_s = arith.constant 4 : index
-  %svl_s = arith.muli %min_elts_s, %vscale : index
-  %za_s_size = arith.muli %svl_s, %svl_s : index
-
-  // Allocate memory.
-  %mem = memref.alloca(%za_s_size) : memref<?xf32>
-
-  // Store the tile to memory.
-  vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
-
-  // Reload and print. The smallest SVL is 128-bits so the tile will be at
-  // least 4x4xf32.
+  // Print the tile. The smallest SVL is 128-bits so the tile will be at least
+  // 4x4xf32.
   //
   // WITH-ACC:      TILE BEGIN
   // WITH-ACC-NEXT: ( 10, 10, 10, 10
@@ -103,10 +76,7 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
   // WITH-ACC-NEXT: ( 10, 13, 16, 19
   // WITH-ACC:      TILE END
   func.call @printTileBegin() : () -> ()
-  scf.for %i = %c0 to %za_s_size step %svl_s {
-    %tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
-    vector.print %tileslice : vector<[4]xf32>
-  }
+  vector.print %tile : vector<[4]x[4]xf32>
   func.call @printTileEnd() : () -> ()
 
   return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 2c2a06fa8db26e1..cb2c6b98a4eef3a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -13,7 +13,7 @@
 
 llvm.func @printCString(!llvm.ptr<i8>)
 
-func.func @printTileBegin() {
+func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
   %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
   %1 = llvm.mlir.constant(0 : index) : i64
   %2 = llvm.getelementptr %0[%1, %1]
@@ -22,7 +22,7 @@ func.func @printTileBegin() {
   return
 }
 
-func.func @printTileEnd() {
+func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
   %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
   %1 = llvm.mlir.constant(0 : index) : i64
   %2 = llvm.getelementptr %0[%1, %1]
@@ -32,7 +32,6 @@ func.func @printTileEnd() {
 }
 
 func.func @test_outerproduct_with_accumulator_2x2xf64() {
-  %c0 = arith.constant 0 : index
   %f1 = arith.constant 1.0 : f64
   %f2 = arith.constant 2.0 : f64
   %f10 = arith.constant 10.0 : f64
@@ -44,30 +43,15 @@ func.func @test_outerproduct_with_accumulator_2x2xf64() {
 
   %tile = vector.outerproduct %a, %b, %c : vector<[2]xf64>, vector<[2]xf64>
 
-  // Calculate the size of a 64-bit tile, e.g. ZA{n}.d.
-  %vscale = vector.vscale
-  %min_elts_d = arith.constant 2 : index
-  %svl_d = arith.muli %min_elts_d, %vscale : index
-  %za_d_size = arith.muli %svl_d, %svl_d : index
-
-  // Allocate memory.
-  %mem = memref.alloca(%za_d_size) : memref<?xf64>
-
-  // Store the tile to memory.
-  vector.store %tile, %mem[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
-
-  // Reload and print. The smallest SVL is 128-bits so the tile will be at
-  // least 2x2xf64.
+  // Print the tile. The smallest SVL is 128-bits so the tile will be at least
+  // 2x2xf64.
   //
   // CHECK:      TILE BEGIN
   // CHECK-NEXT: ( 12, 12
   // CHECK-NEXT: ( 12, 12
   // CHECK:      TILE END
   func.call @printTileBegin() : () -> ()
-  scf.for %i = %c0 to %za_d_size step %svl_d {
-    %tileslice = vector.load %mem[%i] : memref<?xf64>, vector<[2]xf64>
-    vector.print %tileslice : vector<[2]xf64>
-  }
+  vector.print %tile : vector<[2]x[2]xf64>
   func.call @printTileEnd() : () -> ()
 
   return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index a407b13b541839f..fe6ded71c1613fa 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -13,7 +13,7 @@
 
 llvm.func @printCString(!llvm.ptr<i8>)
 
-func.func @printTileBegin() {
+func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
   %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
   %1 = llvm.mlir.constant(0 : index) : i64
   %2 = llvm.getelementptr %0[%1, %1]
@@ -22,7 +22,7 @@ func.func @printTileBegin() {
   return
 }
 
-func.func @printTileEnd() {
+func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
   %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
   %1 = llvm.mlir.constant(0 : index) : i64
   %2 = llvm.getelementptr %0[%1, %1]
@@ -32,29 +32,12 @@ func.func @printTileEnd() {
 }
 
 func.func @entry() -> i32 {
-  %c0 = arith.constant 0 : index
-  %c1_index = arith.constant 1 : index
-
-  %min_elts_s = arith.constant 4 : index
-  %vscale = vector.vscale
-
-  // "svl" refers to the Streaming Vector Length and "svl_s" the number of
-  // 32-bit elements in a vector of SVL bits.
-  %svl_s = arith.muli %min_elts_s, %vscale : index
-
-  // Allocate memory.
-  %tilesize = arith.muli %svl_s, %svl_s : index
-  %mem = memref.alloca(%tilesize) : memref<?xi32>
-
   // Fill a tile with '123'. This will get lowered to a 1-d vector splat of
   // '123' and a loop that writes this vector to each tile slice in the ZA
   // tile.
   %tile = arith.constant dense<123> : vector<[4]x[4]xi32>
 
-  // Store tile to memory so it can be dumped.
-  vector.store %tile, %mem[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-
-  // Dump "mem". The smallest SVL is 128-bits so the tile will be at least
+  // Print the tile. The smallest SVL is 128-bits so the tile will be at least
   // 4x4xi32.
   //
   // CHECK:      TILE BEGIN
@@ -64,10 +47,7 @@ func.func @entry() -> i32 {
   // CHECK-NEXT: ( 123, 123, 123, 123
   // CHECK:      TILE END
   func.call @printTileBegin() : () -> ()
-  scf.for %i = %c0 to %tilesize step %svl_s {
-    %tileslice = vector.load %mem[%i] : memref<?xi32>, vector<[4]xi32>
-    vector.print %tileslice : vector<[4]xi32>
-  }
+  vector.print %tile : vector<[4]x[4]xi32>
   func.call @printTileEnd() : () -> ()
 
   %c0_i32 = arith.constant 0 : i32



More information about the Mlir-commits mailing list