[Mlir-commits] [mlir] [mlir][vector] Add scalable lowering for `transfer_write(transpose)` (PR #101353)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Aug 7 13:30:55 PDT 2024
================
@@ -987,6 +988,185 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
}
};
+/// Retrieves the dimension sizes of a mask (empty if `mask` is null, failure
+/// if it is not defined by a CreateMaskOp or a ConstantMaskOp).
+template <typename VscaleConstantBuilder>
+static FailureOr<SmallVector<OpFoldResult>>
+getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
+ if (!mask)
+ return SmallVector<OpFoldResult>{};
+ if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
+ return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
+ return OpFoldResult(dimSize); // Forward each operand as-is.
+ });
+ }
+ if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
+ int dimIdx = 0; // Advanced by the lambda below, once per mask dim size.
+ VectorType maskType = constantMask.getVectorType();
+ auto indexType = IndexType::get(mask.getContext());
+ return llvm::map_to_vector(
+ constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
+ // A scalable dim in a constant_mask means vscale x dimSize.
+ if (maskType.getScalableDims()[dimIdx++])
+ return OpFoldResult(createVscaleMultiple(dimSize));
+ return OpFoldResult(IntegerAttr::get(indexType, dimSize)); // Non-scalable dim: index attribute.
+ });
+ }
+ return failure(); // Neither a CreateMaskOp nor a ConstantMaskOp defines the mask.
+}
+
+/// Scalable vector lowering of transfer_write(transpose). This lowering only
+/// supports rank 2 (scalable) vectors, but can be used in in conjunction with
----------------
banach-space wrote:
```suggestion
/// supports rank 2 (scalable) vectors, but can be used in conjunction with
```
https://github.com/llvm/llvm-project/pull/101353
More information about the Mlir-commits
mailing list