[Mlir-commits] [mlir] 24ed3a9 - [mlir][Vector] Add ExtractOp folding

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jul 7 13:49:37 PDT 2020


Author: Nicolas Vasilache
Date: 2020-07-07T16:48:49-04:00
New Revision: 24ed3a9403fa0494275212765026a1bb4169ac76

URL: https://github.com/llvm/llvm-project/commit/24ed3a9403fa0494275212765026a1bb4169ac76
DIFF: https://github.com/llvm/llvm-project/commit/24ed3a9403fa0494275212765026a1bb4169ac76.diff

LOG: [mlir][Vector] Add ExtractOp folding

This revision adds foldings for ExtractOp operations whose vector operand comes from a previous InsertOp.
InsertOps have cumulative semantics: multiple chained inserts are necessary to produce the final value from which the extracts are obtained.
Additionally, TransposeOps may be interleaved; they need to be tracked in order to follow the producer-consumer relationships and properly compute positions.

Differential revision: https://reviews.llvm.org/D83150

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index a02f39f943f8..b205e5a2e286 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -390,6 +390,7 @@ def Vector_ExtractOp :
       return vector().getType().cast<VectorType>();
     }
   }];
+  let hasFolder = 1;
 }
 
 def Vector_ExtractSlicesOp :

diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 841ec6abcedf..a44723024843 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -180,6 +180,11 @@ class AffineMap {
   /// Returns the map consisting of the `resultPos` subset.
   AffineMap getSubMap(ArrayRef<unsigned> resultPos);
 
+  /// Returns the map consisting of the most minor (i.e. trailing)
+  /// `numResults` results, with the same dims and symbols as this map.
+  /// Returns the null AffineMap if `numResults` == 0.
+  /// Returns `*this` if `numResults` >= `this->getNumResults()`.
+  AffineMap getMinorSubMap(unsigned numResults);
+
   friend ::llvm::hash_code hash_value(AffineMap arg);
 
 private:

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index f97906c2570d..019b73a0e00b 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -571,6 +571,100 @@ static LogicalResult verify(vector::ExtractOp op) {
   return success();
 }
 
+/// Returns the integer values held by `arrayAttr`, converted to `unsigned`,
+/// in their original order. Every element is expected to be an IntegerAttr.
+static SmallVector<unsigned, 4> extractUnsignedVector(ArrayAttr arrayAttr) {
+  SmallVector<unsigned, 4> values;
+  for (IntegerAttr intAttr : arrayAttr.getAsRange<IntegerAttr>())
+    values.push_back(static_cast<unsigned>(intAttr.getInt()));
+  return values;
+}
+
+/// Folds `extractOp` by walking back the chain of `vector.insert` and
+/// `vector.transpose` ops that produces its vector operand.
+///
+/// Transposes met along the way are composed into `permutationMap`, which maps
+/// an index of the vector currently inspected to the corresponding index of
+/// the extract operand (a null map stands for the identity). For each
+/// `vector.insert` met along the way:
+///   - if the inserted slice is provably disjoint from the extracted slice,
+///     the insert cannot affect the result and the walk continues into its
+///     destination;
+///   - if both slices are exactly the same and the remaining (minor)
+///     dimensions are not transposed, the inserted source is returned;
+///   - otherwise the slices partially overlap: neither the inserted source nor
+///     the destination alone holds the extracted value, so folding stops.
+/// Returns a null Value when no fold is found.
+static Value foldExtractOp(ExtractOp extractOp) {
+  MLIRContext *context = extractOp.getContext();
+  AffineMap permutationMap;
+  auto extractedPos = extractUnsignedVector(extractOp.position());
+
+  // Returns the dimension of the vector currently inspected from which
+  // dimension `idx` of the extract operand is read.
+  auto sourceDim = [&](unsigned idx) -> unsigned {
+    if (!permutationMap)
+      return idx;
+    return permutationMap.getResult(idx).cast<AffineDimExpr>().getPosition();
+  };
+
+  Value current = extractOp.vector();
+  while (true) {
+    if (auto transposeOp = current.getDefiningOp<vector::TransposeOp>()) {
+      // `newMap` maps an index of the transpose source to an index of its
+      // result, i.e. of the vector inspected so far. It must therefore be
+      // applied before the transposes already walked through:
+      // permutationMap o newMap (`a.compose(b)` computes a(b(x))).
+      auto permutation = extractUnsignedVector(transposeOp.transp());
+      AffineMap newMap = AffineMap::getPermutationMap(permutation, context);
+      if (!permutationMap)
+        permutationMap = newMap;
+      else if (permutationMap.getNumDims() != newMap.getNumResults())
+        return Value();
+      else
+        permutationMap = permutationMap.compose(newMap);
+      current = transposeOp.vector();
+      continue;
+    }
+
+    auto insertOp = current.getDefiningOp<vector::InsertOp>();
+    if (!insertOp)
+      return Value();
+    assert((!permutationMap || permutationMap.isPermutation()) &&
+           "expected a permutation");
+    auto insertedPos = extractUnsignedVector(insertOp.position());
+
+    // The slices are disjoint iff some dimension fixed by both the insert and
+    // the extract is fixed to different values. `sameSlice` additionally
+    // tracks whether the extract fixes exactly the dimensions the insert fixes.
+    bool disjoint = false;
+    bool sameSlice = insertedPos.size() == extractedPos.size();
+    for (unsigned idx = 0, e = extractedPos.size(); idx < e; ++idx) {
+      unsigned pos = sourceDim(idx);
+      if (pos >= insertedPos.size()) {
+        // The insert leaves this dimension free while the extract fixes it.
+        sameSlice = false;
+        continue;
+      }
+      if (insertedPos[pos] != extractedPos[idx]) {
+        disjoint = true;
+        break;
+      }
+    }
+
+    if (disjoint) {
+      // This insert cannot affect the extracted value: look further back.
+      current = insertOp.dest();
+      continue;
+    }
+
+    // The slices overlap. Only an exact match whose minor dimensions are not
+    // transposed folds to the inserted source; a partial overlap mixes values
+    // from the source and the destination and cannot be folded here.
+    if (!sameSlice)
+      return Value();
+    if (!permutationMap)
+      return insertOp.source();
+    assert(permutationMap.getNumResults() >= insertedPos.size() &&
+           "expected map of rank larger than insert indexing");
+    unsigned minorRank = permutationMap.getNumResults() - insertedPos.size();
+    AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
+    if (!minorMap || AffineMap::isMinorIdentity(minorMap))
+      return insertOp.source();
+    return Value();
+  }
+}
+
+/// Folds vector.extract by forwarding a value found by `foldExtractOp`; the
+/// constant operands are not used.
+OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
+  Value folded = foldExtractOp(*this);
+  if (!folded)
+    return OpFoldResult();
+  return folded;
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractSlicesOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 050cb831f7a1..c17df954558b 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -355,12 +355,20 @@ bool AffineMap::isPermutation() {
 AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) {
   SmallVector<AffineExpr, 4> exprs;
   exprs.reserve(resultPos.size());
-  for (auto idx : resultPos) {
+  // Keep only the results at `resultPos`, in that order; the dims and
+  // symbols of this map are carried over unchanged.
+  for (auto idx : resultPos)
     exprs.push_back(getResult(idx));
-  }
   return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
 }
 
+/// Returns the map made of the last `numResults` results of this map, keeping
+/// the same dims and symbols. Returns a null map for `numResults` == 0 and
+/// `*this` when `numResults` >= getNumResults(), as documented in the header.
+AffineMap AffineMap::getMinorSubMap(unsigned numResults) {
+  if (numResults == 0)
+    return AffineMap();
+  unsigned totalResults = getNumResults();
+  // `>=` (not `>`) matches the documented contract and avoids rebuilding an
+  // identical map when every result is requested.
+  if (numResults >= totalResults)
+    return *this;
+  return getSubMap(llvm::to_vector<4>(
+      llvm::seq<unsigned>(totalResults - numResults, totalResults)));
+}
+
 AffineMap mlir::simplifyAffineMap(AffineMap map) {
   SmallVector<AffineExpr, 8> exprs;
   for (auto e : map.getResults()) {

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5e4ba39895ed..7ba79528aee6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -175,3 +175,123 @@ func @cast_transfers(%A: memref<4x8xf32>) -> (vector<4x8xf32>) {
   vector.transfer_write %1, %0[%c0, %c0] : vector<4x8xf32>, memref<?x?xf32>
   return %1 : vector<4x8xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_transpose_2d(
+//  CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3xf32>,
+//  CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: f32,
+//  CHECK-SAME: %[[F1:[a-zA-Z0-9]*]]: f32,
+//  CHECK-SAME: %[[F2:[a-zA-Z0-9]*]]: f32,
+//  CHECK-SAME: %[[F3:[a-zA-Z0-9]*]]: f32
+// Tests that scalar extracts from 2-D vectors fold through chains of inserts
+// and [1, 0] transposes, so only the return of forwarded arguments remains.
+func @insert_extract_transpose_2d(
+    %v: vector<2x3xf32>, %f0: f32, %f1: f32, %f2: f32, %f3: f32)
+-> (f32, f32, f32)
+{
+  %0 = vector.insert %f0, %v[0, 0] : f32 into vector<2x3xf32>
+  %1 = vector.insert %f1, %0[0, 1] : f32 into vector<2x3xf32>
+  %2 = vector.insert %f2, %1[1, 0] : f32 into vector<2x3xf32>
+  %3 = vector.insert %f3, %2[1, 1] : f32 into vector<2x3xf32>
+  %4 = vector.transpose %3, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
+  %5 = vector.insert %f3, %4[1, 0] : f32 into vector<3x2xf32>
+  %6 = vector.transpose %5, [1, 0] : vector<3x2xf32> to vector<2x3xf32>
+
+  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0].
+  %r1 = vector.extract %3[1, 0] : vector<2x3xf32>
+
+  // Expected %f1 from %1 = vector.insert %f1, %0[0, 1] followed by
+  // transpose [1, 0].
+  %r2 = vector.extract %4[1, 0] : vector<3x2xf32>
+
+  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0] followed by double
+  // transpose [1, 0]. The two transposes cancel out; the insert into %4 at
+  // [1, 0] corresponds to position [0, 1] of %6 and does not interfere.
+  %r3 = vector.extract %6[1, 0] : vector<2x3xf32>
+
+  // CHECK-NEXT: return %[[F2]], %[[F1]], %[[F2]] : f32, f32, f32
+  return %r1, %r2, %r3 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_transpose_3d(
+//  CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3x4xf32>,
+//  CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: f32,
+//  CHECK-SAME: %[[F1:[a-zA-Z0-9]*]]: f32,
+//  CHECK-SAME: %[[F2:[a-zA-Z0-9]*]]: f32,
+//  CHECK-SAME: %[[F3:[a-zA-Z0-9]*]]: f32
+func @insert_extract_transpose_3d(
+    %v: vector<2x3x4xf32>, %f0: f32, %f1: f32, %f2: f32, %f3: f32)
+-> (f32, f32, f32, f32)
+{
+  %0 = vector.insert %f0, %v[0, 0, 0] : f32 into vector<2x3x4xf32>
+  %1 = vector.insert %f1, %0[0, 1, 0] : f32 into vector<2x3x4xf32>
+  %2 = vector.insert %f2, %1[1, 0, 0] : f32 into vector<2x3x4xf32>
+  %3 = vector.insert %f3, %2[0, 0, 1] : f32 into vector<2x3x4xf32>
+  %4 = vector.transpose %3, [1, 2, 0] : vector<2x3x4xf32> to vector<3x4x2xf32>
+  %5 = vector.insert %f3, %4[1, 0, 0] : f32 into vector<3x4x2xf32>
+  %6 = vector.transpose %5, [1, 2, 0] : vector<3x4x2xf32> to vector<4x2x3xf32>
+  %7 = vector.insert %f3, %6[1, 0, 0] : f32 into vector<4x2x3xf32>
+  %8 = vector.transpose %7, [1, 2, 0] : vector<4x2x3xf32> to vector<2x3x4xf32>
+
+  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0].
+  %r1 = vector.extract %3[1, 0, 0] : vector<2x3x4xf32>
+
+  // Expected %f1 from %1 = vector.insert %f1, %0[0, 1, 0] followed by
+  // transpose[1, 2, 0].
+  %r2 = vector.extract %4[1, 0, 0] : vector<3x4x2xf32>
+
+  // Expected %f3 from %3 = vector.insert %f3, %2[0, 0, 1] followed by double
+  // transpose[1, 2, 0].
+  %r3 = vector.extract %6[1, 0, 0] : vector<4x2x3xf32>
+
+  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0] followed by triple
+  // transpose[1, 2, 0].
+  %r4 = vector.extract %8[1, 0, 0] : vector<2x3x4xf32>
+
+  // CHECK-NEXT: return %[[F2]], %[[F1]], %[[F3]], %[[F2]] : f32, f32, f32, f32
+  return %r1, %r2, %r3, %r4 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_transpose_3d_2d(
+//  CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3x4xf32>,
+//  CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: vector<4xf32>,
+//  CHECK-SAME: %[[F1:[a-zA-Z0-9]*]]: vector<4xf32>,
+//  CHECK-SAME: %[[F2:[a-zA-Z0-9]*]]: vector<4xf32>,
+//  CHECK-SAME: %[[F3:[a-zA-Z0-9]*]]: vector<4xf32>
+// Tests folding of 1-D slice extracts out of 3-D vectors: the transposes only
+// permute the two major dimensions or compose back to the identity.
+func @insert_extract_transpose_3d_2d(
+    %v: vector<2x3x4xf32>,
+    %f0: vector<4xf32>, %f1: vector<4xf32>, %f2: vector<4xf32>, %f3: vector<4xf32>)
+-> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>)
+{
+  %0 = vector.insert %f0, %v[0, 0] : vector<4xf32> into vector<2x3x4xf32>
+  %1 = vector.insert %f1, %0[0, 1] : vector<4xf32> into vector<2x3x4xf32>
+  %2 = vector.insert %f2, %1[1, 0] : vector<4xf32> into vector<2x3x4xf32>
+  %3 = vector.insert %f3, %2[1, 1] : vector<4xf32> into vector<2x3x4xf32>
+  %4 = vector.transpose %3, [1, 0, 2] : vector<2x3x4xf32> to vector<3x2x4xf32>
+  %5 = vector.transpose %4, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
+
+  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0].
+  %r1 = vector.extract %3[1, 0] : vector<2x3x4xf32>
+
+  // Expected %f1 from %1 = vector.insert %f1, %0[0, 1] followed by
+  // transpose[1, 0, 2].
+  %r2 = vector.extract %4[1, 0] : vector<3x2x4xf32>
+
+  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0] followed by double
+  // transpose[1, 0, 2].
+  %r3 = vector.extract %5[1, 0] : vector<2x3x4xf32>
+
+  %6 = vector.transpose %3, [1, 2, 0] : vector<2x3x4xf32> to vector<3x4x2xf32>
+  %7 = vector.transpose %6, [1, 2, 0] : vector<3x4x2xf32> to vector<4x2x3xf32>
+  %8 = vector.transpose %7, [1, 2, 0] : vector<4x2x3xf32> to vector<2x3x4xf32>
+
+  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0] followed by triple
+  // transpose[1, 2, 0].
+  %r4 = vector.extract %8[1, 0] : vector<2x3x4xf32>
+
+  //      CHECK: return %[[F2]], %[[F1]], %[[F2]], %[[F2]]
+  // CHECK-SAME: vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+  return %r1, %r2, %r3, %r4 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+}


        


More information about the Mlir-commits mailing list