[Mlir-commits] [mlir] 8e64e9c - [mlir][ArmSME] Add support for vector.transfer_read with transpose (#67527)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 28 02:54:25 PDT 2023


Author: Cullen Rhodes
Date: 2023-09-28T10:54:20+01:00
New Revision: 8e64e9c365443bf9790fb04064bf3f1d3ff0bafc

URL: https://github.com/llvm/llvm-project/commit/8e64e9c365443bf9790fb04064bf3f1d3ff0bafc
DIFF: https://github.com/llvm/llvm-project/commit/8e64e9c365443bf9790fb04064bf3f1d3ff0bafc.diff

LOG: [mlir][ArmSME] Add support for vector.transfer_read with transpose (#67527)

This patch adds support for lowering a vector.transfer_read with a
transpose permutation map to a vertical tile load, for example:

  vector.transfer_read ...  permutation_map: (d0, d1) -> (d1, d0)

is converted to:

  arm_sme.tile_load ... <vertical>

On SME the transpose can be done in-flight, rather than as a separate
operation as in the TransferReadPermutationLowering, which would do the
following:

  %0 = vector.transfer_read ...
  vector.transpose %0, [1, 0] ...

The lowering doesn't support masking yet and the transfer_read must be
in-bounds. Unlike the existing transfer_write lowering, it also
intentionally doesn't handle simple (non-transposed) loads, as the generic
TransferReadToVectorLoadLowering can lower these to simple vector.load
ops, which can already be lowered to ArmSME.

A subsequent patch will update the existing transfer_write lowering.
This patch is kept separate because, until now, there was no lowering
for vector.transfer_read.

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
    mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 264539b85c0ee23..3a4866e00799404 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -59,6 +59,67 @@ getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,
 
 namespace {
 
+/// Conversion pattern for vector.transfer_read op with transpose permutation
+/// map to vertical arm_sme.tile_load (in-flight transpose).
+///
+///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d1, d0)
+///
+/// is converted to:
+///
+///   arm_sme.tile_load ... <vertical>
+struct TransferReadPermutationToArmSMELowering
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
+                                PatternRewriter &rewriter) const final {
+    // The permutation map must have two results.
+    if (transferReadOp.getTransferRank() != 2)
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "not a 2 result permutation map");
+
+    AffineMap map = transferReadOp.getPermutationMap();
+
+    // Permutation map doesn't perform permutation, can be lowered to
+    // vector.load by TransferReadToVectorLoadLowering and then
+    // arm_sme.tile_load by VectorLoadToArmSMELowering.
+    if (map.isIdentity())
+      return rewriter.notifyMatchFailure(
+          transferReadOp, "map is an identity, apply another pattern");
+
+    auto vectorType = transferReadOp.getVectorType();
+    if (!arm_sme::isValidSMETileVectorType(vectorType))
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "not a valid vector type for SME");
+
+    if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
+      return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
+
+    if (transferReadOp.getMask())
+      // TODO: support masking.
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "masking not yet supported");
+
+    // Out-of-bounds dims are not supported.
+    if (transferReadOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "not inbounds transfer read");
+
+    AffineExpr d0, d1;
+    bindDims(transferReadOp.getContext(), d0, d1);
+    if (map != AffineMap::get(map.getNumDims(), 0, {d1, d0},
+                              transferReadOp.getContext()))
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "not true 2-D matrix transpose");
+
+    rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
+        transferReadOp, vectorType, transferReadOp.getSource(),
+        transferReadOp.getIndices(), arm_sme::TileSliceLayout::Vertical);
+
+    return success();
+  }
+};
+
 /// Conversion pattern for vector.transfer_write.
 ///
 ///   vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
@@ -317,7 +378,8 @@ struct TransposeOpToArmSMELowering
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
-  patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
+  patterns.add<TransferReadPermutationToArmSMELowering,
+               TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
                VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
                BroadcastOpToArmSMELowering, TransposeOpToArmSMELowering>(&ctx);
 }

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index a64753578a1c861..d8bbdea5a06e946 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,5 +1,212 @@
 // RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// vector.transfer_read (with in-flight transpose)
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @transfer_read_2d_transpose_i8
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i8
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_i16
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transfer_read_2d_transpose_i16(%src : memref<?x?xi16>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i16
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_i32
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @transfer_read_2d_transpose_i32(%src : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_i64
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_read_2d_transpose_i64(%src : memref<?x?xi64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_i128
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @transfer_read_2d_transpose_i128(%src : memref<?x?xi128>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i128
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_f16
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @transfer_read_2d_transpose_f16(%src : memref<?x?xf16>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f16
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_bf16
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_read_2d_transpose_bf16(%src : memref<?x?xbf16>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : bf16
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_f32
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @transfer_read_2d_transpose_f32(%src : memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d_transpose_f64
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @transfer_read_2d_transpose_f64(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__bad_type
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__non_memref_type
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__unsupported_mask
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__unsupported_mask(%src : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+/// transfer_read with identity map should be lowered to vector.load by
+/// TransferReadToVectorLoadLowering and then arm_sme.tile_load by
+/// VectorLoadToArmSMELowering.
+
+// CHECK-LABEL: @transfer_read_2d__non_permuting_map
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__non_permuting_map(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__non_transpose
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__non_transpose(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, 0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__out_of_bounds
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__out_of_bounds(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.transfer_write
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list