[Mlir-commits] [mlir] ea6a60a - [mlir][vector] Add folder for ExtractStridedSliceOp

Thomas Raoux llvmlistbot at llvm.org
Fri Oct 23 12:18:26 PDT 2020


Author: Thomas Raoux
Date: 2020-10-23T12:18:09-07:00
New Revision: ea6a60a9a6c18a2a512b066fd8a873ff0db49836

URL: https://github.com/llvm/llvm-project/commit/ea6a60a9a6c18a2a512b066fd8a873ff0db49836
DIFF: https://github.com/llvm/llvm-project/commit/ea6a60a9a6c18a2a512b066fd8a873ff0db49836.diff

LOG: [mlir][vector] Add folder for ExtractStridedSliceOp

Add a folder for the case where the source of an ExtractStridedSliceOp comes
from a chain of InsertStridedSliceOps. Also add a folder for the trivial case
where the ExtractStridedSliceOp is a no-op.

Differential Revision: https://reviews.llvm.org/D89850

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index cc6bc2c8c77b..edf44c7c7110 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1016,6 +1016,7 @@ def Vector_ExtractStridedSliceOp :
     void getOffsets(SmallVectorImpl<int64_t> &results);
   }];
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
   let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
 }
 

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index d1deb5abd541..4b8cbda197f3 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1629,6 +1629,110 @@ static LogicalResult verify(ExtractStridedSliceOp op) {
   return success();
 }
 
+// When the source of ExtractStrided comes from a chain of InsertStrided ops,
+// try to use the source of one of the InsertStrided ops if we can detect that
+// the extracted vector is a subset of that inserted vector.
+//
+// The chain is walked from the extract's source towards its root. For each
+// InsertStridedSliceOp, the extracted and inserted chunks are compared:
+//   - if the chunks are disjoint, the insert cannot affect the extracted
+//     elements and the walk continues with the insert's destination;
+//   - if the extracted chunk lies entirely inside the inserted one, the
+//     extract is updated in place to read from the inserted vector, with
+//     offsets made relative to it;
+//   - otherwise the chunks partially overlap and nothing is folded.
+// The walk also gives up when it meets an insert whose source rank differs
+// from the extract's source rank or whose strides differ from the extract's.
+static LogicalResult
+foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
+  // Helper to extract integer out of ArrayAttr.
+  auto getElement = [](ArrayAttr array, int idx) {
+    return array[idx].cast<IntegerAttr>().getInt();
+  };
+  ArrayAttr extractOffsets = op.offsets();
+  ArrayAttr extractStrides = op.strides();
+  ArrayAttr extractSizes = op.sizes();
+  VectorType extractSourceType = op.getVectorType();
+  auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
+  while (insertOp) {
+    VectorType insertSourceType = insertOp.getSourceVectorType();
+    if (extractSourceType.getRank() != insertSourceType.getRank())
+      return failure();
+    ArrayAttr insertOffsets = insertOp.offsets();
+    ArrayAttr insertStrides = insertOp.strides();
+    // If the rank of extract is greater than the rank of insert, we are likely
+    // extracting a partial chunk of the vector inserted.
+    if (extractOffsets.size() > insertOffsets.size())
+      return failure();
+    bool partialOverlap = false;
+    bool disjoint = false;
+    SmallVector<int64_t, 4> offsetDiffs;
+    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
+      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
+        return failure();
+      // Along this dimension the insert writes [start, end) and the extract
+      // reads [offset, offset + size).
+      int64_t start = getElement(insertOffsets, dim);
+      int64_t end = start + insertSourceType.getDimSize(dim);
+      int64_t offset = getElement(extractOffsets, dim);
+      int64_t size = getElement(extractSizes, dim);
+      // If the intervals do not intersect along one dimension, the two chunks
+      // share no element at all.
+      if (offset + size <= start || offset >= end) {
+        disjoint = true;
+        break;
+      }
+      // The intervals intersect. Unless the extracted interval is fully
+      // included in the inserted one (an extract starting before `start` is
+      // not), this is a partial overlap that prevents any folding.
+      if (offset < start || offset + size > end)
+        partialOverlap = true;
+      offsetDiffs.push_back(offset - start);
+    }
+    // The extracted chunk is disjoint from the inserted one: keep looking in
+    // the insert chain.
+    if (disjoint) {
+      insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
+      continue;
+    }
+    // The extracted chunk partially overlaps the inserted one; we cannot fold.
+    if (partialOverlap)
+      return failure();
+    // Dimensions without an explicit extract offset are extracted in full, so
+    // the insert must cover them entirely; otherwise this is also a partial
+    // overlap.
+    for (unsigned dim = extractOffsets.size(), e = extractSourceType.getRank();
+         dim < e; ++dim) {
+      if (getElement(insertOffsets, dim) != 0 ||
+          insertSourceType.getDimSize(dim) != extractSourceType.getDimSize(dim))
+        return failure();
+    }
+    // The extracted chunk is a subset of the inserted one.
+    op.setOperand(insertOp.source());
+    // OpBuilder is only used as a helper to build an I64ArrayAttr.
+    OpBuilder b(op.getContext());
+    op.setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
+               b.getI64ArrayAttr(offsetDiffs));
+    return success();
+  }
+  return failure();
+}
+
+/// Folds an ExtractStridedSliceOp when either:
+///   - its result type equals its source type: the slice then covers the
+///     whole source and the op folds to its source operand;
+///   - its source comes from a chain of InsertStridedSliceOp and the extracted
+///     chunk lies inside one of the inserted vectors: the op is then updated
+///     in place to extract directly from that inserted vector, and its own
+///     result is returned to report the in-place update.
+OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
+  if (getVectorType() == getResult().getType())
+    return vector();
+  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
+    return getResult();
+  return {};
+}
+
 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
   populateFromInt64AttrArray(offsets(), results);
 }

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 66bad06e6b60..b20acccb9e7b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -90,6 +90,167 @@ func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
 
 // -----
 
