[Mlir-commits] [mlir] 4d6b992 - [mlir][ArmSME] Fold MoveTileSliceToVector + TransferWrite to StoreTileSlice (#95907)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 19 04:52:56 PDT 2024


Author: Benjamin Maxwell
Date: 2024-06-19T12:52:53+01:00
New Revision: 4d6b9921b3801709dca9245b5b4d7c17944a036f

URL: https://github.com/llvm/llvm-project/commit/4d6b9921b3801709dca9245b5b4d7c17944a036f
DIFF: https://github.com/llvm/llvm-project/commit/4d6b9921b3801709dca9245b5b4d7c17944a036f.diff

LOG: [mlir][ArmSME] Fold MoveTileSliceToVector + TransferWrite to StoreTileSlice (#95907)

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
    mlir/test/Conversion/VectorToArmSME/unsupported.mlir
    mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index c2f1584e43bac..56ae46a6098ee 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -666,14 +666,69 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
   }
 };
 
+/// Replaces a vector.transfer_write of an arm_sme.move_tile_slice_to_vector
+/// result with one arm_sme.store_tile_slice that keeps the slice's layout.
+///
+///  BEFORE:
+///  ```mlir
+///  %slice = arm_sme.move_tile_slice_to_vector %tile[%index]
+///             : vector<[4]xf32> from vector<[4]x[4]xf32>
+///  vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
+///             : vector<[4]xf32>, memref<?x?xf32>
+///  ```
+///  AFTER:
+///  ```mlir
+///  arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
+///             : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+///  ```
+struct FoldTransferWriteOfExtractTileSlice
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const final {
+    Value dest = writeOp.getSource();
+    if (!isa<MemRefType>(dest.getType()))
+      return rewriter.notifyMatchFailure(writeOp, "destination not a memref");
+    if (writeOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "not inbounds transfer write");
+    auto sliceOp =
+        writeOp.getVector().getDefiningOp<arm_sme::MoveTileSliceToVectorOp>();
+    if (!sliceOp)
+      return rewriter.notifyMatchFailure(
+          writeOp, "vector to store not from MoveTileSliceToVectorOp");
+
+    // The slice must map onto the innermost memref dimension.
+    if (!writeOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "unsupported permutation map");
+
+    // Unmasked writes get an all-true mask so store_tile_slice always has one.
+    Value sliceMask = writeOp.getMask();
+    if (!sliceMask) {
+      auto maskTy = writeOp.getVectorType().clone(rewriter.getI1Type());
+      sliceMask = rewriter.create<arith::ConstantOp>(
+          writeOp.getLoc(), maskTy, DenseElementsAttr::get(maskTy, true));
+    }
+
+    rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
+        writeOp, sliceOp.getTile(), sliceOp.getTileSliceIndex(), sliceMask,
+        dest, writeOp.getIndices(), sliceOp.getLayout());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
-  patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
-               TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
-               TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
-               VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
-               VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
-               VectorPrintToArmSMELowering>(&ctx);
+  patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+               TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
+               TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
+               VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+               VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
+               VectorPrintToArmSMELowering>(&ctx);
+  // Fold slice extraction + transfer_write into arm_sme.store_tile_slice.
+  patterns.add<FoldTransferWriteOfExtractTileSlice>(&ctx);
 }

diff  --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
index 35089ebebac7e..8ed52cde784ce 100644
--- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
@@ -145,6 +145,18 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
   return
 }
 
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_slice_unsupported_permutation
+// CHECK-NOT: arm_sme.store_tile_slice
+// CHECK: vector.transfer_write
+func.func @transfer_write_slice_unsupported_permutation(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
+  %c0 = arith.constant 0 : index
+  %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  vector.transfer_write %slice, %dest[%slice_index, %c0] { permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // vector.outerproduct
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index f22b6de52f367..8aeffb066de90 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -334,6 +334,50 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb
   return
 }
 
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_slice(
+// CHECK-SAME:                                  %[[VECTOR:.*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                  %[[DEST:.*]]: memref<?x?xf32>,
+// CHECK-SAME:                                  %[[INDEX:.*]]: index) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         %[[MASK:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK:         arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+func.func @transfer_write_slice(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
+  %zero = arith.constant 0 : index
+  %row = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  vector.transfer_write %row, %dest[%slice_index, %zero] { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_slice_with_mask(
+// CHECK-SAME:                                            %[[VECTOR:.*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                            %[[DEST:.*]]: memref<?x?xf32>,
+// CHECK-SAME:                                            %[[MASK:.*]]: vector<[4]xi1>,
+// CHECK-SAME:                                            %[[INDEX:.*]]: index) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+func.func @transfer_write_slice_with_mask(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %mask: vector<[4]xi1>, %slice_index: index) {
+  %zero = arith.constant 0 : index
+  %row = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  vector.transfer_write %row, %dest[%slice_index, %zero], %mask { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_vertical_slice
+// CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical>
+func.func @transfer_write_vertical_slice(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
+  %c0 = arith.constant 0 : index
+  // The slice's vertical layout must be propagated to arm_sme.store_tile_slice.
+  %slice = arm_sme.move_tile_slice_to_vector %vector[%slice_index] layout<vertical> : vector<[4]xf32> from vector<[4]x[4]xf32>
+  vector.transfer_write %slice, %dest[%slice_index, %c0] { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // vector.broadcast
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list