[Mlir-commits] [mlir] 9081e75 - [mlir][vector] Address post-commit review comments on vector ops folding patterns
Thomas Raoux
llvmlistbot at llvm.org
Mon Nov 2 10:58:06 PST 2020
Author: Thomas Raoux
Date: 2020-11-02T10:57:32-08:00
New Revision: 9081e7594df2356f72c038423a0f8130f140b255
URL: https://github.com/llvm/llvm-project/commit/9081e7594df2356f72c038423a0f8130f140b255
DIFF: https://github.com/llvm/llvm-project/commit/9081e7594df2356f72c038423a0f8130f140b255.diff
LOG: [mlir][vector] Address post-commit review comments on vector ops folding patterns
Differential Revision: https://reviews.llvm.org/D90183
Added:
Modified:
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index faae278ad6d4d..53cdf3fc91035 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -850,10 +850,12 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
return Value();
// Get the nth dimension size starting from lowest dimension.
auto getDimReverse = [](VectorType type, int64_t n) {
- return type.getDimSize(type.getRank() - n - 1);
+ return type.getShape().take_back(n+1).front();
};
int64_t destinationRank =
- extractOp.getVectorType().getRank() - extractOp.position().size();
+ extractOp.getType().isa<VectorType>()
+ ? extractOp.getType().cast<VectorType>().getRank()
+ : 0;
if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
return Value();
if (destinationRank > 0) {
@@ -861,6 +863,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
for (int64_t i = 0; i < destinationRank; i++) {
// The lowest dimension of of the destination must match the lowest
// dimension of the shapecast op source.
+ // TODO: This case could be supported in a canonicalization pattern.
if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
getDimReverse(destinationType, i))
return Value();
@@ -891,6 +894,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
}
std::reverse(newStrides.begin(), newStrides.end());
SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
+ // OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp.setAttr(ExtractOp::getPositionAttrName(),
b.getI64ArrayAttr(newPosition));
@@ -1632,8 +1636,8 @@ static LogicalResult verify(ExtractStridedSliceOp op) {
}
// When the source of ExtractStrided comes from a chain of InsertStrided ops try
-// to use the source o the InsertStrided ops if we can detect that the extracted
-// vector is a subset of one of the vector inserted.
+// to use the source of the InsertStrided ops if we can detect that the
+// extracted vector is a subset of one of the inserted vectors.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
// Helper to extract integer out of ArrayAttr.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b20acccb9e7b9..00905420c118f 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -160,20 +160,20 @@ func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
// Case where we need to go through 2 level of insert element.
// CHECK-LABEL: extract_strided_fold_insert
-// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x8xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
// CHECK-NEXT: %[[EXT:.*]] = vector.extract_strided_slice %[[ARG1]]
-// CHECK-SAME: {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
+// CHECK-SAME: {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]}
// CHECK-SAME: : vector<1x4xf32> to vector<1x1xf32>
// CHECK-NEXT: return %[[EXT]] : vector<1x1xf32>
-func @extract_strided_fold_insert(%a: vector<2x4xf32>, %b: vector<1x4xf32>,
+func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
%c : vector<1x4xf32>) -> (vector<1x1xf32>) {
- %0 = vector.insert_strided_slice %b, %a {offsets = [0, 0], strides = [1, 1]}
- : vector<1x4xf32> into vector<2x4xf32>
+ %0 = vector.insert_strided_slice %b, %a {offsets = [0, 1], strides = [1, 1]}
+ : vector<1x4xf32> into vector<2x8xf32>
%1 = vector.insert_strided_slice %c, %0 {offsets = [1, 0], strides = [1, 1]}
- : vector<1x4xf32> into vector<2x4xf32>
+ : vector<1x4xf32> into vector<2x8xf32>
%2 = vector.extract_strided_slice %1
{offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
- : vector<2x4xf32> to vector<1x1xf32>
+ : vector<2x8xf32> to vector<1x1xf32>
return %2 : vector<1x1xf32>
}
More information about the Mlir-commits
mailing list