[Mlir-commits] [mlir] [mlir][vector] Propagate scalability in TransferWriteNonPermutationLowering (PR #85632)
Crefeda Rodrigues
llvmlistbot at llvm.org
Wed Mar 20 08:00:24 PDT 2024
https://github.com/cfRod updated https://github.com/llvm/llvm-project/pull/85632
>From f2f28d739561822068a24c24f0aba5f7ea5bc057 Mon Sep 17 00:00:00 2001
From: Crefeda Rodrigues <crefeda.rodrigues at arm.com>
Date: Mon, 18 Mar 2024 11:26:51 +0000
Subject: [PATCH 1/4] [mlir][vector] Fix TransferWriteNonPermutationLowering
for scalable vectors
Signed-off-by: Crefeda Rodrigues <crefeda.rodrigues at arm.com>
---
.../Vector/Transforms/LowerVectorTransfer.cpp | 12 ++++++++++--
.../vector-transfer-permutation-lowering.mlir | 19 +++++++++++++++++++
2 files changed, 29 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4a5e8fcfb6edaf..6c63928a009377 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -41,8 +41,16 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
SmallVector<int64_t> newShape(addedRank, 1);
newShape.append(originalVecType.getShape().begin(),
originalVecType.getShape().end());
- VectorType newVecType =
- VectorType::get(newShape, originalVecType.getElementType());
+
+ ArrayRef<bool> originalScalableDims = originalVecType.getScalableDims();
+ SmallVector<bool> tempScalableDims(originalVecType.getShape().size());
+ for (const auto &pos : llvm::enumerate(originalScalableDims)) {
+ tempScalableDims[pos.index()] = originalScalableDims[pos.index()];
+ }
+ SmallVector<bool> newScalableDims(addedRank, 0);
+ newScalableDims.append(tempScalableDims);
+ VectorType newVecType = VectorType::get(
+ newShape, originalVecType.getElementType(), newScalableDims);
return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 13e07f59a72a77..20afa574327d05 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -41,6 +41,25 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
return %1 : vector<8x[4]x2xf32>
}
+// CHECK: func.func @permutation_with_mask_transfer_write_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<4x[8]xi1>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = vector.broadcast %[[VAL_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
+// CHECK: %[[VAL_5:.*]] = vector.broadcast %[[VAL_2]] : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
+// CHECK: %[[VAL_6:.*]] = vector.transpose %[[VAL_5]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
+// CHECK: %[[VAL_7:.*]] = vector.transpose %[[VAL_4]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
+// CHECK: vector.transfer_write %[[VAL_7]], %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_6]] {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
+// CHECK: return
+// CHECK: }
+func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask: vector<4x[8]xi1>){
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
+} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
+
+ return
+}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
>From 4e1a036d8007ae4f8e18195183a728fd234356c9 Mon Sep 17 00:00:00 2001
From: Crefeda Rodrigues <65665931+cfRod at users.noreply.github.com>
Date: Tue, 19 Mar 2024 16:33:55 +0000
Subject: [PATCH 2/4] Update
mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
Co-authored-by: Benjamin Maxwell <macdue at dueutil.tech>
---
.../Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 10 +++-------
1 file changed, 3 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 6c63928a009377..4ce5974e27ab22 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -42,13 +42,9 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
newShape.append(originalVecType.getShape().begin(),
originalVecType.getShape().end());
- ArrayRef<bool> originalScalableDims = originalVecType.getScalableDims();
- SmallVector<bool> tempScalableDims(originalVecType.getShape().size());
- for (const auto &pos : llvm::enumerate(originalScalableDims)) {
- tempScalableDims[pos.index()] = originalScalableDims[pos.index()];
- }
- SmallVector<bool> newScalableDims(addedRank, 0);
- newScalableDims.append(tempScalableDims);
+ SmallVector<bool> newScalableDims(addedRank, false);
+ newScalableDims.append(originalVecType.getScalableDims().begin(),
+ originalVecType.getScalableDims().end());
VectorType newVecType = VectorType::get(
newShape, originalVecType.getElementType(), newScalableDims);
return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
>From 2ae81116c32a646331484167b6c9a687cc80956a Mon Sep 17 00:00:00 2001
From: Crefeda Rodrigues <crefeda.rodrigues at arm.com>
Date: Tue, 19 Mar 2024 16:47:58 +0000
Subject: [PATCH 3/4] fix clang-format issues
Signed-off-by: Crefeda Rodrigues <crefeda.rodrigues at arm.com>
---
mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4ce5974e27ab22..0693aa596cb28f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -43,7 +43,7 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
originalVecType.getShape().end());
SmallVector<bool> newScalableDims(addedRank, false);
- newScalableDims.append(originalVecType.getScalableDims().begin(),
+ newScalableDims.append(originalVecType.getScalableDims().begin(),
originalVecType.getScalableDims().end());
VectorType newVecType = VectorType::get(
newShape, originalVecType.getElementType(), newScalableDims);
>From e40056abe65dbe140caa35ff85a3a97217b6a676 Mon Sep 17 00:00:00 2001
From: Crefeda Rodrigues <65665931+cfRod at users.noreply.github.com>
Date: Wed, 20 Mar 2024 15:00:17 +0000
Subject: [PATCH 4/4] Update
mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Co-authored-by: Benjamin Maxwell <macdue at dueutil.tech>
---
.../Dialect/Vector/vector-transfer-permutation-lowering.mlir | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 20afa574327d05..9adc5c1b31f7eb 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -52,7 +52,6 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
// CHECK: %[[VAL_7:.*]] = vector.transpose %[[VAL_4]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
// CHECK: vector.transfer_write %[[VAL_7]], %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_6]] {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
// CHECK: return
-// CHECK: }
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask: vector<4x[8]xi1>){
%c0 = arith.constant 0 : index
vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
More information about the Mlir-commits mailing list