[Mlir-commits] [mlir] b0a309d - [mlir][vector] Add folding for extract + extract/insert_strided

Thomas Raoux llvmlistbot at llvm.org
Wed Jan 12 11:48:40 PST 2022


Author: Thomas Raoux
Date: 2022-01-12T11:48:35-08:00
New Revision: b0a309dd7a59c7fd4298116be63be4eedb28176e

URL: https://github.com/llvm/llvm-project/commit/b0a309dd7a59c7fd4298116be63be4eedb28176e
DIFF: https://github.com/llvm/llvm-project/commit/b0a309dd7a59c7fd4298116be63be4eedb28176e.diff

LOG: [mlir][vector] Add folding for extract + extract/insert_strided

Differential Revision: https://reviews.llvm.org/D116785
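
The new folds rewrite a vector.extract whose source comes from a unit-stride
extract_strided_slice or insert_strided_slice into an extract taken directly
from the original (or inserted) vector, shifting the extraction position by
the slice offsets. A minimal sketch, mirroring the tests added below (the
shapes and offsets are illustrative):

    %1 = vector.extract_strided_slice %arg0
      {offsets = [7, 3], sizes = [10, 8], strides = [1, 1]}
      : vector<32x16x4xf16> to vector<10x8x4xf16>
    %2 = vector.extract %1[2, 4] : vector<10x8x4xf16>
    // folds to:
    %2 = vector.extract %arg0[9, 7] : vector<32x16x4xf16>

    %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
      : vector<6x4xf32> into vector<8x16xf32>
    %2 = vector.extract %0[2, 4] : vector<8x16xf32>
    // folds to (the extract lands inside the inserted chunk):
    %2 = vector.extract %a[0, 2] : vector<6x4xf32>

Both folds bail out on non-unit strides; the insert-chain fold also gives up
when the extracted chunk only partially overlaps the inserted chunk, and walks
further up the chain when the two are disjoint. The folds fire as part of
ExtractOp::fold, e.g. under mlir-opt -canonicalize.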

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 5538099d15b5a..20b431a7b7b25 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -865,6 +865,11 @@ def Vector_InsertStridedSliceOp :
     VectorType getDestVectorType() {
       return dest().getType().cast<VectorType>();
     }
+    bool hasNonUnitStrides() { 
+      return llvm::any_of(strides(), [](Attribute attr) {
+        return attr.cast<IntegerAttr>().getInt() != 1;
+      });
+    }
   }];
 
   let hasFolder = 1;
@@ -1120,6 +1125,11 @@ def Vector_ExtractStridedSliceOp :
     static StringRef getStridesAttrName() { return "strides"; }
     VectorType getVectorType(){ return vector().getType().cast<VectorType>(); }
     void getOffsets(SmallVectorImpl<int64_t> &results);
+    bool hasNonUnitStrides() { 
+      return llvm::any_of(strides(), [](Attribute attr) {
+        return attr.cast<IntegerAttr>().getInt() != 1;
+      });
+    }
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 3f83578caade3..224bccfd71022 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1204,6 +1204,109 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
   return extractOp.getResult();
 }
 
+/// Fold an ExtractOp from ExtractStridedSliceOp.
+static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
+  auto extractStridedSliceOp =
+      extractOp.vector().getDefiningOp<vector::ExtractStridedSliceOp>();
+  if (!extractStridedSliceOp)
+    return Value();
+  // Return if 'extractStridedSliceOp' has non-unit strides.
+  if (extractStridedSliceOp.hasNonUnitStrides())
+    return Value();
+
+  // Trim offsets for dimensions fully extracted.
+  auto sliceOffsets = extractVector<int64_t>(extractStridedSliceOp.offsets());
+  while (!sliceOffsets.empty()) {
+    size_t lastOffset = sliceOffsets.size() - 1;
+    if (sliceOffsets.back() != 0 ||
+        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
+            extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
+      break;
+    sliceOffsets.pop_back();
+  }
+  unsigned destinationRank = 0;
+  if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
+    destinationRank = vecType.getRank();
+  // The dimensions of the result need to be untouched by the
+  // extractStridedSlice op.
+  if (destinationRank >
+      extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
+    return Value();
+  auto extractedPos = extractVector<int64_t>(extractOp.position());
+  assert(extractedPos.size() >= sliceOffsets.size());
+  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
+    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
+  extractOp.vectorMutable().assign(extractStridedSliceOp.vector());
+  // OpBuilder is only used as a helper to build an I64ArrayAttr.
+  OpBuilder b(extractOp.getContext());
+  extractOp->setAttr(ExtractOp::getPositionAttrName(),
+                     b.getI64ArrayAttr(extractedPos));
+  return extractOp.getResult();
+}
+
+/// Fold an extract_op fed from a chain of insertStridedSlice ops.
+static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
+  int64_t destinationRank = op.getType().isa<VectorType>()
+                                ? op.getType().cast<VectorType>().getRank()
+                                : 0;
+  auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
+  while (insertOp) {
+    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
+                             insertOp.getSourceVectorType().getRank();
+    if (destinationRank > insertOp.getSourceVectorType().getRank())
+      return Value();
+    auto insertOffsets = extractVector<int64_t>(insertOp.offsets());
+    auto extractOffsets = extractVector<int64_t>(op.position());
+
+    if (llvm::any_of(insertOp.strides(), [](Attribute attr) {
+          return attr.cast<IntegerAttr>().getInt() != 1;
+        }))
+      return Value();
+    bool disjoint = false;
+    SmallVector<int64_t, 4> offsetDiffs;
+    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
+      int64_t start = insertOffsets[dim];
+      int64_t size =
+          (dim < insertRankDiff)
+              ? 1
+              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
+      int64_t end = start + size;
+      int64_t offset = extractOffsets[dim];
+      // Check if the start of the extract offset is in the interval inserted.
+      if (start <= offset && offset < end) {
+        if (dim >= insertRankDiff)
+          offsetDiffs.push_back(offset - start);
+        continue;
+      }
+      disjoint = true;
+      break;
+    }
+    // The extracted element chunk overlaps with the vector inserted.
+    if (!disjoint) {
+      // If any of the inner dimensions are only partially inserted, we have a
+      // partial overlap.
+      int64_t srcRankDiff =
+          insertOp.getSourceVectorType().getRank() - destinationRank;
+      for (int64_t i = 0; i < destinationRank; i++) {
+        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
+            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
+                                                    insertRankDiff))
+          return Value();
+      }
+      op.vectorMutable().assign(insertOp.source());
+      // OpBuilder is only used as a helper to build an I64ArrayAttr.
+      OpBuilder b(op.getContext());
+      op->setAttr(ExtractOp::getPositionAttrName(),
+                  b.getI64ArrayAttr(offsetDiffs));
+      return op.getResult();
+    }
+    // If the chunk extracted is disjoint from the chunk inserted, keep
+    // looking in the insert chain.
+    insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
+  }
+  return Value();
+}
+
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
   if (position().empty())
     return vector();
@@ -1217,6 +1320,10 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
     return val;
   if (auto val = foldExtractFromShapeCast(*this))
     return val;
+  if (auto val = foldExtractFromExtractStrided(*this))
+    return val;
+  if (auto val = foldExtractStridedOpFromInsertChain(*this))
+    return val;
   return OpFoldResult();
 }
 
@@ -2183,9 +2290,7 @@ class StridedSliceConstantMaskFolder final
     if (!constantMaskOp)
       return failure();
     // Return if 'extractStridedSliceOp' has non-unit strides.
-    if (llvm::any_of(extractStridedSliceOp.strides(), [](Attribute attr) {
-          return attr.cast<IntegerAttr>().getInt() != 1;
-        }))
+    if (extractStridedSliceOp.hasNonUnitStrides())
       return failure();
     // Gather constant mask dimension sizes.
     SmallVector<int64_t, 4> maskDimSizes;

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index faf801fe534a9..ba0e0a2b175cc 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1109,3 +1109,87 @@ func @extract_constant() -> (vector<7xf32>, i32) {
   %1 = vector.extract %cst_1[1, 4, 5] : vector<4x37x9xi32>
   return %0, %1 : vector<7xf32>, i32
 }
+
+// -----
+
+// CHECK-LABEL: extract_extract_strided
+//  CHECK-SAME: %[[A:.*]]: vector<32x16x4xf16>
+//       CHECK: %[[V:.*]] = vector.extract %[[A]][9, 7] : vector<32x16x4xf16>
+//       CHECK: return %[[V]] : vector<4xf16>
+func @extract_extract_strided(%arg0: vector<32x16x4xf16>) -> vector<4xf16> {
+ %1 = vector.extract_strided_slice %arg0
+  {offsets = [7, 3], sizes = [10, 8], strides = [1, 1]} :
+  vector<32x16x4xf16> to vector<10x8x4xf16>
+  %2 = vector.extract %1[2, 4] : vector<10x8x4xf16>
+  return %2 : vector<4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: extract_insert_strided
+//  CHECK-SAME: %[[A:.*]]: vector<6x4xf32>
+//       CHECK: %[[V:.*]] = vector.extract %[[A]][0, 2] : vector<6x4xf32>
+//       CHECK: return %[[V]] : f32
+func @extract_insert_strided(%a: vector<6x4xf32>, %b: vector<8x16xf32>)
+  -> f32 {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
+    : vector<6x4xf32> into vector<8x16xf32>
+  %2 = vector.extract %0[2, 4] : vector<8x16xf32>
+  return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: extract_insert_rank_reduce
+//  CHECK-SAME: %[[A:.*]]: vector<4xf32>
+//       CHECK: %[[V:.*]] = vector.extract %[[A]][2] : vector<4xf32>
+//       CHECK: return %[[V]] : f32
+func @extract_insert_rank_reduce(%a: vector<4xf32>, %b: vector<8x16xf32>)
+  -> f32 {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1]}
+    : vector<4xf32> into vector<8x16xf32>
+  %2 = vector.extract %0[2, 4] : vector<8x16xf32>
+  return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: extract_insert_negative
+//       CHECK: vector.insert_strided_slice
+//       CHECK: vector.extract
+func @extract_insert_negative(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
+  -> vector<16xf32> {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [4, 2, 0], strides = [1, 1]}
+    : vector<2x15xf32> into vector<12x8x16xf32>
+  %2 = vector.extract %0[4, 2] : vector<12x8x16xf32>
+  return %2 : vector<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: extract_insert_chain
+//  CHECK-SAME: (%[[A:.*]]: vector<2x16xf32>, %[[B:.*]]: vector<12x8x16xf32>, %[[C:.*]]: vector<2x16xf32>)
+//       CHECK: %[[V:.*]] = vector.extract %[[C]][0] : vector<2x16xf32>
+//       CHECK: return %[[V]] : vector<16xf32>
+func @extract_insert_chain(%a: vector<2x16xf32>, %b: vector<12x8x16xf32>, %c: vector<2x16xf32>)
+  -> vector<16xf32> {
+  %0 = vector.insert_strided_slice %c, %b {offsets = [4, 2, 0], strides = [1, 1]}
+    : vector<2x16xf32> into vector<12x8x16xf32>
+  %1 = vector.insert_strided_slice %a, %0 {offsets = [0, 2, 0], strides = [1, 1]}
+    : vector<2x16xf32> into vector<12x8x16xf32>
+  %2 = vector.extract %1[4, 2] : vector<12x8x16xf32>
+  return %2 : vector<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: extract_extract_strided2
+//  CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
+//       CHECK: %[[V:.*]] = vector.extract %[[A]][1] : vector<2x4xf32>
+//       CHECK: return %[[V]] : vector<4xf32>
+func @extract_extract_strided2(%A: vector<2x4xf32>)
+  -> (vector<4xf32>) {
+ %0 = vector.extract_strided_slice %A {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<2x4xf32> to vector<1x4xf32>
+ %1 = vector.extract %0[0] : vector<1x4xf32>
+ return %1 : vector<4xf32>
+}
