[Mlir-commits] [mlir] a490d38 - [mlir][Vector] Add ExtractOp folding when fed by a TransposeOp

Nicolas Vasilache llvmlistbot at llvm.org
Fri Jul 10 08:11:17 PDT 2020


Author: Nicolas Vasilache
Date: 2020-07-10T11:09:27-04:00
New Revision: a490d387e6e6085b35a6850581b62db3d2d47009

URL: https://github.com/llvm/llvm-project/commit/a490d387e6e6085b35a6850581b62db3d2d47009
DIFF: https://github.com/llvm/llvm-project/commit/a490d387e6e6085b35a6850581b62db3d2d47009.diff

LOG: [mlir][Vector] Add ExtractOp folding when fed by a TransposeOp

TransposeOps are often followed by ExtractOps.
In certain cases, however, it is unnecessary (and even detrimental) to lower a TransposeOp either to a flat transpose (llvm.matrix intrinsics) or to unrolled chains of scalar insert/extract ops.

Folding an ExtractOp directly through its producing TransposeOp avoids some of this unnecessary complexity.

Differential revision: https://reviews.llvm.org/D83487

Added: 
    

Modified: 
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index a44723024843..54f81db92a3e 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -170,6 +170,10 @@ class AffineMap {
   ///     `(d0)[s0, s1, s2] -> (d0 + s1 + s2 + 1, d0 - s0 - s2 - 1)`
   AffineMap compose(AffineMap map);
 
+  /// Composes `this` with the integer `values` (dim i becomes `values[i]`) and
+  /// returns the resulting constant results. `this` must be symbol-less.
+  SmallVector<int64_t, 4> compose(ArrayRef<int64_t> values);
+
   /// Returns true if the AffineMap represents a subset (i.e. a projection) of a
   /// symbol-less permutation map.
   bool isProjectedPermutation();
@@ -180,6 +184,11 @@ class AffineMap {
   /// Returns the map consisting of the `resultPos` subset.
   AffineMap getSubMap(ArrayRef<unsigned> resultPos);
 
+  /// Returns the map consisting of the first (most major) `numResults` results.
+  /// Returns the null AffineMap if `numResults` == 0.
+  /// Returns `*this` if `numResults` >= `this->getNumResults()`.
+  AffineMap getMajorSubMap(unsigned numResults);
+
   /// Returns the map consisting of the most minor `numResults` results.
   /// Returns the null AffineMap if `numResults` == 0.
   /// Returns `*this` if `numResults` >= `this->getNumResults()`.

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 0aae97f24d20..cdf09c4a8f68 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -602,6 +603,63 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
   return success();
 }
 
+/// In-place fold of an ExtractOp whose vector operand is a TransposeOp result.
+/// On success, the op is rewritten to extract directly from the transpose
+/// source at the inverse-permuted position; on failure it is left untouched.
+static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
+  auto transposeOp = extractOp.vector().getDefiningOp<TransposeOp>();
+  if (!transposeOp)
+    return failure();
+
+  auto permutation = extractVector<unsigned>(transposeOp.transp());
+  auto extractedPos = extractVector<int64_t>(extractOp.position());
+
+  // If the transposition permutation has more entries than the extracted
+  // position, the remaining minor dimensions must be an identity for folding
+  // to occur. Otherwise, individual elements within the extracted value are
+  // transposed and this is not just a simple folding.
+  unsigned minorRank = permutation.size() - extractedPos.size();
+  MLIRContext *ctx = extractOp.getContext();
+  AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx);
+  AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
+  if (minorMap && !AffineMap::isMinorIdentity(minorMap))
+    return failure();
+
+  //   %1 = transpose %0, [p, q, r] : vector<axbxcxf32>
+  //   %2 = extract %1[u, v] : vector<..xf32>
+  // may turn into:
+  //   %2 = extract %0[w, x] : vector<..xf32>
+  // iff r == 2 and [w, x] = [p, q]^-1 o [u, v], where o denotes composition
+  // and ^-1 denotes the inverse permutation.
+  permutationMap = permutationMap.getMajorSubMap(extractedPos.size());
+  // The major submap has fewer results but the same number of dims. To compose
+  // cleanly, drop the trailing dims to form a square map. This is sound since
+  // this is a permutation map whose minor part was checked to be an identity,
+  // so the major results only use dims whose position is less than the number
+  // of results.
+  assert(llvm::all_of(permutationMap.getResults(),
+                      [&](AffineExpr e) {
+                        auto dim = e.dyn_cast<AffineDimExpr>();
+                        return dim && dim.getPosition() <
+                                          permutationMap.getNumResults();
+                      }) &&
+         "Unexpected map results depend on higher rank positions");
+  // Keep only the first `numResults` dims so the map is square and invertible.
+  permutationMap = AffineMap::get(permutationMap.getNumResults(), 0,
+                                  permutationMap.getResults(), ctx);
+
+  extractOp.setOperand(transposeOp.vector());
+  // Map the extracted position back onto the transpose source: P^-1 o pos.
+  auto newExtractedPos =
+      inversePermutation(permutationMap).compose(extractedPos);
+  // OpBuilder is only used as a helper to build an I64ArrayAttr.
+  OpBuilder b(extractOp.getContext());
+  extractOp.setAttr(ExtractOp::getPositionAttrName(),
+                    b.getI64ArrayAttr(newExtractedPos));
+
+  return success();
+}
+
 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
 /// result is always the input to some InsertOp.
 static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
@@ -689,6 +747,8 @@ static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
+  if (succeeded(foldExtractOpFromTranspose(*this)))
+    return getResult();
   if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
     return val;
   return OpFoldResult();

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index c17df954558b..b09c51a3abbb 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -330,6 +330,21 @@ AffineMap AffineMap::compose(AffineMap map) {
   return AffineMap::get(numDims, numSymbols, exprs, map.getContext());
 }
 
+SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) {
+  assert(getNumSymbols() == 0 && "Expected symbol-less map");
+  SmallVector<AffineExpr, 4> exprs;
+  exprs.reserve(values.size());
+  MLIRContext *ctx = getContext();
+  for (auto v : values) // Each value becomes a constant expression.
+    exprs.push_back(getAffineConstantExpr(v, ctx));
+  auto resMap = compose(AffineMap::get(0, 0, exprs, ctx)); // d_i := values[i]
+  SmallVector<int64_t, 4> res;
+  res.reserve(resMap.getNumResults());
+  for (auto e : resMap.getResults()) // Each result must fold to a constant.
+    res.push_back(e.cast<AffineConstantExpr>().getValue());
+  return res;
+}
+
 bool AffineMap::isProjectedPermutation() {
   if (getNumSymbols() > 0)
     return false;
@@ -360,6 +375,14 @@ AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) {
   return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
 }
 
+AffineMap AffineMap::getMajorSubMap(unsigned numResults) {
+  if (numResults == 0)
+    return AffineMap();
+  if (numResults > getNumResults()) // `==` keeps all results via getSubMap.
+    return *this;
+  return getSubMap(llvm::to_vector<4>(llvm::seq<unsigned>(0, numResults)));
+}
+
 AffineMap AffineMap::getMinorSubMap(unsigned numResults) {
   if (numResults == 0)
     return AffineMap();

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 94f3f627e777..836a9869248d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -300,13 +300,48 @@ func @insert_extract_transpose_3d_2d(
 
 // CHECK-LABEL: fold_extracts
 //  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: vector<3x4x5x6xf32>
-//  CHECK-NEXT: vector.extract %[[A]][0, 1, 2, 3] : vector<3x4x5x6xf32>
-//  CHECK-NEXT: vector.extract %[[A]][0] : vector<3x4x5x6xf32>
-//  CHECK-NEXT: return
 func @fold_extracts(%a : vector<3x4x5x6xf32>) -> (f32, vector<4x5x6xf32>) {
   %b = vector.extract %a[0] : vector<3x4x5x6xf32>
   %c = vector.extract %b[1, 2] : vector<4x5x6xf32>
+  //  CHECK-NEXT: vector.extract %[[A]][0, 1, 2, 3] : vector<3x4x5x6xf32>
   %d = vector.extract %c[3] : vector<6xf32>
+  // An extract directly on %a has nothing to fold and is kept as is.
+  //  CHECK-NEXT: vector.extract %[[A]][0] : vector<3x4x5x6xf32>
   %e = vector.extract %a[0] : vector<3x4x5x6xf32>
+
+  //  CHECK-NEXT: return
   return %d, %e : f32, vector<4x5x6xf32>
 }
+
+// -----
+
+// CHECK-LABEL: fold_extract_transpose
+//  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: vector<3x4x5x6xf32>
+//  CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: vector<3x6x5x6xf32>
+func @fold_extract_transpose(
+    %a : vector<3x4x5x6xf32>, %b : vector<3x6x5x6xf32>) -> (
+      vector<6xf32>, vector<6xf32>, vector<6xf32>) {
+  // The trailing entry [3] is a minor identity, so folding is possible.
+  // The major permutation [0, 2, 1] is its own inverse, so:
+  // [0, 2, 1] ^ -1 o [0, 1, 2] = [0, 2, 1] o [0, 1, 2]
+  //                            = [0, 2, 1]
+  //  CHECK-NEXT: vector.extract %[[A]][0, 2, 1] : vector<3x4x5x6xf32>
+  %0 = vector.transpose %a, [0, 2, 1, 3] : vector<3x4x5x6xf32> to vector<3x5x4x6xf32>
+  %1 = vector.extract %0[0, 1, 2] : vector<3x5x4x6xf32>
+
+  // The trailing entry [3] is a minor identity, so folding is possible.
+  // The major permutation [1, 2, 0] is not its own inverse, so:
+  // [1, 2, 0] ^ -1 o [0, 1, 2] = [2, 0, 1] o [0, 1, 2]
+  //                            = [2, 0, 1]
+  //  CHECK-NEXT: vector.extract %[[A]][2, 0, 1] : vector<3x4x5x6xf32>
+  %2 = vector.transpose %a, [1, 2, 0, 3] : vector<3x4x5x6xf32> to vector<4x5x3x6xf32>
+  %3 = vector.extract %2[0, 1, 2] : vector<4x5x3x6xf32>
+
+  // Trailing entry is 1, not 3 (not a minor identity): no fold is possible.
+  //  CHECK-NEXT: vector.transpose %[[B]], [0, 2, 3, 1]
+  //  CHECK-NEXT: vector.extract %{{.*}}[0, 1, 2]
+  %4 = vector.transpose %b, [0, 2, 3, 1] : vector<3x6x5x6xf32> to vector<3x5x6x6xf32>
+  %5 = vector.extract %4[0, 1, 2] : vector<3x5x6x6xf32>
+
+  return %1, %3, %5 : vector<6xf32>, vector<6xf32>, vector<6xf32>
+}


        


More information about the Mlir-commits mailing list