[Mlir-commits] [mlir] [mlir][ArmSME] Add rudimentary support for tile spills to the stack (PR #76086)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Jan 9 07:26:57 PST 2024


================
@@ -129,8 +132,267 @@ IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
   return tileId;
 }
 
-struct GetTileConversion : public ConvertOpToLLVMPattern<arm_sme::GetTileOp> {
-  using ConvertOpToLLVMPattern<arm_sme::GetTileOp>::ConvertOpToLLVMPattern;
+/// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
+/// placed in the first block of the function.
+static memref::AllocaOp
+createAllocaForTile(RewriterBase &rewriter, Location loc,
+                    FunctionOpInterface func,
+                    arm_sme::ArmSMETileOpInterface tileOp) {
+  RewriterBase::InsertionGuard g(rewriter);
+  // Move to the first operation in the function.
+  rewriter.setInsertionPointToStart(&func.getBlocks().front());
+  // Create an alloca matching the tile size of the `tileOp`.
+  auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+  auto tileElementType = tileOp.getTileType().getElementType();
+  auto memrefType = MemRefType::get(
+      {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
+  unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
+  auto minElementsOp =
+      rewriter.create<arith::ConstantIndexOp>(loc, minElements);
+  auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
+  auto alloca = rewriter.create<memref::AllocaOp>(
+      loc, memrefType, ValueRange{vectorLen, vectorLen});
+  return alloca;
+}
+
+/// Finds or creates an alloca for a spill of a tile.
+static memref::AllocaOp getOrCreateAllocaForTile(
+    RewriterBase &rewriter, Location loc, FunctionOpInterface func,
+    arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
+  // Find an alloca at the top of the function tagged with a
+  // 'arm_sme.in_memory_tile_id' that matches `tileId`.
+  for (auto &op : func.getBlocks().front()) {
+    auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
+    if (!alloca)
+      continue;
+    auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
+        alloca->getDiscardableAttr(kInMemoryTileIdAttr));
+    if (!inMemoryTileId)
+      continue;
+    if (inMemoryTileId.getInt() == tileId)
+      return alloca;
+  }
+  // Otherwise, create a new alloca:
+  auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
+  alloca->setDiscardableAttr(kInMemoryTileIdAttr,
+                             rewriter.getI32IntegerAttr(tileId));
+  return alloca;
+}
+
+/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
+/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
+/// the op to tile 0, then emitting a full tile swap between ZA and memory
+/// before + after the tile op.
+///
+/// Example:
+///
+///    // Note: <IN MEMORY TILE> = tile ID >= 16.
+///    arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
+///
+/// is converted to:
+///     // At function entry:
+///     %spill = memref.alloca ... : memref<?x?xty>
+///
+///     // Around op:
+///     scf.for %slice_idx {
+///       %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
+///       "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx)  <{tile_id = 0 : i32}>
+///       vector.store %slice_to_save, %spill[%slice_idx, %c0]
+///     }
+///     arm_sme.tile_op { tile_id = 0 }
+///     scf.for %slice_idx {
+///       %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
+///       "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx)  <{tile_id = 0 : i32}>
+///       vector.store %slice_to_save, %spill[%slice_idx, %c0]
+///     }
+///
+/// Note that these spills/fills are not inserted earlier, as the concept of a
+/// register, and the need to swap its contents, can't really be represented
+/// correctly at a high level in MLIR.
+///
+/// TODO: Reduce the spills/reloads to single slices where possible (and omit
+/// redundant reloads). This could be done via a method on the
+/// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
+///
+/// `tileOp.getZaUsage()` could return:
+///
+/// struct ArmSMEOpZAUsage {
+///   enum class Kind {
+///     TileRead,        // Omit store after tile operation.
+///     TileWrite,       // Omit load before tile operation.
+///     TileReadWrite,   // Needs both tile load and store.
+///     SliceRead,       // Spill single slice and omit store after operation.
+///     SliceWrite,      // Spill single slice and omit load before operation.
+///     SliceReadWrite   // Spill single slice.
+///   };
+///   Value sliceIndex {};
+///   TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
+/// };
+///
+struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
+
+  ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
+                                    const LLVMTypeConverter &typeConverter,
+                                    PatternBenefit benefit)
+      : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
+                             typeConverter, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
+    // Tile has a real (hardware) tile. No spills/reloads required.
+    if (!tileOp.isInMemoryTile())
+      return failure();
+
+    // Step 1. Create an alloca for the tile at the top of the function (if one
+    // does not already exist).
+    auto loc = tileOp.getLoc();
+    auto func = tileOp->getParentOfType<FunctionOpInterface>();
+    auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
+                                               tileOp.getTileId().getInt());
+
+    // Step 2. Assign the op a real tile ID.
+    // For simplicity, we always use tile 0 (which always exists).
+    auto zeroTileId = rewriter.getI32IntegerAttr(0);
+    rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
+
+    VectorType tileVectorType = tileOp.getTileType();
+    auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
+    auto emitTileSwap = [&] {
+      emitFullTileSwap(rewriter, loc, tileAlloca,
+                       *arm_sme::getSMETileType(tileVectorType), sliceType,
+                       zeroTileId);
+    };
+
+    // Step 3. Emit tile swaps before and after the op.
+    // TODO: Reduce the amount spilled to the amount of data the `tileOp`
+    // touches (i.e. a single tile slice).
+    {
+      rewriter.setInsertionPoint(op);
+      // Swap the contents of ZA and the in-memory tile before the op.
+      emitTileSwap();
+      rewriter.setInsertionPointAfter(op);
+      // Swap the tile back out to memory again after the op.
+      emitTileSwap();
+    }
+
+    return success();
+  }
+
+  /// Extracts a pointer to a slice of an in-memory tile.
+  Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
+                                Value tileMemory, Value sliceIndex) const {
+    auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
+    auto descriptor =
+        rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
+    auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
+    auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI64Type(), sliceIndex);
+    return getStridedElementPtr(
+        loc, llvm::cast<MemRefType>(tileMemory.getType()),
+        descriptor.getResult(0), {sliceIndexI64, zero},
+        static_cast<ConversionPatternRewriter &>(rewriter));
+  }
+
+  /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
+  /// tile-sized memref (`tileAlloca`).
+  void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
+                     arm_sme::ArmSMETileType tileType, VectorType sliceType,
+                     IntegerAttr tileId, Value sliceIndex) const {
+    // Cast the slice index to an i32.
+    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI32Type(), sliceIndex);
+    // Create an all-true predicate for the slice.
+    auto predicateType = sliceType.clone(rewriter.getI1Type());
+    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(predicateType, true));
+    // Create padding vector (never used due to all-true predicate).
+    auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType);
+    // Get a pointer to the current slice.
+    auto slicePtr =
+        getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
+    // Read the value of the current slice from ZA.
+    auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
+        loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
+    // Load the new tile slice back from memory into ZA.
+    createLoadTileSliceIntrinsic(
+        rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
+        allTruePredicate, slicePtr, tileId, sliceIndexI32);
+    // Store the current tile slice to memory.
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
+                                     ValueRange{sliceIndex, zero});
+  }
+
+  /// Emits a full in-place swap of the contents of a tile in ZA and a
+  /// tile-sized memref (`tileAlloca`).
+  void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
----------------
banach-space wrote:

I think that removing the final argument (which is hard-coded in practice) and renaming this as e.g. `swapInMemoryTileWithSMETileZero` could help to convey the fact that right now everything is hard-coded to use Tile 0.

https://github.com/llvm/llvm-project/pull/76086


More information about the Mlir-commits mailing list