[Mlir-commits] [mlir] [mlir][ArmSME] Add support for vector.transpose (PR #66760)

Cullen Rhodes llvmlistbot at llvm.org
Mon Sep 25 02:38:44 PDT 2023


https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/66760

>From d40e18ca7eea336430f1714460948bb2d60f98e2 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Tue, 19 Sep 2023 10:03:47 +0000
Subject: [PATCH 1/4] [mlir][ArmSME] Add support for vector.transpose

This patch adds support for lowering vector.transpose to ArmSME. It's
implemented by storing the input tile of the transpose to memory and
reloading vertically, building on top of the tile slice layout support.

Transposing via memory is obviously expensive, the current intention is
to avoid the transpose if possible, this is therefore intended as a
fallback and to provide base support for Vector ops. If it turns out
transposes can't be avoided then this should be replaced with a more
optimal implementation, perhaps with tile <-> vector (MOVA) ops.

Depends on #66758.
---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td |   3 +-
 .../VectorToArmSME/VectorToArmSME.cpp         |  76 ++++++++++-
 mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp         |   1 +
 mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt     |   1 +
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     | 122 ++++++++++++++++++
 .../Vector/CPU/ArmSME/test-transpose.mlir     | 113 ++++++++++++++++
 6 files changed, 314 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 01a8670fd3817b0..d10cee5956d5e5f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -36,7 +36,8 @@ def ArmSME_Dialect : Dialect {
     https://developer.arm.com/documentation/ddi0616
     https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
   }];
-  let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
+  let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect",
+                           "memref::MemRefDialect"];
   let useDefaultAttributePrinterParser = 1;
 }
 
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 0a1a087d9c8d6c7..bb665375c00842c 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Casting.h"
 
@@ -239,11 +240,84 @@ struct BroadcastOpToArmSMELowering
   }
 };
 
+/// Conversion pattern for vector.transpose.
+///
+/// Stores the input tile to memory and reloads vertically.
+///
+/// Example:
+///
+///   %transposed_src = vector.transpose %src, [1, 0]
+///     : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+///
+/// is converted to:
+///
+///   %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
+///   arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
+///     : memref<?x?xi32>, vector<[4]x[4]xi32>
+///   %transposed_src = arm_sme.tile_load <ver>, %alloca[%c0, %c0]
+///     : memref<?x?xi32>, vector<[4]x[4]xi32>
+///
+/// NOTE: Transposing via memory is obviously expensive, the current intention
+/// is to avoid the transpose if possible, this is therefore intended as a
+/// fallback and to provide base support for Vector ops. If it turns out
+/// transposes can't be avoided then this should be replaced with a more optimal
+/// implementation, perhaps with tile <-> vector (MOVA) ops.
+struct TransposeOpToArmSMELowering
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const final {
+    auto tileType = transposeOp.getResultVectorType();
+    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+      return failure();
+
+    SmallVector<int64_t> transp;
+    for (auto attr : transposeOp.getTransp())
+      transp.push_back(cast<IntegerAttr>(attr).getInt());
+
+    if (transp[0] != 1 && transp[1] != 0)
+      return failure();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = transposeOp.getLoc();
+
+    // Allocate buffer to store input tile to.
+    Value vscale =
+        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+    Value minTileSlices = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
+    Value c0 =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+    Value numTileSlices =
+        rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
+    auto bufferType =
+        MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
+                        tileType.getElementType());
+    auto buffer = rewriter.create<memref::AllocaOp>(
+        loc, bufferType, ValueRange{numTileSlices, numTileSlices});
+
+    Value input = transposeOp.getVector();
+
+    // Store input tile.
+    auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
+        loc, input, arm_sme::TileSliceLayout::Horizontal, buffer,
+        ValueRange{c0, c0});
+
+    // Reload input tile vertically.
+    rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
+        transposeOp, tileType, arm_sme::TileSliceLayout::Vertical,
+        tileStoreOp.getBase(), tileStoreOp.getIndices());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
   patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
                VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