+// The extract covers the whole source vector: it folds to its source.
+// CHECK-LABEL: extract_strided_fold
+//  CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>)
+//  CHECK-NEXT:   return %[[ARG]] : vector<4x3xi1>
+func @extract_strided_fold(%arg : vector<4x3xi1>) -> (vector<4x3xi1>) {
+  %0 = vector.extract_strided_slice %arg
+    {offsets = [0, 0], sizes = [4, 3], strides = [1, 1]}
+      : vector<4x3xi1> to vector<4x3xi1>
+  return %0 : vector<4x3xi1>
+}
+
+// -----
+
+// Case where the vector extracted is exactly the vector inserted.
+// CHECK-LABEL: extract_strided_fold_insert
+//  CHECK-SAME: (%[[ARG:.*]]: vector<4x4xf32>
+//  CHECK-NEXT:   return %[[ARG]] : vector<4x4xf32>
+func @extract_strided_fold_insert(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
+  -> (vector<4x4xf32>) {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
+    : vector<4x4xf32> into vector<8x16xf32>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [2, 2], sizes = [4, 4], strides = [1, 1]}
+      : vector<8x16xf32> to vector<4x4xf32>
+  return %1 : vector<4x4xf32>
+}
+
+// -----
+
+// Case where the vector extracted is a subset of the vector inserted.
+// CHECK-LABEL: extract_strided_fold_insert
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<6x4xf32>
+//  CHECK-NEXT:   %[[EXT:.*]] = vector.extract_strided_slice %[[ARG0]]
+//  CHECK-SAME:     {offsets = [0, 0], sizes = [4, 4], strides = [1, 1]}
+//  CHECK-SAME:       : vector<6x4xf32> to vector<4x4xf32>
+//  CHECK-NEXT:   return %[[EXT]] : vector<4x4xf32>
+func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>)
+  -> (vector<4x4xf32>) {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
+    : vector<6x4xf32> into vector<8x16xf32>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [2, 2], sizes = [4, 4], strides = [1, 1]}
+      : vector<8x16xf32> to vector<4x4xf32>
+  return %1 : vector<4x4xf32>
+}
+
+// -----
+
+// Negative test where the extract is not a subset of the vector inserted.
+// CHECK-LABEL: extract_strided_fold_negative
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<4x4xf32>, %[[ARG1:.*]]: vector<8x16xf32>
+//       CHECK:   %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
+//  CHECK-SAME:     {offsets = [2, 2], strides = [1, 1]}
+//  CHECK-SAME:       : vector<4x4xf32> into vector<8x16xf32>
+//       CHECK:   %[[EXT:.*]] = vector.extract_strided_slice %[[INS]]
+//  CHECK-SAME:     {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
+//  CHECK-SAME:       : vector<8x16xf32> to vector<6x4xf32>
+//  CHECK-NEXT:   return %[[EXT]] : vector<6x4xf32>
+func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
+  -> (vector<6x4xf32>) {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
+    : vector<4x4xf32> into vector<8x16xf32>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
+      : vector<8x16xf32> to vector<6x4xf32>
+  return %1 : vector<6x4xf32>
+}
+
+// -----
+
+// Case where we need to go through 2 levels of insert ops.
+// CHECK-LABEL: extract_strided_fold_insert
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
+//  CHECK-NEXT:   %[[EXT:.*]] = vector.extract_strided_slice %[[ARG1]]
+//  CHECK-SAME:     {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
+//  CHECK-SAME:       : vector<1x4xf32> to vector<1x1xf32>
+//  CHECK-NEXT:   return %[[EXT]] : vector<1x1xf32>
+func @extract_strided_fold_insert(%a: vector<2x4xf32>, %b: vector<1x4xf32>,
+                                  %c : vector<1x4xf32>) -> (vector<1x1xf32>) {
+  %0 = vector.insert_strided_slice %b, %a {offsets = [0, 0], strides = [1, 1]}
+    : vector<1x4xf32> into vector<2x4xf32>
+  %1 = vector.insert_strided_slice %c, %0 {offsets = [1, 0], strides = [1, 1]}
+    : vector<1x4xf32> into vector<2x4xf32>
+  %2 = vector.extract_strided_slice %1
+      {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
+        : vector<2x4xf32> to vector<1x1xf32>
+  return %2 : vector<1x1xf32>
+}
+
+// -----
+
+// Case where the extract leaves the trailing dimension implicit (so it is
+// extracted in full) and the vector inserted covers that dimension entirely.
+// CHECK-LABEL: extract_strided_trailing_dim_fold
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<1x4xf32>
+//  CHECK-NEXT:   return %[[ARG0]] : vector<1x4xf32>
+func @extract_strided_trailing_dim_fold(%a: vector<1x4xf32>,
+                                        %b: vector<2x4xf32>)
+  -> (vector<1x4xf32>) {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [1, 0], strides = [1, 1]}
+    : vector<1x4xf32> into vector<2x4xf32>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [1], sizes = [1], strides = [1]}
+      : vector<2x4xf32> to vector<1x4xf32>
+  return %1 : vector<1x4xf32>
+}
+
+// -----
+
+// Negative test where the extract leaves the trailing dimension implicit (so
+// it is extracted in full) but the vector inserted only covers part of it.
+// CHECK-LABEL: extract_strided_trailing_dim_no_fold
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<1x2xf32>, %[[ARG1:.*]]: vector<2x4xf32>
+//       CHECK:   %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
+//  CHECK-SAME:     {offsets = [0, 1], strides = [1, 1]}
+//  CHECK-SAME:       : vector<1x2xf32> into vector<2x4xf32>
+//       CHECK:   %[[EXT:.*]] = vector.extract_strided_slice %[[INS]]
+//  CHECK-SAME:     {offsets = [0], sizes = [1], strides = [1]}
+//  CHECK-SAME:       : vector<2x4xf32> to vector<1x4xf32>
+//  CHECK-NEXT:   return %[[EXT]] : vector<1x4xf32>
+func @extract_strided_trailing_dim_no_fold(%a: vector<1x2xf32>,
+                                           %b: vector<2x4xf32>)
+  -> (vector<1x4xf32>) {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [0, 1], strides = [1, 1]}
+    : vector<1x2xf32> into vector<2x4xf32>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [0], sizes = [1], strides = [1]}
+      : vector<2x4xf32> to vector<1x4xf32>
+  return %1 : vector<1x4xf32>
+}
+
+// -----
+
+// Negative test where the extract starts before the vector inserted last and
+// overlaps it: the fold must not look through that insert, even though the
+// extracted chunk lies inside the vector inserted first.
+// CHECK-LABEL: extract_strided_overlap_no_fold
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<4xf32>,
+//  CHECK-SAME:  %[[ARG2:.*]]: vector<4xf32>
+//       CHECK:   %[[INS0:.*]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]]
+//  CHECK-SAME:     {offsets = [0], strides = [1]}
+//       CHECK:   %[[INS1:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INS0]]
+//  CHECK-SAME:     {offsets = [2], strides = [1]}
+//       CHECK:   %[[EXT:.*]] = vector.extract_strided_slice %[[INS1]]
+//  CHECK-SAME:     {offsets = [0], sizes = [4], strides = [1]}
+//  CHECK-SAME:       : vector<8xf32> to vector<4xf32>
+//  CHECK-NEXT:   return %[[EXT]] : vector<4xf32>
+func @extract_strided_overlap_no_fold(%a: vector<8xf32>, %b: vector<4xf32>,
+                                      %c: vector<4xf32>) -> (vector<4xf32>) {
+  %0 = vector.insert_strided_slice %b, %a {offsets = [0], strides = [1]}
+    : vector<4xf32> into vector<8xf32>
+  %1 = vector.insert_strided_slice %c, %0 {offsets = [2], strides = [1]}
+    : vector<4xf32> into vector<8xf32>
+  %2 = vector.extract_strided_slice %1
+    {offsets = [0], sizes = [4], strides = [1]}
+      : vector<8xf32> to vector<4xf32>
+  return %2 : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: transpose_1D_identity
 // CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>)
 func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {


        


More information about the Mlir-commits mailing list