[Mlir-commits] [mlir] [mlir][vector] Rewrite vector transfer write with unit dims for scalable vectors (PR #85270)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 14 10:11:23 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Crefeda Rodrigues (cfRod)
<details>
<summary>Changes</summary>
This PR fixes an issue where vector transfer writes on scalable vectors with unit dims were lowered to vector broadcast ops and vector transpose ops, which dropped the scalable dims.
---
Full diff: https://github.com/llvm/llvm-project/pull/85270.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (+67)
- (modified) mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir (+16)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4a5e8fcfb6edaf..cef8a497a80996 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -226,6 +226,38 @@ struct TransferWritePermutationLowering
/// {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
/// vector<1x8x16xf32>
/// ```
+/// Counts the dims of `shape` that are not unit dims (size other than 1).
+static int getReducedRank(ArrayRef<int64_t> shape) {
+  return static_cast<int>(shape.size() - llvm::count(shape, 1));
+}
+
+/// Returns the position of the first dim of `oldType` whose size is not 1.
+/// Falls back to 0 when no such dim exists (every dim is a unit dim, or the
+/// type has rank 0).
+static int getFirstNonUnitDim(MemRefType oldType) {
+  ArrayRef<int64_t> shape = oldType.getShape();
+  const int rank = static_cast<int>(shape.size());
+  // Scan forward from dim 0; the first non-unit hit wins.
+  for (int pos = 0; pos < rank; ++pos)
+    if (shape[pos] != 1)
+      return pos;
+  return 0;
+}
+
+/// Returns the position of the last dim of `oldType` whose size is not 1.
+/// Falls back to 0 when no such dim exists (every dim is a unit dim, or the
+/// type has rank 0).
+static int getLasttNonUnitDim(MemRefType oldType) {
+  ArrayRef<int64_t> shape = oldType.getShape();
+  const int rank = static_cast<int>(shape.size());
+  // Walk backwards from the trailing dim; `pos` indexes the original
+  // (non-reversed) shape directly, so no index re-mapping is needed.
+  for (int pos = rank - 1; pos >= 0; --pos)
+    if (shape[pos] != 1)
+      return pos;
+  return 0;
+}
+
struct TransferWriteNonPermutationLowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
@@ -264,6 +296,41 @@ struct TransferWriteNonPermutationLowering
missingInnerDim.push_back(i);
exprs.push_back(rewriter.getAffineDimExpr(i));
}
+
+ // Fix for lowering transfer write when we have Scalable vectors and unit
+ // dims
+ auto sourceVectorType = op.getVectorType();
+ auto memRefType = dyn_cast<MemRefType>(op.getShapedType());
+
+ if (sourceVectorType.isScalable() && !memRefType.hasStaticShape()) {
+ int reducedRank = getReducedRank(memRefType.getShape());
+
+ auto loc = op.getLoc();
+ SmallVector<Value> indices(
+ reducedRank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+ // Check whether the result shape has unit dims before and after the
+ // scalable and non-scalable dims.
+ int firstIdx = getFirstNonUnitDim(memRefType);
+ int lastIdx = getLasttNonUnitDim(memRefType);
+
+ SmallVector<ReassociationIndices> reassociation;
+ ReassociationIndices collapsedFirstIndices;
+ for (int64_t i = 0; i < firstIdx + 1; ++i)
+ collapsedFirstIndices.push_back(i);
+ reassociation.push_back(ReassociationIndices{collapsedFirstIndices});
+ ReassociationIndices collapsedIndices;
+ for (int64_t i = lastIdx; i < memRefType.getRank(); ++i)
+ collapsedIndices.push_back(i);
+ reassociation.push_back(collapsedIndices);
+ // Create mem collapse op
+ auto newOp = rewriter.create<memref::CollapseShapeOp>(loc, op.getSource(),
+ reassociation);
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(op, op.getVector(),
+ newOp, indices);
+ return success();
+ }
+
// Vector: add unit dims at the beginning of the shape.
Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
missingInnerDim.size());
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 13e07f59a72a77..a654274f0a73e9 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -41,6 +41,22 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
return %1 : vector<8x[4]x2xf32>
}
+// Masked transfer_write of a scalable vector<4x[8]xi16> into a rank-7 memref
+// whose shape has unit dims around the two written dims. The expected output
+// collapses the memref to rank 2 and writes the vector at all-zero indices.
+// NOTE(review): the expected transfer_write has no mask operand (VAL_2 is
+// never referenced again) and its in_bounds is [true, false] although the
+// input op has [true, true] -- confirm that dropping both is intended.
+// CHECK-LABEL: func.func @permutation_with_masked_transfer_write_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<4x[8]xi1>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2, 3, 4, 5, 6]] : memref<1x4x?x1x1x1x1xi16> into memref<4x?xi16>
+// CHECK: vector.transfer_write %[[VAL_0]], %[[VAL_4]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, false]} : vector<4x[8]xi16>, memref<4x?xi16>
+// CHECK: return
+// CHECK: }
+ func.func @permutation_with_masked_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask: vector<4x[8]xi1>){
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
+} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
+ return
+ }
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
``````````
</details>
https://github.com/llvm/llvm-project/pull/85270
More information about the Mlir-commits
mailing list