[Mlir-commits] [mlir] eaf1590 - [mlir][ArmSME] Add support for vector.transpose (#66760)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 25 04:15:17 PDT 2023


Author: Cullen Rhodes
Date: 2023-09-25T12:15:12+01:00
New Revision: eaf15900ff5b4c103b52acf7327e6e6f7e8b2ebe

URL: https://github.com/llvm/llvm-project/commit/eaf15900ff5b4c103b52acf7327e6e6f7e8b2ebe
DIFF: https://github.com/llvm/llvm-project/commit/eaf15900ff5b4c103b52acf7327e6e6f7e8b2ebe.diff

LOG: [mlir][ArmSME] Add support for vector.transpose (#66760)

This patch adds support for lowering vector.transpose to ArmSME. It's
implemented by storing the input tile of the transpose to memory and
reloading it vertically, building on top of the tile slice layout support.

Transposing via memory is obviously expensive; the current intention is
to avoid the transpose if possible, so this is intended as a
fallback and to provide base support for Vector ops. If it turns out
transposes can't be avoided, then this should be replaced with a more
optimal implementation, perhaps with tile <-> vector (MOVA) ops.

Depends on https://github.com/llvm/llvm-project/pull/66758.

Added: 
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir

Modified: 
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
    mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
    mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
    mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
    mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 01a8670fd3817b0..d10cee5956d5e5f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -36,7 +36,8 @@ def ArmSME_Dialect : Dialect {
     https://developer.arm.com/documentation/ddi0616
     https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
   }];
-  let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
+  let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect",
+                           "memref::MemRefDialect"];
   let useDefaultAttributePrinterParser = 1;
 }
 

diff  --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 0a1a087d9c8d6c7..264539b85c0ee23 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Casting.h"
 
@@ -239,11 +240,84 @@ struct BroadcastOpToArmSMELowering
   }
 };
 
+/// Conversion pattern for vector.transpose.
+///
+/// Stores the input tile to a memref.alloca buffer and reloads it vertically.
+///
+/// Example:
+///
+///   %transposed_src = vector.transpose %src, [1, 0]
+///     : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+///
+/// is converted to (where %svl_s = vector.vscale * 4):
+///
+///   %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
+///   arm_sme.tile_store %src, %alloca[%c0, %c0]
+///     : memref<?x?xi32>, vector<[4]x[4]xi32>
+///   %transposed_src = arm_sme.tile_load %alloca[%c0, %c0], <vertical>
+///     : memref<?x?xi32>, vector<[4]x[4]xi32>
+///
+/// NOTE: Transposing via memory is expensive. The intention is to avoid the
+/// transpose where possible; this pattern is a fallback that provides base
+/// support for Vector ops. If it turns out transposes can't be avoided, this
+/// should be replaced with a more optimal implementation, perhaps using
+/// tile <-> vector (MOVA) ops.
+struct TransposeOpToArmSMELowering
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const final {
+    auto tileType = transposeOp.getResultVectorType();
+    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+      return failure();
+
+    SmallVector<int64_t> transp;
+    for (auto attr : transposeOp.getTransp())
+      transp.push_back(cast<IntegerAttr>(attr).getInt());
+
+    // Bail unless the permutation is [1, 0], i.e. a 2-D matrix transpose.
+    if (transp[0] != 1 || transp[1] != 0)
+      return failure();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = transposeOp.getLoc();
+
+    // Allocate a numTileSlices x numTileSlices buffer to hold the input tile.
+    Value vscale =
+        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+    Value minTileSlices = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
+    Value c0 =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+    Value numTileSlices =
+        rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
+    auto bufferType =
+        MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
+                        tileType.getElementType());
+    auto buffer = rewriter.create<memref::AllocaOp>(
+        loc, bufferType, ValueRange{numTileSlices, numTileSlices});
+
+    Value input = transposeOp.getVector();
+
+    // Store the input tile at [0, 0] using the default tile-slice layout.
+    auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
+        loc, input, buffer, ValueRange{c0, c0});
+
+    // Reload the tile with a vertical layout, which yields the transpose.
+    rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
+        transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
+        arm_sme::TileSliceLayout::Vertical);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
   patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
                VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
-               BroadcastOpToArmSMELowering>(&ctx);
+               BroadcastOpToArmSMELowering, TransposeOpToArmSMELowering>(&ctx);
 }

diff  --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 25fed2c477a1886..101cb750f4a6f30 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/TypeSwitch.h"

diff  --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 85f90a8303d466f..50cfd9bf2f27810 100644
--- a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRArmSMEDialect
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRLLVMDialect
+  MLIRMemRefDialect
   MLIRSCFDialect
   MLIRSideEffectInterfaces
   MLIRVectorDialect

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index cb35de11ab5b3ed..a64753578a1c861 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// vector.transfer_write
+//===----------------------------------------------------------------------===//
+
 // CHECK-LABEL: func.func @transfer_write_2d_i8(
 // CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xi8>) {
@@ -165,9 +169,9 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?
   return
 }
 
-// =============================================================================
+//===----------------------------------------------------------------------===//
 // vector.broadcast
-// =============================================================================
+//===----------------------------------------------------------------------===//
 
 // -----
 
@@ -215,3 +219,121 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
   "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
   return
 }
