[Mlir-commits] [mlir] [mlir][ArmSME] Fix loop bounds of masked loads/stores (PR #78983)
Benjamin Maxwell
llvmlistbot at llvm.org
Tue Jan 23 09:25:01 PST 2024
================
@@ -400,77 +383,25 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
PatternRewriter &rewriter) const override {
- OpBuilder::InsertionGuard g(rewriter);
- auto loc = tileStoreOp.getLoc();
- auto tileType = tileStoreOp.getVectorType();
- auto tileElementType = tileType.getElementType();
-
- auto predicateType =
- VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
-
- Value maskCols;
- Value upperBound;
- auto maskOp = tileStoreOp.getMask();
- if (maskOp) {
- auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
- if (!createMaskOp)
- return rewriter.notifyMatchFailure(
- tileStoreOp, "unsupported mask op, only 'vector.create_mask' is "
- "currently supported");
-
- auto numRows = createMaskOp.getOperands()[0];
- auto numCols = createMaskOp.getOperands()[1];
-
- upperBound = numRows;
- maskCols =
- rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
- } else {
- // Store all tile slices if no mask.
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
- auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- // This describes both the number of ZA tile slices and the number of
- // elements in a vector of SVL bits for a given element type (SVL_B,
- // SVL_H,
- // ..., SVL_Q).
- auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-
- upperBound = numTileSlices;
- // Create an 'all true' predicate for the tile slice.
- maskCols = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
- }
-
// Create a loop that stores each (active) active ZA tile slice from memory.
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
-
- rewriter.setInsertionPointToStart(forOp.getBody());
-
- SmallVector<Value> memrefIndices;
- auto tileSliceIndex = forOp.getInductionVar();
- getMemrefIndices(tileStoreOp.getIndices(),
- tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
- upperBound, memrefIndices, loc, rewriter);
-
- tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
- rewriter, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
- tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
-
- return success();
+ return createLoadStoreForOverTileSlices(
+ rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
+ tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
+ tileStoreOp.getMask(),
+ [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
+ tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
+ rewriter, tileStoreOp.getValueToStore(), tileSliceIndex,
+ predicate, tileStoreOp.getBase(), memrefIndices,
+ tileStoreOp.getLayout());
+ });
}
};
} // namespace
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
- patterns
- .add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
- TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion>(
- patterns.getContext());
+ patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
+ TileStoreOpConversion>(patterns.getContext());
----------------
MacDue wrote:
See: #79172.
https://github.com/llvm/llvm-project/pull/78983
More information about the Mlir-commits
mailing list