-               BroadcastOpToArmSMELowering>(&ctx);
+               BroadcastOpToArmSMELowering, TransposeOpToArmSMELowering>(&ctx);
 }
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 25fed2c477a1886..101cb750f4a6f30 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/TypeSwitch.h"
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 85f90a8303d466f..50cfd9bf2f27810 100644
--- a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRArmSMEDialect
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRLLVMDialect
+  MLIRMemRefDialect
   MLIRSCFDialect
   MLIRSideEffectInterfaces
   MLIRVectorDialect
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index cb35de11ab5b3ed..c4b0c6d495137e3 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
 
+// =============================================================================
+// vector.transfer_write
+// =============================================================================
+
 // CHECK-LABEL: func.func @transfer_write_2d_i8(
 // CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xi8>) {
@@ -215,3 +219,121 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
   "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
   return
 }
+
+// =============================================================================
+// vector.transpose
+// =============================================================================
+
+// -----
+
+// CHECK-LABEL:   func.func @transpose_i8(
+// CHECK-SAME:                            %[[TILE:.*]]: vector<[16]x[16]xi8>)
+// CHECK:           %[[C16:.*]] = arith.constant 16 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[VSCALE:.*]] = vector.vscale
+// CHECK:           %[[MIN_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
+// CHECK:           %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[MIN_TILE_SLICES]], %[[MIN_TILE_SLICES]]) : memref<?x?xi8>
+// CHECK:           arm_sme.tile_store %[[TILE]], <hor>, %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
+// CHECK:           arm_sme.tile_load <ver>, %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transpose_i8(%arg0: vector<[16]x[16]xi8>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[16]x[16]xi8> to vector<[16]x[16]xi8>
+  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_i16
+// CHECK: arith.constant 8
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_i16(%arg0: vector<[8]x[8]xi16>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xi16> to vector<[8]x[8]xi16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_i32
+// CHECK: arith.constant 4
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_i32(%arg0: vector<[4]x[4]xi32>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_i64
+// CHECK: arith.constant 2
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_i64(%arg0: vector<[2]x[2]xi64>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xi64> to vector<[2]x[2]xi64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_i128
+// CHECK: %[[VSCALE:.*]] = vector.vscale
+// CHECK: %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[VSCALE]], %[[VSCALE]]) : memref<?x?xi128>
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_i128(%arg0: vector<[1]x[1]xi128>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[1]x[1]xi128> to vector<[1]x[1]xi128>
+  "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_f16
+// CHECK: arith.constant 8
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_f16(%arg0: vector<[8]x[8]xf16>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xf16> to vector<[8]x[8]xf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_bf16
+// CHECK: arith.constant 8
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_bf16(%arg0: vector<[8]x[8]xbf16>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xbf16> to vector<[8]x[8]xbf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_f32
+// CHECK: arith.constant 4
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_f32(%arg0: vector<[4]x[4]xf32>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_f64
+// CHECK: arith.constant 2
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xf64> to vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
new file mode 100644
index 000000000000000..4350abbd13eca75
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -0,0 +1,113 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:   -e %{entry_point} -entry-point-result=void \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+llvm.func @printCString(!llvm.ptr<i8>)
+
+func.func @printTileBegin() {
+  %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]
+    : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @printTileEnd() {
+  %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]
+    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @entry() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1_i32 = arith.constant 1 : i32
+
+  // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
+  %vscale = vector.vscale
+  %min_elts_s = arith.constant 4 : index
+  %svl_s = arith.muli %min_elts_s, %vscale : index
+  %za_s_size = arith.muli %svl_s, %svl_s : index
+
+  // Allocate memory.
+  %mem1 = memref.alloca(%za_s_size) : memref<?xi32>
+  %mem2 = memref.alloca(%za_s_size) : memref<?xi32>
+
+  // Fill each "row" of "mem1" with row number.
+  //
+  // For example, assuming an SVL of 128-bits:
+  //
+  //   0, 0, 0, 0
+  //   1, 1, 1, 1
+  //   2, 2, 2, 2
+  //   3, 3, 3, 3
+  //
+  %init_0 = arith.constant 0 : i32
+  scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
+    %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
+    vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+    %val_next = arith.addi %val, %c1_i32 : i32
+    scf.yield %val_next : i32
+  }
+
+  // Load tile from "mem1".
+  %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Transpose tile.
+  %transposed_tile = vector.transpose %tile, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+
+  // Store tile back to "mem2" to print.
+  // TODO: Replace this with vector.print when
+  // https://github.com/llvm/llvm-project/pull/66691 lands.
+  vector.store %transposed_tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
+  // 4x4xi32.
+  //
+  // CHECK:      TILE BEGIN
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 3, 3, 3, 3
+  // CHECK:      TILE END
+  func.call @printTileBegin() : () -> ()
+  scf.for %i = %c0 to %za_s_size step %svl_s {
+    %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+    vector.print %tileslice : vector<[4]xi32>
+  }
+  func.call @printTileEnd() : () -> ()
+
+  // Dump "mem2". The smallest SVL is 128-bits so the tile will be at least
+  // 4x4xi32.
+  //
+  // CHECK:      TILE BEGIN
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK:      TILE END
+  func.call @printTileBegin() : () -> ()
+  scf.for %i = %c0 to %za_s_size step %svl_s {
+    %tileslice = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
+    vector.print %tileslice : vector<[4]xi32>
+  }
+  func.call @printTileEnd() : () -> ()
+
+  return
+}
+
+llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
+llvm.mlir.global internal constant @str_tile_end("TILE END\0A")

>From 0e213dbdfdcd235de6ade98b1bd15f480a7f20a5 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 20 Sep 2023 11:18:15 +0000
Subject: [PATCH 2/4] check for layout and data type in vector-ops-to-sme.mlir

---
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     | 32 +++++++++----------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index c4b0c6d495137e3..4ac538fd37c5c5e 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -245,8 +245,8 @@ func.func @transpose_i8(%arg0: vector<[16]x[16]xi8>) {
 
 // CHECK-LABEL: @transpose_i16
 // CHECK: arith.constant 8
-// CHECK: arm_sme.tile_store
-// CHECK: arm_sme.tile_load
+// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
 func.func @transpose_i16(%arg0: vector<[8]x[8]xi16>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xi16> to vector<[8]x[8]xi16>
   "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
@@ -257,8 +257,8 @@ func.func @transpose_i16(%arg0: vector<[8]x[8]xi16>) {
 
 // CHECK-LABEL: @transpose_i32
 // CHECK: arith.constant 4
-// CHECK: arm_sme.tile_store
-// CHECK: arm_sme.tile_load
+// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
 func.func @transpose_i32(%arg0: vector<[4]x[4]xi32>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
   "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
@@ -269,8 +269,8 @@ func.func @transpose_i32(%arg0: vector<[4]x[4]xi32>) {
 
 // CHECK-LABEL: @transpose_i64
 // CHECK: arith.constant 2
-// CHECK: arm_sme.tile_store
-// CHECK: arm_sme.tile_load
+// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
 func.func @transpose_i64(%arg0: vector<[2]x[2]xi64>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xi64> to vector<[2]x[2]xi64>
   "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> ()
@@ -282,8 +282,8 @@ func.func @transpose_i64(%arg0: vector<[2]x[2]xi64>) {
 // CHECK-LABEL: @transpose_i128
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[VSCALE]], %[[VSCALE]]) : memref<?x?xi128>
-// CHECK: arm_sme.tile_store
-// CHECK: arm_sme.tile_load
+// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
 func.func @transpose_i128(%arg0: vector<[1]x[1]xi128>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[1]x[1]xi128> to vector<[1]x[1]xi128>
   "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
@@ -294,8 +294,8 @@ func.func @transpose_i128(%arg0: vector<[1]x[1]xi128>) {
 
 // CHECK-LABEL: @transpose_f16
 // CHECK: arith.constant 8
-// CHECK: arm_sme.tile_store
-// CHECK: arm_sme.tile_load
+// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
 func.func @transpose_f16(%arg0: vector<[8]x[8]xf16>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xf16> to vector<[8]x[8]xf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
@@ -306,8 +306,8 @@ func.func @transpose_f16(%arg0: vector<[8]x[8]xf16>) {
 
 // CHECK-LABEL: @transpose_bf16
 // CHECK: arith.constant 8
-// CHECK: arm_sme.tile_store
-// CHECK: arm_sme.tile_load
+// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
 func.func @transpose_bf16(%arg0: vector<[8]x[8]xbf16>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xbf16> to vector<[8]x[8]xbf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
@@ -318,8 +318,8 @@ func.func @transpose_bf16(%arg0: vector<[8]x[8]xbf16>) {
 
 // CHECK-LABEL: @transpose_f32
 // CHECK: arith.constant 4
-// CHECK: arm_sme.tile_store
-// CHECK: arm_sme.tile_load
+// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
 func.func @transpose_f32(%arg0: vector<[4]x[4]xf32>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
@@ -330,8 +330,8 @@ func.func @transpose_f32(%arg0: vector<[4]x[4]xf32>) {
 
 // CHECK-LABEL: @transpose_f64
 // CHECK: arith.constant 2
-// CHECK: arm_sme.tile_store
-// CHECK: arm_sme.tile_load
+// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
 func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xf64> to vector<[2]x[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()

>From 544e77e7c21122ee3b086615ba56162b6f71b2d5 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 20 Sep 2023 11:25:26 +0000
Subject: [PATCH 3/4] add comment for bailing unless true 2-D matrix transpose

---
 mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index bb665375c00842c..59d4ec368da2bbf 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -276,6 +276,7 @@ struct TransposeOpToArmSMELowering
     for (auto attr : transposeOp.getTransp())
       transp.push_back(cast<IntegerAttr>(attr).getInt());
 
+    // Bail unless this is a true 2-D matrix transpose.
     if (transp[0] != 1 && transp[1] != 0)
       return failure();
 

>From fed0df90ec22784b9f69bf8aab8a3f14154081bf Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Mon, 25 Sep 2023 09:37:57 +0000
Subject: [PATCH 4/4] resolve rebase conflicts

---
 .../VectorToArmSME/VectorToArmSME.cpp         | 11 ++---
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     | 48 +++++++++----------
 2 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 59d4ec368da2bbf..264539b85c0ee23 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -254,7 +254,7 @@ struct BroadcastOpToArmSMELowering
 ///   %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
 ///   arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
 ///     : memref<?x?xi32>, vector<[4]x[4]xi32>
-///   %transposed_src = arm_sme.tile_load <ver>, %alloca[%c0, %c0]
+///   %transposed_src = arm_sme.tile_load %alloca[%c0, %c0], <vertical>
 ///     : memref<?x?xi32>, vector<[4]x[4]xi32>
 ///
 /// NOTE: Transposing via memory is obviously expensive, the current intention
@@ -277,7 +277,7 @@ struct TransposeOpToArmSMELowering
       transp.push_back(cast<IntegerAttr>(attr).getInt());
 
     // Bail unless this is a true 2-D matrix transpose.
-    if (transp[0] != 1 && transp[1] != 0)
+    if (transp[0] != 1 || transp[1] != 0)
       return failure();
 
     OpBuilder::InsertionGuard g(rewriter);
@@ -302,13 +302,12 @@ struct TransposeOpToArmSMELowering
 
     // Store input tile.
     auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
-        loc, input, arm_sme::TileSliceLayout::Horizontal, buffer,
-        ValueRange{c0, c0});
+        loc, input, buffer, ValueRange{c0, c0});
 
     // Reload input tile vertically.
     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
-        transposeOp, tileType, arm_sme::TileSliceLayout::Vertical,
-        tileStoreOp.getBase(), tileStoreOp.getIndices());
+        transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
+        arm_sme::TileSliceLayout::Vertical);
 
     return success();
   }
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 4ac538fd37c5c5e..a64753578a1c861 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,8 +1,8 @@
 // RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
 
-// =============================================================================
+//===----------------------------------------------------------------------===//
 // vector.transfer_write
-// =============================================================================
+//===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: func.func @transfer_write_2d_i8(
 // CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[16]x[16]xi8>,
@@ -169,9 +169,9 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?
   return
 }
 
-// =============================================================================
+//===----------------------------------------------------------------------===//
 // vector.broadcast
-// =============================================================================
+//===----------------------------------------------------------------------===//
 
 // -----
 
@@ -220,9 +220,9 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
   return
 }
 
-// =============================================================================
+//===----------------------------------------------------------------------===//
 // vector.transpose
-// =============================================================================
+//===----------------------------------------------------------------------===//
 
 // -----
 
@@ -233,8 +233,8 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
 // CHECK:           %[[VSCALE:.*]] = vector.vscale
 // CHECK:           %[[MIN_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
 // CHECK:           %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[MIN_TILE_SLICES]], %[[MIN_TILE_SLICES]]) : memref<?x?xi8>
-// CHECK:           arm_sme.tile_store %[[TILE]], <hor>, %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
-// CHECK:           arm_sme.tile_load <ver>, %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
+// CHECK:           arm_sme.tile_store %[[TILE]], %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
+// CHECK:           arm_sme.tile_load %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
 func.func @transpose_i8(%arg0: vector<[16]x[16]xi8>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[16]x[16]xi8> to vector<[16]x[16]xi8>
   "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
@@ -245,8 +245,8 @@ func.func @transpose_i8(%arg0: vector<[16]x[16]xi8>) {
 
 // CHECK-LABEL: @transpose_i16
 // CHECK: arith.constant 8
-// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
-// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
 func.func @transpose_i16(%arg0: vector<[8]x[8]xi16>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xi16> to vector<[8]x[8]xi16>
   "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
@@ -257,8 +257,8 @@ func.func @transpose_i16(%arg0: vector<[8]x[8]xi16>) {
 
 // CHECK-LABEL: @transpose_i32
 // CHECK: arith.constant 4
-// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
-// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
 func.func @transpose_i32(%arg0: vector<[4]x[4]xi32>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
   "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
@@ -269,8 +269,8 @@ func.func @transpose_i32(%arg0: vector<[4]x[4]xi32>) {
 
 // CHECK-LABEL: @transpose_i64
 // CHECK: arith.constant 2
-// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
-// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
 func.func @transpose_i64(%arg0: vector<[2]x[2]xi64>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xi64> to vector<[2]x[2]xi64>
   "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> ()
@@ -282,8 +282,8 @@ func.func @transpose_i64(%arg0: vector<[2]x[2]xi64>) {
 // CHECK-LABEL: @transpose_i128
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[VSCALE]], %[[VSCALE]]) : memref<?x?xi128>
-// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
-// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
 func.func @transpose_i128(%arg0: vector<[1]x[1]xi128>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[1]x[1]xi128> to vector<[1]x[1]xi128>
   "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
@@ -294,8 +294,8 @@ func.func @transpose_i128(%arg0: vector<[1]x[1]xi128>) {
 
 // CHECK-LABEL: @transpose_f16
 // CHECK: arith.constant 8
-// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
-// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
 func.func @transpose_f16(%arg0: vector<[8]x[8]xf16>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xf16> to vector<[8]x[8]xf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
@@ -306,8 +306,8 @@ func.func @transpose_f16(%arg0: vector<[8]x[8]xf16>) {
 
 // CHECK-LABEL: @transpose_bf16
 // CHECK: arith.constant 8
-// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
-// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
 func.func @transpose_bf16(%arg0: vector<[8]x[8]xbf16>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xbf16> to vector<[8]x[8]xbf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
@@ -318,8 +318,8 @@ func.func @transpose_bf16(%arg0: vector<[8]x[8]xbf16>) {
 
 // CHECK-LABEL: @transpose_f32
 // CHECK: arith.constant 4
-// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
-// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
 func.func @transpose_f32(%arg0: vector<[4]x[4]xf32>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
@@ -330,8 +330,8 @@ func.func @transpose_f32(%arg0: vector<[4]x[4]xf32>) {
 
 // CHECK-LABEL: @transpose_f64
 // CHECK: arith.constant 2
-// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
-// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
 func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xf64> to vector<[2]x[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()



More information about the Mlir-commits mailing list