[Mlir-commits] [mlir] [mlir][ArmSME] Fold transpose into xfer read to enable in-flight transpose (PR #92562)
Cullen Rhodes
llvmlistbot at llvm.org
Mon May 20 00:33:38 PDT 2024
https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/92562
>From 650e1e8637d3297e4d3929eac074b3c44251c915 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Fri, 17 May 2024 09:52:35 +0000
Subject: [PATCH 1/3] [mlir][ArmSME] Fold transpose into xfer read to enable
in-flight transpose
vector.transpose ops whose input comes from a vector.transfer_read can be
eliminated by folding the transpose into the transfer_read's permutation
map. This enables in-flight transposition when the transfer_read is
converted to arm_sme.tile_load.
---
.../Conversion/VectorToArmSME/VectorToArmSME.cpp | 16 ++++++++++++++--
.../VectorToArmSME/vector-to-arm-sme.mlir | 12 ++++++++++++
2 files changed, 26 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d8e473a562e53..b1b84705da7d3 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -356,6 +356,20 @@ struct TransposeOpToArmSMELowering
return failure();
auto loc = transposeOp.getLoc();
+ Value input = transposeOp.getVector();
+
+ if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>()) {
+ // Fold transpose into transfer_read to enable in-flight transpose when
+ // converting to arm_sme.tile_load.
+ rewriter.modifyOpInPlace(xferOp, [&]() {
+ SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
+ xferOp->setAttr(xferOp.getPermutationMapAttrName(),
+ AffineMapAttr::get(AffineMap::getPermutationMap(
+ permutation, transposeOp.getContext())));
+ });
+ rewriter.replaceOp(transposeOp, xferOp);
+ return success();
+ }
// Allocate buffer to store input tile to.
Value vscale =
@@ -372,8 +386,6 @@ struct TransposeOpToArmSMELowering
auto buffer = rewriter.create<memref::AllocaOp>(
loc, bufferType, ValueRange{numTileSlices, numTileSlices});
- Value input = transposeOp.getVector();
-
// Store input tile.
auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
loc, input, buffer, ValueRange{c0, c0});
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index ce0b46e0f061a..48e92ce88ed16 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -150,6 +150,18 @@ func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mas
// -----
+// CHECK-LABEL: @fold_transpose_into_load
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @fold_transpose_into_load(%src : memref<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %1 = vector.transpose %0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
+ "prevent.dce"(%1) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.transfer_write
//===----------------------------------------------------------------------===//
>From 637c9537ec4226e46a593d86ca7eef84a89c069e Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Fri, 17 May 2024 15:31:04 +0000
Subject: [PATCH 2/3] address comments
---
.../Conversion/VectorToArmSME/VectorToArmSME.cpp | 4 ++--
.../VectorToArmSME/vector-to-arm-sme.mlir | 15 +++++++++++++++
2 files changed, 17 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index b1b84705da7d3..87923477766d1 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -358,11 +358,11 @@ struct TransposeOpToArmSMELowering
auto loc = transposeOp.getLoc();
Value input = transposeOp.getVector();
- if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>()) {
+ if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>();
+ xferOp && xferOp->hasOneUse()) {
// Fold transpose into transfer_read to enable in-flight transpose when
// converting to arm_sme.tile_load.
rewriter.modifyOpInPlace(xferOp, [&]() {
- SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
xferOp->setAttr(xferOp.getPermutationMapAttrName(),
AffineMapAttr::get(AffineMap::getPermutationMap(
permutation, transposeOp.getContext())));
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index 48e92ce88ed16..c68327df8a076 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -162,6 +162,21 @@ func.func @fold_transpose_into_load(%src : memref<?x?xf32>) {
// -----
+// CHECK-LABEL: @fold_transpose_into_load_multi_use
+// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @fold_transpose_into_load_multi_use(%src : memref<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ "test.some_use"(%0) : (vector<[4]x[4]xf32>) -> ()
+ %1 = vector.transpose %0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
+ "prevent.dce"(%1) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.transfer_write
//===----------------------------------------------------------------------===//
>From 3fc4fe44275c45fd58462df7858284b2ee1dd09c Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Mon, 20 May 2024 07:32:33 +0000
Subject: [PATCH 3/3] address comments
---
.../test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index c68327df8a076..f22b6de52f367 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -151,7 +151,9 @@ func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mas
// -----
// CHECK-LABEL: @fold_transpose_into_load
+// CHECK-NOT: arm_sme.tile_store
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK-NOT: arm_sme.tile_store
func.func @fold_transpose_into_load(%src : memref<?x?xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
@@ -162,10 +164,14 @@ func.func @fold_transpose_into_load(%src : memref<?x?xf32>) {
// -----
+/// Transposes whose input transfer_read has more than a single use cannot be
+/// folded into the load, and will instead be transposed via memory.
+
// CHECK-LABEL: @fold_transpose_into_load_multi_use
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: %[[TILE_TRANSPOSED_VIA_MEM:.*]] = arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: "prevent.dce"(%[[TILE_TRANSPOSED_VIA_MEM]]) : (vector<[4]x[4]xf32>) -> ()
func.func @fold_transpose_into_load_multi_use(%src : memref<?x?xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
More information about the Mlir-commits
mailing list