[Mlir-commits] [mlir] [mlir][vector] Update `castAwayContractionLeadingOneDim` to omit transposes solely on leading unit dims. (PR #85694)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 19 09:47:16 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir
Author: Kojo Acquah (KoolJBlack)
<details>
<summary>Changes</summary>
Fixes #<!-- -->85691
---
Full diff: https://github.com/llvm/llvm-project/pull/85694.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+17-2)
- (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+22)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 74382b027c2f48..6b69f5f1932ad7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -399,13 +399,28 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
transposeResults.push_back(targetExpr);
}
}
+
+ // Check if the transpose effects outer unit dims only. Such transposes do
+ // not materially effect the underlying vector and can be omitted.
+ bool tranposeNonOuterUnitDims = false;
+ for (int64_t i = 0; i < (int64_t)perm.size(); ++i) {
+ if (perm[i] != i && i != (int64_t)perm.size() - 1) {
+ if (operands[it.index()].getType().cast<ShapedType>().getDimSize(i) !=
+ 1) {
+ tranposeNonOuterUnitDims = true;
+ }
+ }
+ }
+
// Do the tranpose now if needed so that we can drop the
// correct dim using extract later.
if (tranposeNeeded) {
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
contractOp.getContext());
- operands[it.index()] = rewriter.create<vector::TransposeOp>(
- contractOp.getLoc(), operands[it.index()], perm);
+ if (tranposeNonOuterUnitDims) {
+ operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
+ contractOp.getLoc(), operands[it.index()], perm);
+ }
}
}
// We have taken care to have the dim to be dropped be
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index af6e636245b04e..31b0867c851f58 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -166,6 +166,28 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
// -----
+// CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dims_vec_mat(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x8xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x8xi32>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0] : vector<1x8xi32> from vector<1x1x8xi32>
+// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_2]][0] : vector<8xi32> from vector<1x8xi32>
+// CHECK: %[[VAL_5:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
+// CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : vector<8xi32> to vector<1x8xi32>
+// CHECK: return %[[VAL_6]] : vector<1x8xi32>
+// CHECK: }
+func.func @cast_away_contraction_leading_one_dims_vec_mat(%lhs: vector<1x1x8xi32>,
+ %rhs: vector<1x8x8xi32>,
+ %acc: vector<1x8xi32>) -> vector<1x8xi32> {
+ %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
+ return %result : vector<1x8xi32>
+}
+
+// -----
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
``````````
</details>
https://github.com/llvm/llvm-project/pull/85694
More information about the Mlir-commits
mailing list