[Mlir-commits] [mlir] Update `castAwayContractionLeadingOneDim` to omit transposes solely on leading unit dims. (PR #85694)

Kojo Acquah llvmlistbot at llvm.org
Mon Mar 18 13:31:37 PDT 2024


https://github.com/KoolJBlack created https://github.com/llvm/llvm-project/pull/85694

Fixes #85691

>From b1ff4b0f0380c2408f6f3aa078e13c06a9eeaaf9 Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Mon, 18 Mar 2024 18:25:48 +0000
Subject: [PATCH] only transpose non leading unit dims

---
 .../Transforms/VectorDropLeadUnitDim.cpp      | 19 ++++++++++++++--
 .../vector-dropleadunitdim-transforms.mlir    | 22 +++++++++++++++++++
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 74382b027c2f48..6b69f5f1932ad7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -399,13 +399,28 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
           transposeResults.push_back(targetExpr);
         }
       }
+
+      // Check if the transpose affects outer unit dims only. Such transposes do
+      // not materially affect the underlying vector and can be omitted.
+      bool tranposeNonOuterUnitDims = false;
+      for (int64_t i = 0; i < (int64_t)perm.size(); ++i) {
+        if (perm[i] != i && i != (int64_t)perm.size() - 1) {
+          if (operands[it.index()].getType().cast<ShapedType>().getDimSize(i) !=
+              1) {
+            tranposeNonOuterUnitDims = true;
+          }
+        }
+      }
+
       // Do the tranpose now if needed so that we can drop the
       // correct dim using extract later.
       if (tranposeNeeded) {
         map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                              contractOp.getContext());
-        operands[it.index()] = rewriter.create<vector::TransposeOp>(
-            contractOp.getLoc(), operands[it.index()], perm);
+        if (tranposeNonOuterUnitDims) {
+          operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
+              contractOp.getLoc(), operands[it.index()], perm);
+        }
       }
     }
     // We have taken care to have the dim to be dropped be
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index af6e636245b04e..31b0867c851f58 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -166,6 +166,28 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
 
 // -----
 
+// CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dims_vec_mat(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: vector<1x1x8xi32>,
+// CHECK-SAME:                                 %[[VAL_1:.*]]: vector<1x8x8xi32>,
+// CHECK-SAME:                                 %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK:           %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0] : vector<1x8xi32> from vector<1x1x8xi32>
+// CHECK:           %[[VAL_4:.*]] = vector.extract %[[VAL_2]][0] : vector<8xi32> from vector<1x8xi32>
+// CHECK:           %[[VAL_5:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
+// CHECK:           %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : vector<8xi32> to vector<1x8xi32>
+// CHECK:           return %[[VAL_6]] : vector<1x8xi32>
+// CHECK:         }
+func.func @cast_away_contraction_leading_one_dims_vec_mat(%lhs: vector<1x1x8xi32>,
+                          %rhs: vector<1x8x8xi32>,
+                          %acc: vector<1x8xi32>) -> vector<1x8xi32> {
+  %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
+  return %result : vector<1x8xi32>
+}
+
+// -----
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>



More information about the Mlir-commits mailing list