[Mlir-commits] [mlir] [mlir][ArmSME] Fix loop bounds of masked loads/stores (PR #78983)
Cullen Rhodes
llvmlistbot at llvm.org
Tue Jan 23 01:59:24 PST 2024
================
@@ -47,99 +48,103 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
if (rank == 2)
outIndices.push_back(indices[1]);
-}
-/// Lower `arm_sme.tile_load` to a loop over the tile slices and load each slice
-/// using `arm_sme.load_tile_slice`.
-///
-/// BEFORE:
-/// ```mlir
-/// %tile = arm_sme.tile_load %src[%c0, %c0] :
-/// memref<?x?xi32>, vector<[4]x[4]xi32>
-/// ```
-///
-/// AFTER:
-/// ```mlir
-/// %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
-/// %init_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
-/// %vscale = vector.vscale
-/// %c0 = arith.constant 0 : index
-/// %c1 = arith.constant 1 : index
-/// %min_svl_s = arith.constant 4 : index
-/// %svl_s = arith.muli %min_svl_s, %vscale : index
-/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
-/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
-/// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
-/// %ptrue_s, %iter_tile, %tile_slice_idx
-/// : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
-/// scf.yield %tile_update : vector<[4]x[4]xi32>
-/// }
-/// ```
-struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
- using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
- PatternRewriter &rewriter) const override {
- if (tileLoadOp.getMask())
- return rewriter.notifyMatchFailure(tileLoadOp,
- "op has mask, apply masked patterns");
-
- OpBuilder::InsertionGuard g(rewriter);
- auto loc = tileLoadOp.getLoc();
- auto tileType = tileLoadOp.getVectorType();
- auto tileElementType = tileType.getElementType();
-
- // Allocate a new SME tile.
- auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
- rewriter, loc, tileType);
+ return outIndices;
+}
- // Create a loop that loads each ZA tile slice from memory.
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
- auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- // This describes both the number of ZA tile slices and the number of
- // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
- // ..., SVL_Q).
- auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
- step, ValueRange{initTile});
+/// Creates an scf.for for the load/store of an ArmSME tile.
+FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
+ PatternRewriter &rewriter, Location loc, VectorType tileType,
----------------
c-rhodes wrote:
The other day I came across [ImplicitLocOpBuilder](https://mlir.llvm.org/doxygen/classmlir_1_1ImplicitLocOpBuilder.html#a6b440b25bf12be7622e136c5111419d7), which can be used to combine the rewriter and loc into a single builder. It's pretty nice and could simplify the args a little. It's used like:
```
ImplicitLocOpBuilder b(getLoc(), rewriter);
```
https://github.com/llvm/llvm-project/pull/78983
More information about the Mlir-commits
mailing list