[Mlir-commits] [mlir] [mlir][vector] Add scalable lowering for `transfer_write(transpose)` (PR #101353)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Aug 7 13:30:55 PDT 2024
================
@@ -987,6 +988,185 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
}
};
+/// Retrieves the dimension sizes of a mask. Currently supports CreateMaskOp
+/// and ConstantMaskOp.
+///
+/// A null `mask` yields an empty list. For a CreateMaskOp the op's operands
+/// are returned as-is. For a ConstantMaskOp each non-scalable dim size is
+/// returned as an index attribute, while each scalable dim size is passed to
+/// `createVscaleMultiple` and the result of that call is returned for the dim.
+/// Returns failure for any other mask.
+template <typename VscaleConstantBuilder>
+static FailureOr<SmallVector<OpFoldResult>>
+getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
+ // No mask: report an empty list of dim sizes rather than a failure.
+ if (!mask)
+ return SmallVector<OpFoldResult>{};
+ // create_mask: the operands are forwarded unchanged as the dim sizes.
+ if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
+ return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
+ return OpFoldResult(dimSize);
+ });
+ }
+ if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
+ // `dimIdx` is advanced once per mapped dim size so the lambda below can
+ // look up the scalable flag that corresponds to the current dim.
+ int dimIdx = 0;
+ VectorType maskType = constantMask.getVectorType();
+ auto indexType = IndexType::get(mask.getContext());
+ return llvm::map_to_vector(
+ constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
+ // A scalable dim in a constant_mask means vscale x dimSize.
+ if (maskType.getScalableDims()[dimIdx++])
+ return OpFoldResult(createVscaleMultiple(dimSize));
+ // Non-scalable dim: the size is emitted as an index attribute.
+ return OpFoldResult(IntegerAttr::get(indexType, dimSize));
+ });
+ }
+ // The mask is not defined by one of the supported mask ops.
+ return failure();
+}
+
+/// Scalable vector lowering of transfer_write(transpose). This lowering only
+/// supports rank 2 (scalable) vectors, but can be used in conjunction with
+/// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
+/// unrolls until the first scalable dimension.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %transpose = vector.transpose %vec, [1, 0]
+/// : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
+/// : vector<[4]x4xf32>, memref<?x?xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %c1 = arith.constant 1 : index
+/// %c4 = arith.constant 4 : index
+/// %c0 = arith.constant 0 : index
+/// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %vscale = vector.vscale
+/// %c4_vscale = arith.muli %vscale, %c4 : index
+/// scf.for %idx = %c0 to %c4_vscale step %c1 {
+/// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
+/// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
+/// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
+/// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
+/// %slice_i = affine.apply #map(%idx)[%i]
+/// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
+/// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
+/// : vector<4xf32>, memref<?x?xf32>
+/// }
+/// ```
+struct ScalableTransposeTransferWriteConversion
+ : VectorToSCFPattern<vector::TransferWriteOp> {
+ using VectorToSCFPattern::VectorToSCFPattern;
+
+ LogicalResult matchAndRewrite(TransferWriteOp writeOp,
+ PatternRewriter &rewriter) const override {
+ if (isTensorOp(writeOp) && !options.lowerTensors) {
+ return rewriter.notifyMatchFailure(
+ writeOp, "lowering tensor transfers is disabled");
+ }
+
+ auto vector = writeOp.getVector();
+ auto vectorType = vector.getType();
+ auto scalableFlags = vectorType.getScalableDims();
+ if (scalableFlags != ArrayRef<bool>{true, false}) {
+ return rewriter.notifyMatchFailure(
+ writeOp, "expected vector of the form vector<[N]xMxty>");
+ }
+
+ auto permutationMap = writeOp.getPermutationMap();
+ if (!permutationMap.isIdentity()) {
+ return rewriter.notifyMatchFailure(
+ writeOp, "non-identity permutations are unsupported (lower first)");
+ }
+
+ if (!writeOp.isDimInBounds(0)) {
+ return rewriter.notifyMatchFailure(
+ writeOp, "out-of-bounds dims are unsupported (use masking)");
+ }
+
+ auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
+ if (!transposeOp ||
+ transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
+ return rewriter.notifyMatchFailure(writeOp, "source not transpose");
+ }
+
+ auto loc = writeOp.getLoc();
+ auto createVscaleMultiple =
----------------
banach-space wrote:
[nit] spell-out `auto`
https://github.com/llvm/llvm-project/pull/101353
More information about the Mlir-commits
mailing list