[Mlir-commits] [mlir] [mlir][ArmSME] Add support for vector.transfer_read with transpose (PR #67527)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 27 01:11:27 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
This patch adds support for lowering a vector.transfer_read with a transpose permutation map to a vertical tile load, for example:
```
vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0)
```
is converted to:
```
arm_sme.tile_load ... <vertical>
```
On SME, the transpose can be done in-flight as part of the load, rather than as a separate operation as in the generic `TransferReadPermutationLowering`, which would produce the following:
```
%0 = vector.transfer_read ...
vector.transpose %0, [1, 0] ...
```
The lowering doesn't support masking yet, and the transfer_read must be in-bounds. Unlike the current transfer_write lowering, it also intentionally doesn't handle simple (non-transposing) loads, since the generic
`TransferReadToVectorLoadLowering` can lower these to plain vector.load ops, which can already be lowered to ArmSME.
A subsequent patch will update the existing transfer_write lowering; this is kept as a separate patch because there is currently no lowering for vector.transfer_read.
---
Full diff: https://github.com/llvm/llvm-project/pull/67527.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+63-1)
- (modified) mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir (+207)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 264539b85c0ee23..16eb60b2bd9bd69 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -59,6 +59,67 @@ getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,
namespace {
+/// Conversion pattern for vector.transfer_read op with transpose permutation
+/// map to vertical arm_sme.tile_load (in-flight transpose).
+///
+/// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0)
+///
+/// is converted to:
+///
+/// arm_sme.tile_load ... <vertical>
+///
+/// Only unmasked, fully in-bounds reads from a memref source, producing a
+/// valid SME tile vector type, whose permutation map is exactly
+/// (d0, d1) -> (d1, d0) are matched. Identity maps are rejected on purpose so
+/// that the generic transfer_read -> vector.load lowering handles them.
+struct TransferReadPermutationToArmSMELowering
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
+                                PatternRewriter &rewriter) const final {
+    // The permutation map must have two results.
+    if (transferReadOp.getTransferRank() != 2)
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "not a 2 result permutation map");
+
+    auto vectorType = transferReadOp.getVectorType();
+    if (!arm_sme::isValidSMETileVectorType(vectorType))
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "not a valid vector type for SME");
+
+    if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
+      return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
+
+    // TODO: support masking.
+    if (transferReadOp.getMask())
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "masking not yet supported");
+
+    // Out-of-bounds dims are not supported.
+    if (transferReadOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "not inbounds transfer read");
+
+    AffineMap map = transferReadOp.getPermutationMap();
+
+    // Permutation map doesn't perform permutation, can be lowered to
+    // vector.load by TransferReadToVectorLoadLowering and then
+    // arm_sme.tile_load by VectorLoadToArmSMELowering.
+    if (map.isIdentity())
+      return rewriter.notifyMatchFailure(
+          transferReadOp, "map is an identity, apply another pattern");
+
+    // The expected map has exactly two dims. Building it with
+    // map.getNumDims() would also accept maps over higher-rank memrefs such
+    // as (d0, d1, d2) -> (d1, d0), which transpose two non-minor dims and are
+    // not a 2-D tile transpose of the indexed memref.
+    MLIRContext *ctx = transferReadOp.getContext();
+    AffineExpr d0, d1;
+    bindDims(ctx, d0, d1);
+    if (map != AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, {d1, d0},
+                              ctx))
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "not true 2-D matrix transpose");
+
+    rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
+        transferReadOp, vectorType, transferReadOp.getSource(),
+        transferReadOp.getIndices(), arm_sme::TileSliceLayout::Vertical);
+
+    return success();
+  }
+};
+
/// Conversion pattern for vector.transfer_write.
///
/// vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
@@ -317,7 +378,8 @@ struct TransposeOpToArmSMELowering
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
+  // TransferReadPermutationToArmSMELowering only matches transposing reads;
+  // identity-map reads are deliberately left to the generic transfer_read ->
+  // vector.load lowering, which VectorLoadToArmSMELowering then converts.
+ patterns.add<TransferReadPermutationToArmSMELowering,
+ TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
BroadcastOpToArmSMELowering, TransposeOpToArmSMELowering>(&ctx);
}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index a64753578a1c861..d8bbdea5a06e946 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,5 +1,212 @@
// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
+//===----------------------------------------------------------------------===//
+// vector.transfer_read (with in-flight transpose)
+//===----------------------------------------------------------------------===//
+
+// A transposing read of an i8 tile lowers to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_i8
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
+  %padding = arith.constant 0 : i8
+  %zero = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  "prevent.dce"(%tile) : (vector<[16]x[16]xi8>) -> ()
+  return
+}
+
+// -----
+
+// A transposing read of an i16 tile lowers to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_i16
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transfer_read_2d_transpose_i16(%src : memref<?x?xi16>) {
+  %padding = arith.constant 0 : i16
+  %zero = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  "prevent.dce"(%tile) : (vector<[8]x[8]xi16>) -> ()
+  return
+}
+
+// -----
+
+// A transposing read of an i32 tile lowers to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_i32
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @transfer_read_2d_transpose_i32(%src : memref<?x?xi32>) {
+  %padding = arith.constant 0 : i32
+  %zero = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "prevent.dce"(%tile) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// -----
+
+// A transposing read of an i64 tile lowers to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_i64
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_read_2d_transpose_i64(%src : memref<?x?xi64>) {
+  %padding = arith.constant 0 : i64
+  %zero = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  "prevent.dce"(%tile) : (vector<[2]x[2]xi64>) -> ()
+  return
+}
+
+// -----
+
+// A transposing read of an i128 tile lowers to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_i128
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @transfer_read_2d_transpose_i128(%src : memref<?x?xi128>) {
+  %padding = arith.constant 0 : i128
+  %zero = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  "prevent.dce"(%tile) : (vector<[1]x[1]xi128>) -> ()
+  return
+}
+
+// -----
+
+// A transposing read of an f16 tile lowers to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_f16
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @transfer_read_2d_transpose_f16(%src : memref<?x?xf16>) {
+  %padding = arith.constant 0.0 : f16
+  %zero = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  "prevent.dce"(%tile) : (vector<[8]x[8]xf16>) -> ()
+  return
+}
+
+// -----
+
+// A transposing read of a bf16 tile lowers to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_bf16
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_read_2d_transpose_bf16(%src : memref<?x?xbf16>) {
+  %padding = arith.constant 0.0 : bf16
+  %zero = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  "prevent.dce"(%tile) : (vector<[8]x[8]xbf16>) -> ()
+  return
+}
+
+// -----
+
+// A transposing read of an f32 tile lowers to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_f32
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @transfer_read_2d_transpose_f32(%src : memref<?x?xf32>) {
+  %padding = arith.constant 0.0 : f32
+  %zero = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  "prevent.dce"(%tile) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+// A transposing read of an f64 tile lowers to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_f64
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @transfer_read_2d_transpose_f64(%src : memref<?x?xf64>) {
+  %padding = arith.constant 0.0 : f64
+  %zero = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%tile) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// vector<[4]x[4]xf64> is not a valid SME tile type (the f64 tile used above is
+// [2]x[2]). The read is otherwise eligible (in-bounds, unmasked, transposing
+// map), so only the tile-type restriction can reject it.
+// CHECK-LABEL: @transfer_read_2d__bad_type
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[4]x[4]xf64>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
+  return
+}
+
+// -----
+
+// Reads from a tensor (not a memref) source are left untouched.
+// CHECK-LABEL: @transfer_read_2d__non_memref_type
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
+  %padding = arith.constant 0.0 : f64
+  %zero = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : tensor<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%tile) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// A 1-D transfer (single-result permutation map) is left untouched.
+// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
+  %padding = arith.constant 0.0 : f64
+  %zero = arith.constant 0 : index
+  %row = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<?x?xf64>, vector<[2]xf64>
+  "prevent.dce"(%row) : (vector<[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// Masked reads are not lowered yet and are left untouched.
+// CHECK-LABEL: @transfer_read_2d__unsupported_mask
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__unsupported_mask(%src : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
+  %padding = arith.constant 0.0 : f64
+  %zero = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%zero, %zero], %padding, %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%tile) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+/// An identity-map transfer_read is intentionally not matched here: it is
+/// expected to be lowered to vector.load by TransferReadToVectorLoadLowering,
+/// and then to arm_sme.tile_load by VectorLoadToArmSMELowering.
+
+// CHECK-LABEL: @transfer_read_2d__non_permuting_map
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__non_permuting_map(%src : memref<?x?xf64>) {
+  %padding = arith.constant 0.0 : f64
+  %zero = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%tile) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// A non-identity map that is not a 2-D transpose (here a broadcast) is left
+// untouched.
+// CHECK-LABEL: @transfer_read_2d__non_transpose
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__non_transpose(%src : memref<?x?xf64>) {
+  %padding = arith.constant 0.0 : f64
+  %zero = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d0, 0)>} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%tile) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// A transposing read with out-of-bounds dims is left untouched.
+// CHECK-LABEL: @transfer_read_2d__out_of_bounds
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__out_of_bounds(%src : memref<?x?xf64>) {
+  %padding = arith.constant 0.0 : f64
+  %zero = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [false, false], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%tile) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.transfer_write
//===----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/67527
More information about the Mlir-commits
mailing list