[Mlir-commits] [mlir] ea6a60a - [mlir][vector] Add folder for ExtractStridedSliceOp
Thomas Raoux
llvmlistbot at llvm.org
Fri Oct 23 12:18:26 PDT 2020
Author: Thomas Raoux
Date: 2020-10-23T12:18:09-07:00
New Revision: ea6a60a9a6c18a2a512b066fd8a873ff0db49836
URL: https://github.com/llvm/llvm-project/commit/ea6a60a9a6c18a2a512b066fd8a873ff0db49836
DIFF: https://github.com/llvm/llvm-project/commit/ea6a60a9a6c18a2a512b066fd8a873ff0db49836.diff
LOG: [mlir][vector] Add folder for ExtractStridedSliceOp
Add a folder for the case where the source of an ExtractStridedSliceOp comes from a
chain of InsertStridedSliceOps. Also add a folder for the trivial case where the
ExtractStridedSliceOp is a no-op.
Differential Revision: https://reviews.llvm.org/D89850
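As a quick illustration, here is a minimal sketch of the two patterns the folder
targets, mirroring the tests added below (the value names are illustrative only):

  // No-op case: the extract covers the entire source vector, so it folds to
  // its operand %v.
  %0 = vector.extract_strided_slice %v
      {offsets = [0, 0], sizes = [4, 3], strides = [1, 1]}
      : vector<4x3xi1> to vector<4x3xi1>

  // Insert-chain case: the extracted chunk lies entirely inside the inserted
  // chunk, so the extract folds to the inserted source %a.
  %1 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
      : vector<4x4xf32> into vector<8x16xf32>
  %2 = vector.extract_strided_slice %1
      {offsets = [2, 2], sizes = [4, 4], strides = [1, 1]}
      : vector<8x16xf32> to vector<4x4xf32>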
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index cc6bc2c8c77b..edf44c7c7110 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1016,6 +1016,7 @@ def Vector_ExtractStridedSliceOp :
void getOffsets(SmallVectorImpl<int64_t> &results);
}];
let hasCanonicalizer = 1;
+ let hasFolder = 1;
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
}
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index d1deb5abd541..4b8cbda197f3 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1629,6 +1629,81 @@ static LogicalResult verify(ExtractStridedSliceOp op) {
return success();
}
+// When the source of ExtractStridedSliceOp comes from a chain of InsertStridedSliceOps,
+// try to use the source of the InsertStridedSliceOps if we can detect that the
+// extracted vector is a subset of one of the inserted vectors.
+static LogicalResult
+foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
+ // Helper to extract integer out of ArrayAttr.
+ auto getElement = [](ArrayAttr array, int idx) {
+ return array[idx].cast<IntegerAttr>().getInt();
+ };
+ ArrayAttr extractOffsets = op.offsets();
+ ArrayAttr extractStrides = op.strides();
+ ArrayAttr extractSizes = op.sizes();
+ auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
+ while (insertOp) {
+ if (op.getVectorType().getRank() !=
+ insertOp.getSourceVectorType().getRank())
+ return failure();
+ ArrayAttr insertOffsets = insertOp.offsets();
+ ArrayAttr insertStrides = insertOp.strides();
+ // If the rank of extract is greater than the rank of insert, we are likely
+ // extracting a partial chunk of the vector inserted.
+ if (extractOffsets.size() > insertOffsets.size())
+ return failure();
+ bool partialOverlap = false;
+ bool disjoint = false;
+ SmallVector<int64_t, 4> offsetDiffs;
+ for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
+ if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
+ return failure();
+ int64_t start = getElement(insertOffsets, dim);
+ int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
+ int64_t offset = getElement(extractOffsets, dim);
+ int64_t size = getElement(extractSizes, dim);
+ // Check if the start of the extract offset is in the interval inserted.
+ if (start <= offset && offset < end) {
+ // If the extract interval overlaps but is not fully included, we may
+ // have a partial overlap that will prevent any folding.
+ if (offset + size > end)
+ partialOverlap = true;
+ offsetDiffs.push_back(offset - start);
+ continue;
+ }
+ disjoint = true;
+ break;
+ }
+ // The extracted chunk is a subset of the inserted chunk.
+ if (!disjoint && !partialOverlap) {
+ op.setOperand(insertOp.source());
+ // OpBuilder is only used as a helper to build an I64ArrayAttr.
+ OpBuilder b(op.getContext());
+ op.setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
+ b.getI64ArrayAttr(offsetDiffs));
+ return success();
+ }
+ // If the chunk extracted is disjoint from the chunk inserted, keep looking
+ // in the insert chain.
+ if (disjoint)
+ insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
+ else {
+ // The extracted vector partially overlaps the inserted vector; we cannot
+ // fold.
+ return failure();
+ }
+ }
+ return failure();
+}
+
+OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
+ if (getVectorType() == getResult().getType())
+ return vector();
+ if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
+ return getResult();
+ return {};
+}
+
void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
populateFromInt64AttrArray(offsets(), results);
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 66bad06e6b60..b20acccb9e7b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -90,6 +90,95 @@ func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
// -----
+// CHECK-LABEL: extract_strided_fold
+// CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>)
+// CHECK-NEXT: return %[[ARG]] : vector<4x3xi1>
+func @extract_strided_fold(%arg : vector<4x3xi1>) -> (vector<4x3xi1>) {
+ %0 = vector.extract_strided_slice %arg
+ {offsets = [0, 0], sizes = [4, 3], strides = [1, 1]}
+ : vector<4x3xi1> to vector<4x3xi1>
+ return %0 : vector<4x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_fold_insert
+// CHECK-SAME: (%[[ARG:.*]]: vector<4x4xf32>
+// CHECK-NEXT: return %[[ARG]] : vector<4x4xf32>
+func @extract_strided_fold_insert(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
+ -> (vector<4x4xf32>) {
+ %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
+ : vector<4x4xf32> into vector<8x16xf32>
+ %1 = vector.extract_strided_slice %0
+ {offsets = [2, 2], sizes = [4, 4], strides = [1, 1]}
+ : vector<8x16xf32> to vector<4x4xf32>
+ return %1 : vector<4x4xf32>
+}
+
+// -----
+
+// Case where the vector extracted is a subset of the vector inserted.
+// CHECK-LABEL: extract_strided_fold_insert
+// CHECK-SAME: (%[[ARG0:.*]]: vector<6x4xf32>
+// CHECK-NEXT: %[[EXT:.*]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0, 0], sizes = [4, 4], strides = [1, 1]}
+// CHECK-SAME: : vector<6x4xf32> to vector<4x4xf32>
+// CHECK-NEXT: return %[[EXT]] : vector<4x4xf32>
+func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>)
+ -> (vector<4x4xf32>) {
+ %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
+ : vector<6x4xf32> into vector<8x16xf32>
+ %1 = vector.extract_strided_slice %0
+ {offsets = [2, 2], sizes = [4, 4], strides = [1, 1]}
+ : vector<8x16xf32> to vector<4x4xf32>
+ return %1 : vector<4x4xf32>
+}
+
+// -----
+
+// Negative test where the extracted vector is not a subset of the inserted vector.
+// CHECK-LABEL: extract_strided_fold_negative
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4x4xf32>, %[[ARG1:.*]]: vector<8x16xf32>
+// CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
+// CHECK-SAME: {offsets = [2, 2], strides = [1, 1]}
+// CHECK-SAME: : vector<4x4xf32> into vector<8x16xf32>
+// CHECK: %[[EXT:.*]] = vector.extract_strided_slice %[[INS]]
+// CHECK-SAME: {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
+// CHECK-SAME: : vector<8x16xf32> to vector<6x4xf32>
+// CHECK-NEXT: return %[[EXT]] : vector<6x4xf32>
+func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
+ -> (vector<6x4xf32>) {
+ %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
+ : vector<4x4xf32> into vector<8x16xf32>
+ %1 = vector.extract_strided_slice %0
+ {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
+ : vector<8x16xf32> to vector<6x4xf32>
+ return %1 : vector<6x4xf32>
+}
+
+// -----
+
+// Case where we need to go through 2 levels of insert ops.
+// CHECK-LABEL: extract_strided_fold_insert
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
+// CHECK-NEXT: %[[EXT:.*]] = vector.extract_strided_slice %[[ARG1]]
+// CHECK-SAME: {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
+// CHECK-SAME: : vector<1x4xf32> to vector<1x1xf32>
+// CHECK-NEXT: return %[[EXT]] : vector<1x1xf32>
+func @extract_strided_fold_insert(%a: vector<2x4xf32>, %b: vector<1x4xf32>,
+ %c : vector<1x4xf32>) -> (vector<1x1xf32>) {
+ %0 = vector.insert_strided_slice %b, %a {offsets = [0, 0], strides = [1, 1]}
+ : vector<1x4xf32> into vector<2x4xf32>
+ %1 = vector.insert_strided_slice %c, %0 {offsets = [1, 0], strides = [1, 1]}
+ : vector<1x4xf32> into vector<2x4xf32>
+ %2 = vector.extract_strided_slice %1
+ {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
+ : vector<2x4xf32> to vector<1x1xf32>
+ return %2 : vector<1x1xf32>
+}
+
+// -----
+
// CHECK-LABEL: transpose_1D_identity
// CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>)
func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {