[Mlir-commits] [mlir] 66fed33 - [mlir][vector] Update `castAwayContractionLeadingOneDim` to omit transposes solely on leading unit dims. (#85694)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 3 16:27:05 PDT 2024


Author: Kojo Acquah
Date: 2024-04-03T19:27:01-04:00
New Revision: 66fed33db014bd705433e4b4f1ea766a8d71cadf

URL: https://github.com/llvm/llvm-project/commit/66fed33db014bd705433e4b4f1ea766a8d71cadf
DIFF: https://github.com/llvm/llvm-project/commit/66fed33db014bd705433e4b4f1ea766a8d71cadf.diff

LOG: [mlir][vector] Update `castAwayContractionLeadingOneDim` to omit transposes solely on leading unit dims.  (#85694)

Updates `castAwayContractionLeadingOneDim` to check for leading unit
dimensions before inserting `vector.transpose` ops.

Currently `castAwayContractionLeadingOneDim` removes all leading unit
dims based on the accumulator and transposes any subsequent operands to
match the accumulator indexing. This does not take into account whether
the transpose is strictly necessary, for instance when given this
vector-matrix contract:
```mlir
  %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
```
Passing this through `castAwayContractionLeadingOneDim` pattern produces
the following:
```mlir
    %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x1x8xi32> to vector<1x1x8xi32>
    %1 = vector.extract %0[0] : vector<1x8xi32> from vector<1x1x8xi32>
    %2 = vector.extract %arg2[0] : vector<8xi32> from vector<1x8xi32>
    %3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %arg1, %2 : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
    %4 = vector.broadcast %3 : vector<8xi32> to vector<1x8xi32>
```
The `vector.transpose` introduced does not affect the underlying data
layout (effectively a no-op), but it cannot be folded automatically.
This change avoids inserting transposes when only leading unit
dimensions are involved.

Fixes #85691

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
    mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 593c1e53557ad8..8d733c5a8849b6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -398,13 +398,30 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
           transposeResults.push_back(targetExpr);
         }
       }
+
+      // Checks if only the outer, unit dimensions (of size 1) are permuted.
+      // Such transposes do not materially affect the underlying vector and can
+      // be omitted. E.g., perm [1, 0, 2] applied to vector<1x1x8xi32>.
+      bool transposeNonOuterUnitDims = false;
+      auto operandShape = operands[it.index()].getType().cast<ShapedType>();
+      for (auto [index, dim] :
+           llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
+        if (dim != static_cast<int64_t>(index) &&
+            operandShape.getDimSize(index) != 1) {
+          transposeNonOuterUnitDims = true;
+          break;
+        }
+      }
+
       // Do the tranpose now if needed so that we can drop the
       // correct dim using extract later.
       if (tranposeNeeded) {
         map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                              contractOp.getContext());
-        operands[it.index()] = rewriter.create<vector::TransposeOp>(
-            loc, operands[it.index()], perm);
+        if (transposeNonOuterUnitDims) {
+          operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
+              loc, operands[it.index()], perm);
+        }
       }
     }
     // We have taken care to have the dim to be dropped be

diff  --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 3a120a56056cad..252aeb0c15cbeb 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -238,6 +238,17 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
   return %0: vector<1x1x2x16xf32>
 }
 
+// -----
+
+// CHECK-LABEL:   func.func @cast_away_contraction_does_not_transpose_leading_unit_dims
+// CHECK-NOT:       vector.transpose
+// CHECK:           vector.contract
+func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(%lhs: vector<1x1x8xi32>,
+                          %rhs: vector<1x8x8xi32>,
+                          %acc: vector<1x8xi32>) -> vector<1x8xi32> {
+  %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
+  return %result : vector<1x8xi32>
+}
 
 // -----
 // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
@@ -663,4 +674,3 @@ func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>,
   %sel = arith.select %cond, %arg0, %arg1 : vector<1x16xi1>
   return %sel : vector<1x16xi1>
 }
-


        


More information about the Mlir-commits mailing list