[Mlir-commits] [mlir] [mlir][ArmSME] Add support for lowering masked tile_load ops (PR #70915)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Nov 7 02:46:06 PST 2023
================
@@ -142,6 +141,245 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
}
};
+/// Lower `arm_sme.tile_load` with mask and pad of constant zero.
+///
+/// BEFORE:
+/// ```mlir
+/// %pad = arith.constant 0 : i32
+/// %num_rows = arith.constant 2 : index
+/// %num_cols = arith.constant 4 : index
+/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+/// memref<?x?xi32>, vector<[4]x[4]xi32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %tile = arm_sme.zero : vector<[4]x[4]xi32>
+/// %num_cols = vector.create_mask %c4 : vector<[4]xi1>
+/// scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
+/// %tile_update = arm_sme.load_tile_slice
+/// %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
+/// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
+/// }
+/// ```
+///
+/// NOTE: Only mask of 'vector.create_mask' op is currently supported.
+struct TileLoadOpWithMaskAndPadZeroConversion
+ : public OpRewritePattern<arm_sme::TileLoadOp> {
+ using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+ PatternRewriter &rewriter) const override {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto loc = tileLoadOp.getLoc();
+ auto tileType = tileLoadOp.getVectorType();
+
+ auto maskOp = tileLoadOp.getMask();
+ if (!maskOp)
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has no mask, needs unmasked pattern");
+
+ auto padOp = tileLoadOp.getPadding();
+ assert(padOp && "expected padding when masking!");
+
+ auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
+ "currently supported");
+
+ auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+ if (!constPadOp || constPadOp.getValue() !=
+ rewriter.getZeroAttr(tileType.getElementType()))
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
+
+ auto numRows = createMaskOp.getOperands()[0];
+ auto numCols = createMaskOp.getOperands()[1];
+
+ auto predicateType =
+ VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+ auto numColsOp =
+ rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+
+ // Initialize tile with zero to satisfy padding. Inactive cols will be
+ // zeroed anyway since the loads use zeroing predication. For inactive rows
+ // however, no load will occur so these need to be zeroed.
+ auto tile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+
+ // Create a loop to load the active tile slices from memory.
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto upperBound = numRows;
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+
+ rewriter.setInsertionPointToStart(forOp.getBody());
+
+ // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
+ // tile.
+ SmallVector<Value> memrefIndices;
+ auto tileSliceIndex = forOp.getInductionVar();
+ getMemrefIndices(tileLoadOp.getIndices(),
+ tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+ upperBound, memrefIndices, loc, rewriter);
+ rewriter.create<arm_sme::LoadTileSliceOp>(
+ loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices,
+ tileSliceIndex, tileLoadOp.getLayout());
+
+ rewriter.setInsertionPointAfter(forOp);
+
+ // Replace 'arm_sme.tile_load' with the tile.
+ rewriter.replaceOp(tileLoadOp, tile);
+
+ return success();
+ }
+};
+
+/// Lower `arm_sme.tile_load` with mask and non-zero pad.
+///
+/// BEFORE:
+/// ```mlir
+/// %pad = arith.constant 1 : i32
+/// %num_rows = arith.constant 2 : index
+/// %num_cols = arith.constant 4 : index
+/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+/// memref<?x?xi32>, vector<[4]x[4]xi32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %pad_1d = arith.constant dense<1> : vector<[4]xi32>
+/// %num_rows = arith.constant 2 : index
+/// %num_cols = arith.constant 4 : index
+/// %num_cols_i32 = arith.index_castui %num_cols : index to i32
+/// %tile_id = arm_sme.get_tile_id : i32
+/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+/// %vscale = vector.vscale
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %min_svl_s = arith.constant 4 : index
+/// %svl_s = arith.muli %min_svl_s, %vscale : index
+/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+/// %row_is_active = arith.cmpi ult %tile_slice_idx, %num_rows : index
+/// %row_is_active_i32 = arith.extsi %row_is_active : i1 to i32
+/// %mask = arith.andi %row_is_active_i32, %num_cols_i32 : i32
+/// %mask_index = arith.index_cast %mask : i32 to index
+/// %mask_1d = vector.create_mask %mask_index : vector<[4]xi1>
+/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad
+/// : memref<?x?xi32>, vector<[4]xi1>,
+/// vector<[4]xi32> into vector<[4]xi32>
+/// // Insert slice into tile
+/// arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
+/// : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// }
----------------
banach-space wrote:
I find these snippets very helpful, but I think that you can strip them down a bit and leave only the key parts. For example:
```suggestion
/// %pad_1d = arith.constant dense<1> : vector<[4]xi32>
/// (...)
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
/// (...)
/// %mask_1d = vector.create_mask %mask_index : vector<[4]xi1>
/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
/// : memref<?x?xi32>, vector<[4]xi1>,
/// vector<[4]xi32> into vector<[4]xi32>
/// // Insert slice into tile
/// arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
/// }
```
Basically, what's key is that `arm_sme.tile_load` is replaced with an `scf.for` loop. Within that loop, two Ops are key:
* masked load
* `arm_sme.move_vector_to_tile_slice`
I think that everything else is just glue code. Happy to have this as is (though rename `%pad` as `%pad_1d` inside the loop).
https://github.com/llvm/llvm-project/pull/70915
More information about the Mlir-commits
mailing list