[Mlir-commits] [mlir] [mlir][vector] Add scalable lowering for `transfer_write(transpose)` (PR #101353)

Benjamin Maxwell llvmlistbot at llvm.org
Sat Aug 10 07:58:45 PDT 2024


================
@@ -987,6 +988,185 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
   }
 };
 
+/// Returns the size of each dimension of `mask` as a list of OpFoldResults.
+///
+/// Handled cases:
+///  * A null `mask`: yields an empty list.
+///  * `vector.create_mask`: each operand is forwarded as a dimension size.
+///  * `vector.constant_mask`: each static size becomes an index attribute,
+///    except for scalable dimensions, whose size is whatever
+///    `createVscaleMultiple(dimSize)` returns.
+///
+/// Any other kind of mask value results in failure.
+template <typename VscaleConstantBuilder>
+static FailureOr<SmallVector<OpFoldResult>>
+getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
+  SmallVector<OpFoldResult> dimSizes;
+
+  // Nothing to inspect without a mask; report an empty list of sizes.
+  if (!mask)
+    return dimSizes;
+
+  // create_mask: the operands are the dimension sizes.
+  if (auto createMask = mask.getDefiningOp<vector::CreateMaskOp>()) {
+    for (Value operand : createMask.getOperands())
+      dimSizes.push_back(OpFoldResult(operand));
+    return dimSizes;
+  }
+
+  // Only constant_mask is left to handle; everything else is unsupported.
+  auto constMask = mask.getDefiningOp<vector::ConstantMaskOp>();
+  if (!constMask)
+    return failure();
+
+  VectorType maskType = constMask.getVectorType();
+  auto indexType = IndexType::get(mask.getContext());
+  auto scalableDims = maskType.getScalableDims();
+  int dim = 0;
+  for (int64_t staticSize : constMask.getMaskDimSizes()) {
+    // A scalable dim in a constant_mask means vscale x staticSize.
+    if (scalableDims[dim])
+      dimSizes.push_back(OpFoldResult(createVscaleMultiple(staticSize)));
+    else
+      dimSizes.push_back(OpFoldResult(IntegerAttr::get(indexType, staticSize)));
+    ++dim;
+  }
+  return dimSizes;
+}
+
+/// Scalable vector lowering of transfer_write(transpose). This lowering only
+/// supports rank 2 (scalable) vectors, but can be used in conjunction with
+/// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
+/// unrolls until the first scalable dimension.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %transpose = vector.transpose %vec, [1, 0]
+///    : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
+///    : vector<[4]x4xf32>,  memref<?x?xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %c1 = arith.constant 1 : index
+/// %c4 = arith.constant 4 : index
+/// %c0 = arith.constant 0 : index
+/// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %vscale = vector.vscale
+/// %c4_vscale = arith.muli %vscale, %c4 : index
+/// scf.for %idx = %c0 to %c4_vscale step %c1 {
+///   %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
+///   %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
+///   %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
+///   %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
+///   %slice_i = affine.apply #map(%idx)[%i]
+///   %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
+///   vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
+///     : vector<4xf32>, memref<?x?xf32>
+/// }
+/// ```
+struct ScalableTransposeTransferWriteConversion
+    : VectorToSCFPattern<vector::TransferWriteOp> {
+  using VectorToSCFPattern::VectorToSCFPattern;
+
+  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    if (isTensorOp(writeOp) && !options.lowerTensors) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "lowering tensor transfers is disabled");
+    }
----------------
MacDue wrote:

I have no immediate plans to do so, but I need this option as I want to be able to control when this lowering is enabled. Also, it does not make much sense under the existing options. It's not unrolling, so it should not be under "full-unroll", and I also don't want to enable the (non-scalable) loop-based transfer lowerings when I use this pattern.

https://github.com/llvm/llvm-project/pull/101353


More information about the Mlir-commits mailing list