[Mlir-commits] [mlir] [mlir][ArmSME] Add support for vector.transfer_read with transpose (PR #67527)

Cullen Rhodes llvmlistbot at llvm.org
Wed Sep 27 01:10:21 PDT 2023


https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/67527

This patch adds support for lowering a vector.transfer_read with a transpose permutation map to a vertical tile load, for example:
```
vector.transfer_read ...  permutation_map: (d0, d1) -> (d1, d0)
```
is converted to:
```
arm_sme.tile_load ... <vertical>
```
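For instance, a sketch based on the f32 test added in this patch (`%src`, `%c0` and `%pad` are placeholder values, and the exact printed form of `arm_sme.tile_load` may differ):
```
%0 = vector.transfer_read %src[%c0, %c0], %pad
       {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
       : memref<?x?xf32>, vector<[4]x[4]xf32>
```
becomes:
```
%0 = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
```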
On SME the transpose can be done in-flight, rather than as a separate operation as with the generic `TransferReadPermutationLowering`, which would produce:
```
%0 = vector.transfer_read ...
vector.transpose %0, [1, 0] ...
```
The lowering doesn't yet support masking, and the transfer_read must be in-bounds. Unlike the current transfer_write lowering, it also intentionally doesn't handle simple (non-permuting) loads, since the generic `TransferReadToVectorLoadLowering` can lower these to plain vector.load ops, which can already be lowered to ArmSME.
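For reference, a sketch of that other route (assuming horizontal is the default `arm_sme.tile_load` layout and so isn't printed): an identity-map read such as
```
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
       : memref<?x?xf32>, vector<[4]x[4]xf32>
```
is first rewritten to
```
%0 = vector.load %src[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
```
by `TransferReadToVectorLoadLowering`, which `VectorLoadToArmSMELowering` then lowers to a (horizontal) `arm_sme.tile_load`.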

A subsequent patch will update the existing transfer_write lowering; this is a separate patch because there is currently no lowering for vector.transfer_read.

From 5f841ef681d55f7e6b036c5036cf4cdb33b778c4 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Tue, 26 Sep 2023 11:01:00 +0000
Subject: [PATCH] [mlir][ArmSME] Add support for vector.transfer_read with
 transpose

This patch adds support for lowering a vector.transfer_read with a
transpose permutation map to a vertical tile load, for example:

  vector.transfer_read ...  permutation_map: (d0, d1) -> (d1, d0)

is converted to:

  arm_sme.tile_load ... <vertical>

On SME the transpose can be done in-flight, rather than as a separate
operation as with the generic TransferReadPermutationLowering, which
would produce:

  %0 = vector.transfer_read ...
  vector.transpose %0, [1, 0] ...

The lowering doesn't yet support masking, and the transfer_read must
be in-bounds. Unlike the current transfer_write lowering, it also
intentionally doesn't handle simple (non-permuting) loads, since the
generic TransferReadToVectorLoadLowering can lower these to plain
vector.load ops, which can already be lowered to ArmSME.

A subsequent patch will update the existing transfer_write lowering;
this is a separate patch because there is currently no lowering for
vector.transfer_read.
---
 .../VectorToArmSME/VectorToArmSME.cpp         |  64 +++++-
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     | 207 ++++++++++++++++++
 2 files changed, 270 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 264539b85c0ee23..16eb60b2bd9bd69 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -59,6 +59,67 @@ getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,
 
 namespace {
 
+/// Conversion pattern for vector.transfer_read op with transpose permutation
+/// map to vertical arm_sme.tile_load (in-flight transpose).
+///
+///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d1, d0)
+///
+/// is converted to:
+///
+///   arm_sme.tile_load ... <vertical>
+struct TransferReadPermutationToArmSMELowering
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
+                                PatternRewriter &rewriter) const final {
+    // The permutation map must have two results.
+    if (transferReadOp.getTransferRank() != 2)
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "not a 2 result permutation map");
+
+    auto vectorType = transferReadOp.getVectorType();
+    if (!arm_sme::isValidSMETileVectorType(vectorType))
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "not a valid vector type for SME");
+
+    if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
+      return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
+
+    // TODO: support masking.
+    if (transferReadOp.getMask())
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "masking not yet supported");
+
+    // Out-of-bounds dims are not supported.
+    if (transferReadOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "not inbounds transfer read");
+
+    AffineMap map = transferReadOp.getPermutationMap();
+
+    // The permutation map doesn't perform any permutation; such reads can be
+    // lowered to vector.load by TransferReadToVectorLoadLowering and then to
+    // arm_sme.tile_load by VectorLoadToArmSMELowering.
+    if (map.isIdentity())
+      return rewriter.notifyMatchFailure(
+          transferReadOp, "map is an identity, apply another pattern");
+
+    AffineExpr d0, d1;
+    bindDims(transferReadOp.getContext(), d0, d1);
+    if (map != AffineMap::get(map.getNumDims(), 0, {d1, d0},
+                              transferReadOp.getContext()))
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "not true 2-D matrix transpose");
+
+    rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
+        transferReadOp, vectorType, transferReadOp.getSource(),
+        transferReadOp.getIndices(), arm_sme::TileSliceLayout::Vertical);
+
+    return success();
+  }
+};
+
 /// Conversion pattern for vector.transfer_write.
 ///
 ///   vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
@@ -317,7 +378,8 @@ struct TransposeOpToArmSMELowering
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
-  patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
+  patterns.add<TransferReadPermutationToArmSMELowering,
+               TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
                VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
                BroadcastOpToArmSMELowering, TransposeOpToArmSMELowering>(&ctx);
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index a64753578a1c861..d8bbdea5a06e946 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,5 +1,212 @@
 // RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// vector.transfer_read (with in-flight transpose)
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @transfer_read_2d_transpose_i8
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i8
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_i16
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transfer_read_2d_transpose_i16(%src : memref<?x?xi16>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i16
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_i32
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @transfer_read_2d_transpose_i32(%src : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_i64
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_read_2d_transpose_i64(%src : memref<?x?xi64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_i128
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @transfer_read_2d_transpose_i128(%src : memref<?x?xi128>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i128
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_f16
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @transfer_read_2d_transpose_f16(%src : memref<?x?xf16>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f16
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_bf16
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_read_2d_transpose_bf16(%src : memref<?x?xbf16>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : bf16
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_f32
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @transfer_read_2d_transpose_f32(%src : memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_f64
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @transfer_read_2d_transpose_f64(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__bad_type
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__non_memref_type
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__unsupported_mask
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__unsupported_mask(%src : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+/// transfer_read with identity map should be lowered to vector.load by
+/// TransferReadToVectorLoadLowering and then arm_sme.tile_load by
+/// VectorLoadToArmSMELowering.
+
+// CHECK-LABEL: @transfer_read_2d__non_permuting_map
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__non_permuting_map(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__non_transpose
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__non_transpose(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, 0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__out_of_bounds
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__out_of_bounds(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.transfer_write
 //===----------------------------------------------------------------------===//


