[Mlir-commits] [mlir] [mlir][ArmSME] Lower multi-tile stores to a single loop (PR #96187)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Jun 25 02:35:22 PDT 2024


================
@@ -373,6 +374,139 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+/// Legalize a multi-tile transfer_write as a single store loop. This is done as
+/// part of type decomposition because, at this level, we know each tile write
+/// is disjoint, but that information is lost after decomposition (without
+/// analysis to reconstruct it).
+///
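+/// Example (pseudo-MLIR; %upper_tile and %lower_tile are the two
+/// vector<[8]x[8]xi16> SME tiles that %vector is decomposed into):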
+///
+/// ```
+/// vector.transfer_write %vector, %dest[%y, %x], %mask
+///   : vector<[16]x[8]xi16>, memref<?x?xi16>
+/// ```
+/// Is rewritten to:
+/// ```
+/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
+///   %upper_slice_y = arith.addi %slice_idx, %y : index
+///   %upper_slice_mask = vector.extract %mask[%slice_idx]
+///     : vector<[8]xi1> from vector<[16]x[8]xi1>
+///   %upper_slice = vector.extract %upper_tile[%slice_idx]
+///     : vector<[8]xi16> from vector<[8]x[8]xi16>
+///   vector.transfer_write %upper_slice,
+///     %dest[%upper_slice_y, %x], %upper_slice_mask
+///     : vector<[8]xi16>, memref<?x?xi16>
+///   // Same again for the lower tile:
+///   %lower_slice_idx = arith.addi %c8_vscale, %slice_idx : index
+///   %lower_slice_y = arith.addi %lower_slice_idx, %y : index
+///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]
+///     : vector<[8]xi1> from vector<[16]x[8]xi1>
+///   %lower_slice = vector.extract %lower_tile[%slice_idx]
+///     : vector<[8]xi16> from vector<[8]x[8]xi16>
+///   vector.transfer_write %lower_slice,
+///     %dest[%lower_slice_y, %x], %lower_slice_mask
+///     : vector<[8]xi16>, memref<?x?xi16>
+/// }
+/// ```
+struct LegalizeMultiTileTransferWriteAsStoreLoop
+    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (writeOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(
+          writeOp, "TODO: tensor semantics are unsupported");
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNonPermutationMap);
+
+    bool transposed = !permutationMap.isIdentity();
+    if (transposed)
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "TODO: transpose unsupported");
+
+    auto vectorType = writeOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNotSMETileTypeMultiple);
+
+    auto mask = writeOp.getMask();
+    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
+                                              vectorType.getDimSize(1) > 16)))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureUnsupportedMaskOp);
+
+    auto loc = writeOp.getLoc();
+    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+    auto createVscaleMultiple = [&](int64_t multiplier) {
+      return rewriter.create<arith::MulIOp>(
+          loc, vscale,
+          rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
+    };
+
+    // Get SME tile and slice types.
+    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto minTileSlices = smeTileType.getDimSize(0);
+    VectorType sliceMaskType =
+        VectorType::get(minTileSlices, rewriter.getI1Type(), true);
+
+    // Create loop over all tile slices.
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = createVscaleMultiple(minTileSlices);
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto storeLoop =
+        rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    rewriter.setInsertionPointToStart(storeLoop.getBody());
+
+    // For each tile sub-tile of the multi-tile `vectorType`.
+    auto inputSMETiles = adaptor.getVector();
+    auto inductionVar = storeLoop.getInductionVar();
----------------
MacDue wrote:

I'll call this `tileSliceIndex` instead.

https://github.com/llvm/llvm-project/pull/96187


More information about the Mlir-commits mailing list