[Mlir-commits] [mlir] [mlir][vector] Update `castAwayContractionLeadingOneDim` to omit transposes solely on leading unit dims. (PR #85694)
Kojo Acquah
llvmlistbot at llvm.org
Thu Mar 21 15:04:40 PDT 2024
https://github.com/KoolJBlack updated https://github.com/llvm/llvm-project/pull/85694
>From 5f038e13bbc7ffc919ebcdfac676e88c09f9765d Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Mon, 18 Mar 2024 18:25:48 +0000
Subject: [PATCH 1/2] only transpose non leading unit dims
---
.../Transforms/VectorDropLeadUnitDim.cpp | 19 ++++++++++++++--
.../vector-dropleadunitdim-transforms.mlir | 22 +++++++++++++++++++
2 files changed, 39 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 74382b027c2f48..6b69f5f1932ad7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -399,13 +399,28 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
transposeResults.push_back(targetExpr);
}
}
+
+ // Check if the transpose effects outer unit dims only. Such transposes do
+ // not materially effect the underlying vector and can be omitted.
+ bool tranposeNonOuterUnitDims = false;
+ for (int64_t i = 0; i < (int64_t)perm.size(); ++i) {
+ if (perm[i] != i && i != (int64_t)perm.size() - 1) {
+ if (operands[it.index()].getType().cast<ShapedType>().getDimSize(i) !=
+ 1) {
+ tranposeNonOuterUnitDims = true;
+ }
+ }
+ }
+
// Do the tranpose now if needed so that we can drop the
// correct dim using extract later.
if (tranposeNeeded) {
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
contractOp.getContext());
- operands[it.index()] = rewriter.create<vector::TransposeOp>(
- contractOp.getLoc(), operands[it.index()], perm);
+ if (tranposeNonOuterUnitDims) {
+ operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
+ contractOp.getLoc(), operands[it.index()], perm);
+ }
}
}
// We have taken care to have the dim to be dropped be
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index af6e636245b04e..31b0867c851f58 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -166,6 +166,28 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
// -----
+// CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dims_vec_mat(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x8xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x8xi32>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0] : vector<1x8xi32> from vector<1x1x8xi32>
+// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_2]][0] : vector<8xi32> from vector<1x8xi32>
+// CHECK: %[[VAL_5:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
+// CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : vector<8xi32> to vector<1x8xi32>
+// CHECK: return %[[VAL_6]] : vector<1x8xi32>
+// CHECK: }
+func.func @cast_away_contraction_leading_one_dims_vec_mat(%lhs: vector<1x1x8xi32>,
+ %rhs: vector<1x8x8xi32>,
+ %acc: vector<1x8xi32>) -> vector<1x8xi32> {
+ %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
+ return %result : vector<1x8xi32>
+}
+
+// -----
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
>From 163ea7347b90de514d82b81d0f4328d1acaabe7d Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Thu, 21 Mar 2024 22:03:23 +0000
Subject: [PATCH 2/2] review comments
---
.../Transforms/VectorDropLeadUnitDim.cpp | 17 ++++++++-------
.../vector-dropleadunitdim-transforms.mlir | 21 ++++++++++---------
2 files changed, 20 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 6b69f5f1932ad7..b9101577eba553 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -400,15 +400,16 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
}
}
- // Check if the transpose effects outer unit dims only. Such transposes do
- // not materially effect the underlying vector and can be omitted.
+ // Checks if only the outer, unit dimensions (of size 1) are permuted.
+ // Such transposes do not materially affect the underlying vector and can
+ // be omitted. E.g. perm [1, 0, 2] applied to vector<1x1x8xi32>
bool tranposeNonOuterUnitDims = false;
- for (int64_t i = 0; i < (int64_t)perm.size(); ++i) {
- if (perm[i] != i && i != (int64_t)perm.size() - 1) {
- if (operands[it.index()].getType().cast<ShapedType>().getDimSize(i) !=
- 1) {
- tranposeNonOuterUnitDims = true;
- }
+ for (auto [index, dim] :
+ llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
+ if (dim != static_cast<int64_t>(index) &&
+ operands[it.index()].getType().cast<ShapedType>().getDimSize(
+ index) != 1) {
+ tranposeNonOuterUnitDims = true;
}
}
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 31b0867c851f58..3dd38c339f50a3 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -170,17 +170,18 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
-// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dims_vec_mat(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x8xi32>,
-// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x8xi32>,
-// CHECK-SAME: %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
-// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0] : vector<1x8xi32> from vector<1x1x8xi32>
-// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_2]][0] : vector<8xi32> from vector<1x8xi32>
-// CHECK: %[[VAL_5:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
-// CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : vector<8xi32> to vector<1x8xi32>
-// CHECK: return %[[VAL_6]] : vector<1x8xi32>
+// CHECK-LABEL: func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(
+// CHECK-SAME: %[[LHS:.*]]: vector<1x1x8xi32>,
+// CHECK-SAME: %[[RHS:.*]]: vector<1x8x8xi32>,
+// CHECK-SAME: %[[ACC:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK: %[[EXT_LHS:.*]] = vector.extract %[[LHS]][0] : vector<1x8xi32> from vector<1x1x8xi32>
+// CHECK: %[[EXT_ACC:.*]] = vector.extract %[[ACC]][0] : vector<8xi32> from vector<1x8xi32>
+// CHECK: %[[RES:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[EXT_LHS]], %[[RHS]], %[[EXT_ACC]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
+// CHECK: %[[BROADCAST_RES:.*]] = vector.broadcast %[[RES]] : vector<8xi32> to vector<1x8xi32>
+// CHECK: return %[[BROADCAST_RES]] : vector<1x8xi32>
// CHECK: }
-func.func @cast_away_contraction_leading_one_dims_vec_mat(%lhs: vector<1x1x8xi32>,
+// CHECK-NOT: vector.transpose
+func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(%lhs: vector<1x1x8xi32>,
%rhs: vector<1x8x8xi32>,
%acc: vector<1x8xi32>) -> vector<1x8xi32> {
%result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
More information about the Mlir-commits
mailing list