[Mlir-commits] [mlir] [mlir][SVE] Add an e2e test for vectorization of linalg.matmul (PR #69592)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Oct 23 00:28:17 PDT 2023


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/69592

>From 0c2fd7b2692b5e2d932b2f635933e5de59720b05 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 10 Oct 2023 11:38:12 +0000
Subject: [PATCH 1/5] [mlir][VectorOps] Support string literals in
 `vector.print`

Printing strings within integration tests is currently quite annoyingly
verbose, and can't be tucked into shared helpers as the types depend on
the length of the string:

```
llvm.mlir.global internal constant @hello_world("Hello, World!\0")

func.func @entry() {
  %0 = llvm.mlir.addressof @hello_world : !llvm.ptr<array<14 x i8>>
  %1 = llvm.mlir.constant(0 : index) : i64
  %2 = llvm.getelementptr %0[%1, %1]
    : (!llvm.ptr<array<14 x i8>>, i64, i64) -> !llvm.ptr<i8>
  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
  return
}
```

So this patch adds a simple extension to `vector.print` to simplify
this:
```
func.func @entry() {
   // Print a vector of characters ;)
   vector.print str "Hello, World!"
   return
}
```

Most of the logic for this is now shared with `cf.assert` which already
does something similar.
---
 .../Conversion/LLVMCommon/PrintCallHelper.h   | 36 ++++++++++
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 37 +++++++++--
 .../ControlFlowToLLVM/ControlFlowToLLVM.cpp   | 49 +-------------
 mlir/lib/Conversion/LLVMCommon/CMakeLists.txt |  1 +
 .../Conversion/LLVMCommon/PrintCallHelper.cpp | 66 +++++++++++++++++++
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  6 +-
 .../VectorToLLVM/vector-to-llvm.mlir          | 14 ++++
 .../Dialect/Vector/CPU/test-hello-world.mlir  | 10 +++
 8 files changed, 168 insertions(+), 51 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
 create mode 100644 mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
new file mode 100644
index 000000000000000..7e26858589f2756
--- /dev/null
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -0,0 +1,36 @@
+
+//===- PrintCallHelper.h - LLVM Interfaces ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
+#define MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+
+class Location;
+class ModuleOp;
+class OpBuilder;
+class Operation;
+class Type;
+class ValueRange;
+class LLVMTypeConverter;
+
+namespace LLVM {
+
+/// Generate IR that prints the given string to stdout.
+void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
+                        StringRef symbolName, StringRef string,
+                        const LLVMTypeConverter &typeConverter);
+} // namespace LLVM
+
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2df2fe4c5ce8e9c..2b60055ca9db94b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
+include "mlir/IR/BuiltinAttributes.td"
 
 // TODO: Add an attribute to specify a different algebra with operators other
 // than the current set: {*, +}.
@@ -2476,12 +2477,18 @@ def Vector_TransposeOp :
 }
 
 def Vector_PrintOp :
-  Vector_Op<"print", []>,
+  Vector_Op<"print", [
+    PredOpTrait<
+      "`source` or `punctuation` are not set printing strings",
+      CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
+    >,
+  ]>,
   Arguments<(ins Optional<Type<Or<[
     AnyVectorOfAnyRank.predicate,
     AnyInteger.predicate, Index.predicate, AnyFloat.predicate
   ]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
-                      "::mlir::vector::PrintPunctuation::NewLine">:$punctuation)
+                      "::mlir::vector::PrintPunctuation::NewLine">:$punctuation,
+                OptionalAttr<Builtin_StringAttr>:$stringLiteral)
   > {
   let summary = "print operation (for testing and debugging)";
   let description = [{
@@ -2520,6 +2527,13 @@ def Vector_PrintOp :
     ```mlir
     vector.print punctuation <newline>
     ```
+
+    Additionally, to aid with debugging and testing `vector.print` can also
+    print constant strings:
+
+    ```mlir
+    vector.print str "Hello, World!"
+    ```
   }];
   let extraClassDeclaration = [{
     Type getPrintType() {
@@ -2528,11 +2542,26 @@ def Vector_PrintOp :
   }];
   let builders = [
     OpBuilder<(ins "PrintPunctuation":$punctuation), [{
-      build($_builder, $_state, {}, punctuation);
+      build($_builder, $_state, {}, punctuation, {});
+    }]>,
+    OpBuilder<(ins "::mlir::Value":$source), [{
+      build($_builder, $_state, source, PrintPunctuation::NewLine);
+    }]>,
+    OpBuilder<(ins "::mlir::Value":$source, "PrintPunctuation":$punctuation), [{
+      build($_builder, $_state, source, punctuation, {});
+    }]>,
+    OpBuilder<(ins "::llvm::StringRef":$string), [{
+      build($_builder, $_state, {}, PrintPunctuation::NewLine, $_builder.getStringAttr(string));
     }]>,
   ];
 
-  let assemblyFormat = "($source^ `:` type($source))? (`punctuation` $punctuation^)? attr-dict";
+  let assemblyFormat = [{
+      ($source^ `:` type($source))?
+        oilist(
+            `str` $stringLiteral
+          | `punctuation` $punctuation)
+        attr-dict
+    }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index a4f146bbe475cc6..6b7647b038f1d94 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -36,51 +37,6 @@ using namespace mlir;
 
 #define PASS_NAME "convert-cf-to-llvm"
 
-static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
-  std::string prefix = "assert_msg_";
-  int counter = 0;
-  while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
-    ++counter;
-  return prefix + std::to_string(counter);
-}
-
-/// Generate IR that prints the given string to stderr.
-static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
-                           StringRef msg,
-                           const LLVMTypeConverter &typeConverter) {
-  auto ip = builder.saveInsertionPoint();
-  builder.setInsertionPointToStart(moduleOp.getBody());
-  MLIRContext *ctx = builder.getContext();
-
-  // Create a zero-terminated byte representation and allocate global symbol.
-  SmallVector<uint8_t> elementVals;
-  elementVals.append(msg.begin(), msg.end());
-  elementVals.push_back(0);
-  auto dataAttrType = RankedTensorType::get(
-      {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
-  auto dataAttr =
-      DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
-  auto arrayTy =
-      LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
-  std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
-  auto globalOp = builder.create<LLVM::GlobalOp>(
-      loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
-      dataAttr);
-
-  // Emit call to `printStr` in runtime library.
-  builder.restoreInsertionPoint(ip);
-  auto msgAddr = builder.create<LLVM::AddressOfOp>(
-      loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
-  SmallVector<LLVM::GEPArg> indices(1, 0);
-  Value gep = builder.create<LLVM::GEPOp>(
-      loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
-      indices);
-  Operation *printer = LLVM::lookupOrCreatePrintStrFn(
-      moduleOp, typeConverter.useOpaquePointers());
-  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
-                               gep);
-}
-
 namespace {
 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
 /// assertion is violated and has no effect otherwise. The failure message is
@@ -105,7 +61,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
 
     // Failed block: Generate IR to print the message and call `abort`.
     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
-    createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
+    LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
+                             *getTypeConverter());
     if (abortOnFailedAssert) {
       // Insert the `abort` declaration if necessary.
       auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 091cd539f0ae014..568d9339aaabcb4 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
   LoweringOptions.cpp
   MemRefBuilder.cpp
   Pattern.cpp
+  PrintCallHelper.cpp
   StructBuilder.cpp
   TypeConverter.cpp
   VectorPattern.cpp
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
new file mode 100644
index 000000000000000..487abb435d10ad7
--- /dev/null
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -0,0 +1,66 @@
+
+//===- PrintCallHelper.cpp - LLVM Interfaces --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/ArrayRef.h"
+
+using namespace mlir;
+using namespace llvm;
+
+static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
+                                            StringRef symbolName) {
+  static int counter = 0;
+  std::string uniqueName = std::string(symbolName);
+  while (moduleOp.lookupSymbol(uniqueName)) {
+    uniqueName = std::string(symbolName) + "_" + std::to_string(counter++);
+  }
+  return uniqueName;
+}
+
+void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc,
+                                    ModuleOp moduleOp, StringRef symbolName,
+                                    StringRef string,
+                                    const LLVMTypeConverter &typeConverter) {
+  auto ip = builder.saveInsertionPoint();
+  builder.setInsertionPointToStart(moduleOp.getBody());
+  MLIRContext *ctx = builder.getContext();
+
+  // Create a zero-terminated byte representation and allocate global symbol.
+  SmallVector<uint8_t> elementVals;
+  elementVals.append(string.begin(), string.end());
+  elementVals.push_back(0);
+  auto dataAttrType = RankedTensorType::get(
+      {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
+  auto dataAttr =
+      DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
+  auto arrayTy =
+      LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
+  auto globalOp = builder.create<LLVM::GlobalOp>(
+      loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
+      ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
+
+  // Emit call to `printStr` in runtime library.
+  builder.restoreInsertionPoint(ip);
+  auto msgAddr = builder.create<LLVM::AddressOfOp>(
+      loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
+  SmallVector<LLVM::GEPArg> indices(1, 0);
+  Value gep = builder.create<LLVM::GEPOp>(
+      loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
+      indices);
+  Operation *printer = LLVM::lookupOrCreatePrintStrFn(
+      moduleOp, typeConverter.useOpaquePointers());
+  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
+                               gep);
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 8427d60f14c0bcc..4af58653c8227ae 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -1548,7 +1549,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     }
 
     auto punct = printOp.getPunctuation();
-    if (punct != PrintPunctuation::NoPunctuation) {
+    if (auto stringLiteral = printOp.getStringLiteral()) {
+      LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
+                               *stringLiteral, *getTypeConverter());
+    } else if (punct != PrintPunctuation::NoPunctuation) {
       emitCall(rewriter, printOp->getLoc(), [&] {
         switch (punct) {
         case PrintPunctuation::Close:
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 9aa4d735681f576..65b3a78e295f0c4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1068,6 +1068,20 @@ func.func @vector_print_scalar_f64(%arg0: f64) {
 
 // -----
 
+// CHECK-LABEL: module {
+// CHECK: llvm.func @puts(!llvm.ptr)
+// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]](dense<[72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33, 0]> : tensor<14xi8>) {addr_space = 0 : i32} : !llvm.array<14 x i8>
+// CHECK: @vector_print_string
+//       CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr
+//       CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
+//       CHECK-NEXT: llvm.call @puts(%[[STR_PTR]]) : (!llvm.ptr) -> ()
+func.func @vector_print_string() {
+  vector.print str "Hello, World!"
+  return
+}
+
+// -----
+
 func.func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
   %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
   return %0 : vector<2xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
new file mode 100644
index 000000000000000..c4076e65151ac72
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+   // CHECK: Hello, World!
+   vector.print str "Hello, World!"
+   return
+}

>From 3b807f16bd49286342906d1c187bedf6856fe486 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 10 Oct 2023 18:53:05 +0000
Subject: [PATCH 2/5] [mlir][ArmSVE] Add `-arm-sve-legalize-vector-storage`
 pass

This patch adds a pass that ensures that loads, stores, and allocations
of SVE vector types will be legal in the LLVM backend. It does this at
the memref level, so this pass must be applied before lowering all the
way to LLVM.

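This pass currently fixes two issues: loads/stores of predicate types
smaller than a full predicate register, and the default alignment of
SVE vector allocas.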

It is only legal to load/store predicate types equal to (or greater
than) a full predicate register, which in MLIR is `vector<[16]xi1>`.
Smaller predicate types (`vector<[1|2|4|8]xi1>`) must be converted
to/from a full predicate type (referred to as a `svbool`) before and
after storing and loading respectively. This pass does this by widening
allocations and inserting conversion intrinsics.

For example:

```mlir
%alloca = memref.alloca() : memref<vector<[4]xi1>>
%mask = vector.constant_mask [4] : vector<[4]xi1>
memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
%reload = memref.load %alloca[] : memref<vector<[4]xi1>>
```
Becomes:
```mlir
%alloca = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
%mask = vector.constant_mask [4] : vector<[4]xi1>
%svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
memref.store %svbool, %alloca[] : memref<vector<[16]xi1>>
%reload_svbool = memref.load %alloca[] : memref<vector<[16]xi1>>
%reload = arm_sve.convert_from_svbool %reload_svbool : vector<[4]xi1>
```

The storage for SVE vector types only needs to have an alignment that
matches the element type (for example 4 byte alignment for `f32`s).
However, the LLVM backend currently defaults to aligning to `base size`
x `element size` bytes. For non-legal vector types like
`vector<[8]xf32>` this results in 8 x 4 = 32-byte alignment, but the
backend only supports up to 16-byte alignment for SVE vectors on the
stack. Explicitly setting a smaller alignment prevents this issue.
---
 .../mlir/Dialect/ArmSVE/CMakeLists.txt        |   1 +
 .../Dialect/ArmSVE/Transforms/CMakeLists.txt  |   5 +
 .../mlir/Dialect/ArmSVE/Transforms/Passes.h   |  33 ++
 .../mlir/Dialect/ArmSVE/Transforms/Passes.td  |  67 ++++
 mlir/include/mlir/InitAllPasses.h             |   2 +
 .../Dialect/ArmSVE/Transforms/CMakeLists.txt  |   2 +
 .../Transforms/LegalizeVectorStorage.cpp      | 310 ++++++++++++++++++
 .../ArmSVE/legalize-vector-storage.mlir       | 160 +++++++++
 .../ArmSVE/arrays-of-scalable-vectors.mlir    | 121 +++++++
 9 files changed, 701 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h
 create mode 100644 mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td
 create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
 create mode 100644 mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/arrays-of-scalable-vectors.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
index f33061b2d87cffc..9f57627c321fb0c 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..7226642daf86172
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name ArmSVE)
+add_public_tablegen_target(MLIRArmSVEPassIncGen)
+
+add_mlir_doc(Passes ArmSVEPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h
new file mode 100644
index 000000000000000..317fb9021b3c577
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h
@@ -0,0 +1,33 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_H
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::arm_sve {
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
+
+/// Pass to legalize the types of mask stores.
+std::unique_ptr<Pass> createLegalizeVectorStoragePass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
+
+} // namespace mlir::arm_sve
+
+#endif // MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td
new file mode 100644
index 000000000000000..35c49607181da0c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td
@@ -0,0 +1,67 @@
+//===-- Passes.td - ArmSVE pass definition file ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+def LegalizeVectorStorage
+    : Pass<"arm-sve-legalize-vector-storage", "mlir::func::FuncOp"> {
+  let summary = "Ensures stores of SVE vector types will be legal";
+  let description = [{
+    This pass ensures that loads, stores, and allocations of SVE vector types
+    will be legal in the LLVM backend. It does this at the memref level, so this
+    pass must be applied before lowering all the way to LLVM.
+
+    This pass currently fixes two issues.
+
+    ## Loading and storing predicate types
+
+    It is only legal to load/store predicate types equal to (or greater than) a
+    full predicate register, which in MLIR is `vector<[16]xi1>`. Smaller
+    predicate types (`vector<[1|2|4|8]xi1>`) must be converted to/from a full
+    predicate type (referred to as a `svbool`) before and after storing and
+    loading respectively. This pass does this by widening allocations and
+    inserting conversion intrinsics.
+
+    For example:
+
+    ```mlir
+    %alloca = memref.alloca() : memref<vector<[4]xi1>>
+    %mask = vector.constant_mask [4] : vector<[4]xi1>
+    memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
+    %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
+    ```
+    Becomes:
+    ```mlir
+    %alloca = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+    %mask = vector.constant_mask [4] : vector<[4]xi1>
+    %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
+    memref.store %svbool, %alloca[] : memref<vector<[16]xi1>>
+    %reload_svbool = memref.load %alloca[] : memref<vector<[16]xi1>>
+    %reload = arm_sve.convert_from_svbool %reload_svbool : vector<[4]xi1>
+    ```
+
+    ## Relax alignments for SVE vector allocas
+
+    The storage for SVE vector types only needs to have an alignment that
+    matches the element type (for example 4 byte alignment for `f32`s). However,
+    the LLVM backend currently defaults to aligning to `base size` x
+    `element size` bytes. For non-legal vector types like `vector<[8]xf32>` this
+    results in 8 x 4 = 32-byte alignment, but the backend only supports up to
+    16-byte alignment for SVE vectors on the stack. Explicitly setting a smaller
+    alignment prevents this issue.
+  }];
+  let constructor = "mlir::arm_sve::createLegalizeVectorStoragePass()";
+  let dependentDialects = ["func::FuncDialect",
+    "memref::MemRefDialect", "vector::VectorDialect",
+    "arm_sve::ArmSVEDialect"];
+}
+
+#endif // MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 5489a13a8040bdb..7301905954f56d8 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
 #include "mlir/Dialect/Async/Passes.h"
 #include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
@@ -82,6 +83,7 @@ inline void registerAllPasses() {
   transform::registerTransformPasses();
   vector::registerVectorPasses();
   arm_sme::registerArmSMEPasses();
+  arm_sve::registerArmSVEPasses();
 
   // Dialect pipelines
   bufferization::registerBufferizationPipelines();
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index 2f1c43fae240d51..a70c489a51fea9a 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,8 +1,10 @@
 add_mlir_dialect_library(MLIRArmSVETransforms
   LegalizeForLLVMExport.cpp
+  LegalizeVectorStorage.cpp
 
   DEPENDS
   MLIRArmSVEConversionsIncGen
+  MLIRArmSVEPassIncGen
 
   LINK_LIBS PUBLIC
   MLIRArmSVEDialect
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
new file mode 100644
index 000000000000000..610eb38089c4c88
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -0,0 +1,310 @@
+//===- LegalizeVectorStorage.cpp - Ensures SVE loads/stores are legal -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::arm_sve {
+#define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sve
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+constexpr StringLiteral kPassLabel("__arm_sve_legalize_vector_storage__");
+
+namespace {
+
+/// A (legal) SVE predicate mask that has a logical size, i.e. the number of
+/// bits match the number of lanes it masks (such as vector<[4]xi1>), but is too
+/// small to be stored to memory.
+bool isLogicalSVEPredicateType(VectorType type) {
+  return type.getRank() > 0 && type.getElementType().isInteger(1) &&
+         type.getScalableDims().back() && type.getShape().back() < 16 &&
+         llvm::isPowerOf2_32(type.getShape().back()) &&
+         !llvm::is_contained(type.getScalableDims().drop_back(), true);
+}
+
+VectorType widenScalableMaskTypeToSvbool(VectorType type) {
+  assert(isLogicalSVEPredicateType(type));
+  return VectorType::Builder(type).setDim(type.getRank() - 1, 16);
+}
+
+template <typename TOp, typename TLegalizerCallback>
+void replaceOpWithLegalizedOp(PatternRewriter &rewriter, TOp op,
+                              TLegalizerCallback callback) {
+  // Clone the previous op to preserve any properties/attributes.
+  auto newOp = op.clone();
+  rewriter.insert(newOp);
+  rewriter.replaceOp(op, callback(newOp));
+}
+
+template <typename TOp, typename TLegalizerCallback>
+void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op,
+                                       TLegalizerCallback callback) {
+  replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) {
+    // Mark our `unrealized_conversion_casts` with a pass label.
+    return rewriter.create<UnrealizedConversionCastOp>(
+        op.getLoc(), TypeRange{op.getResult().getType()},
+        ValueRange{callback(newOp)},
+        NamedAttribute(rewriter.getStringAttr(kPassLabel),
+                       rewriter.getUnitAttr()));
+  });
+}
+
+/// Extracts the legal memref value from the `unrealized_conversion_casts` added
+/// by this pass.
+static FailureOr<Value> getLegalMemRef(Value illegalMemref) {
+  Operation *definingOp = illegalMemref.getDefiningOp();
+  if (!definingOp || !definingOp->hasAttr(kPassLabel))
+    return failure();
+  auto unrealizedConversion =
+      llvm::cast<UnrealizedConversionCastOp>(definingOp);
+  return unrealizedConversion.getOperand(0);
+}
+
+/// The default alignment of an alloca may request overaligned sizes for SVE
+/// types, which will fail during stack frame allocation. For allocas of
+/// scalable vectors with no alignment set, this uses the element byte size.
+struct RelaxScalableVectorAllocaAlignment
+    : public OpRewritePattern<memref::AllocaOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
+                                PatternRewriter &rewriter) const override {
+    auto elementType = allocaOp.getType().getElementType();
+    auto vectorType = llvm::dyn_cast<VectorType>(elementType);
+    if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
+      return failure();
+
+    // Sub-byte element types (e.g. i1) round down to 0, so clamp to 1.
+    unsigned alignment = std::max(
+        1u, vectorType.getElementType().getIntOrFloatBitWidth() / 8);
+    // In-place changes must go through the rewriter so the driver is notified.
+    rewriter.updateRootInPlace(allocaOp,
+                               [&] { allocaOp.setAlignment(alignment); });
+    return success();
+  }
+};
+
+/// Replaces allocations of SVE predicates smaller than an svbool with a wider
+/// allocation and a tagged unrealized conversion.
+///
+/// Example
+/// ```
+/// %alloca = memref.alloca() : memref<vector<[4]xi1>>
+/// ```
+/// is rewritten into:
+/// ```
+/// %widened = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+/// %alloca = builtin.unrealized_conversion_cast %widened
+///   : memref<vector<[16]xi1>> to memref<vector<[4]xi1>>
+///     {__arm_sve_legalize_vector_storage__}
+/// ```
+template <typename AllocLikeOp>
+struct LegalizeAllocLikeOpConversion : public OpRewritePattern<AllocLikeOp> {
+  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AllocLikeOp allocLikeOp,
+                                PatternRewriter &rewriter) const override {
+    auto vectorType =
+        llvm::dyn_cast<VectorType>(allocLikeOp.getType().getElementType());
+
+    if (!vectorType || !isLogicalSVEPredicateType(vectorType))
+      return failure();
+
+    // Replace this alloc-like op of an SVE mask with one of a (storable)
+    // svbool_t mask. A temporary unrealized_conversion_cast is added to the old
+    // type to allow local rewrites.
+    replaceOpWithUnrealizedConversion(
+        rewriter, allocLikeOp, [&](AllocLikeOp newAllocLikeOp) {
+          newAllocLikeOp.getResult().setType(
+              llvm::cast<MemRefType>(newAllocLikeOp.getType().cloneWith(
+                  {}, widenScalableMaskTypeToSvbool(vectorType))));
+          return newAllocLikeOp;
+        });
+
+    return success();
+  }
+};
+
+/// Replaces vector.type_casts of unrealized conversions to illegal memref types
+/// with legal type casts, followed by unrealized conversions.
+///
+/// Example:
+/// ```
+/// %alloca = builtin.unrealized_conversion_cast %widened
+///   : memref<vector<3x[16]xi1>> to memref<vector<3x[8]xi1>>
+///     {__arm_sve_legalize_vector_storage__}
+/// %cast = vector.type_cast %alloca
+///   : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>>
+/// ```
+/// is rewritten into:
+/// ```
+/// %widened_cast = vector.type_cast %widened
+///   : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>>
+/// %cast = builtin.unrealized_conversion_cast %widened_cast
+///   : memref<3xvector<[16]xi1>> to memref<3xvector<[8]xi1>>
+///     {__arm_sve_legalize_vector_storage__}
+/// ```
+struct LegalizeVectorTypeCastConversion
+    : public OpRewritePattern<vector::TypeCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TypeCastOp typeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto resultType = typeCastOp.getResultMemRefType();
+    auto vectorType = llvm::dyn_cast<VectorType>(resultType.getElementType());
+
+    if (!vectorType || !isLogicalSVEPredicateType(vectorType))
+      return failure();
+
+    auto legalMemref = getLegalMemRef(typeCastOp.getMemref());
+    if (failed(legalMemref))
+      return failure();
+
+    // Replace this vector.type_cast with one of a (storable) svbool_t mask.
+    replaceOpWithUnrealizedConversion(
+        rewriter, typeCastOp, [&](vector::TypeCastOp newTypeCast) {
+          newTypeCast.setOperand(*legalMemref);
+          newTypeCast.getResult().setType(
+              llvm::cast<MemRefType>(newTypeCast.getType().cloneWith(
+                  {}, widenScalableMaskTypeToSvbool(vectorType))));
+          return newTypeCast;
+        });
+
+    return success();
+  }
+};
+
+/// Replaces stores to unrealized conversions to illegal memref types with
+/// `arm_sve.convert_to_svbool`s followed by (legal) wider stores.
+///
+/// Example:
+/// ```
+/// memref.store %mask, %alloca[] : memref<vector<[8]xi1>>
+/// ```
+/// is rewritten into:
+/// ```
+/// %svbool = arm_sve.convert_to_svbool %mask : vector<[8]xi1>
+/// memref.store %svbool, %widened[] : memref<vector<[16]xi1>>
+/// ```
+struct LegalizeMemrefStoreConversion
+    : public OpRewritePattern<memref::StoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::StoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = storeOp.getLoc();
+
+    Value valueToStore = storeOp.getValueToStore();
+    auto vectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
+
+    if (!vectorType || !isLogicalSVEPredicateType(vectorType))
+      return failure();
+
+    auto legalMemref = getLegalMemRef(storeOp.getMemref());
+    if (failed(legalMemref))
+      return failure();
+
+    auto legalMaskType = widenScalableMaskTypeToSvbool(
+        llvm::cast<VectorType>(valueToStore.getType()));
+    auto convertToSvbool = rewriter.create<arm_sve::ConvertToSvboolOp>(
+        loc, legalMaskType, valueToStore);
+    // Replace this store with a conversion to a storable svbool_t mask,
+    // followed by a wider store.
+    replaceOpWithLegalizedOp(rewriter, storeOp,
+                             [&](memref::StoreOp newStoreOp) {
+                               newStoreOp.setOperand(0, convertToSvbool);
+                               newStoreOp.setOperand(1, *legalMemref);
+                               return newStoreOp;
+                             });
+
+    return success();
+  }
+};
+
+/// Replaces loads from unrealized conversions to illegal memref types with
+/// (legal) wider loads, followed by `arm_sve.convert_from_svbool`s.
+///
+/// Example:
+/// ```
+/// %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
+/// ```
+/// is rewritten into:
+/// ```
+/// %svbool = memref.load %widened[] : memref<vector<[16]xi1>>
+/// %reload = arm_sve.convert_from_svbool %svbool : vector<[4]xi1>
+/// ```
+struct LegalizeMemrefLoadConversion : public OpRewritePattern<memref::LoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::LoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = loadOp.getLoc();
+
+    Value loadedMask = loadOp.getResult();
+    auto vectorType = llvm::dyn_cast<VectorType>(loadedMask.getType());
+
+    if (!vectorType || !isLogicalSVEPredicateType(vectorType))
+      return failure();
+
+    auto legalMemref = getLegalMemRef(loadOp.getMemref());
+    if (failed(legalMemref))
+      return failure();
+
+    auto legalMaskType = widenScalableMaskTypeToSvbool(vectorType);
+    // Replace this load with a legal load of an svbool_t type, followed by a
+    // conversion back to the original type.
+    replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) {
+      newLoadOp.setMemRef(*legalMemref);
+      newLoadOp.getResult().setType(legalMaskType);
+      return rewriter.create<arm_sve::ConvertFromSvboolOp>(
+          loc, loadedMask.getType(), newLoadOp);
+    });
+
+    return success();
+  }
+};
+
+struct LegalizeVectorStorage
+    : public arm_sve::impl::LegalizeVectorStorageBase<LegalizeVectorStorage> {
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<RelaxScalableVectorAllocaAlignment,
+                 LegalizeAllocLikeOpConversion<memref::AllocaOp>,
+                 LegalizeAllocLikeOpConversion<memref::AllocOp>,
+                 LegalizeVectorTypeCastConversion,
+                 LegalizeMemrefStoreConversion, LegalizeMemrefLoadConversion>(
+        patterns.getContext());
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      signalPassFailure();
+    }
+    ConversionTarget target(getContext());
+    target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
+        [](UnrealizedConversionCastOp unrealizedConversion) {
+          return !unrealizedConversion->hasAttr(kPassLabel);
+        });
+    // This detects if we failed to completely legalize the IR.
+    if (failed(applyPartialConversion(getOperation(), target, {})))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::arm_sve::createLegalizeVectorStoragePass() {
+  return std::make_unique<LegalizeVectorStorage>();
+}
diff --git a/mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir b/mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir
new file mode 100644
index 000000000000000..fda8e2e0fab9618
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir
@@ -0,0 +1,160 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -arm-sve-legalize-vector-storage -split-input-file -verify-diagnostics | FileCheck %s
+
+/// This tests the basic functionality of the -arm-sve-legalize-vector-storage pass.
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv1i1(
+// CHECK-SAME:                                         %[[MASK:.*]]: vector<[1]xi1>)
+func.func @store_and_reload_sve_predicate_nxv1i1(%mask: vector<[1]xi1>) -> vector<[1]xi1> {
+  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+  %alloca = memref.alloca() : memref<vector<[1]xi1>>
+  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[1]xi1>
+  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  memref.store %mask, %alloca[] : memref<vector<[1]xi1>>
+  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[1]xi1>
+  %reload = memref.load %alloca[] : memref<vector<[1]xi1>>
+  // CHECK-NEXT: return %[[MASK]] : vector<[1]xi1>
+  return %reload : vector<[1]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv2i1(
+// CHECK-SAME:                                         %[[MASK:.*]]: vector<[2]xi1>)
+func.func @store_and_reload_sve_predicate_nxv2i1(%mask: vector<[2]xi1>) -> vector<[2]xi1> {
+  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+  %alloca = memref.alloca() : memref<vector<[2]xi1>>
+  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[2]xi1>
+  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  memref.store %mask, %alloca[] : memref<vector<[2]xi1>>
+  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[2]xi1>
+  %reload = memref.load %alloca[] : memref<vector<[2]xi1>>
+  // CHECK-NEXT: return %[[MASK]] : vector<[2]xi1>
+  return %reload : vector<[2]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv4i1(
+// CHECK-SAME:                                         %[[MASK:.*]]: vector<[4]xi1>)
+func.func @store_and_reload_sve_predicate_nxv4i1(%mask: vector<[4]xi1>) -> vector<[4]xi1> {
+  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+  %alloca = memref.alloca() : memref<vector<[4]xi1>>
+  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[4]xi1>
+  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
+  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[4]xi1>
+  %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
+  // CHECK-NEXT: return %[[MASK]] : vector<[4]xi1>
+  return %reload : vector<[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv8i1(
+// CHECK-SAME:                                         %[[MASK:.*]]: vector<[8]xi1>)
+func.func @store_and_reload_sve_predicate_nxv8i1(%mask: vector<[8]xi1>) -> vector<[8]xi1> {
+  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+  %alloca = memref.alloca() : memref<vector<[8]xi1>>
+  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[8]xi1>
+  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  memref.store %mask, %alloca[] : memref<vector<[8]xi1>>
+  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
+  %reload = memref.load %alloca[] : memref<vector<[8]xi1>>
+  // CHECK-NEXT: return %[[MASK]] : vector<[8]xi1>
+  return %reload : vector<[8]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_2d_mask_and_reload_slice(
+// CHECK-SAME:                                  %[[MASK:.*]]: vector<3x[8]xi1>)
+func.func @store_2d_mask_and_reload_slice(%mask: vector<3x[8]xi1>) -> vector<[8]xi1> {
+  // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<3x[16]xi1>>
+  %alloca = memref.alloca() : memref<vector<3x[8]xi1>>
+  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<3x[8]xi1>
+  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<3x[16]xi1>>
+  memref.store %mask, %alloca[] : memref<vector<3x[8]xi1>>
+  // CHECK-NEXT: %[[UNPACK:.*]] = vector.type_cast %[[ALLOCA]] : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>>
+  %unpack = vector.type_cast %alloca : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>>
+  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[UNPACK]][%[[C0]]] : memref<3xvector<[16]xi1>>
+  // CHECK-NEXT: %[[SLICE:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
+  %slice = memref.load %unpack[%c0] : memref<3xvector<[8]xi1>>
+  // CHECK-NEXT: return %[[SLICE]] : vector<[8]xi1>
+  return %slice : vector<[8]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @set_sve_alloca_alignment
+func.func @set_sve_alloca_alignment() {
+  // CHECK-COUNT-6: alignment = 1
+  %a1 = memref.alloca() : memref<vector<[32]xi8>>
+  %a2 = memref.alloca() : memref<vector<[16]xi8>>
+  %a3 = memref.alloca() : memref<vector<[8]xi8>>
+  %a4 = memref.alloca() : memref<vector<[4]xi8>>
+  %a5 = memref.alloca() : memref<vector<[2]xi8>>
+  %a6 = memref.alloca() : memref<vector<[1]xi8>>
+
+  // CHECK-COUNT-6: alignment = 2
+  %b1 = memref.alloca() : memref<vector<[32]xi16>>
+  %b2 = memref.alloca() : memref<vector<[16]xi16>>
+  %b3 = memref.alloca() : memref<vector<[8]xi16>>
+  %b4 = memref.alloca() : memref<vector<[4]xi16>>
+  %b5 = memref.alloca() : memref<vector<[2]xi16>>
+  %b6 = memref.alloca() : memref<vector<[1]xi16>>
+
+  // CHECK-COUNT-6: alignment = 4
+  %c1 = memref.alloca() : memref<vector<[32]xi32>>
+  %c2 = memref.alloca() : memref<vector<[16]xi32>>
+  %c3 = memref.alloca() : memref<vector<[8]xi32>>
+  %c4 = memref.alloca() : memref<vector<[4]xi32>>
+  %c5 = memref.alloca() : memref<vector<[2]xi32>>
+  %c6 = memref.alloca() : memref<vector<[1]xi32>>
+
+  // CHECK-COUNT-6: alignment = 8
+  %d1 = memref.alloca() : memref<vector<[32]xi64>>
+  %d2 = memref.alloca() : memref<vector<[16]xi64>>
+  %d3 = memref.alloca() : memref<vector<[8]xi64>>
+  %d4 = memref.alloca() : memref<vector<[4]xi64>>
+  %d5 = memref.alloca() : memref<vector<[2]xi64>>
+  %d6 = memref.alloca() : memref<vector<[1]xi64>>
+
+  // CHECK-COUNT-6: alignment = 4
+  %e1 = memref.alloca() : memref<vector<[32]xf32>>
+  %e2 = memref.alloca() : memref<vector<[16]xf32>>
+  %e3 = memref.alloca() : memref<vector<[8]xf32>>
+  %e4 = memref.alloca() : memref<vector<[4]xf32>>
+  %e5 = memref.alloca() : memref<vector<[2]xf32>>
+  %e6 = memref.alloca() : memref<vector<[1]xf32>>
+
+  // CHECK-COUNT-6: alignment = 8
+  %f1 = memref.alloca() : memref<vector<[32]xf64>>
+  %f2 = memref.alloca() : memref<vector<[16]xf64>>
+  %f3 = memref.alloca() : memref<vector<[8]xf64>>
+  %f4 = memref.alloca() : memref<vector<[4]xf64>>
+  %f5 = memref.alloca() : memref<vector<[2]xf64>>
+  %f6 = memref.alloca() : memref<vector<[1]xf64>>
+
+  "prevent.dce"(
+    %a1, %a2, %a3, %a4, %a5, %a6,
+    %b1, %b2, %b3, %b4, %b5, %b6,
+    %c1, %c2, %c3, %c4, %c5, %c6,
+    %d1, %d2, %d3, %d4, %d5, %d6,
+    %e1, %e2, %e3, %e4, %e5, %e6,
+    %f1, %f2, %f3, %f4, %f5, %f6)
+    : (memref<vector<[32]xi8>>, memref<vector<[16]xi8>>, memref<vector<[8]xi8>>, memref<vector<[4]xi8>>, memref<vector<[2]xi8>>, memref<vector<[1]xi8>>,
+       memref<vector<[32]xi16>>, memref<vector<[16]xi16>>, memref<vector<[8]xi16>>, memref<vector<[4]xi16>>, memref<vector<[2]xi16>>, memref<vector<[1]xi16>>,
+       memref<vector<[32]xi32>>, memref<vector<[16]xi32>>, memref<vector<[8]xi32>>, memref<vector<[4]xi32>>, memref<vector<[2]xi32>>, memref<vector<[1]xi32>>,
+       memref<vector<[32]xi64>>, memref<vector<[16]xi64>>, memref<vector<[8]xi64>>, memref<vector<[4]xi64>>, memref<vector<[2]xi64>>, memref<vector<[1]xi64>>,
+       memref<vector<[32]xf32>>, memref<vector<[16]xf32>>, memref<vector<[8]xf32>>, memref<vector<[4]xf32>>, memref<vector<[2]xf32>>, memref<vector<[1]xf32>>,
+       memref<vector<[32]xf64>>, memref<vector<[16]xf64>>, memref<vector<[8]xf64>>, memref<vector<[4]xf64>>, memref<vector<[2]xf64>>, memref<vector<[1]xf64>>) -> ()
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/arrays-of-scalable-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/arrays-of-scalable-vectors.mlir
new file mode 100644
index 000000000000000..260cd1eabe7dae1
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/arrays-of-scalable-vectors.mlir
@@ -0,0 +1,121 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -arm-sve-legalize-vector-storage -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd -e=entry -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+/// This tests basic functionality of arrays of scalable vectors, which in MLIR
+/// are vectors with a single trailing scalable dimension. This test requires
+/// the -arm-sve-legalize-vector-storage pass to ensure the loads/stores done
+/// here are legal for the LLVM backend.
+
+func.func @read_and_print_2d_vector(%memref: memref<3x?xf32>)  {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %dim = memref.dim %memref, %c1 : memref<3x?xf32>
+  %mask = vector.create_mask %c2, %dim : vector<3x[8]xi1>
+  %vector = vector.transfer_read %memref[%c0,%c0], %cst, %mask {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[8]xf32>
+
+  /// TODO: Support vector.print for arrays of scalable vectors.
+  %row0 = vector.extract %vector[0] : vector<[8]xf32> from vector<3x[8]xf32>
+  %row1 = vector.extract %vector[1] : vector<[8]xf32> from vector<3x[8]xf32>
+  %row2 = vector.extract %vector[2] : vector<[8]xf32> from vector<3x[8]xf32>
+
+  /// Print each of the vectors. This only checks the first eight elements (which
+  /// works for all vscale >= 1).
+
+  // CHECK-LABEL: TEST 1
+  vector.print str "TEST 1 (print and read 2D arrays of scalable vectors)"
+  // CHECK: ( 8, 8, 8, 8, 8, 8, 8, 8
+  vector.print %row0 : vector<[8]xf32>
+  // CHECK: ( 8, 8, 8, 8, 8, 8, 8, 8
+  vector.print %row1 : vector<[8]xf32>
+  /// This last row is all zero due to our mask.
+  // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0
+  vector.print %row2 : vector<[8]xf32>
+
+  return
+}
+
+func.func @print_1x2xVSCALExf32(%vector: vector<1x2x[4]xf32>) {
+  /// Prints the two [4]xf32 rows one at a time. TODO: Support vector.print for arrays of scalable vectors.
+  %slice0 = vector.extract %vector[0, 0] : vector<[4]xf32> from vector<1x2x[4]xf32>
+  %slice1 = vector.extract %vector[0, 1] : vector<[4]xf32> from vector<1x2x[4]xf32>
+  vector.print %slice0 : vector<[4]xf32>
+  vector.print %slice1 : vector<[4]xf32>
+  return
+}
+
+func.func @add_arrays_of_scalable_vectors(%a: memref<1x2x?xf32>, %b: memref<1x2x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 2 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim_a = memref.dim %a, %c2 : memref<1x2x?xf32>
+  %dim_b = memref.dim %b, %c2 : memref<1x2x?xf32>
+  %mask_a = vector.create_mask %c2, %c3, %dim_a : vector<1x2x[4]xi1>
+  %mask_b = vector.create_mask %c2, %c3, %dim_b : vector<1x2x[4]xi1>
+
+  vector.print str "TEST 2 (reading and adding two 3D arrays of scalable vectors)"
+
+  /// Print each of the vectors. This only checks the first four elements (which
+  /// works for all vscale >= 1).
+
+  // CHECK-LABEL: Vector A
+  // CHECK-NEXT: ( 5, 5, 5, 5
+  // CHECK-NEXT: ( 5, 5, 5, 5
+  vector.print str "\nVector A"
+  %vector_a = vector.transfer_read %a[%c0, %c0, %c0], %cst, %mask_a {in_bounds = [true, true, true]} : memref<1x2x?xf32>, vector<1x2x[4]xf32>
+  func.call @print_1x2xVSCALExf32(%vector_a) : (vector<1x2x[4]xf32>) -> ()
+
+  vector.print str "\nVector B"
+  // CHECK-LABEL: Vector B
+  // CHECK-NEXT: ( 4, 4, 4, 4
+  // CHECK-NEXT: ( 4, 4, 4, 4
+  %vector_b = vector.transfer_read %b[%c0, %c0, %c0], %cst, %mask_b {in_bounds = [true, true, true]} : memref<1x2x?xf32>, vector<1x2x[4]xf32>
+  func.call @print_1x2xVSCALExf32(%vector_b) : (vector<1x2x[4]xf32>) -> ()
+
+  %sum = arith.addf %vector_a, %vector_b : vector<1x2x[4]xf32>
+  // CHECK-LABEL: Sum
+  // CHECK-NEXT: ( 9, 9, 9, 9
+  // CHECK-NEXT: ( 9, 9, 9, 9
+  vector.print str "\nSum"
+  func.call @print_1x2xVSCALExf32(%sum) : (vector<1x2x[4]xf32>) -> ()
+
+  return
+}
+
+func.func @entry() {
+  %vscale = vector.vscale
+
+  %c4 = arith.constant 4 : index
+  %c8 = arith.constant 8 : index
+  %f32_8 = arith.constant 8.0 : f32
+  %f32_5 = arith.constant 5.0 : f32
+  %f32_4 = arith.constant 4.0 : f32
+
+  vector.print str "\n====================\n"
+
+  %test_1_memref_size = arith.muli %vscale, %c8 : index
+  %test_1_memref = memref.alloca(%test_1_memref_size) : memref<3x?xf32>
+
+  linalg.fill ins(%f32_8 : f32) outs(%test_1_memref :memref<3x?xf32>)
+
+  func.call @read_and_print_2d_vector(%test_1_memref) : (memref<3x?xf32>) -> ()
+
+  vector.print str "\n====================\n"
+
+  %test_2_memref_size = arith.muli %vscale, %c4 : index
+  %test_2_memref_a = memref.alloca(%test_2_memref_size) : memref<1x2x?xf32>
+  %test_2_memref_b = memref.alloca(%test_2_memref_size) : memref<1x2x?xf32>
+
+  linalg.fill ins(%f32_5 : f32) outs(%test_2_memref_a :memref<1x2x?xf32>)
+  linalg.fill ins(%f32_4 : f32) outs(%test_2_memref_b :memref<1x2x?xf32>)
+
+  func.call @add_arrays_of_scalable_vectors(
+    %test_2_memref_a, %test_2_memref_b) : (memref<1x2x?xf32>, memref<1x2x?xf32>) -> ()
+
+  vector.print str "\n====================\n"
+
+  return
+}

>From 8888576b4e27b4cf5b16afd6bf126e90ba42088f Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 19 Oct 2023 10:17:54 +0000
Subject: [PATCH 3/5] [mlir][SVE] Add an e2e test for vectorization of
 linalg.matmul

Adds an end-2-end test for scalable vectorization of linalg.matmul. This
is the most basic case where the dimension along which we vectorize fits
perfectly within SVE registers. I will be extending this to more generic
cases in the forthcoming patches.

Depends on #68794.
---
 .../Dialect/Linalg/CPU/ArmSVE/matmul.mlir     | 77 +++++++++++++++++++
 1 file changed, 77 insertions(+)
 create mode 100644 mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
new file mode 100644
index 000000000000000..9b2aaefc5d63136
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
@@ -0,0 +1,77 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule \
+// RUN:   -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
+// RUN:   -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd -e=entry -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @printTestEnd() {
+  %0 = llvm.mlir.addressof @str_sve_end : !llvm.ptr<array<24 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]
+    : (!llvm.ptr<array<24 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @entry() {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %step = arith.constant 1 : index
+  %c0_f32 = arith.constant 0.0 : f32
+
+  %vscale = vector.vscale
+  %vl_fp = arith.muli %c4, %vscale : index
+  %A_alloc = bufferization.alloc_tensor(%c2, %c1) : tensor<?x?xf32>
+  %B_alloc = bufferization.alloc_tensor(%c1, %vl_fp) : tensor<?x?xf32>
+  %C_alloc = bufferization.alloc_tensor(%c2, %vl_fp) : tensor<?x?xf32>
+
+  %pi = arith.constant  3.14 : f32
+  %A = linalg.fill ins(%pi : f32) outs(%A_alloc : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %B = linalg.fill ins(%pi : f32) outs(%B_alloc : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %C_in = linalg.fill ins(%c0_f32 : f32) outs(%C_alloc : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  %C_out = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>) outs(%C_in: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  // There are at least 4 f32 elements in every SVE vector, i.e. 
+  //    * %vscale is >= 1. 
+  // For implementations with wider vectors, you should see more elements being
+  // printed.
+  // CHECK: {{\[}}[9.8596,   9.8596,   9.8596,   9.8596
+  // CHECK-NEXT: [9.8596,   9.8596,   9.8596,   9.8596
+
+  %xf = tensor.cast %C_out : tensor<?x?xf32> to tensor<*xf32>
+  call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
+
+  // CHECK: SVE: END OF TEST OUTPUT
+  func.call @printTestEnd() : () -> ()
+
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!transform.any_op) -> !transform.any_op
+  %func_op = get_parent_op %0 : (!transform.any_op) -> !transform.op<"func.func">
+  // The tile sizes match the output matrix sizes
+  %1, %loops:3 = transform.structured.tile_using_for %0 [2, [4], 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+  %2 = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!transform.any_op) -> !transform.any_op
+  // The vector sizes match the output matrix sizes
+  // TODO: Use variables to re-use "shared" sizes
+  transform.structured.vectorize %2 vector_sizes [2, [4], 1] : !transform.any_op
+
+  transform.apply_patterns to %func_op {
+    transform.apply_patterns.vector.reduction_to_contract
+    transform.apply_patterns.vector.transfer_permutation_patterns
+    transform.apply_patterns.vector.lower_masked_transfers
+  } : !transform.op<"func.func">
+  transform.apply_patterns to %func_op {
+    transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+    transform.apply_patterns.vector.lower_outerproduct
+  } : !transform.op<"func.func">
+}
+
+llvm.func @printCString(!llvm.ptr<i8>)
+func.func private @printMemrefF32(%ptr : tensor<*xf32>)
+llvm.mlir.global internal constant @str_sve_end("SVE: END OF TEST OUTPUT\0A")

>From 2e5f33fdedb1a4f4c6c1e415fad6295d0198ac26 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 19 Oct 2023 12:38:24 +0000
Subject: [PATCH 4/5] fixup! [mlir][SVE] Add an e2e test for vectorization of
 linalg.matmul

Use vector.print instead of @printCString
---
 .../Dialect/Linalg/CPU/ArmSVE/matmul.mlir           | 13 +------------
 1 file changed, 1 insertion(+), 12 deletions(-)

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
index 9b2aaefc5d63136..df6144bfdc019db 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
@@ -4,15 +4,6 @@
 // RUN: %mcr_aarch64_cmd -e=entry -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
 // RUN: FileCheck %s
 
-func.func @printTestEnd() {
-  %0 = llvm.mlir.addressof @str_sve_end : !llvm.ptr<array<24 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<24 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
 func.func @entry() {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -45,7 +36,7 @@ func.func @entry() {
   call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
 
   // CHECK: SVE: END OF TEST OUTPUT
-  func.call @printTestEnd() : () -> ()
+  vector.print str "SVE: END OF TEST OUTPUT"
 
   return
 }
@@ -72,6 +63,4 @@ transform.sequence failures(propagate) {
   } : !transform.op<"func.func">
 }
 
-llvm.func @printCString(!llvm.ptr<i8>)
 func.func private @printMemrefF32(%ptr : tensor<*xf32>)
-llvm.mlir.global internal constant @str_sve_end("SVE: END OF TEST OUTPUT\0A")

>From 395644a72ac8d1e2f103534884de3bf4f15457b6 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 23 Oct 2023 07:26:40 +0000
Subject: [PATCH 5/5] fixup! [mlir][SVE] Add an e2e test for vectorization of
 linalg.matmul

Simplify check-lines
---
 .../Dialect/Linalg/CPU/ArmSVE/matmul.mlir        | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
index df6144bfdc019db..922c5d89e3217cf 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
@@ -25,17 +25,21 @@ func.func @entry() {
 
   %C_out = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>) outs(%C_in: tensor<?x?xf32>) -> tensor<?x?xf32>
 
-  // There are at least 4 f32 elements in every SVE vector, i.e. 
-  //    * %vscale is >= 1. 
-  // For implementations with wider vectors, you should see more elements being
-  // printed.
-  // CHECK: {{\[}}[9.8596,   9.8596,   9.8596,   9.8596
+  // CHECK-LABEL: SVE: START OF TEST OUTPUT
+  vector.print str "SVE: START OF TEST OUTPUT"
+
+  // There are at least 4 x f32 elements in every SVE vector, i.e. 
+  //    * %vscale >= 1. 
+  // Hence, when checking the output there will always be at least 4 elements
+  // in every row. For implementations with wider vectors, you should see more
+  // elements being printed.
+  // CHECK: [9.8596,   9.8596,   9.8596,   9.8596
   // CHECK-NEXT: [9.8596,   9.8596,   9.8596,   9.8596
 
   %xf = tensor.cast %C_out : tensor<?x?xf32> to tensor<*xf32>
   call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
 
-  // CHECK: SVE: END OF TEST OUTPUT
+  // CHECK-NEXT: SVE: END OF TEST OUTPUT
   vector.print str "SVE: END OF TEST OUTPUT"
 
   return



More information about the Mlir-commits mailing list