[Mlir-commits] [mlir] [mlir][vector] Rewrite vector transfer write with unit dims for scalable vectors (PR #85270)

Crefeda Rodrigues llvmlistbot at llvm.org
Thu Mar 14 10:10:35 PDT 2024


https://github.com/cfRod created https://github.com/llvm/llvm-project/pull/85270

This PR fixes the lowering of `vector.transfer_write` ops that write scalable vectors into memrefs with unit dims: previously, such writes were lowered to `vector.broadcast` and `vector.transpose` ops that dropped the scalable dims.

>From 44b533ce60cda0761b50d0b98e99729d9224977d Mon Sep 17 00:00:00 2001
From: Crefeda Rodrigues <crefeda.rodrigues at arm.com>
Date: Thu, 14 Mar 2024 17:01:01 +0000
Subject: [PATCH] Rewrite vector transfer write with unit dims for scalable
 vectors

Signed-off-by: Crefeda Rodrigues <crefeda.rodrigues at arm.com>
---
 .../Vector/Transforms/LowerVectorTransfer.cpp | 67 +++++++++++++++++++
 .../vector-transfer-permutation-lowering.mlir | 16 +++++
 2 files changed, 83 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4a5e8fcfb6edaf..cef8a497a80996 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -226,6 +226,38 @@ struct TransferWritePermutationLowering
 ///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
 ///     vector<1x8x16xf32>
 /// ```
+/// Returns how many dims of `shape` are not unit dims (dynamic dims included).
+static int getReducedRank(ArrayRef<int64_t> shape) {
+  return static_cast<int>(shape.size() - llvm::count(shape, 1));
+}
+
+/// Returns the index of the outermost dimension of `oldType` whose size is
+/// not 1. Dynamic sizes are never 1, so they count as non-unit dimensions.
+/// Returns 0 when there is no such dimension (all-unit shape or rank 0);
+/// that is indistinguishable from "dimension 0 is non-unit", so a caller
+/// that may see all-unit shapes has to check for that case itself.
+static int getFirstNonUnitDim(MemRefType oldType) {
+  ArrayRef<int64_t> shape = oldType.getShape();
+  const int64_t *it =
+      llvm::find_if(shape, [](int64_t dimSize) { return dimSize != 1; });
+  // No non-unit dimension found: fall back to index 0.
+  return it == shape.end() ? 0 : static_cast<int>(it - shape.begin());
+}
+
+/// Returns the index of the innermost dimension of `oldType` whose size is
+/// not 1. Dynamic sizes are never 1, so they count as non-unit dimensions.
+/// Returns 0 when there is no such dimension (all-unit shape or rank 0).
+/// NOTE(review): "Lastt" is a typo; rename this together with its caller.
+static int getLasttNonUnitDim(MemRefType oldType) {
+  ArrayRef<int64_t> shape = oldType.getShape();
+  // Walk from the innermost dimension outwards; a signed index keeps the
+  // loop well-defined for rank-0 shapes.
+  for (int64_t idx = static_cast<int64_t>(shape.size()) - 1; idx >= 0; --idx)
+    if (shape[idx] != 1)
+      return static_cast<int>(idx);
+  return 0;
+}
+
 struct TransferWriteNonPermutationLowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -264,6 +296,41 @@ struct TransferWriteNonPermutationLowering
       missingInnerDim.push_back(i);
       exprs.push_back(rewriter.getAffineDimExpr(i));
     }
+
+    // Fix for lowering transfer write when we have Scalable vectors and unit
+    // dims
+    auto sourceVectorType = op.getVectorType();
+    auto memRefType = dyn_cast<MemRefType>(op.getShapedType());
+
+    if (sourceVectorType.isScalable() && !memRefType.hasStaticShape()) {
+      int reducedRank = getReducedRank(memRefType.getShape());
+
+      auto loc = op.getLoc();
+      SmallVector<Value> indices(
+          reducedRank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+      // Check if the result shapes has unit dim before and after the scalable
+      // and non-scalable dim
+      int firstIdx = getFirstNonUnitDim(memRefType);
+      int lastIdx = getLasttNonUnitDim(memRefType);
+
+      SmallVector<ReassociationIndices> reassociation;
+      ReassociationIndices collapsedFirstIndices;
+      for (int64_t i = 0; i < firstIdx + 1; ++i)
+        collapsedFirstIndices.push_back(i);
+      reassociation.push_back(ReassociationIndices{collapsedFirstIndices});
+      ReassociationIndices collapsedIndices;
+      for (int64_t i = lastIdx; i < memRefType.getRank(); ++i)
+        collapsedIndices.push_back(i);
+      reassociation.push_back(collapsedIndices);
+      // Create mem collapse op
+      auto newOp = rewriter.create<memref::CollapseShapeOp>(loc, op.getSource(),
+                                                            reassociation);
+      rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(op, op.getVector(),
+                                                           newOp, indices);
+      return success();
+    }
+
     // Vector: add unit dims at the beginning of the shape.
     Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                     missingInnerDim.size());
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 13e07f59a72a77..a654274f0a73e9 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -41,6 +41,22 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
   return %1 : vector<8x[4]x2xf32>
 }
 
+// CHECK-LABEL:     func.func @permutation_with_masked_transfer_write_scalable(
+// CHECK-SAME:        %[[VAL_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME:        %[[VAL_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
+// CHECK-SAME:        %[[VAL_2:.*]]: vector<4x[8]xi1>) {
+// CHECK:             %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:             %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2, 3, 4, 5, 6]] : memref<1x4x?x1x1x1x1xi16> into memref<4x?xi16>
+// CHECK:             vector.transfer_write %[[VAL_0]], %[[VAL_4]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, false]} : vector<4x[8]xi16>, memref<4x?xi16>
+// CHECK:             return
+// CHECK:           }
+// Masked scalable write into a memref with unit dims. NOTE(review): the expected output above drops %mask -- confirm that is intended.
+func.func @permutation_with_masked_transfer_write_scalable(%vec: vector<4x[8]xi16>, %dest: memref<1x4x?x1x1x1x1xi16>, %mask: vector<4x[8]xi1>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %dest[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
+  return
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op



More information about the Mlir-commits mailing list