[Mlir-commits] [mlir] 8c72eea - [mlir][vector] Add folding for ExtractOp with ShapeCastOp source

Thomas Raoux llvmlistbot at llvm.org
Fri Oct 23 12:06:41 PDT 2020


Author: Thomas Raoux
Date: 2020-10-23T12:06:18-07:00
New Revision: 8c72eea9a04ca8349224f26d1982d838824786a3

URL: https://github.com/llvm/llvm-project/commit/8c72eea9a04ca8349224f26d1982d838824786a3
DIFF: https://github.com/llvm/llvm-project/commit/8c72eea9a04ca8349224f26d1982d838824786a3.diff

LOG: [mlir][vector] Add folding for ExtractOp with ShapeCastOp source

Differential Revision: https://reviews.llvm.org/D89853

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index b71102cde1cf..d1deb5abd541 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -843,6 +843,61 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   return Value();
 }
 
+// Fold an ExtractOp whose vector operand is produced by a ShapeCastOp so that
+// it extracts directly from the shape_cast source.
+//
+// This applies when the extracted value spans the same trailing dimensions in
+// the shape_cast source: the extract position is linearized against the
+// shape_cast result shape and delinearized against the leading dimensions of
+// the shape_cast source. On success the ExtractOp is updated in place (new
+// position attribute and operand) and its own result is returned. When the
+// extracted value is the whole shape_cast source, that source is returned
+// instead. Returns a null Value when the fold does not apply.
+static Value foldExtractFromShapeCast(ExtractOp extractOp) {
+  auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
+  if (!shapeCastOp)
+    return Value();
+  VectorType sourceType = shapeCastOp.getSourceVectorType();
+  VectorType extractVectorType = extractOp.getVectorType();
+  // Get the nth dimension size starting from lowest dimension.
+  auto getDimReverse = [](VectorType type, int64_t n) {
+    return type.getDimSize(type.getRank() - n - 1);
+  };
+  // Rank of the extracted value; 0 when a scalar element is extracted.
+  int64_t destinationRank =
+      extractVectorType.getRank() - extractOp.position().size();
+  if (destinationRank > sourceType.getRank())
+    return Value();
+  if (destinationRank > 0) {
+    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
+    for (int64_t i = 0; i < destinationRank; i++) {
+      // The lowest dimensions of the destination must match the lowest
+      // dimensions of the shape_cast source, so the extracted slice has the
+      // same layout in both vectors.
+      if (getDimReverse(sourceType, i) != getDimReverse(destinationType, i))
+        return Value();
+    }
+  }
+  // Number of leading shape_cast source dimensions that index the slice. When
+  // it is 0, every source dimension matched a destination dimension, so the
+  // extracted value has the shape_cast source type and covers all of it:
+  // forward the source. Rewriting the extract instead would leave it with an
+  // empty position, which vector.extract does not accept.
+  int64_t numDimension = sourceType.getRank() - destinationRank;
+  if (numDimension == 0)
+    return shapeCastOp.source();
+
+  // Linearize the extract position against the shape_cast result shape.
+  // Positions and strides are both ordered lowest dimension first here.
+  auto extractedPos = extractVector<int64_t>(extractOp.position());
+  std::reverse(extractedPos.begin(), extractedPos.end());
+  SmallVector<int64_t, 4> strides;
+  int64_t stride = 1;
+  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
+    strides.push_back(stride);
+    stride *= getDimReverse(extractVectorType, i + destinationRank);
+  }
+  int64_t position = linearize(extractedPos, strides);
+
+  // Compute the strides associated with the leading shape_cast source
+  // dimensions and delinearize the position with them. The strides are
+  // reversed to outermost-first order so the new position is outermost-first.
+  SmallVector<int64_t, 4> newStrides;
+  stride = 1;
+  for (int64_t i = 0; i < numDimension; i++) {
+    newStrides.push_back(stride);
+    stride *= getDimReverse(sourceType, i + destinationRank);
+  }
+  std::reverse(newStrides.begin(), newStrides.end());
+  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
+
+  // Rewrite the extract in place to read from the shape_cast source.
+  OpBuilder b(extractOp.getContext());
+  extractOp.setAttr(ExtractOp::getPositionAttrName(),
+                    b.getI64ArrayAttr(newPosition));
+  extractOp.setOperand(shapeCastOp.source());
+  return extractOp.getResult();
+}
+
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
@@ -852,6 +907,8 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
     return val;
   if (auto val = foldExtractFromBroadcast(*this))
     return val;
+  if (auto val = foldExtractFromShapeCast(*this))
+    return val;
   return OpFoldResult();
 }
 

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 2f927a1bbc81..66bad06e6b60 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -394,6 +394,39 @@ func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> {
   return %r : vector<4xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @fold_extract_shapecast
+//  CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32>
+//       CHECK:   %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>
+//       CHECK:   %[[R1:.*]] = vector.extract %[[A0]][1, 0, 2] : vector<5x1x3x2xf32>
+//       CHECK:   %[[R2:.*]] = vector.extract %[[A1]][7] : vector<8x4x2xf32>
+//       CHECK:   return %[[R0]], %[[R1]], %[[R2]] : f32, vector<2xf32>, vector<4x2xf32>
+// Each extract is rewritten to read from the shape_cast source: its position
+// is linearized over the shape_cast result and delinearized over the leading
+// dimensions of the source (the extracted trailing dimensions are unchanged).
+//   %r1: [4, 1] in 15x2       -> linear 9 -> [1, 0, 1, 1] in 5x1x3x2
+//   %r2: [5] in 15 (x2)       -> linear 5 -> [1, 0, 2] in 5x1x3 (x2)
+//   %r3: [3, 1] in 4x2 (x4x2) -> linear 7 -> [7] in 8 (x4x2)
+func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
+                             %arg1 : vector<8x4x2xf32>)
+  -> (f32, vector<2xf32>, vector<4x2xf32>) {
+  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>
+  %1 = vector.shape_cast %arg1 : vector<8x4x2xf32> to vector<4x2x4x2xf32>
+  %r1 = vector.extract %0[4, 1] : vector<15x2xf32>
+  %r2 = vector.extract %0[5] : vector<15x2xf32>
+  %r3 = vector.extract %1[3, 1] : vector<4x2x4x2xf32>
+  return %r1, %r2, %r3 : f32, vector<2xf32>, vector<4x2xf32>
+}
+
+// -----
+
+// The extracted vector<4x2xf32> has rank 2, which exceeds the rank of the
+// vector<16xf32> shape_cast source, so the extract must not be folded.
+// CHECK-LABEL: func @fold_extract_shapecast_negative
+//       CHECK:   %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
+//       CHECK:   %[[R:.*]] = vector.extract %[[V]][1] : vector<2x4x2xf32>
+//       CHECK:   return %[[R]] : vector<4x2xf32>
+func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
+  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
+  %r = vector.extract %0[1] : vector<2x4x2xf32>
+  return %r : vector<4x2xf32>
+}
+
+
 // -----
 
 // CHECK-LABEL: fold_vector_transfers


        


More information about the Mlir-commits mailing list