[Mlir-commits] [mlir] [mlir][ArmSME] Fold MoveTileSliceToVector + TransferWrite to StoreTileSlice (PR #95907)

Cullen Rhodes llvmlistbot at llvm.org
Tue Jun 18 23:32:25 PDT 2024


================
@@ -666,14 +666,64 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
   }
 };
 
+/// Folds a MoveTileSliceToVectorOp + TransferWriteOp to a StoreTileSliceOp.
+///
+///  BEFORE:
+///  ```mlir
+///  %slice = arm_sme.move_tile_slice_to_vector %tile[%index]
+///             : vector<[4]xf32> from vector<[4]x[4]xf32>
+///  vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
+///             : vector<[4]xf32>, memref<?x?xf32>
+///  ```
+///  AFTER:
+///  ```mlir
+///  arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
+///             : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+///  ```
+struct FoldTransferWriteOfExtractTileSlice
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const final {
+    if (!isa<MemRefType>(writeOp.getSource().getType()))
+      return failure();
+
+    auto moveTileSlice =
+        writeOp.getVector().getDefiningOp<arm_sme::MoveTileSliceToVectorOp>();
+    if (!moveTileSlice)
+      return failure();
+
+    AffineMap map = writeOp.getPermutationMap();
+    if (!map.isMinorIdentity())
----------------
c-rhodes wrote:

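Makes sense, thanks. Perhaps it's worth adding a negative test for this (e.g. a `vector.transfer_write` whose permutation map is not a minor identity, which the pattern should not fold)?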

https://github.com/llvm/llvm-project/pull/95907


More information about the Mlir-commits mailing list