[Mlir-commits] [mlir] 76d71f3 - Revert "[mlir][Vector] Extend xfer drop unit dim patterns"
Diego Caballero
llvmlistbot at llvm.org
Wed May 31 11:20:11 PDT 2023
Author: Diego Caballero
Date: 2023-05-31T18:20:05Z
New Revision: 76d71f3792b2b1864992446f7b1028b026dccd11
URL: https://github.com/llvm/llvm-project/commit/76d71f3792b2b1864992446f7b1028b026dccd11
DIFF: https://github.com/llvm/llvm-project/commit/76d71f3792b2b1864992446f7b1028b026dccd11.diff
LOG: Revert "[mlir][Vector] Extend xfer drop unit dim patterns"
This reverts commit a53cd03deac5e6272e9dae88a90cd51410d312d5.
This commit is exposing some implementation gaps in other patterns.
Reverting for now.
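For context, a rough sketch of the rewrite the reverted extension performed, reconstructed from the removed test cases further down in this diff (variable names and the exact output form are illustrative, not taken verbatim from the pattern output):

    // Input IR (from the removed test @transfer_read_and_vector_rank_reducing):
    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
      memref<1x1x3x2x1xf32>, vector<3x2x1xf32>

    // Approximate result of the reverted pattern: unit dims are dropped from
    // both the memref source and the vector type via memref.subview, and a
    // vector.shape_cast restores the original vector shape.
    %subview = memref.subview %arg[0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
        : memref<1x1x3x2x1xf32> to memref<3x2xf32>
    %read = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]}
        : memref<3x2xf32>, vector<3x2xf32>
    %v = vector.shape_cast %read : vector<3x2xf32> to vector<3x2x1xf32>

After this revert, the patterns again require the vector rank to already match the reduced source rank, as the restored code below shows.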
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 0e9dcf27c5585..af0fcd097028d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -63,7 +63,6 @@ class TransferOptimization {
std::vector<Operation *> opToErase;
};
-} // namespace
/// Return true if there is a path from start operation to dest operation,
/// otherwise return false. The operations have to be in the same region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
@@ -289,25 +288,14 @@ static int getReducedRank(ArrayRef<int64_t> shape) {
return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}
-/// Returns a copy of `shape` without unit dims.
-static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
- SmallVector<int64_t> reducedShape;
- llvm::copy_if(shape, std::back_inserter(reducedShape),
- [](int64_t dimSize) { return dimSize != 1; });
- return reducedShape;
-}
-
/// Returns true if all values are `arith.constant 0 : index`
static bool isZero(Value v) {
auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
return cst && cst.value() == 0;
}
-namespace {
-
-/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
-/// inserting a memref.subview dropping those unit dims. The vector shapes are
-/// also reduced accordingly.
+/// Rewrites vector.transfer_read ops where the source has unit dims, by
+/// inserting a memref.subview dropping those unit dims.
class TransferReadDropUnitDimsPattern
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
@@ -329,15 +317,12 @@ class TransferReadDropUnitDimsPattern
return failure();
if (!transferReadOp.getPermutationMap().isMinorIdentity())
return failure();
- // Check if the source shape can be further reduced.
int reducedRank = getReducedRank(sourceType.getShape());
if (reducedRank == sourceType.getRank())
- return failure();
- // Check if the reduced vector shape matches the reduced source shape.
- // Otherwise, this case is not supported yet.
- int vectorReducedRank = getReducedRank(vectorType.getShape());
- if (reducedRank != vectorReducedRank)
- return failure();
+ return failure(); // The source shape can't be further reduced.
+ if (reducedRank != vectorType.getRank())
+ return failure(); // This pattern requires the vector shape to match the
+ // reduced source shape.
if (llvm::any_of(transferReadOp.getIndices(),
[](Value v) { return !isZero(v); }))
return failure();
@@ -346,22 +331,14 @@ class TransferReadDropUnitDimsPattern
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
- auto reducedVectorType = VectorType::get(
- getReducedShape(vectorType.getShape()), vectorType.getElementType());
-
- auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
- loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
- auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
- loc, vectorType, newTransferReadOp);
- rewriter.replaceOp(transferReadOp, shapeCast);
-
+ rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+ transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
return success();
}
};
-/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
-/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
-/// vector shapes are also reduced accordingly.
+/// Rewrites vector.transfer_write ops where the "source" (i.e. destination) has
+/// unit dims, by inserting a memref.subview dropping those unit dims.
class TransferWriteDropUnitDimsPattern
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
@@ -383,15 +360,12 @@ class TransferWriteDropUnitDimsPattern
return failure();
if (!transferWriteOp.getPermutationMap().isMinorIdentity())
return failure();
- // Check if the destination shape can be further reduced.
int reducedRank = getReducedRank(sourceType.getShape());
if (reducedRank == sourceType.getRank())
- return failure();
- // Check if the reduced vector shape matches the reduced destination shape.
- // Otherwise, this case is not supported yet.
- int vectorReducedRank = getReducedRank(vectorType.getShape());
- if (reducedRank != vectorReducedRank)
- return failure();
+ return failure(); // The source shape can't be further reduced.
+ if (reducedRank != vectorType.getRank())
+ return failure(); // This pattern requires the vector shape to match the
+ // reduced source shape.
if (llvm::any_of(transferWriteOp.getIndices(),
[](Value v) { return !isZero(v); }))
return failure();
@@ -400,20 +374,12 @@ class TransferWriteDropUnitDimsPattern
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
- VectorType reducedVectorType = VectorType::get(
- getReducedShape(vectorType.getShape()), vectorType.getElementType());
-
- auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
- loc, reducedVectorType, vector);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
-
+ transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
return success();
}
};
-} // namespace
-
/// Return true if the memref type has its inner dimension matching the given
/// shape. Otherwise return false.
static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
@@ -473,8 +439,6 @@ checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
return success();
}
-namespace {
-
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
@@ -768,7 +732,6 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
return success();
}
};
-
} // namespace
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 3efa06948f546..e4e2e3b69c67b 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -15,14 +15,6 @@ func.func @transfer_read_rank_reducing(
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]]
-transform.sequence failures(propagate) {
-^bb1(%module_op: !pdl.operation):
- transform.vector.apply_rank_reducing_subview_patterns %module_op
- : (!pdl.operation) -> !pdl.operation
-}
-
-// -----
-
func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
@@ -36,97 +28,6 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
-transform.sequence failures(propagate) {
-^bb1(%module_op: !pdl.operation):
- transform.vector.apply_rank_reducing_subview_patterns %module_op
- : (!pdl.operation) -> !pdl.operation
-}
-
-// -----
-
-func.func @transfer_read_and_vector_rank_reducing(
- %arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0.0 : f32
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
- memref<1x1x3x2x1xf32>, vector<3x2x1xf32>
- return %v : vector<3x2x1xf32>
-}
-
-// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing
-// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32>
-// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
-// CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32>
-// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : memref<3x2xf32>, vector<3x2xf32>
-
-transform.sequence failures(propagate) {
-^bb1(%module_op: !pdl.operation):
- transform.vector.apply_rank_reducing_subview_patterns %module_op
- : (!pdl.operation) -> !pdl.operation
-}
-
-// -----
-
-func.func @transfer_write_and_vector_rank_reducing(
- %arg : memref<1x1x3x2x1xf32>,
- %vec : vector<3x2x1xf32>) {
- %c0 = arith.constant 0 : index
- vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0, %c0] :
- vector<3x2x1xf32>, memref<1x1x3x2x1xf32>
- return
-}
-
-// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing
-// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32>
-// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
-// CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32>
-// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : vector<3x2xf32>, memref<3x2xf32>
-
-transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
- transform.vector.apply_rank_reducing_subview_patterns %module_op
- : (!transform.any_op) -> !transform.any_op
-}
-
-// -----
-
-func.func @transfer_read_and_vector_rank_reducing_to_0d(
- %arg : memref<1x1x1x1x1xf32>) -> vector<1x1x1xf32> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0.0 : f32
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
- memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
- return %v : vector<1x1x1xf32>
-}
-
-// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d
-// CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>
-// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
-// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
-// CHECK: vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>
-
-transform.sequence failures(propagate) {
-^bb1(%module_op: !pdl.operation):
- transform.vector.apply_rank_reducing_subview_patterns %module_op
- : (!pdl.operation) -> !pdl.operation
-}
-
-// -----
-
-func.func @transfer_write_and_vector_rank_reducing_to_0d(
- %arg : memref<1x1x1x1x1xf32>,
- %vec : vector<1x1x1xf32>) {
- %c0 = arith.constant 0 : index
- vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0, %c0] :
- vector<1x1x1xf32>, memref<1x1x1x1x1xf32>
- return
-}
-
-// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d
-// CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>, %[[VECTOR:.+]]: vector<1x1x1xf32>
-// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
-// CHECK: %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
-// CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):