[Mlir-commits] [mlir] [mlir][ArmSME] Add support for vector.transpose (PR #66760)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 19 03:41:02 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

This patch adds support for lowering vector.transpose to ArmSME. It's
implemented by storing the input tile of the transpose to memory and
reloading vertically, building on top of the tile slice layout support.

Transposing via memory is obviously expensive; the current intention is
to avoid the transpose where possible, so this is intended as a
fallback and to provide base support for Vector ops. If it turns out
transposes can't be avoided, then this should be replaced with a more
optimal implementation, perhaps using tile <-> vector (MOVA) ops.

Depends on https://github.com/llvm/llvm-project/pull/66758.

---

Patch is 110.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66760.diff


15 Files Affected:

- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h (+5) 
- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td (+79-40) 
- (modified) mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt (+7) 
- (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+9-9) 
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+83-7) 
- (modified) mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp (+13) 
- (modified) mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt (+2) 
- (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+71-22) 
- (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+31-11) 
- (modified) mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir (+3-3) 
- (added) mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir (+401) 
- (modified) mlir/test/Dialect/ArmSME/roundtrip.mlir (+450-90) 
- (modified) mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir (+130-8) 
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir (+110) 
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir (+113) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index d1ed02abfd5c552..f947fc8fe1631b8 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -21,6 +21,11 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
+
 #include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..884773aa559bcf9 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -14,6 +14,7 @@
 #ifndef ARMSME_OPS
 #define ARMSME_OPS
 
+include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
@@ -35,7 +36,9 @@ def ArmSME_Dialect : Dialect {
     https://developer.arm.com/documentation/ddi0616
     https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
   }];
-  let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
+  let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect",
+                           "memref::MemRefDialect"];
+  let useDefaultAttributePrinterParser = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -83,6 +86,24 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
                   "::llvm::cast<VectorType>($_self).getElementType())"
                   ".getWidth())">;
 
+//===----------------------------------------------------------------------===//
+// ArmSME attr definitions
+//===----------------------------------------------------------------------===//
+
+def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
+  I32EnumAttrCase<"Horizontal", 0, "hor">,
+  I32EnumAttrCase<"Vertical", 1, "ver">,
+]> {
+  let cppNamespace = "::mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+/// An attribute that specifies the layout of a tile slice in a tile.
+def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
+                                          "layout"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME op definitions
 //===----------------------------------------------------------------------===//
@@ -239,27 +260,33 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
   let summary = "Tile load operation";
   let description = [{
     Loads a 2D SME "virtual tile" from memory defined by a base and indices,
-    with the shape defined by the 2D scalable vector type of the result tile.
-    The slice of memory must be contiguous. The memref must be either rank 1 or
-    rank 2 with dynamic dimensions, since the operation is scalable, and the
-    element type must be a scalar that matches the element type of the result.
+    with the shape defined by the 2D scalable vector type of the result tile. A
+    tile slice layout attribute specifies whether the slices of the tile being
+    loaded are horizontal or vertical. The slice of memory must be contiguous.
+    The memref must be either rank 1 or rank 2 with dynamic dimensions, since
+    the operation is scalable, and the element type must be a scalar that
+    matches the element type of the result.
+
+    The default tile slice layout when lowering from higher-level dialects is
+    horizontal.
 
-    Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
+    Example 1: Load an 8-bit element ZA tile with horizontal layout from memory (ZA0.B).
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+    %tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
     ```
 
-    Example 2: Load a FP 32-bit element ZA tile from memory.
+    Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+    %tile = arm_sme.tile_load <ver>, %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
     ```
 
-    Example 3: Load a 128-bit element ZA tile from memory.
+    Example 3: Load a 128-bit element ZA tile with horizontal layout from memory.
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+    %tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
       Variadic<Index>:$indices);
   let results = (outs SMETile:$result);
@@ -274,7 +301,8 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
   }];
 
   let assemblyFormat =
