[Mlir-commits] [mlir] [mlir][vector] Add scalable lowering for `transfer_write(transpose)` (PR #101353)
Benjamin Maxwell
llvmlistbot at llvm.org
Fri Aug 9 09:30:54 PDT 2024
================
@@ -987,6 +988,185 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
}
};
+/// Retrieves the dimension sizes of a mask. Currently supports CreateMaskOp
+/// and ConstantMaskOp. Null mask: empty result; any other op: failure().
+template <typename VscaleConstantBuilder>
+static FailureOr<SmallVector<OpFoldResult>>
+getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
+ if (!mask)
+ return SmallVector<OpFoldResult>{};
+ if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
+ return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
+ return OpFoldResult(dimSize);
+ });
+ }
+ if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
+ int dimIdx = 0;
+ VectorType maskType = constantMask.getVectorType();
+ auto indexType = IndexType::get(mask.getContext());
+ return llvm::map_to_vector(
+ constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
+ // A scalable dim in a constant_mask means vscale x dimSize.
+ if (maskType.getScalableDims()[dimIdx++])
+ return OpFoldResult(createVscaleMultiple(dimSize));
+ return OpFoldResult(IntegerAttr::get(indexType, dimSize));
+ });
+ }
+ return failure();
+}
+
+/// Scalable vector lowering of transfer_write(transpose). This lowering only
+/// supports rank 2 (scalable) vectors, but can be used in conjunction with
+/// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
+/// unrolls until the first scalable dimension.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %transpose = vector.transpose %vec, [1, 0]
+/// : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
+/// : vector<[4]x4xf32>, memref<?x?xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %c1 = arith.constant 1 : index
+/// %c4 = arith.constant 4 : index
+/// %c0 = arith.constant 0 : index
+/// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %vscale = vector.vscale
+/// %c4_vscale = arith.muli %vscale, %c4 : index
+/// scf.for %idx = %c0 to %c4_vscale step %c1 {
+/// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
+/// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
+/// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
+/// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
+/// %slice_i = affine.apply #map(%idx)[%i]
+/// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
+/// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
+/// : vector<4xf32>, memref<?x?xf32>
+/// }
+/// ```
+struct ScalableTransposeTransferWriteConversion
+ : VectorToSCFPattern<vector::TransferWriteOp> {
+ using VectorToSCFPattern::VectorToSCFPattern;
+
+ LogicalResult matchAndRewrite(TransferWriteOp writeOp,
+ PatternRewriter &rewriter) const override {
+ if (isTensorOp(writeOp) && !options.lowerTensors) {
+ return rewriter.notifyMatchFailure(
+ writeOp, "lowering tensor transfers is disabled");
+ }
----------------
MacDue wrote:
Lower tensors does not work like that. The patterns can work for both memrefs and tensors (rather than only memrefs or only tensors). You can conceivably have some IR that has both types of transfers, and if lower-tensors is disabled those same patterns will still work for the transfer ops that work on memrefs.
https://github.com/llvm/llvm-project/pull/101353
More information about the Mlir-commits
mailing list