+
+//===----------------------------------------------------------------------===//
+// vector.transpose
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL:   func.func @transpose_i8(
+// CHECK-SAME:                            %[[TILE:.*]]: vector<[16]x[16]xi8>)
+// CHECK:           %[[C16:.*]] = arith.constant 16 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[VSCALE:.*]] = vector.vscale
+// CHECK:           %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
+// CHECK:           %[[BUFFER:.*]] = memref.alloca(%[[NUM_TILE_SLICES]], %[[NUM_TILE_SLICES]]) : memref<?x?xi8>
+// CHECK:           arm_sme.tile_store %[[TILE]], %[[BUFFER]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
+// CHECK:           arm_sme.tile_load %[[BUFFER]]{{\[}}%[[C0]], %[[C0]]], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transpose_i8(%arg0: vector<[16]x[16]xi8>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[16]x[16]xi8> to vector<[16]x[16]xi8>
+  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_i16
+// CHECK: arith.constant 8 : index
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transpose_i16(%arg0: vector<[8]x[8]xi16>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xi16> to vector<[8]x[8]xi16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_i32
+// CHECK: arith.constant 4 : index
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @transpose_i32(%arg0: vector<[4]x[4]xi32>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_i64
+// CHECK: arith.constant 2 : index
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transpose_i64(%arg0: vector<[2]x[2]xi64>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xi64> to vector<[2]x[2]xi64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_i128
+// CHECK: %[[VSCALE:.*]] = vector.vscale
+// CHECK: %[[BUFFER:.*]] = memref.alloca(%[[VSCALE]], %[[VSCALE]]) : memref<?x?xi128>
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @transpose_i128(%arg0: vector<[1]x[1]xi128>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[1]x[1]xi128> to vector<[1]x[1]xi128>
+  "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_f16
+// CHECK: arith.constant 8 : index
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @transpose_f16(%arg0: vector<[8]x[8]xf16>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xf16> to vector<[8]x[8]xf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_bf16
+// CHECK: arith.constant 8 : index
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transpose_bf16(%arg0: vector<[8]x[8]xbf16>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xbf16> to vector<[8]x[8]xbf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_f32
+// CHECK: arith.constant 4 : index
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @transpose_f32(%arg0: vector<[4]x[4]xf32>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_f64
+// CHECK: arith.constant 2 : index
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xf64> to vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
new file mode 100644
index 000000000000000..4350abbd13eca75
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -0,0 +1,113 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:   -e %{entry_point} -entry-point-result=void \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+llvm.func @printCString(!llvm.ptr<i8>)
+
+func.func @printTileBegin() {
+  %str = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
+  %zero = llvm.mlir.constant(0 : index) : i64
+  %str_ptr = llvm.getelementptr %str[%zero, %zero]
+    : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%str_ptr) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @printTileEnd() {
+  %str = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
+  %zero = llvm.mlir.constant(0 : index) : i64
+  %str_ptr = llvm.getelementptr %str[%zero, %zero]
+    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%str_ptr) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @entry() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1_i32 = arith.constant 1 : i32
+
+  // Number of elements in a 32-bit tile (e.g. ZA{n}.s): svl_s * svl_s.
+  %vscale = vector.vscale
+  %min_elts_s = arith.constant 4 : index
+  %svl_s = arith.muli %min_elts_s, %vscale : index
+  %za_s_size = arith.muli %svl_s, %svl_s : index
+
+  // Allocate flat input ("mem1") and output ("mem2") buffers for one tile.
+  %mem1 = memref.alloca(%za_s_size) : memref<?xi32>
+  %mem2 = memref.alloca(%za_s_size) : memref<?xi32>
+
+  // Fill each "row" (svl_s consecutive elements) of "mem1" with its row number.
+  //
+  // For example, assuming an SVL of 128 bits:
+  //
+  //   0, 0, 0, 0
+  //   1, 1, 1, 1
+  //   2, 2, 2, 2
+  //   3, 3, 3, 3
+  //
+  %init_0 = arith.constant 0 : i32
+  scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
+    %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
+    vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+    %val_next = arith.addi %val, %c1_i32 : i32
+    scf.yield %val_next : i32
+  }
+
+  // Load the whole tile from "mem1".
+  %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Transpose the tile (lowered via memory by -convert-vector-to-arm-sme).
+  %transposed_tile = vector.transpose %tile, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+
+  // Store the transposed tile to "mem2" so it can be printed row by row.
+  // TODO: Replace this with vector.print when
+  // https://github.com/llvm/llvm-project/pull/66691 lands.
+  vector.store %transposed_tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Dump "mem1". The minimum SVL is 128 bits, so the tile is at least 4x4xi32;
+  // only the leading 4x4 elements are matched, as larger SVLs print more.
+  //
+  // CHECK:      TILE BEGIN
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 3, 3, 3, 3
+  // CHECK:      TILE END
+  func.call @printTileBegin() : () -> ()
+  scf.for %i = %c0 to %za_s_size step %svl_s {
+    %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+    vector.print %tileslice : vector<[4]xi32>
+  }
+  func.call @printTileEnd() : () -> ()
+
+  // Dump "mem2", the transpose of "mem1". As above, only the leading 4x4
+  // elements are matched: each row now counts up 0, 1, 2, 3, ...
+  //
+  // CHECK:      TILE BEGIN
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK:      TILE END
+  func.call @printTileBegin() : () -> ()
+  scf.for %i = %c0 to %za_s_size step %svl_s {
+    %tileslice = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
+    vector.print %tileslice : vector<[4]xi32>
+  }
+  func.call @printTileEnd() : () -> ()
+
+  return
+}
+
+llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
+llvm.mlir.global internal constant @str_tile_end("TILE END\0A")


        


More information about the Mlir-commits mailing list