-      "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
+      "$layout `,` $base `[` $indices `]` attr-dict "
+        "`:` type($base) `,` type($result)";
 }
 
 def TileStoreOp : ArmSME_Op<"tile_store"> {
@@ -282,27 +310,32 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
   let description = [{
     Stores a 2D SME "virtual tile" to memory defined by a base and indices,
     with the shape defined by the 2D scalable vector type of the tile being
-    stored. The slice of memory must be contiguous. The memref must be either
-    rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
-    and the element type must be a scalar that matches the element type of the
-    result.
+    stored. A tile slice layout attribute specifies whether the slices of the
+    tile being stored are horizontal or vertical. The slice of memory must be
+    contiguous. The memref must be either rank 1 or rank 2 with dynamic
+    dimensions, since the operation is scalable, and the element type must be a
+    scalar that matches the element type of the tile being stored.
+
+    The default tile slice layout when lowering from higher-level dialects is
+    horizontal.
 
-    Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
+    Example 1: Store an 8-bit element ZA tile with horizontal layout to memory (ZA0.B).
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+    arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
     ```
 
-    Example 2: Store a FP 32-bit element ZA tile to memory.
+    Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+    arm_sme.tile_store %tile, <ver>, %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
     ```
 
-    Example 3: Store a 128-bit element ZA tile to memory.
+    Example 3: Store a 128-bit element ZA tile with horizontal layout to memory.
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+    arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins SMETile:$valueToStore,
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
       Variadic<Index>:$indices);
   let extraClassDeclaration = [{
@@ -314,8 +347,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
     }
   }];
 
-  let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
-                       "`:` type($base) `,` type($valueToStore)";
+  let assemblyFormat =
+    "$valueToStore `,` $layout `,` $base `[` $indices `]` attr-dict "
+      "`:` type($base) `,` type($valueToStore)";
 }
 
 def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
@@ -326,29 +360,32 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
     Loads a 1D tile slice from memory into a 2D SME "virtual tile". The tile
     slice is defined by the dimension of the 2D scalable vector type pointed by
     the index. A tile slice index describes where in the input tile the tile
-    slice is loaded to. The updated tile is returned as the result.
+    slice is loaded to. A tile slice layout attribute specifies whether the
+    tile slice being loaded at the given index is horizontal or vertical. The
+    updated tile is returned as the result.
 
     The slice of memory read is defined by a base and indices and must be
     contiguous. The memref must be either rank 1 or rank 2, have dynamic
     dimensions since the operation is scalable, and the element type must be a
     scalar that matches the element type of the result.
 
-    Example 1: Load a vector<[16]xi8> tile slice from memory into tile at given index.
+    Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+    %tile_update = arm_sme.load_tile_slice <hor>, %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
     ```
 
-    Example 2: Load a vector<[4]xf32> tile slice from memory into tile at given index.
+    Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+    %tile_update = arm_sme.load_tile_slice <ver>, %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
     ```
 
-    Example 3: Load a vector<[1]xi128> tile slice from memory into tile at given index.
+    Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+    %tile_update = arm_sme.load_tile_slice <ver>, %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to load from">:$base,
       SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index);
   let results = (outs SMETile:$result);
@@ -363,7 +400,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
   }];
 
   let assemblyFormat = [{
-    $base `[` $indices `]` `,` $tile `,` $tile_slice_index
+    $layout `,` $base `[` $indices `]` `,` $tile `,` $tile_slice_index
       attr-dict `:` type($base) `,` type($result)
   }];
 }
@@ -374,29 +411,31 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
     Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
     slice is defined by the dimension of the 2D scalable vector type pointed by
     the index. A tile slice index describes where in the input tile the tile
-    slice is stored from.
+    slice is stored from. A tile slice layout attribute specifies whether the
+    tile slice being stored from the given index is horizontal or vertical.
 
     The slice of memory written is defined by a base and indices and must be
     contiguous. The memref must be either rank 1 or rank 2, have dynamic
     dimensions since the operation is scalable, and the element type must be a
     scalar that matches the element type of the input tile.
 
-    Example 1: Store vector<[16]xi8> tile slice from tile at given index to memory.
+    Example 1: Store vector<[16]xi8> horizontal tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %base[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
     ```
 
-    Example 2: Store vector<[4]xf32> tile slice from tile at given index to memory.
+    Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
     ```
 
-    Example 3: Store a vector<[1]xi128> tile slice from tile at given index to memory.
+    Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
       Variadic<Index>:$indices);
   let extraClassDeclaration = [{
@@ -409,7 +448,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
   }];
 
   let assemblyFormat = [{
-    $tile `,` $tile_slice_index `,` $base `[` $indices `]`
+    $tile `,` $tile_slice_index `,` $layout `,` $base `[` $indices `]`
       attr-dict `:` type($base) `,` type($tile)
   }];
 }
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index d20ee65e62e7dc0..7afd0d014541687 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -4,3 +4,10 @@ add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme)
 set(LLVM_TARGET_DEFINITIONS ArmSME.td)
 mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
+
+mlir_tablegen(ArmSMEEnums.h.inc -gen-enum-decls)
+mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
+mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
+add_public_tablegen_target(MLIRArmSMEEnumsIncGen)
+add_dependencies(mlir-headers MLIRArmSMEEnumsIncGen)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 4028a7ad0870b51..86cabe67f2695f1 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -54,7 +54,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///
 ///  BEFORE:
 ///  ```mlir
