[Mlir-commits] [mlir] [mlir][ArmSME] Add rudimentary support for tile spills to the stack (PR #76086)
Cullen Rhodes
llvmlistbot at llvm.org
Thu Dec 21 07:02:10 PST 2023
================
@@ -40,6 +132,204 @@ IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
return tileId;
}
+/// Builds a stack allocation sized to hold the tile that `tileOp` operates on.
+/// The allocation, and the ops that compute its dynamic sizes, are emitted at
+/// the start of the first block of `func`. The rewriter's insertion point is
+/// restored on return.
+static memref::AllocaOp
+createAllocaForTile(RewriterBase &rewriter, Location loc,
+                    FunctionOpInterface func,
+                    arm_sme::ArmSMETileOpInterface tileOp) {
+  RewriterBase::InsertionGuard restoreInsertionPoint(rewriter);
+  auto &entryBlock = func.getBlocks().front();
+  rewriter.setInsertionPointToStart(&entryBlock);
+
+  // Both dimensions of the buffer are dynamic and equal to
+  // `vscale * <minimum number of elements in a tile slice>`.
+  auto elementType = tileOp.getTileType().getElementType();
+  auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+  unsigned minSliceElements = arm_sme::getSMETileSliceMinNumElts(elementType);
+  auto minSliceElementsOp =
+      rewriter.create<arith::ConstantIndexOp>(loc, minSliceElements);
+  auto sliceLength =
+      rewriter.create<arith::MulIOp>(loc, vscale, minSliceElementsOp);
+
+  // A 2-D memref with the tile's element type and two dynamic extents.
+  auto spillType = MemRefType::get(
+      {ShapedType::kDynamic, ShapedType::kDynamic}, elementType);
+  return rewriter.create<memref::AllocaOp>(
+      loc, spillType, ValueRange{sliceLength, sliceLength});
+}
+
+/// Returns the spill alloca associated with in-memory tile `tileId`. If the
+/// first block of `func` has no alloca tagged with that ID, a new one is
+/// created via `createAllocaForTile` and tagged before being returned.
+static memref::AllocaOp getOrCreateAllocaForTile(
+    RewriterBase &rewriter, Location loc, FunctionOpInterface func,
+    arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
+  // Scan the first block of the function for an alloca whose
+  // `kInMemoryTileIdAttr` discardable attribute equals `tileId`.
+  for (auto &candidate : func.getBlocks().front()) {
+    if (auto existingAlloca = llvm::dyn_cast<memref::AllocaOp>(candidate)) {
+      // Allocas without the tag (or with a non-integer tag) are ignored.
+      auto taggedTileId = llvm::dyn_cast_or_null<IntegerAttr>(
+          existingAlloca->getDiscardableAttr(kInMemoryTileIdAttr));
+      if (taggedTileId && taggedTileId.getInt() == tileId)
+        return existingAlloca;
+    }
+  }
+
+  // Nothing matched: allocate a fresh buffer and record which tile it backs.
+  auto newAlloca = createAllocaForTile(rewriter, loc, func, tileOp);
+  newAlloca->setDiscardableAttr(kInMemoryTileIdAttr,
+                                rewriter.getI32IntegerAttr(tileId));
+  return newAlloca;
+}
+
+/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
+/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
+/// the op to tile 0, then emitting a full tile swap between ZA and memory
+/// before + after the tile op.
+///
+/// Example:
+///
+/// // Note: <IN MEMORY TILE> = tile ID >= 16.
+/// arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
+///
+/// is converted to:
+/// // At function entry:
+/// %spill = memref.alloca ... : memref<?x?xty>
+///
+/// // Around op:
+/// scf.for %slice_idx {
+/// %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
+/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
+/// vector.store %current_slice, %spill[%slice_idx, %c0]
+/// }
+/// arm_sme.tile_op { tile_id = 0 }
+/// scf.for %slice_idx {
+/// %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
+/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
+/// vector.store %current_slice, %spill[%slice_idx, %c0]
+/// }
+///
+/// Note that these spills/fills are not inserted earlier as the concept of a
+/// register, and the need to swap the contents, can't really be represented
+/// correctly at a high level in MLIR.
+///
+/// TODO: Reduce the spills/reloads to single slices where possible.
+struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
+
+  /// Constructs the pattern for ops named `rootOpName`. The name, the context
+  /// obtained from `typeConverter`, the converter itself, and `benefit` are
+  /// passed straight through to the `ConvertToLLVMPattern` base.
+  ConvertArmSMESpillsAndFillsToLLVM(
+      StringRef rootOpName, const LLVMTypeConverter &typeConverter,
+      PatternBenefit benefit)
+      : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
+                             typeConverter, benefit) {}
+
+  /// Rewrites a tile op that was assigned an in-memory tile so that it uses
+  /// hardware tile 0, and emits a full tile swap (via `emitFullTileSwap`)
+  /// immediately before and immediately after the op. Returns `failure()`
+  /// without modifying anything when the op is not using an in-memory tile.
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
+    // Ops that already own a real (hardware) tile need no spills/reloads.
+    const bool usesInMemoryTile = tileOp.isInMemoryTile();
+    if (!usesInMemoryTile)
+      return failure();
+
+    // Step 1. Look up (or create) the alloca at the top of the function that
+    // backs this in-memory tile. This reads the tile ID before it is
+    // overwritten in step 2.
+    auto loc = tileOp.getLoc();
+    auto parentFunc = tileOp->getParentOfType<FunctionOpInterface>();
+    auto spillBuffer = getOrCreateAllocaForTile(
+        rewriter, loc, parentFunc, tileOp, tileOp.getTileId().getInt());
+
+    // Step 2. Retarget the op at a real tile. Tile 0 is used unconditionally
+    // for simplicity (it always exists).
+    auto hardwareTileId = rewriter.getI32IntegerAttr(0);
+    rewriter.updateRootInPlace(tileOp,
+                               [&] { tileOp.setTileId(hardwareTileId); });
+
+    // Step 3. Swap the tile with memory on both sides of the op.
+    // TODO: Reduce the amount spilled to the amount of data the `tileOp`
+    // touches (i.e. a single tile slice).
+    VectorType tileVectorType = tileOp.getTileType();
+    auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
+    auto swapTileWithMemory = [&] {
+      emitFullTileSwap(rewriter, loc, spillBuffer,
+                       *arm_sme::getSMETileType(tileVectorType), sliceType,
+                       hardwareTileId);
+    };
+
+    // Before the op: swap the in-memory tile's contents into ZA.
+    rewriter.setInsertionPoint(op);
+    swapTileWithMemory();
+    // After the op: swap the tile back out to memory again.
+    rewriter.setInsertionPointAfter(op);
+    swapTileWithMemory();
+
+    return success();
+  }
+
+  /// Computes a pointer to element `[sliceIndex, 0]` of the memref
+  /// `tileMemory`, i.e. to the start of one slice of an in-memory tile.
+  /// `rewriter` is downcast to a `ConversionPatternRewriter` for the call to
+  /// `getStridedElementPtr`.
+  Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
+                                Value tileMemory, Value sliceIndex) const {
+    auto memrefType = llvm::cast<MemRefType>(tileMemory.getType());
+
+    // Cast the memref to the type produced by the type converter so it can be
+    // handed to `getStridedElementPtr`.
+    auto convertedType = getTypeConverter()->convertType(tileMemory.getType());
+    auto memrefDescriptor = rewriter.create<UnrealizedConversionCastOp>(
+        loc, convertedType, tileMemory);
+
+    // The indices are supplied as i64 values: [sliceIndex, 0].
+    auto zeroI64 = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
+    auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI64Type(), sliceIndex);
+
+    auto &conversionRewriter =
+        static_cast<ConversionPatternRewriter &>(rewriter);
+    return getStridedElementPtr(loc, memrefType, memrefDescriptor.getResult(0),
+                                {sliceIndexI64, zeroI64}, conversionRewriter);
+  }
+
+ /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
+ /// tile-sized memref (`tileAlloca`).
+ void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
+ arm_sme::ArmSMETileType tileType, VectorType sliceType,
+ IntegerAttr tileId, Value sliceIndex) const {
+ // Cast the slice index to an i32.
+ auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI32Type(), sliceIndex);
+ // Create an all-true predicate for the slice.
+ auto predicateType = sliceType.clone(rewriter.getI1Type());
+ auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(predicateType, true));
+ // Create zero padding vector (never used due to all-true predicate).
+ auto zeroVector = rewriter.create<arith::ConstantOp>(
+ loc, sliceType, rewriter.getZeroAttr(sliceType));
----------------
c-rhodes wrote:
should undef be used here instead?
https://github.com/llvm/llvm-project/pull/76086
More information about the Mlir-commits
mailing list