[Mlir-commits] [mlir] [mlir][ArmSME] Add rewrite to handle unsupported SVE transposes via SME/ZA (PR #98620)
Cullen Rhodes
llvmlistbot at llvm.org
Mon Jul 22 04:20:13 PDT 2024
================
@@ -775,6 +780,148 @@ struct ConvertIllegalShapeCastOpsToTransposes
}
};
+/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
+/// the ZA state. This is a workaround rewrite to support these transposes when
+/// ZA is available.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %transpose = vector.transpose %vec, [1, 0]
+/// : vector<2x[4]xf32> to vector<[4]x2xf32>
+/// vector.transfer_write %transpose, %dest[%y, %x]
+/// : vector<[4]x2xf32>, memref<?x?xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
+/// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
+/// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
+/// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
+/// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
+/// %c4_vscale = arith.muli %vscale, %c4 : index
+/// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
+/// vector.transfer_write %4, %dest[%y, %x], %mask
+/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
+/// : vector<[4]x[4]xf32>, memref<?x?xf32>
+/// ```
+///
+/// Values larger than a single tile are supported via decomposition.
+struct LowerIllegalTransposeStoreViaZA
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+ PatternRewriter &rewriter) const override {
+ if (!isSupportedMaskOp(writeOp.getMask()))
+ return rewriter.notifyMatchFailure(writeOp,
+ kMatchFailureUnsupportedMaskOp);
+
+ auto permutationMap = writeOp.getPermutationMap();
+ if (!permutationMap.isIdentity())
+ return rewriter.notifyMatchFailure(writeOp,
+ kMatchFailureNonPermutationMap);
+
+ auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
+ if (!transposeOp)
+ return failure();
+
+ auto sourceType = transposeOp.getSourceVectorType();
+ auto resultType = transposeOp.getResultVectorType();
+
+ if (resultType.getRank() != 2)
+ return rewriter.notifyMatchFailure(transposeOp, "not rank 2");
+
+ if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
+ return rewriter.notifyMatchFailure(
+ transposeOp, "not illegal/unsupported SVE transpose");
+
+ auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
+ VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);
+
+ if (sourceType.getDimSize(0) <= 1 ||
+ sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
+ return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");
+
+ auto loc = writeOp.getLoc();
+ auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
+
+ auto transposeMap = AffineMapAttr::get(
+ AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));
+
+ // Note: We need to use `get_tile` as there's no vector-level `undef`.
+ Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
+ Value destTensorOrMemref = writeOp.getSource();
+ auto numSlicesPerTile =
+ std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
+ auto numSlices =
+ rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
+ for (auto [index, smeTile] : llvm::enumerate(
+ decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
+ // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
+ // of slices from the source type into the SME tile. Without checking
+ // vscale (and emitting multiple implementations) we can't make use of the
+ // rows of the tile after 1*vscale rows.
+ Value tile = undefTile;
+ for (int d = 0, e = numSlicesPerTile; d < e; ++d) {
+ Value vector = rewriter.create<vector::ExtractOp>(
+ loc, transposeOp.getVector(),
+ rewriter.getIndexAttr(d + smeTile.row));
+ if (vector.getType() != smeSliceType) {
+ vector = rewriter.create<vector::ScalableExtractOp>(
+ loc, smeSliceType, vector, smeTile.col);
+ }
----------------
c-rhodes wrote:
untested
https://github.com/llvm/llvm-project/pull/98620
More information about the Mlir-commits
mailing list