[Mlir-commits] [mlir] [mlir][ArmSME] Lower multi-tile stores to a single loop (PR #96187)
Cullen Rhodes
llvmlistbot at llvm.org
Tue Jun 25 04:11:28 PDT 2024
================
@@ -373,6 +374,139 @@ struct LegalizeTransferWriteOpsByDecomposition
}
};
+/// Legalize a multi-tile transfer_write as a single store loop. This is done as
+/// part of type decomposition as at this level we know each tile write is
+/// disjoint, but that information is lost after decomposition (without analysis
+/// to reconstruct it).
+///
+/// Example (pseudo-MLIR):
+///
+/// ```
+/// vector.transfer_write %vector, %dest[%y, %x], %mask
+/// : vector<[16]x[8]xi16>, memref<?x?xi16>
+/// ```
+/// Is rewritten to:
+/// ```
+/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
+/// %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐
+/// : vector<[8]xi1> from vector<[16]x[8]xi1> |
+/// %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile
+/// : vector<[8]xi16> from vector<[8]x[8]xi16> |
+/// vector.transfer_write %upper_slice, |
+/// %dest[%slice_idx + %y, %x], %upper_slice_mask |
+/// : vector<[8]xi16>, memref<?x?xi16> ┘
+/// %lower_slice_idx = %slice_idx + %c8_vscale ─┐
+/// %lower_slice_mask = vector.extract %mask[%lower_slice_idx] |
+/// : vector<[8]xi1> from vector<[16]x[8]xi1> |
+/// %lower_slice = vector.extract %lower_tile[%slice_idx] |- Store lower
+/// : vector<[8]xi16> from vector<[8]x[8]xi16> | tile
+/// vector.transfer_write %lower_slice, |
+/// %dest[%lower_slice_idx + %y, %x], %lower_slice_mask |
+/// : vector<[8]xi16>, memref<?x?xi16> ┘
+/// }
+/// ```
+struct LegalizeMultiTileTransferWriteAsStoreLoop
+ : public OneToNOpConversionPattern<vector::TransferWriteOp> {
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ if (writeOp.hasPureTensorSemantics())
+ return rewriter.notifyMatchFailure(
+ writeOp, "TODO: tensor semantics are unsupported");
+
+ auto permutationMap = writeOp.getPermutationMap();
+ if (!permutationMap.isPermutation())
+ return rewriter.notifyMatchFailure(writeOp,
+ kMatchFailureNonPermutationMap);
+
+ bool transposed = !permutationMap.isIdentity();
+ if (transposed)
+ return rewriter.notifyMatchFailure(writeOp,
+ "TODO: transpose unsupported");
+
+ auto vectorType = writeOp.getVectorType();
+ if (!isMultipleOfSMETileVectorType(vectorType))
+ return rewriter.notifyMatchFailure(writeOp,
+ kMatchFailureNotSMETileTypeMultiple);
+
+ // Note: We also disallow masks where any dimension is larger than 16 as
+ // that won't be possible to arm_sve.psel.
----------------
c-rhodes wrote:
nit: grammar
```suggestion
// Note: We also disallow masks where any dimension is larger than 16 as
// it won't be possible to use arm_sve.psel.
```
https://github.com/llvm/llvm-project/pull/96187
More information about the Mlir-commits mailing list