[Mlir-commits] [mlir] bfb5fe2 - [mlir][ArmSME] Fold transpose into xfer read to enable in-flight transpose (#92562)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 21 00:08:08 PDT 2024


Author: Cullen Rhodes
Date: 2024-05-21T08:08:05+01:00
New Revision: bfb5fe218ec7aa282375905189bf0aab79609c04

URL: https://github.com/llvm/llvm-project/commit/bfb5fe218ec7aa282375905189bf0aab79609c04
DIFF: https://github.com/llvm/llvm-project/commit/bfb5fe218ec7aa282375905189bf0aab79609c04.diff

LOG: [mlir][ArmSME] Fold transpose into xfer read to enable in-flight transpose (#92562)

vector.transpose ops whose input comes from a single-use vector.transfer_read
can be eliminated by folding the transpose into the read's permutation map.
This enables in-flight transposition when the transfer_read is converted to
arm_sme.tile_load.

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
    mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d8e473a562e53..87923477766d1 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -356,6 +356,20 @@ struct TransposeOpToArmSMELowering
       return failure();
 
     auto loc = transposeOp.getLoc();
+    Value input = transposeOp.getVector();
+
+    if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>();
+        xferOp && xferOp->hasOneUse()) {
+      // Fold transpose into transfer_read to enable in-flight transpose when
+      // converting to arm_sme.tile_load.
+      rewriter.modifyOpInPlace(xferOp, [&]() {
+        xferOp->setAttr(xferOp.getPermutationMapAttrName(),
+                        AffineMapAttr::get(AffineMap::getPermutationMap(
+                            permutation, transposeOp.getContext())));
+      });
+      rewriter.replaceOp(transposeOp, xferOp);
+      return success();
+    }
 
     // Allocate buffer to store input tile to.
     Value vscale =
@@ -372,8 +386,6 @@ struct TransposeOpToArmSMELowering
     auto buffer = rewriter.create<memref::AllocaOp>(
         loc, bufferType, ValueRange{numTileSlices, numTileSlices});
 
-    Value input = transposeOp.getVector();
-
     // Store input tile.
     auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
         loc, input, buffer, ValueRange{c0, c0});

diff  --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index ce0b46e0f061a..f22b6de52f367 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -150,6 +150,39 @@ func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mas
 
 // -----
 
+// CHECK-LABEL: @fold_transpose_into_load
+// CHECK-NOT: arm_sme.tile_store
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK-NOT: arm_sme.tile_store
+func.func @fold_transpose_into_load(%src : memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %1 = vector.transpose %0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
+  "prevent.dce"(%1) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
+/// Transposes whose source transfer_read has more than one use cannot be
+/// folded into the load and will instead be transposed via memory.
+
+// CHECK-LABEL: @fold_transpose_into_load_multi_use
+// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: %[[TILE_TRANSPOSED_VIA_MEM:.*]] = arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: "prevent.dce"(%[[TILE_TRANSPOSED_VIA_MEM]]) : (vector<[4]x[4]xf32>) -> ()
+func.func @fold_transpose_into_load_multi_use(%src : memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  "test.some_use"(%0) : (vector<[4]x[4]xf32>) -> ()
+  %1 = vector.transpose %0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
+  "prevent.dce"(%1) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.transfer_write
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list