[Mlir-commits] [mlir] Fix unsupported transpose ops for scalable vectors (PR #86163)

Crefeda Rodrigues llvmlistbot at llvm.org
Thu Mar 21 11:05:17 PDT 2024


https://github.com/cfRod created https://github.com/llvm/llvm-project/pull/86163

Addresses comment in https://github.com/llvm/llvm-project/pull/85632#discussion_r1530727576

>From f2f28d739561822068a24c24f0aba5f7ea5bc057 Mon Sep 17 00:00:00 2001
From: Crefeda Rodrigues <crefeda.rodrigues at arm.com>
Date: Mon, 18 Mar 2024 11:26:51 +0000
Subject: [PATCH 1/6] [mlir][vector] Fix TransferWriteNonPermutationLowering
 for scalable vectors

Signed-off-by: Crefeda Rodrigues <crefeda.rodrigues at arm.com>
---
 .../Vector/Transforms/LowerVectorTransfer.cpp | 12 ++++++++++--
 .../vector-transfer-permutation-lowering.mlir | 19 +++++++++++++++++++
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4a5e8fcfb6edaf..6c63928a009377 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -41,8 +41,16 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
   SmallVector<int64_t> newShape(addedRank, 1);
   newShape.append(originalVecType.getShape().begin(),
                   originalVecType.getShape().end());
-  VectorType newVecType =
-      VectorType::get(newShape, originalVecType.getElementType());
+
+  ArrayRef<bool> originalScalableDims = originalVecType.getScalableDims();
+  SmallVector<bool> tempScalableDims(originalVecType.getShape().size());
+  for (const auto &pos : llvm::enumerate(originalScalableDims)) {
+    tempScalableDims[pos.index()] = originalScalableDims[pos.index()];
+  }
+  SmallVector<bool> newScalableDims(addedRank, 0);
+  newScalableDims.append(tempScalableDims);
+  VectorType newVecType = VectorType::get(
+      newShape, originalVecType.getElementType(), newScalableDims);
   return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 13e07f59a72a77..20afa574327d05 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -41,6 +41,25 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
   return %1 : vector<8x[4]x2xf32>
 }
 
+// CHECK:           func.func @permutation_with_mask_transfer_write_scalable(
+// CHECK-SAME:        %[[VAL_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME:        %[[VAL_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
+// CHECK-SAME:        %[[VAL_2:.*]]: vector<4x[8]xi1>) {
+// CHECK:             %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:             %[[VAL_4:.*]] = vector.broadcast %[[VAL_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
+// CHECK:             %[[VAL_5:.*]] = vector.broadcast %[[VAL_2]] : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
+// CHECK:             %[[VAL_6:.*]] = vector.transpose %[[VAL_5]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
+// CHECK:             %[[VAL_7:.*]] = vector.transpose %[[VAL_4]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
+// CHECK:             vector.transfer_write %[[VAL_7]], %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_6]] {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
+// CHECK:             return
+// CHECK:           }
+func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask:  vector<4x[8]xi1>){
+     %c0 = arith.constant 0 : index
+      vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
+} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
+
+    return
+}
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op

>From 4e1a036d8007ae4f8e18195183a728fd234356c9 Mon Sep 17 00:00:00 2001
From: Crefeda Rodrigues <65665931+cfRod at users.noreply.github.com>
Date: Tue, 19 Mar 2024 16:33:55 +0000
Subject: [PATCH 2/6] Update
 mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Co-authored-by: Benjamin Maxwell <macdue at dueutil.tech>
---
 .../Dialect/Vector/Transforms/LowerVectorTransfer.cpp  | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 6c63928a009377..4ce5974e27ab22 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -42,13 +42,9 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
   newShape.append(originalVecType.getShape().begin(),
                   originalVecType.getShape().end());
 
