[Mlir-commits] [mlir] [mlir][vector] Update `castAwayContractionLeadingOneDim` to omit transposes solely on leading unit dims. (PR #85694)

Kojo Acquah llvmlistbot at llvm.org
Thu Mar 21 15:04:40 PDT 2024


https://github.com/KoolJBlack updated https://github.com/llvm/llvm-project/pull/85694

>From 5f038e13bbc7ffc919ebcdfac676e88c09f9765d Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Mon, 18 Mar 2024 18:25:48 +0000
Subject: [PATCH 1/2] only tranpose non leading unit dims

---
 .../Transforms/VectorDropLeadUnitDim.cpp      | 19 ++++++++++++++--
 .../vector-dropleadunitdim-transforms.mlir    | 22 +++++++++++++++++++
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 74382b027c2f48..6b69f5f1932ad7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -399,13 +399,28 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
           transposeResults.push_back(targetExpr);
         }
       }
+
+      // Check if the transpose effects outer unit dims only. Such transposes do
+      // not materially effect the underlying vector and can be omitted.
+      bool tranposeNonOuterUnitDims = false;
+      for (int64_t i = 0; i < (int64_t)perm.size(); ++i) {
+        if (perm[i] != i && i != (int64_t)perm.size() - 1) {
+          if (operands[it.index()].getType().cast<ShapedType>().getDimSize(i) !=
+              1) {
+            tranposeNonOuterUnitDims = true;
+          }
+        }
+      }
+
       // Do the tranpose now if needed so that we can drop the
       // correct dim using extract later.
       if (tranposeNeeded) {
         map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                              contractOp.getContext());
-        operands[it.index()] = rewriter.create<vector::TransposeOp>(
-            contractOp.getLoc(), operands[it.index()], perm);
+        if (tranposeNonOuterUnitDims) {
+          operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
+              contractOp.getLoc(), operands[it.index()], perm);
+        }
       }
     }
     // We have taken care to have the dim to be dropped be
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index af6e636245b04e..31b0867c851f58 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -166,6 +166,28 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
 
 // -----
 
+// CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dims_vec_mat(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: vector<1x1x8xi32>,
+// CHECK-SAME:                                 %[[VAL_1:.*]]: vector<1x8x8xi32>,
+// CHECK-SAME:                                 %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK:           %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0] : vector<1x8xi32> from vector<1x1x8xi32>
+// CHECK:           %[[VAL_4:.*]] = vector.extract %[[VAL_2]][0] : vector<8xi32> from vector<1x8xi32>
+// CHECK:           %[[VAL_5:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
+// CHECK:           %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : vector<8xi32> to vector<1x8xi32>
+// CHECK:           return %[[VAL_6]] : vector<1x8xi32>
+// CHECK:         }
+func.func @cast_away_contraction_leading_one_dims_vec_mat(%lhs: vector<1x1x8xi32>,
+                          %rhs: vector<1x8x8xi32>,
+                          %acc: vector<1x8xi32>) -> vector<1x8xi32> {
+  %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
+  return %result : vector<1x8xi32>
+}
+
+// -----
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

>From 163ea7347b90de514d82b81d0f4328d1acaabe7d Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Thu, 21 Mar 2024 22:03:23 +0000
Subject: [PATCH 2/2] review comments

---
 .../Transforms/VectorDropLeadUnitDim.cpp      | 17 ++++++++-------
 .../vector-dropleadunitdim-transforms.mlir    | 21 ++++++++++---------
 2 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 6b69f5f1932ad7..b9101577eba553 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -400,15 +400,16 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
         }
       }
 
-      // Check if the transpose effects outer unit dims only. Such transposes do
-      // not materially effect the underlying vector and can be omitted.
+      // Checks if only the outer, unit dimensions (of size 1) are permuted.
+      // Such transposes do not materially affect the underlying vector and can
+      // be omitted. E.g., perm [1, 0, 2] applied to vector<1x1x8xi32>.
       bool tranposeNonOuterUnitDims = false;
-      for (int64_t i = 0; i < (int64_t)perm.size(); ++i) {
-        if (perm[i] != i && i != (int64_t)perm.size() - 1) {
-          if (operands[it.index()].getType().cast<ShapedType>().getDimSize(i) !=
-              1) {
-            tranposeNonOuterUnitDims = true;
-          }
+      for (auto [index, dim] :
+           llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
+        if (dim != static_cast<int64_t>(index) &&
+            operands[it.index()].getType().cast<ShapedType>().getDimSize(
+                index) != 1) {
+          tranposeNonOuterUnitDims = true;
         }
       }
 
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 31b0867c851f58..3dd38c339f50a3 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -170,17 +170,18 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
 // CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
 
-// CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dims_vec_mat(
-// CHECK-SAME:                                 %[[VAL_0:.*]]: vector<1x1x8xi32>,
-// CHECK-SAME:                                 %[[VAL_1:.*]]: vector<1x8x8xi32>,
-// CHECK-SAME:                                 %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
-// CHECK:           %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0] : vector<1x8xi32> from vector<1x1x8xi32>
-// CHECK:           %[[VAL_4:.*]] = vector.extract %[[VAL_2]][0] : vector<8xi32> from vector<1x8xi32>
-// CHECK:           %[[VAL_5:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
-// CHECK:           %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : vector<8xi32> to vector<1x8xi32>
-// CHECK:           return %[[VAL_6]] : vector<1x8xi32>
+// CHECK-LABEL:   func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(
+// CHECK-SAME:                                 %[[LHS:.*]]: vector<1x1x8xi32>,
+// CHECK-SAME:                                 %[[RHS:.*]]: vector<1x8x8xi32>,
+// CHECK-SAME:                                 %[[ACC:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK:           %[[EXT_LHS:.*]] = vector.extract %[[LHS]][0] : vector<1x8xi32> from vector<1x1x8xi32>
+// CHECK:           %[[EXT_ACC:.*]] = vector.extract %[[ACC]][0] : vector<8xi32> from vector<1x8xi32>
+// CHECK:           %[[RES:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[EXT_LHS]], %[[RHS]], %[[EXT_ACC]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
+// CHECK:           %[[BROADCAST_RES:.*]] = vector.broadcast %[[RES]] : vector<8xi32> to vector<1x8xi32>
+// CHECK:           return %[[BROADCAST_RES]] : vector<1x8xi32>
 // CHECK:         }
-func.func @cast_away_contraction_leading_one_dims_vec_mat(%lhs: vector<1x1x8xi32>,
+// CHECK-NOT:       vector.transpose
+func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(%lhs: vector<1x1x8xi32>,
                           %rhs: vector<1x8x8xi32>,
                           %acc: vector<1x8xi32>) -> vector<1x8xi32> {
   %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>



More information about the Mlir-commits mailing list