-///  %tile = arm_sme.tile_load %src[%c0, %c0] :
+///  %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] :
 ///    memref<?x?xi32>, vector<[4]x[4]xi32>
 ///  ```
 ///
@@ -68,7 +68,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///  %min_svl_s = arith.constant 4 : index
 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
 ///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
-///    %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
+///    %tile_update = arm_sme.load_tile_slice <hor>, %src[%tile_slice_idx],
 ///      %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
 ///  }
 ///  ```
@@ -116,9 +116,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     getMemrefIndices(tileLoadOp.getIndices(),
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      numTileSlices, memrefIndices, loc, rewriter);
-    rewriter.create<arm_sme::LoadTileSliceOp>(loc, tileType,
-                                              tileLoadOp.getBase(), tile,
-                                              memrefIndices, tileSliceIndex);
+    rewriter.create<arm_sme::LoadTileSliceOp>(
+        loc, tileType, tileLoadOp.getLayout(), tileLoadOp.getBase(), tile,
+        memrefIndices, tileSliceIndex);
 
     rewriter.setInsertionPointAfter(forOp);
 
@@ -134,7 +134,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///
 ///  BEFORE:
 ///  ```mlir
-///  arm_sme.tile_store %tile, %dest[%c0, %c0]
+///  arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0]
 ///    : memref<?x?xi32>, vector<[4]x[4]xi32
 ///  ```
 ///
@@ -146,8 +146,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///  %min_svl_s = arith.constant 4 : index
 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
 ///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
-///    arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx]
-///      : memref<?x?xi32>, vector<[4]x[4]xi32>
+///    arm_sme.store_tile_slice %tile, %tile_slice_idx, <ver>,
+///      %dest[%tile_slice_idx] : memref<?x?xi32>, vector<[4]x[4]xi32>
 ///  }
 ///  ```
 struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
@@ -184,7 +184,7 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
                      numTileSlices, memrefIndices, loc, rewriter);
     rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
         tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
-        tileStoreOp.getBase(), memrefIndices);
+        tileStoreOp.getLayout(), tileStoreOp.getBase(), memrefIndices);
 
     return success();
   }
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 0a1a087d9c8d6c7..30c516ffbe1e900 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Casting.h"
 
@@ -65,8 +66,8 @@ namespace {
 ///
 /// is converted to:
 ///
-///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
-///                                                   vector<[16]x[16]xi8>
+///   arm_sme.tile_store %vector, <hor>, %source[%c0, %c0]
+///     : memref<?x?xi8>, vector<[16]x[16]xi8>
 struct TransferWriteToArmSMELowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -81,8 +82,8 @@ struct TransferWriteToArmSMELowering
       return failure();
 
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
-        writeOp, writeOp.getVector(), writeOp.getSource(),
-        writeOp.getIndices());
+        writeOp, writeOp.getVector(), arm_sme::TileSliceLayout::Horizontal,
+        writeOp.getSource(), writeOp.getIndices());
     return success();
   }
 };
@@ -97,7 +98,8 @@ struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
       return failure();
 
     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
-        load, load.getVectorType(), load.getBase(), load.getIndices());
+        load, load.getVectorType(), arm_sme::TileSliceLayout::Horizontal,
+        load.getBase(), load.getIndices());
 
     return success();
   }
@@ -113,7 +115,8 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
       return failure();
 
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
-        store, store.getValueToStore(), store.getBase(), store.getIndices());
+        store, store.getValueToStore(), arm_sme::TileSliceLayout::Horizontal,
+        store.getBase(), store.getIndices());
 
     return success();
   }
@@ -239,11 +242,84 @@ struct BroadcastOpToArmSMELowering
   }
 };
 
+/// Conversion pattern for vector.transpose.
+///
+/// Stores the input tile to memory and reloads vertically.
+///
+/// Example:
+///
+///   %transposed_src = vector.transpose %src, [1, 0]
+///     : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+///
+/// is converted to:
+///
+///   %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
+///   arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
+///     : memref<?x?xi32>, vector<[4]x[4]xi32>
+///   %transposed_src = arm_sme.tile_load <ver>, %alloca[%c0, %c0]
+///     : memref<?x?xi32>, vector<[4]x[4]xi32>
+///
+/// NOTE: Transposing via memory is obviously expensive, the current intention
+/// is to avoid the transpose if possible, this is therefore intended as a
+/// fallback and to provide base support for Vec...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/66760


More information about the Mlir-commits mailing list