-  ArrayRef<bool> originalScalableDims = originalVecType.getScalableDims();
-  SmallVector<bool> tempScalableDims(originalVecType.getShape().size());
-  for (const auto &pos : llvm::enumerate(originalScalableDims)) {
-    tempScalableDims[pos.index()] = originalScalableDims[pos.index()];
-  }
-  SmallVector<bool> newScalableDims(addedRank, 0);
-  newScalableDims.append(tempScalableDims);
+  SmallVector<bool> newScalableDims(addedRank, false);
+  newScalableDims.append(originalVecType.getScalableDims().begin(), 
+                         originalVecType.getScalableDims().end());
   VectorType newVecType = VectorType::get(
       newShape, originalVecType.getElementType(), newScalableDims);
   return builder.create<vector::BroadcastOp>(loc, newVecType, vec);

>From 2ae81116c32a646331484167b6c9a687cc80956a Mon Sep 17 00:00:00 2001
From: Crefeda Rodrigues <crefeda.rodrigues at arm.com>
Date: Tue, 19 Mar 2024 16:47:58 +0000
Subject: [PATCH 3/6] fix clang-format issues

Signed-off-by: Crefeda Rodrigues <crefeda.rodrigues at arm.com>
---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4ce5974e27ab22..0693aa596cb28f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -43,7 +43,7 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
                   originalVecType.getShape().end());
 
   SmallVector<bool> newScalableDims(addedRank, false);
-  newScalableDims.append(originalVecType.getScalableDims().begin(), 
+  newScalableDims.append(originalVecType.getScalableDims().begin(),
                          originalVecType.getScalableDims().end());
   VectorType newVecType = VectorType::get(
       newShape, originalVecType.getElementType(), newScalableDims);

>From e40056abe65dbe140caa35ff85a3a97217b6a676 Mon Sep 17 00:00:00 2001
From: Crefeda Rodrigues <65665931+cfRod at users.noreply.github.com>
Date: Wed, 20 Mar 2024 15:00:17 +0000
Subject: [PATCH 4/6] Update
 mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Co-authored-by: Benjamin Maxwell <macdue at dueutil.tech>
---
 .../Dialect/Vector/vector-transfer-permutation-lowering.mlir     | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 20afa574327d05..9adc5c1b31f7eb 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -52,7 +52,6 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
 // CHECK:             %[[VAL_7:.*]] = vector.transpose %[[VAL_4]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
 // CHECK:             vector.transfer_write %[[VAL_7]], %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_6]] {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
 // CHECK:             return
-// CHECK:           }
 func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask:  vector<4x[8]xi1>){
      %c0 = arith.constant 0 : index
       vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>

>From 32b998d043f1b5f933b79cc05adb335213983951 Mon Sep 17 00:00:00 2001
From: Crefeda Rodrigues <crefeda.rodrigues at arm.com>
Date: Thu, 21 Mar 2024 16:59:44 +0000
Subject: [PATCH 5/6] Name values in test file

Signed-off-by: Crefeda Rodrigues <crefeda.rodrigues at arm.com>
---
 .../vector-transfer-permutation-lowering.mlir  | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 9adc5c1b31f7eb..31bd19c0be8e83 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -42,15 +42,15 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
 }
 
 // CHECK:           func.func @permutation_with_mask_transfer_write_scalable(
