[Mlir-commits] [mlir] b0a309d - [mlir][vector] Add folding for extract + extract/insert_strided
Thomas Raoux
llvmlistbot at llvm.org
Wed Jan 12 11:48:40 PST 2022
Author: Thomas Raoux
Date: 2022-01-12T11:48:35-08:00
New Revision: b0a309dd7a59c7fd4298116be63be4eedb28176e
URL: https://github.com/llvm/llvm-project/commit/b0a309dd7a59c7fd4298116be63be4eedb28176e
DIFF: https://github.com/llvm/llvm-project/commit/b0a309dd7a59c7fd4298116be63be4eedb28176e.diff
LOG: [mlir][vector] Add folding for extract + extract/insert_strided
Differential Revision: https://reviews.llvm.org/D116785
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 5538099d15b5a..20b431a7b7b25 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -865,6 +865,11 @@ def Vector_InsertStridedSliceOp :
VectorType getDestVectorType() {
return dest().getType().cast<VectorType>();
}
+ bool hasNonUnitStrides() { // True if any entry of strides() is not 1.
+ return llvm::any_of(strides(), [](Attribute attr) {
+ return attr.cast<IntegerAttr>().getInt() != 1;
+ });
+ }
}];
let hasFolder = 1;
@@ -1120,6 +1125,11 @@ def Vector_ExtractStridedSliceOp :
static StringRef getStridesAttrName() { return "strides"; }
VectorType getVectorType(){ return vector().getType().cast<VectorType>(); }
void getOffsets(SmallVectorImpl<int64_t> &results);
+ bool hasNonUnitStrides() { // True if any entry of strides() is not 1.
+ return llvm::any_of(strides(), [](Attribute attr) {
+ return attr.cast<IntegerAttr>().getInt() != 1;
+ });
+ }
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 3f83578caade3..224bccfd71022 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1204,6 +1204,109 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
return extractOp.getResult();
}
+/// Fold an ExtractOp from ExtractStridedSliceOp by extracting from the slice's source; Value() if not applicable.
+static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
+ auto extractStridedSliceOp =
+ extractOp.vector().getDefiningOp<vector::ExtractStridedSliceOp>();
+ if (!extractStridedSliceOp)
+ return Value();
+ // Return if 'extractStridedSliceOp' has non-unit strides.
+ if (extractStridedSliceOp.hasNonUnitStrides())
+ return Value();
+ // A trailing offset of 0 on a dimension whose size is unchanged by the slice restricts nothing.
+ // Trim such offsets, i.e. the ones for dimensions fully extracted.
+ auto sliceOffsets = extractVector<int64_t>(extractStridedSliceOp.offsets());
+ while (!sliceOffsets.empty()) {
+ size_t lastOffset = sliceOffsets.size() - 1;
+ if (sliceOffsets.back() != 0 ||
+ extractStridedSliceOp.getType().getDimSize(lastOffset) !=
+ extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
+ break;
+ sliceOffsets.pop_back();
+ }
+ unsigned destinationRank = 0; // Stays 0 when the result is not a VectorType.
+ if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
+ destinationRank = vecType.getRank();
+ // The dimensions of the result need to be untouched by the
+ // extractStridedSlice op, i.e. fit in the dimensions past the remaining slice offsets.
+ if (destinationRank >
+ extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
+ return Value();
+ auto extractedPos = extractVector<int64_t>(extractOp.position());
+ assert(extractedPos.size() >= sliceOffsets.size());
+ for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
+ extractedPos[i] = extractedPos[i] + sliceOffsets[i]; // Rebase the position onto the slice's source.
+ extractOp.vectorMutable().assign(extractStridedSliceOp.vector()); // Update the operand in place.
+ // OpBuilder is only used as a helper to build an I64ArrayAttr.
+ OpBuilder b(extractOp.getContext());
+ extractOp->setAttr(ExtractOp::getPositionAttrName(),
+ b.getI64ArrayAttr(extractedPos));
+ return extractOp.getResult();
+}
+
+/// Fold an extract op fed by a chain of insertStridedSlice ops by reading from the inserted source instead.
+static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
+ int64_t destinationRank = op.getType().isa<VectorType>()
+ ? op.getType().cast<VectorType>().getRank()
+ : 0; // Rank 0 when the result is not a VectorType.
+ auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
+ while (insertOp) {
+ int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
+ insertOp.getSourceVectorType().getRank(); // Leading dest dims absent from the source.
+ if (destinationRank > insertOp.getSourceVectorType().getRank())
+ return Value();
+ auto insertOffsets = extractVector<int64_t>(insertOp.offsets());
+ auto extractOffsets = extractVector<int64_t>(op.position());
+ // Only unit strides are handled. NOTE(review): insertOp.hasNonUnitStrides() looks equivalent here.
+ if (llvm::any_of(insertOp.strides(), [](Attribute attr) {
+ return attr.cast<IntegerAttr>().getInt() != 1;
+ }))
+ return Value();
+ bool disjoint = false;
+ SmallVector<int64_t, 4> offsetDiffs;
+ for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
+ int64_t start = insertOffsets[dim];
+ int64_t size =
+ (dim < insertRankDiff)
+ ? 1
+ : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
+ int64_t end = start + size;
+ int64_t offset = extractOffsets[dim];
+ // Check if the start of the extract offset is in the interval inserted.
+ if (start <= offset && offset < end) {
+ if (dim >= insertRankDiff)
+ offsetDiffs.push_back(offset - start); // Position relative to the inserted source.
+ continue;
+ }
+ disjoint = true;
+ break;
+ }
+ // The extracted element chunk overlaps with the vector inserted.
+ if (!disjoint) {
+ // If any of the inner dimensions are only partially inserted we have a
+ // partial overlap.
+ int64_t srcRankDiff =
+ insertOp.getSourceVectorType().getRank() - destinationRank;
+ for (int64_t i = 0; i < destinationRank; i++) {
+ if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
+ insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
+ insertRankDiff))
+ return Value();
+ }
+ op.vectorMutable().assign(insertOp.source());
+ // OpBuilder is only used as a helper to build an I64ArrayAttr.
+ OpBuilder b(op.getContext());
+ op->setAttr(ExtractOp::getPositionAttrName(),
+ b.getI64ArrayAttr(offsetDiffs));
+ return op.getResult();
+ }
+ // If the chunk extracted is disjoint from the chunk inserted, keep
+ // looking in the insert chain.
+ insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
+ }
+ return Value();
+}
+
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
if (position().empty())
return vector();
@@ -1217,6 +1320,10 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
return val;
if (auto val = foldExtractFromShapeCast(*this))
return val;
+ if (auto val = foldExtractFromExtractStrided(*this))
+ return val;
+ if (auto val = foldExtractStridedOpFromInsertChain(*this))
+ return val;
return OpFoldResult();
}
@@ -2183,9 +2290,7 @@ class StridedSliceConstantMaskFolder final
if (!constantMaskOp)
return failure();
// Return if 'extractStridedSliceOp' has non-unit strides.
- if (llvm::any_of(extractStridedSliceOp.strides(), [](Attribute attr) {
- return attr.cast<IntegerAttr>().getInt() != 1;
- }))
+ if (extractStridedSliceOp.hasNonUnitStrides())
return failure();
// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index faf801fe534a9..ba0e0a2b175cc 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1109,3 +1109,87 @@ func @extract_constant() -> (vector<7xf32>, i32) {
%1 = vector.extract %cst_1[1, 4, 5] : vector<4x37x9xi32>
return %0, %1 : vector<7xf32>, i32
}
+
+// -----
+// Unit-stride slice: the extract position [2, 4] is rebased by the slice offsets [7, 3] to [9, 7].
+// CHECK-LABEL: extract_extract_strided
+// CHECK-SAME: %[[A:.*]]: vector<32x16x4xf16>
+// CHECK: %[[V:.*]] = vector.extract %[[A]][9, 7] : vector<32x16x4xf16>
+// CHECK: return %[[V]] : vector<4xf16>
+func @extract_extract_strided(%arg0: vector<32x16x4xf16>) -> vector<4xf16> {
+ %1 = vector.extract_strided_slice %arg0
+ {offsets = [7, 3], sizes = [10, 8], strides = [1, 1]} :
+ vector<32x16x4xf16> to vector<10x8x4xf16>
+ %2 = vector.extract %1[2, 4] : vector<10x8x4xf16>
+ return %2 : vector<4xf16>
+}
+
+// -----
+// Position [2, 4] lies inside the slice inserted at [2, 2], so it reads element [0, 2] of the source.
+// CHECK-LABEL: extract_insert_strided
+// CHECK-SAME: %[[A:.*]]: vector<6x4xf32>
+// CHECK: %[[V:.*]] = vector.extract %[[A]][0, 2] : vector<6x4xf32>
+// CHECK: return %[[V]] : f32
+func @extract_insert_strided(%a: vector<6x4xf32>, %b: vector<8x16xf32>)
+ -> f32 {
+ %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
+ : vector<6x4xf32> into vector<8x16xf32>
+ %2 = vector.extract %0[2, 4] : vector<8x16xf32>
+ return %2 : f32
+}
+
+// -----
+// Rank-reducing insert: the rank-1 source covers row 2, columns 2-5, so [2, 4] becomes [2] in it.
+// CHECK-LABEL: extract_insert_rank_reduce
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>
+// CHECK: %[[V:.*]] = vector.extract %[[A]][2] : vector<4xf32>
+// CHECK: return %[[V]] : f32
+func @extract_insert_rank_reduce(%a: vector<4xf32>, %b: vector<8x16xf32>)
+ -> f32 {
+ %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1]}
+ : vector<4xf32> into vector<8x16xf32>
+ %2 = vector.extract %0[2, 4] : vector<8x16xf32>
+ return %2 : f32
+}
+
+// -----
+// Negative case: the source only covers 15 of the 16 innermost elements, so nothing is folded.
+// CHECK-LABEL: extract_insert_negative
+// CHECK: vector.insert_strided_slice
+// CHECK: vector.extract
+func @extract_insert_negative(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
+ -> vector<16xf32> {
+ %0 = vector.insert_strided_slice %a, %b {offsets = [4, 2, 0], strides = [1, 1]}
+ : vector<2x15xf32> into vector<12x8x16xf32>
+ %2 = vector.extract %0[4, 2] : vector<12x8x16xf32>
+ return %2 : vector<16xf32>
+}
+
+// -----
+// The last insert (at [0, 2, 0]) misses position [4, 2]; the fold looks through it to the insert of %c.
+// CHECK-LABEL: extract_insert_chain
+// CHECK-SAME: (%[[A:.*]]: vector<2x16xf32>, %[[B:.*]]: vector<12x8x16xf32>, %[[C:.*]]: vector<2x16xf32>)
+// CHECK: %[[V:.*]] = vector.extract %[[C]][0] : vector<2x16xf32>
+// CHECK: return %[[V]] : vector<16xf32>
+func @extract_insert_chain(%a: vector<2x16xf32>, %b: vector<12x8x16xf32>, %c: vector<2x16xf32>)
+ -> vector<16xf32> {
+ %0 = vector.insert_strided_slice %c, %b {offsets = [4, 2, 0], strides = [1, 1]}
+ : vector<2x16xf32> into vector<12x8x16xf32>
+ %1 = vector.insert_strided_slice %a, %0 {offsets = [0, 2, 0], strides = [1, 1]}
+ : vector<2x16xf32> into vector<12x8x16xf32>
+ %2 = vector.extract %1[4, 2] : vector<12x8x16xf32>
+ return %2 : vector<16xf32>
+}
+
+// -----
+// The slice keeps dim 1 whole (offset 0, size 4), so only the dim-0 offset 1 is added to position [0].
+// CHECK-LABEL: extract_extract_strided2
+// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
+// CHECK: %[[V:.*]] = vector.extract %[[A]][1] : vector<2x4xf32>
+// CHECK: return %[[V]] : vector<4xf32>
+func @extract_extract_strided2(%A: vector<2x4xf32>)
+ -> (vector<4xf32>) {
+ %0 = vector.extract_strided_slice %A {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<2x4xf32> to vector<1x4xf32>
+ %1 = vector.extract %0[0] : vector<1x4xf32>
+ return %1 : vector<4xf32>
+}
More information about the Mlir-commits
mailing list