[Mlir-commits] [mlir] [mlir][vector] Add pattern to drop unit dim from elementwise(a, b)) (PR #74817)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Dec 13 00:59:00 PST 2023
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/74817
>From 0d002147e2292d688267299fe28953818a113027 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 6 Dec 2023 21:15:33 +0000
Subject: [PATCH 1/4] [mlir][vector] Add pattern to reorder
shape_cast(arithmetic(a, b))
Adds a new pattern to reorder:
shape_cast(arithmetic(a, b))
as
arithmetic(shape_cast(a), shape_cast(b)).
Example:
```mlir
%mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
%cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
```
gets converted to:
```mlir
%B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
%A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
%mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
```
This helps to, eventually, fold away shape_casts entirely. One specific
case where this is beneficial is when vectorising 1D depthwise
convolutions, e.g. `linalg.depthwise_conv_1d_nwc_wc`, when the channel
dimension is flattened, i.e. `flatten_1d_depthwise_conv` is set:
```
transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv}
```
Output after vectorisation:
```mlir
func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_memref(%arg0: memref<1x8x3xi8>, %arg1: memref<1x3xi8>, %arg2: memref<1x8x3xi8>) {
%c0 = arith.constant 0 : index
%c0_i8 = arith.constant 0 : i8
%0 = vector.transfer_read %arg0[%c0, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
%1 = vector.transfer_read %arg1[%c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<1x3xi8>, vector<1x3xi8>
%2 = vector.transfer_read %arg2[%c0, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
%3 = vector.extract %1[0] : vector<3xi8> from vector<1x3xi8>
%4 = vector.shape_cast %0 : vector<1x8x3xi8> to vector<1x24xi8>
%5 = vector.shape_cast %2 : vector<1x8x3xi8> to vector<1x24xi8>
%6 = vector.broadcast %3 : vector<3xi8> to vector<1x8x3xi8>
%7 = vector.shape_cast %6 : vector<1x8x3xi8> to vector<1x24xi8>
%8 = arith.muli %4, %7 : vector<1x24xi8>
%9 = arith.addi %8, %5 : vector<1x24xi8>
%10 = vector.shape_cast %9 : vector<1x24xi8> to vector<1x8x3xi8>
vector.transfer_write %10, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x8x3xi8>, memref<1x8x3xi8>
return
}
```
Output after applying this pattern and other related patterns (e.g.
-test-vector-transfer-flatten-patterns -canonicalize):
```mlir
func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_memref(%arg0: memref<1x8x3xi8>, %arg1: memref<1x3xi8>, %arg2: memref<1x8x3xi8>) {
%c0_i8 = arith.constant 0 : i8
%c0 = arith.constant 0 : index
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2]] : memref<1x8x3xi8> into memref<24xi8>
%0 = vector.transfer_read %collapse_shape[%c0], %c0_i8 {in_bounds = [true]} : memref<24xi8>, vector<24xi8>
%collapse_shape_0 = memref.collapse_shape %arg1 [[0, 1]] : memref<1x3xi8> into memref<3xi8>
%1 = vector.transfer_read %collapse_shape_0[%c0], %c0_i8 {in_bounds = [true]} : memref<3xi8>, vector<3xi8>
%collapse_shape_1 = memref.collapse_shape %arg2 [[0, 1, 2]] : memref<1x8x3xi8> into memref<24xi8>
%2 = vector.transfer_read %collapse_shape_1[%c0], %c0_i8 {in_bounds = [true]} : memref<24xi8>, vector<24xi8>
%3 = vector.broadcast %1 : vector<3xi8> to vector<1x8x3xi8>
%collapse_shape_2 = memref.collapse_shape %arg2 [[0, 1, 2]] : memref<1x8x3xi8> into memref<24xi8>
%4 = vector.shape_cast %3 : vector<1x8x3xi8> to vector<24xi8>
%5 = arith.muli %0, %4 : vector<24xi8>
%6 = arith.addi %5, %2 : vector<24xi8>
vector.transfer_write %6, %collapse_shape_2[%c0] {in_bounds = [true]} : vector<24xi8>, memref<24xi8>
return
}
```
Both shape_casts and the leading unit dims have been folded away.
---
.../Vector/Transforms/VectorRewritePatterns.h | 4 +
.../Transforms/VectorTransferOpTransforms.cpp | 1 +
.../Vector/Transforms/VectorTransforms.cpp | 111 ++++++++++++++++++
.../Vector/vector-transfer-flatten.mlir | 59 ++++++++++
4 files changed, 175 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 08c08172d0531..7102ed81ec57d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -294,6 +294,10 @@ void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns,
void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Collect a set of vector.shape_cast folding patterns.
+void populateReorderShapeCastPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Collect a set of patterns to flatten n-D vector transfers on contiguous
/// memref.
///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index ed42e6508b431..ac475566ccdb1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -922,4 +922,5 @@ void mlir::vector::populateFlattenVectorTransferPatterns(
FlattenContiguousRowMajorTransferWritePattern>(
patterns.getContext(), benefit);
populateShapeCastFoldingPatterns(patterns, benefit);
+ populateReorderShapeCastPatterns(patterns, benefit);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6e7fab293d3a1..b38442f733814 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1446,6 +1446,108 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
}
};
+/// Reorders:
+/// shape_cast(arithmetic(a + b))
+/// as
+/// arithmetic(shape_cast(a), shape_cast(b)).
+///
+/// Ex:
+/// ```
+/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
+/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
+///
+/// gets converted to:
+///
+/// ```
+/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
+/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
+/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
+/// ```
+///
+/// While this pattern introduces an extra shape_cast Op (1 shape_cast is
+/// replaced with 2), this brings shape_cast closer to vector.xfer operations.
+/// With patterns like e.g. `FlattenContiguousRowMajorTransferWritePattern`,
+/// the addition shape_cast's are eventually folded away.
+///
+/// Here is another example where this pattern is helpful:
+/// ```
+/// %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
+/// %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
+/// %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
+/// %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
+/// %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+/// %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+/// ```
+///
+/// gets folded as:
+///
+///```
+/// %0 = arith.muli %arg0, %arg1 : vector<8xi32>
+/// %res = arith.addi %0, %arg2 : vector<8xi32>
+/// ```
+/// ATM this pattern is limited to `vector.shape_cast` ops that fold the unit
+/// dim, e.g.:
+/// ```
+/// vector.shape_cast %mul : vector<1x4xf32> to vector<4xf32>
+/// ```
+/// In addition, the input vector should be the result of an arithmetic
+/// operation, `AritOp`.
+template <typename ArithOp>
+struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ if (!llvm::isa_and_present<ArithOp>(
+ shapeCastOp.getSource().getDefiningOp()))
+ return failure();
+
+ auto *arithOp = shapeCastOp.getSource().getDefiningOp();
+
+ auto sourceVectorType =
+ dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
+ auto resultVectorType =
+ dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
+ if (!sourceVectorType || !resultVectorType)
+ return failure();
+
+ // Either the leading or the trailing dims of the input should be
+ // non-scalable 1.
+ if (((sourceVectorType.getShape().back() != 1) ||
+ (sourceVectorType.getScalableDims().back())) &&
+ ((sourceVectorType.getShape().front() != 1) ||
+ (sourceVectorType.getScalableDims().front())))
+ return failure();
+
+ // Does this shape_cast fold the input vector?
+ if (resultVectorType.getRank() != (sourceVectorType.getRank() - 1))
+ return failure();
+
+ // Does this shape_cast fold the _unit_ dim?
+ if (llvm::any_of(resultVectorType.getShape(),
+ [](int64_t dim) { return (dim == 1); }))
+ return failure();
+
+ auto loc = shapeCastOp->getLoc();
+
+ // shape_cast(a)
+ auto *lhs = rewriter.create(loc, shapeCastOp->getName().getIdentifier(),
+ arithOp->getOperands()[0], resultVectorType,
+ shapeCastOp->getAttrs());
+ // shape_cast(b)
+ auto *rhs = rewriter.create(loc, shapeCastOp->getName().getIdentifier(),
+ arithOp->getOperands()[1], resultVectorType,
+ shapeCastOp->getAttrs());
+
+ // Replace shape_cast(a ArithOp b) with shape_cast(a) ArithOp shape_cast(b)
+ rewriter.replaceOpWithNewOp<ArithOp>(shapeCastOp, lhs->getResult(0),
+ rhs->getResult(0));
+
+ return success();
+ }
+};
+
/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1514,6 +1616,15 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
}
+void mlir::vector::populateReorderShapeCastPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<ReorderArithAndShapeCast<arith::AddIOp>,
+ ReorderArithAndShapeCast<arith::AddFOp>,
+ ReorderArithAndShapeCast<arith::MulIOp>,
+ ReorderArithAndShapeCast<arith::MulFOp>>(patterns.getContext(),
+ benefit);
+}
+
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<BubbleDownVectorBitCastForExtract,
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ebec2274655e4..88a86755f9f7a 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -254,3 +254,62 @@ func.func @transfer_read_flattenable_negative2(
// CHECK-LABEL: func @transfer_read_flattenable_negative2
// CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8>
+
+// -----
+
+func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
+ %arg1 : vector<1x8xi32>) -> vector<8xi32> {
+ %sc_arg0 = vector.shape_cast %arg0 : vector<8x1xi32> to vector<1x8xi32>
+ %add = arith.addi %sc_arg0, %arg1 : vector<1x8xi32>
+ %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+ return %res : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @fold_unit_dim_add(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8xi32>) -> vector<8xi32> {
+// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1xi32> to vector<8xi32>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8xi32> to vector<8xi32>
+// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
+// CHECK: return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
+ %arg1 : vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+ %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xf32> to vector<1x8x[2]xf32>
+ %add = arith.mulf %sc_arg0, %arg1 : vector<1x8x[2]xf32>
+ %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+ return %res : vector<8x[2]xf32>
+}
+
+// CHECK-LABEL: func.func @fold_unit_dim_mulf(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xf32> to vector<8x[2]xf32>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
+// CHECK: return %[[VAL_4]] : vector<8x[2]xf32>
+
+// -----
+
+// All shape casts are folded away
+
+func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
+ %arg1 : vector<8xi32>,
+ %arg2 : vector<8xi32>) -> vector<8xi32> {
+ %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
+ %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
+ %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
+ %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
+ %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+ %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+ return %res : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @fold_unit_dims_entirely(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi32>, %[[VAL_1:.*]]: vector<8xi32>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
+// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
+// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
+// CHECK: return %[[VAL_4]] : vector<8xi32>
>From f54addbc2b07cd6a99e684304b89bae753bce1ae Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 11 Dec 2023 08:18:30 +0000
Subject: [PATCH 2/4] fixup! [mlir][vector] Add pattern to reorder
shape_cast(arithmetic(a, b))
Address PR comments
---
.../Vector/Transforms/VectorTransforms.cpp | 40 +++++++++----------
1 file changed, 20 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b38442f733814..03aaf85226fbc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1492,39 +1492,40 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
/// vector.shape_cast %mul : vector<1x4xf32> to vector<4xf32>
/// ```
/// In addition, the input vector should be the result of an arithmetic
-/// operation, `AritOp`.
+/// operation, `ArithOp`.
template <typename ArithOp>
struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
- if (!llvm::isa_and_present<ArithOp>(
- shapeCastOp.getSource().getDefiningOp()))
- return failure();
-
auto *arithOp = shapeCastOp.getSource().getDefiningOp();
+ if (!llvm::isa_and_present<ArithOp>(arithOp))
+ return failure();
auto sourceVectorType =
- dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
+ dyn_cast<VectorType>(shapeCastOp.getSource().getType());
auto resultVectorType =
- dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
+ dyn_cast<VectorType>(shapeCastOp.getResult().getType());
if (!sourceVectorType || !resultVectorType)
return failure();
// Either the leading or the trailing dims of the input should be
// non-scalable 1.
- if (((sourceVectorType.getShape().back() != 1) ||
- (sourceVectorType.getScalableDims().back())) &&
- ((sourceVectorType.getShape().front() != 1) ||
- (sourceVectorType.getScalableDims().front())))
+ bool leadDimUnitFixed = ((sourceVectorType.getShape().back() != 1) ||
+ (sourceVectorType.getScalableDims().back()));
+ bool trailinDimUnitFixed = ((sourceVectorType.getShape().front() != 1) ||
+ (sourceVectorType.getScalableDims().front()));
+ if (!leadDimUnitFixed && !trailinDimUnitFixed)
return failure();
// Does this shape_cast fold the input vector?
if (resultVectorType.getRank() != (sourceVectorType.getRank() - 1))
return failure();
- // Does this shape_cast fold the _unit_ dim?
+ // Does this shape_cast fold the traling/leading _unit_ dim?
+ // TODO: Even when the trailing/leading unit dims are folded, there might
+ // still be some "inner" unit dims left.
if (llvm::any_of(resultVectorType.getShape(),
[](int64_t dim) { return (dim == 1); }))
return failure();
@@ -1532,17 +1533,16 @@ struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
auto loc = shapeCastOp->getLoc();
// shape_cast(a)
- auto *lhs = rewriter.create(loc, shapeCastOp->getName().getIdentifier(),
- arithOp->getOperands()[0], resultVectorType,
- shapeCastOp->getAttrs());
+ auto lhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
+ arithOp->getOperands()[0],
+ shapeCastOp->getAttrs());
// shape_cast(b)
- auto *rhs = rewriter.create(loc, shapeCastOp->getName().getIdentifier(),
- arithOp->getOperands()[1], resultVectorType,
- shapeCastOp->getAttrs());
+ auto rhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
+ arithOp->getOperands()[1],
+ shapeCastOp->getAttrs());
// Replace shape_cast(a ArithOp b) with shape_cast(a) ArithOp shape_cast(b)
- rewriter.replaceOpWithNewOp<ArithOp>(shapeCastOp, lhs->getResult(0),
- rhs->getResult(0));
+ rewriter.replaceOpWithNewOp<ArithOp>(shapeCastOp, lhs, rhs);
return success();
}
>From 0ebb9c31b8cac0be465fae8a3cd1d4385d3b45b8 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 12 Dec 2023 19:53:30 +0000
Subject: [PATCH 3/4] fixup! [mlir][vector] Add pattern to reorder
shape_cast(arithmetic(a, b))
Switch to matching an elementwise Op
---
.../Vector/Transforms/VectorTransforms.cpp | 127 +++++++-----------
1 file changed, 50 insertions(+), 77 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 03aaf85226fbc..dcac4f925c5c3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1446,10 +1446,15 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
}
};
-/// Reorders:
-/// shape_cast(arithmetic(a + b))
-/// as
-/// arithmetic(shape_cast(a), shape_cast(b)).
+/// Replace:
+/// elementwise(a, b)
+/// with:
+/// sc_a = shape_cast(a)
+/// sc_b = shape_cast(b)
+/// res = elementwise(sc_a, sc_b)
+/// return shape_cast(res)
+/// for which `a` and `b` are vectors of rank > 2 and have unit leading and/or
+/// trailing dimension.
///
/// Ex:
/// ```
@@ -1464,85 +1469,57 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
/// ```
-///
-/// While this pattern introduces an extra shape_cast Op (1 shape_cast is
-/// replaced with 2), this brings shape_cast closer to vector.xfer operations.
-/// With patterns like e.g. `FlattenContiguousRowMajorTransferWritePattern`,
-/// the addition shape_cast's are eventually folded away.
-///
-/// Here is another example where this pattern is helpful:
-/// ```
-/// %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
-/// %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
-/// %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
-/// %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
-/// %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
-/// %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
-/// ```
-///
-/// gets folded as:
-///
-///```
-/// %0 = arith.muli %arg0, %arg1 : vector<8xi32>
-/// %res = arith.addi %0, %arg2 : vector<8xi32>
-/// ```
-/// ATM this pattern is limited to `vector.shape_cast` ops that fold the unit
-/// dim, e.g.:
-/// ```
-/// vector.shape_cast %mul : vector<1x4xf32> to vector<4xf32>
-/// ```
-/// In addition, the input vector should be the result of an arithmetic
-/// operation, `ArithOp`.
-template <typename ArithOp>
-struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+struct DropUnitDimFromElementwiseOps final
+ : public OpTraitRewritePattern<OpTrait::Elementwise> {
+ using OpTraitRewritePattern::OpTraitRewritePattern;
+ LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- auto *arithOp = shapeCastOp.getSource().getDefiningOp();
- if (!llvm::isa_and_present<ArithOp>(arithOp))
- return failure();
-
- auto sourceVectorType =
- dyn_cast<VectorType>(shapeCastOp.getSource().getType());
- auto resultVectorType =
- dyn_cast<VectorType>(shapeCastOp.getResult().getType());
- if (!sourceVectorType || !resultVectorType)
+ if (op->getNumResults() != 1)
return failure();
- // Either the leading or the trailing dims of the input should be
- // non-scalable 1.
- bool leadDimUnitFixed = ((sourceVectorType.getShape().back() != 1) ||
- (sourceVectorType.getScalableDims().back()));
- bool trailinDimUnitFixed = ((sourceVectorType.getShape().front() != 1) ||
- (sourceVectorType.getScalableDims().front()));
- if (!leadDimUnitFixed && !trailinDimUnitFixed)
+ // Check the pre-condiitions. For `Elementwise` Ops all operands
+ // are guaranteed to have identical shapes and it suffices to only check the
+ // first one.
+ auto op1 = op->getOperands()[0];
+ auto sourceVectorType = dyn_cast<VectorType>(op1.getType());
+ if (!sourceVectorType)
return failure();
- // Does this shape_cast fold the input vector?
- if (resultVectorType.getRank() != (sourceVectorType.getRank() - 1))
+ if (sourceVectorType.getRank() < 2)
return failure();
- // Does this shape_cast fold the traling/leading _unit_ dim?
- // TODO: Even when the trailing/leading unit dims are folded, there might
- // still be some "inner" unit dims left.
- if (llvm::any_of(resultVectorType.getShape(),
- [](int64_t dim) { return (dim == 1); }))
+ bool trailingDimUnitFixed = ((sourceVectorType.getShape().back() == 1) &&
+ (!sourceVectorType.getScalableDims().back()));
+ bool leadDimUnitFixed = ((sourceVectorType.getShape().front() == 1) &&
+ (!sourceVectorType.getScalableDims().front()));
+ if (!leadDimUnitFixed && !trailingDimUnitFixed)
return failure();
- auto loc = shapeCastOp->getLoc();
+ // Drop leading/trailing unit dim by applying vector.shape_cast to all
+ // operands
+ auto elTy = sourceVectorType.getElementType();
+ VectorType newVType =
+ leadDimUnitFixed
+ ? VectorType::get(sourceVectorType.getShape().drop_front(1), elTy,
+ sourceVectorType.getScalableDims().drop_front(1))
+ : VectorType::get(sourceVectorType.getShape().drop_back(1), elTy,
+ sourceVectorType.getScalableDims().drop_back(1));
+ SmallVector<Value> newOperands;
+ auto loc = op->getLoc();
+ for (auto operand : op->getOperands()) {
+ auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
+ newOperands.push_back(opSC);
+ }
- // shape_cast(a)
- auto lhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
- arithOp->getOperands()[0],
- shapeCastOp->getAttrs());
- // shape_cast(b)
- auto rhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
- arithOp->getOperands()[1],
- shapeCastOp->getAttrs());
+ // Create an updated elementwise Op without leading/trailing unit dim
+ Operation *elementwiseOp =
+ rewriter.create(op->getLoc(), op->getName().getIdentifier(),
+ newOperands, newVType, op->getAttrs());
- // Replace shape_cast(a ArithOp b) with shape_cast(a) ArithOp shape_cast(b)
- rewriter.replaceOpWithNewOp<ArithOp>(shapeCastOp, lhs, rhs);
+ // Restore the leading/trailing unit dim by applying vector.shape_cast to
+ // the result
+ rewriter.replaceOpWithNewOp<ShapeCastOp>(op, sourceVectorType,
+ elementwiseOp->getResults()[0]);
return success();
}
@@ -1618,11 +1595,7 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
void mlir::vector::populateReorderShapeCastPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
- patterns.add<ReorderArithAndShapeCast<arith::AddIOp>,
- ReorderArithAndShapeCast<arith::AddFOp>,
- ReorderArithAndShapeCast<arith::MulIOp>,
- ReorderArithAndShapeCast<arith::MulFOp>>(patterns.getContext(),
- benefit);
+ patterns.add<DropUnitDimFromElementwiseOps>(patterns.getContext(), benefit);
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
>From ac5c0163cf354998d0977a925b8fa80338b4d13e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 13 Dec 2023 08:58:26 +0000
Subject: [PATCH 4/4] fixup! fixup! [mlir][vector] Add pattern to reorder
shape_cast(arithmetic(a, b))
Update comment, add ShapeCastOpFolder to the list of patterns
---
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index dcac4f925c5c3..2ec1183000bea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1468,7 +1468,12 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
+/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
+/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
/// ```
+///
+/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
+/// `%cast`.
struct DropUnitDimFromElementwiseOps final
: public OpTraitRewritePattern<OpTrait::Elementwise> {
using OpTraitRewritePattern::OpTraitRewritePattern;
@@ -1595,7 +1600,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
void mlir::vector::populateReorderShapeCastPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
- patterns.add<DropUnitDimFromElementwiseOps>(patterns.getContext(), benefit);
+ patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
+ patterns.getContext(), benefit);
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
More information about the Mlir-commits
mailing list