[Mlir-commits] [mlir] 4240b17 - [mlir][ArmSME] Lower transfer_write + transpose to vertical store (#71181)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 9 23:51:10 PST 2023


Author: Cullen Rhodes
Date: 2023-11-10T07:51:06Z
New Revision: 4240b1790fb539e8d8f834e547a224d7f1e614c9

URL: https://github.com/llvm/llvm-project/commit/4240b1790fb539e8d8f834e547a224d7f1e614c9
DIFF: https://github.com/llvm/llvm-project/commit/4240b1790fb539e8d8f834e547a224d7f1e614c9.diff

LOG: [mlir][ArmSME] Lower transfer_write + transpose to vertical store (#71181)

This patch extends the lowering of vector.transfer_write in
VectorToArmSME to support in-flight transpose via SME vertical store.

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
    mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 5491f7dd30629ad..953a465c18de69f 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -136,13 +136,31 @@ struct TransferReadToArmSMELowering
 
 /// Conversion pattern for vector.transfer_write.
 ///
-///   vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
-///                                                      memref<?x?xi8>
+/// ---
+///
+/// Example 1: op with identity permutation map is lowered to a horizontal
+///            arm_sme.tile_store:
+///
+///   vector.transfer_write %vector, %source[%c0, %c0]
+///     {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
 ///
 /// is converted to:
 ///
 ///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
 ///                                                   vector<[16]x[16]xi8>
+/// ---
+///
+/// Example 2: op with transpose permutation map is lowered to a vertical
+///            arm_sme.tile_store (in-flight transpose):
+///
+///   vector.transfer_write %vector, %source[%c0, %c0]
+///     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+///      in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+///
+/// is converted to:
+///
+///   arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
+///     : memref<?x?xi8>, vector<[16]x[16]xi8>
 struct TransferWriteToArmSMELowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -156,9 +174,28 @@ struct TransferWriteToArmSMELowering
     if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
       return failure();
 
+    // Out-of-bounds dims are not supported.
+    if (writeOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "not inbounds transfer write");
+
+    AffineExpr d0, d1;
+    bindDims(writeOp.getContext(), d0, d1);
+    AffineMap map = writeOp.getPermutationMap();
+    bool isTranspose = (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
+                                              writeOp.getContext()));
+
+    if (!map.isIdentity() && !isTranspose)
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "unsupported permutation map");
+
+    arm_sme::TileSliceLayout layout =
+        isTranspose ? arm_sme::TileSliceLayout::Vertical
+                    : arm_sme::TileSliceLayout::Horizontal;
+
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
         writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
-        writeOp.getMask());
+        writeOp.getMask(), layout);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index ed33f8508dba0bf..ae3b260da83a2c1 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -337,6 +337,37 @@ func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest
 
 // -----
 
+/// in-flight transpose via vertical store.
+
+// CHECK-LABEL: func.func @transfer_write_2d_transpose_i64(
+// CHECK-SAME:                                             %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
+// CHECK-SAME:                                             %[[DEST:.*]]: memref<?x?xi64>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_write_2d_transpose_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
+  return
+}
+
+// -----
+
+/// in-flight transpose via vertical store with mask.
+
+// CHECK-LABEL: func.func @transfer_write_2d_transpose_with_mask_bf16(
+// CHECK-SAME:                                                        %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
+// CHECK-SAME:                                                        %[[DEST:.*]]: memref<?x?xbf16>,
+// CHECK-SAME:                                                        %[[MASK:.*]]: vector<[8]x[8]xi1>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>, %mask : vector<[8]x[8]xi1>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
+  return
+}
+
+// -----
+
 // The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
 // lowering only occurs for vector types of correct rank, shape, element size
 // and number of scalable dims.
@@ -398,6 +429,17 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?
   return
 }
 
+// -----
+
+// CHECK-LABEL: @transfer_write_2d__out_of_bounds
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.tile_store
+func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [false, false]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // vector.broadcast
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index b599b976c3e1592..49c513badb7b071 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -32,6 +32,25 @@ func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: i
   return
 }
 
+// Vector transpose + store.
+func.func @transfer_write_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  vector.transfer_write %0, %A[%base1, %base2] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// Vector transpose + masked store.
+func.func @transfer_write_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %mask = vector.create_mask %c4, %c2 : vector<[4]x[4]xi1>
+  %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  vector.transfer_write %0, %A[%base1, %base2], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
 // Vector load + print.
 func.func @load_and_print(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -116,6 +135,26 @@ func.func @entry() {
   call @transfer_write_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
   call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
 
+  // 4. Reload 3. + transpose + store.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 0, 20, 30
+  // CHECK-NEXT: ( 0, 0, 21, 31
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  // CHECK-NEXT: ( 3, 13, 0, 0
+  call @transfer_write_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 5. Reload 4. + transpose + masked store (nrows=4, ncols=2).
+  // The mask applies after permutation, so only columns 0 and 1 are written;
+  // columns 2 and 3 (from 4.) are preserved.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 0, 20, 30
+  // CHECK-NEXT: ( 0, 0, 21, 31
+  // CHECK-NEXT: ( 20, 21, 0, 0
+  // CHECK-NEXT: ( 30, 31, 0, 0
+  call @transfer_write_2d_mask_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
   memref.dealloc %A : memref<?x?xf32>
 
   return


        


More information about the Mlir-commits mailing list