[Mlir-commits] [mlir] [mlir][vector] Add scalable lowering for `transfer_write(transpose)` (PR #101353)

Benjamin Maxwell llvmlistbot at llvm.org
Fri Aug 9 08:03:01 PDT 2024


================
@@ -987,6 +988,185 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
   }
 };
 
+/// Retrieves the dimension sizes of a mask as OpFoldResults. Currently
+/// supports masks produced by CreateMaskOp and ConstantMaskOp.
+///
+/// \param mask The mask value. It may be null, meaning there is no mask.
+/// \param createVscaleMultiple Callable invoked with the constant size of
+///   each scalable dimension of a ConstantMaskOp. Its result is wrapped in an
+///   OpFoldResult and should represent `vscale * size`.
+/// \returns
+///   - an empty vector if `mask` is null;
+///   - the CreateMaskOp operands, one per mask dimension;
+///   - for a ConstantMaskOp, an index IntegerAttr for each fixed dimension and
+///     `createVscaleMultiple(size)` for each scalable dimension;
+///   - failure() if the mask is not defined by either of those ops.
+template <typename VscaleConstantBuilder>
+static FailureOr<SmallVector<OpFoldResult>>
+getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
+  // An absent mask is not an error: report no dimension sizes.
+  if (!mask)
+    return SmallVector<OpFoldResult>{};
+  if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
+    // create_mask operands are already the per-dimension sizes.
+    return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
+      return OpFoldResult(dimSize);
+    });
+  }
+  if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
+    // Position of the dimension currently being mapped; used to look up the
+    // matching scalable flag in the mask's vector type.
+    int dimIdx = 0;
+    VectorType maskType = constantMask.getVectorType();
+    auto indexType = IndexType::get(mask.getContext());
+    return llvm::map_to_vector(
+        constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
+          // A scalable dim in a constant_mask means vscale x dimSize.
+          if (maskType.getScalableDims()[dimIdx++])
+            return OpFoldResult(createVscaleMultiple(dimSize));
+          return OpFoldResult(IntegerAttr::get(indexType, dimSize));
+        });
+  }
+  // Any other mask producer is unsupported.
+  return failure();
+}
+
+/// Scalable vector lowering of transfer_write(transpose). This lowering only
+/// supports rank 2 (scalable) vectors, but can be used in conjunction with
+/// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
+/// unrolls until the first scalable dimension.
----------------
MacDue wrote:

I've added a comment above L1080 instead, which explains this.

https://github.com/llvm/llvm-project/pull/101353


More information about the Mlir-commits mailing list