-// CHECK-SAME:        %[[VAL_0:.*]]: vector<4x[8]xi16>,
-// CHECK-SAME:        %[[VAL_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
-// CHECK-SAME:        %[[VAL_2:.*]]: vector<4x[8]xi1>) {
-// CHECK:             %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK:             %[[VAL_4:.*]] = vector.broadcast %[[VAL_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
-// CHECK:             %[[VAL_5:.*]] = vector.broadcast %[[VAL_2]] : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
-// CHECK:             %[[VAL_6:.*]] = vector.transpose %[[VAL_5]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
-// CHECK:             %[[VAL_7:.*]] = vector.transpose %[[VAL_4]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
-// CHECK:             vector.transfer_write %[[VAL_7]], %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_6]] {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
+// CHECK-SAME:        %[[ARG_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME:        %[[ARG_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
+// CHECK-SAME:        %[[MASK:.*]]: vector<4x[8]xi1>) {
+// CHECK:             %[[C0:.*]] = arith.constant 0 : index
+// CHECK:             %[[BCAST_1:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
+// CHECK:             %[[BCAST_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
+// CHECK:             %[[TRANSPOSE_1:.*]] = vector.transpose %[[BCAST_2]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
+// CHECK:             %[[TRANSPOSE_2:.*]] = vector.transpose %[[BCAST_1]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
+// CHECK:             vector.transfer_write %[[TRANSPOSE_2]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[TRANSPOSE_1]] {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
 // CHECK:             return
 func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask:  vector<4x[8]xi1>){
      %c0 = arith.constant 0 : index

>From 0c6c1d16650ccd2f509c1328e77ad2804656af14 Mon Sep 17 00:00:00 2001
From: Crefeda Rodrigues <crefeda.rodrigues at arm.com>
Date: Thu, 21 Mar 2024 17:24:42 +0000
Subject: [PATCH 6/6] Fix LowerVectorTransfer patterns to remove unsupported
 transpose ops for scalable vectors

Signed-off-by: Crefeda Rodrigues <crefeda.rodrigues at arm.com>
---
 .../Vector/Transforms/LowerVectorTransfer.cpp | 11 ++++++--
 .../vector-transfer-permutation-lowering.mlir | 25 +++++++++----------
 2 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 0693aa596cb28f..570a5222862b72 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -205,12 +205,19 @@ struct TransferWritePermutationLowering
     // Generate new transfer_write operation.
     Value newVec = rewriter.create<vector::TransposeOp>(
         op.getLoc(), op.getVector(), indices);
+
+    auto vectorType = cast<VectorType>(newVec.getType());
+
+    if (vectorType.isScalable() && !*vectorType.getScalableDims().end()) {
+      rewriter.eraseOp(newVec.getDefiningOp());
+      return failure();
+    }
+
     auto newMap = AffineMap::getMinorIdentityMap(
         map.getNumDims(), map.getNumResults(), rewriter.getContext());
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
         op.getMask(), newInBoundsAttr);
-
     return success();
   }
 };
@@ -273,7 +280,7 @@ struct TransferWriteNonPermutationLowering
                                     missingInnerDim.size());
     // Mask: add unit dims at the end of the shape.
     Value newMask;
-    if (op.getMask())
+    if (op.getMask() && !op.getVectorType().isScalable())
       newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
                                missingInnerDim.size());
     exprs.append(map.getResults().begin(), map.getResults().end());
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 31bd19c0be8e83..83a7f21daf683f 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -41,24 +41,23 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
   return %1 : vector<8x[4]x2xf32>
 }
 
-// CHECK:           func.func @permutation_with_mask_transfer_write_scalable(
-// CHECK-SAME:        %[[ARG_0:.*]]: vector<4x[8]xi16>,
-// CHECK-SAME:        %[[ARG_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
-// CHECK-SAME:        %[[MASK:.*]]: vector<4x[8]xi1>) {
-// CHECK:             %[[C0:.*]] = arith.constant 0 : index
-// CHECK:             %[[BCAST_1:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
-// CHECK:             %[[BCAST_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
-// CHECK:             %[[TRANSPOSE_1:.*]] = vector.transpose %[[BCAST_2]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
-// CHECK:             %[[TRANSPOSE_2:.*]] = vector.transpose %[[BCAST_1]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
-// CHECK:             vector.transfer_write %[[TRANSPOSE_2]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[TRANSPOSE_1]] {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
-// CHECK:             return
-func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask:  vector<4x[8]xi1>){
+// CHECK-LABEL:   func.func @permutation_with_mask_transfer_write_scalable(
+// CHECK-SAME:                                                             %[[ARG_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME:                                                             %[[ARG_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
+// CHECK-SAME:                                                             %[[MASK:.*]]: vector<4x[8]xi1>) {
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
+// CHECK:           vector.transfer_write %[[BCAST]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true, true, true], permutation_map = #map} : vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
+// CHECK:           return
+// CHECK:         }
+  func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask:  vector<4x[8]xi1>){
      %c0 = arith.constant 0 : index
       vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
 } : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
 
     return
-}
+  }
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op



More information about the Mlir-commits mailing list