[Mlir-commits] [mlir] [mlir][vector] Add pattern to reorder shape_cast(arithmetic(a, b)) (PR #74817)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Dec 8 00:39:04 PST 2023
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/74817
Adds a new pattern to reorder:
shape_cast(arithmetic(a, b))
as
arithmetic(shape_cast(a), shape_cast(b)).
For example:
%mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
%cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
gets converted to:
%B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
%A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
%mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
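To see why this reordering preserves semantics, here is a small C++ model (a sketch, not MLIR code). A vector is a shape plus its elements in row-major order. `vector.shape_cast` only reinterprets the shape. An element-wise op like `arith.mulf` works lane by lane, so applying the cast before or after the op gives identical results. The names `ToyVector`, `shapeCast`, `elementwise` and `reorderIsSound` are hypothetical. Scalable dims such as `[4]` are modelled as fixed-size, i.e. vscale = 1 is assumed.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Toy model of a vector value: a shape plus its elements in row-major order.
// Hypothetical illustration only, not the MLIR API.
struct ToyVector {
  std::vector<int64_t> shape;
  std::vector<float> data;
};

// Number of elements implied by a fixed-size shape.
int64_t numElements(const std::vector<int64_t> &shape) {
  int64_t n = 1;
  for (int64_t d : shape)
    n *= d;
  return n;
}

// Models vector.shape_cast: the element order is unchanged, only the shape
// is reinterpreted.
ToyVector shapeCast(const ToyVector &v, const std::vector<int64_t> &newShape) {
  assert(numElements(v.shape) == numElements(newShape));
  return ToyVector{newShape, v.data};
}

// Models an element-wise binary op (operands must have matching shapes).
ToyVector elementwise(const ToyVector &a, const ToyVector &b,
                      const std::function<float(float, float)> &fn) {
  assert(a.shape == b.shape);
  ToyVector res{a.shape, std::vector<float>(a.data.size())};
  for (std::size_t i = 0; i < a.data.size(); ++i)
    res.data[i] = fn(a.data[i], b.data[i]);
  return res;
}

// Compares shape_cast(mulf(a, b)) with mulf(shape_cast(a), shape_cast(b)).
bool reorderIsSound(const ToyVector &a, const ToyVector &b,
                    const std::vector<int64_t> &newShape) {
  auto mul = [](float x, float y) { return x * y; };
  ToyVector before = shapeCast(elementwise(a, b, mul), newShape);
  ToyVector after =
      elementwise(shapeCast(a, newShape), shapeCast(b, newShape), mul);
  return before.shape == after.shape && before.data == after.data;
}
```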
This eventually helps to fold away shape_casts entirely. One specific
case where this is beneficial is when vectorising 1D depthwise
convolutions, e.g. `linalg.depthwise_conv_1d_nwc_wc`, with the channel
dimension flattened, i.e. when `flatten_1d_depthwise_conv` is set:
transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv}
Output after vectorisation:
func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_memref(%arg0: memref<1x8x3xi8>, %arg1: memref<1x3xi8>, %arg2: memref<1x8x3xi8>) {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8
  %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
  %1 = vector.transfer_read %arg1[%c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<1x3xi8>, vector<1x3xi8>
  %2 = vector.transfer_read %arg2[%c0, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
  %3 = vector.extract %1[0] : vector<3xi8> from vector<1x3xi8>
  %4 = vector.shape_cast %0 : vector<1x8x3xi8> to vector<1x24xi8>
  %5 = vector.shape_cast %2 : vector<1x8x3xi8> to vector<1x24xi8>
  %6 = vector.broadcast %3 : vector<3xi8> to vector<1x8x3xi8>
  %7 = vector.shape_cast %6 : vector<1x8x3xi8> to vector<1x24xi8>
  %8 = arith.muli %4, %7 : vector<1x24xi8>
  %9 = arith.addi %8, %5 : vector<1x24xi8>
  %10 = vector.shape_cast %9 : vector<1x24xi8> to vector<1x8x3xi8>
  vector.transfer_write %10, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x8x3xi8>, memref<1x8x3xi8>
  return
}
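In the vectorised IR above, the filter is broadcast to 1x8x3 and then viewed as 24 lanes, so lane `i` is multiplied by filter element `i % 3`. The rough C++ sketch below compares this flattened form against a nested-loop depthwise convolution. It assumes stride = dilation = 1 and covers only the 1x8x3 input / 1x3 filter shapes used here (with a filter width of 1, `out[w][c] += in[w][c] * filter[0][c]`). The helper names `depthwiseNested` and `depthwiseFlattened` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Nested-loop reference for the 1x8x3 (N x W x C) input and 1x3 (KW x C)
// filter used above, assuming stride = dilation = 1 and KW == 1, so each
// output element only reads in[w][c]: out[w][c] += in[w][c] * filter[0][c].
std::vector<int8_t> depthwiseNested(const std::vector<int8_t> &in,
                                    const std::vector<int8_t> &filter,
                                    std::vector<int8_t> out, int W, int C) {
  for (int w = 0; w < W; ++w)
    for (int c = 0; c < C; ++c)
      out[w * C + c] =
          static_cast<int8_t>(out[w * C + c] + in[w * C + c] * filter[c]);
  return out;
}

// Flattened form mirroring the vectorised IR: broadcast the filter to W x C,
// view everything as one 1-D vector of W * C lanes, then one muli + addi.
std::vector<int8_t> depthwiseFlattened(const std::vector<int8_t> &in,
                                       const std::vector<int8_t> &filter,
                                       std::vector<int8_t> out, int W, int C) {
  const int numLanes = W * C;
  assert(static_cast<int>(in.size()) == numLanes);
  std::vector<int8_t> bcast(numLanes);
  for (int i = 0; i < numLanes; ++i)
    bcast[i] = filter[i % C]; // vector.broadcast + vector.shape_cast
  for (int i = 0; i < numLanes; ++i)
    out[i] = static_cast<int8_t>(out[i] + in[i] * bcast[i]); // muli + addi
  return out;
}
```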
Output after applying this pattern and other related patterns (e.g. via
`-test-vector-transfer-flatten-patterns -canonicalize`):
func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_memref(%arg0: memref<1x8x3xi8>, %arg1: memref<1x3xi8>, %arg2: memref<1x8x3xi8>) {
  %c0_i8 = arith.constant 0 : i8
  %c0 = arith.constant 0 : index
  %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2]] : memref<1x8x3xi8> into memref<24xi8>
  %0 = vector.transfer_read %collapse_shape[%c0], %c0_i8 {in_bounds = [true]} : memref<24xi8>, vector<24xi8>
  %collapse_shape_0 = memref.collapse_shape %arg1 [[0, 1]] : memref<1x3xi8> into memref<3xi8>
  %1 = vector.transfer_read %collapse_shape_0[%c0], %c0_i8 {in_bounds = [true]} : memref<3xi8>, vector<3xi8>
  %collapse_shape_1 = memref.collapse_shape %arg2 [[0, 1, 2]] : memref<1x8x3xi8> into memref<24xi8>
  %2 = vector.transfer_read %collapse_shape_1[%c0], %c0_i8 {in_bounds = [true]} : memref<24xi8>, vector<24xi8>
  %3 = vector.broadcast %1 : vector<3xi8> to vector<1x8x3xi8>
  %collapse_shape_2 = memref.collapse_shape %arg2 [[0, 1, 2]] : memref<1x8x3xi8> into memref<24xi8>
  %4 = vector.shape_cast %3 : vector<1x8x3xi8> to vector<24xi8>
  %5 = arith.muli %0, %4 : vector<24xi8>
  %6 = arith.addi %5, %2 : vector<24xi8>
  vector.transfer_write %6, %collapse_shape_2[%c0] {in_bounds = [true]} : vector<24xi8>, memref<24xi8>
  return
}
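The shape_casts of the transfer_read results and of the value written by transfer_write, as well as the leading unit dims, have been folded away; only the shape_cast of the broadcast filter remains.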
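The claim that the extra shape_casts eventually disappear can be illustrated with a toy rewriter. The hypothetical sketch below is much simpler than MLIR and handles fixed-size shapes only. It applies this reorder rule plus two simplified folds, bottom-up: a no-op shape_cast is dropped, and shape_cast(shape_cast(x)) becomes shape_cast(x). These folds only stand in for the real shape_cast folders. On the `fold_unit_dims_entirely` example from the tests in the patch, it leaves no shape_cast at all.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;

// Toy expression node: a function argument, a shape_cast, or an element-wise
// binary op such as "muli"/"addi". Hypothetical, not MLIR's data structures.
struct Node {
  enum Kind { Arg, ShapeCast, BinOp };
  Kind kind;
  Shape shape;               // Shape of the result.
  std::string name;          // Argument name or op name.
  std::shared_ptr<Node> lhs; // shape_cast source, or first operand.
  std::shared_ptr<Node> rhs; // Second operand (BinOp only).
};
using NodePtr = std::shared_ptr<Node>;

NodePtr makeArg(const std::string &name, const Shape &s) {
  return std::make_shared<Node>(Node{Node::Arg, s, name, nullptr, nullptr});
}
NodePtr makeShapeCast(const NodePtr &src, const Shape &s) {
  return std::make_shared<Node>(
      Node{Node::ShapeCast, s, "shape_cast", src, nullptr});
}
NodePtr makeBinOp(const std::string &op, const NodePtr &a, const NodePtr &b) {
  assert(a->shape == b->shape && "element-wise ops need matching shapes");
  return std::make_shared<Node>(Node{Node::BinOp, a->shape, op, a, b});
}

// Simplified match condition: drop exactly one leading/trailing unit dim and
// produce a result without unit dims.
bool dropsUnitDim(const Shape &src, const Shape &res) {
  if (res.size() + 1 != src.size())
    return false;
  if (src.front() != 1 && src.back() != 1)
    return false;
  for (int64_t d : res)
    if (d == 1)
      return false;
  return true;
}

// Bottom-up rewrite: simplified folds first, then the reorder rule.
NodePtr simplify(const NodePtr &n) {
  if (n->kind == Node::Arg)
    return n;
  if (n->kind == Node::BinOp)
    return makeBinOp(n->name, simplify(n->lhs), simplify(n->rhs));
  NodePtr src = simplify(n->lhs);
  if (src->shape == n->shape) // No-op shape_cast.
    return src;
  if (src->kind == Node::ShapeCast) // shape_cast(shape_cast(x)) -> shape_cast(x)
    return simplify(makeShapeCast(src->lhs, n->shape));
  if (src->kind == Node::BinOp && dropsUnitDim(src->shape, n->shape))
    // shape_cast(op(a, b)) -> op(shape_cast(a), shape_cast(b))
    return makeBinOp(src->name, simplify(makeShapeCast(src->lhs, n->shape)),
                     simplify(makeShapeCast(src->rhs, n->shape)));
  return makeShapeCast(src, n->shape);
}

int countShapeCasts(const NodePtr &n) {
  if (!n)
    return 0;
  return (n->kind == Node::ShapeCast ? 1 : 0) + countShapeCasts(n->lhs) +
         countShapeCasts(n->rhs);
}
```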
From 46cac088edf06e037593d46da082b13695c46a18 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 6 Dec 2023 21:15:33 +0000
Subject: [PATCH] [mlir][vector] Add pattern to reorder
shape_cast(arithmetic(a, b))
.../Vector/Transforms/VectorRewritePatterns.h | 4 +
.../Transforms/VectorTransferOpTransforms.cpp | 1 +
.../Vector/Transforms/VectorTransforms.cpp | 111 ++++++++++++++++++
.../Vector/vector-transfer-flatten.mlir | 79 +++++++++++++
4 files changed, 195 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 08c08172d0531e..b0e626cf39b0f8 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -294,6 +294,10 @@ void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns,
void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Collect a set of patterns to reorder shape_cast(arithmetic(a, b)).
+void populateReorderShapeCastPatterns(RewritePatternSet &patterns,
+                                      PatternBenefit benefit = 1);
/// Collect a set of patterns to flatten n-D vector transfers on contiguous
/// memref.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index ed42e6508b4310..ac475566ccdb1e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -922,4 +922,5 @@ void mlir::vector::populateFlattenVectorTransferPatterns(
patterns.getContext(), benefit);
populateShapeCastFoldingPatterns(patterns, benefit);
+ populateReorderShapeCastPatterns(patterns, benefit);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6e7fab293d3a1c..b38442f7338144 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1446,6 +1446,108 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
+/// Reorders:
+/// shape_cast(arithmetic(a, b))
+/// as
+/// arithmetic(shape_cast(a), shape_cast(b)).
+/// Ex:
+/// ```
+/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
+/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
+/// gets converted to:
+/// ```
+/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
+/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
+/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
+/// ```
+/// While this pattern introduces an extra shape_cast Op (1 shape_cast is
+/// replaced with 2), it moves the shape_casts closer to the vector transfer
+/// ops. With patterns like `FlattenContiguousRowMajorTransferWritePattern`,
+/// the additional shape_casts are eventually folded away.
+/// Here is another example where this pattern is helpful:
+/// ```
+/// %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
+/// %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
+/// %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
+/// %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
+/// %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+/// %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+/// ```
+/// gets folded as:
+/// ```
+/// %0 = arith.muli %arg0, %arg1 : vector<8xi32>
+/// %res = arith.addi %0, %arg2 : vector<8xi32>
+/// ```
+/// Currently, this pattern is limited to `vector.shape_cast` ops that drop a
+/// unit dim, e.g.:
+/// ```
+/// vector.shape_cast %mul : vector<1x4xf32> to vector<4xf32>
+/// ```
+/// In addition, the input vector should be the result of an arithmetic
+/// operation, `ArithOp`.
+template <typename ArithOp>
+struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    // The input vector has to be the result of `ArithOp`.
+    auto arithOp = shapeCastOp.getSource().getDefiningOp<ArithOp>();
+    if (!arithOp)
+      return failure();
+
+    VectorType sourceVectorType = shapeCastOp.getSourceVectorType();
+    VectorType resultVectorType = shapeCastOp.getResultVectorType();
+
+    // Does this shape_cast reduce the rank by exactly one? Checked first so
+    // that the source vector is guaranteed to have at least one dim below.
+    if (resultVectorType.getRank() != sourceVectorType.getRank() - 1)
+      return failure();
+
+    // Either the leading or the trailing dim of the input should be a
+    // non-scalable 1.
+    bool isLeadingUnitDim = sourceVectorType.getShape().front() == 1 &&
+                            !sourceVectorType.getScalableDims().front();
+    bool isTrailingUnitDim = sourceVectorType.getShape().back() == 1 &&
+                             !sourceVectorType.getScalableDims().back();
+    if (!isLeadingUnitDim && !isTrailingUnitDim)
+      return failure();
+
+    // Does this shape_cast fold the _unit_ dim?
+    if (llvm::any_of(resultVectorType.getShape(),
+                     [](int64_t dim) { return dim == 1; }))
+      return failure();
+
+    Location loc = shapeCastOp.getLoc();
+    // shape_cast(a)
+    Value lhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
+                                                     arithOp.getLhs());
+    // shape_cast(b)
+    Value rhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
+                                                     arithOp.getRhs());
+    // Replace shape_cast(a ArithOp b) with shape_cast(a) ArithOp shape_cast(b).
+    rewriter.replaceOpWithNewOp<ArithOp>(shapeCastOp, lhs, rhs);
+    return success();
+  }
+};
/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1514,6 +1616,15 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
+void mlir::vector::populateReorderShapeCastPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<ReorderArithAndShapeCast<arith::AddIOp>,
+ ReorderArithAndShapeCast<arith::AddFOp>,
+ ReorderArithAndShapeCast<arith::MulIOp>,
+ ReorderArithAndShapeCast<arith::MulFOp>>(patterns.getContext(),
+ benefit);
+}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ebec2274655e46..232037241fdb7d 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -254,3 +254,82 @@ func.func @transfer_read_flattenable_negative2(
// CHECK-LABEL: func @transfer_read_flattenable_negative2
// CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8>
+// -----
+func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
+ %arg1 : vector<1x8xi32>) -> vector<8xi32> {
+ %sc_arg0 = vector.shape_cast %arg0 : vector<8x1xi32> to vector<1x8xi32>
+ %add = arith.addi %sc_arg0, %arg1 : vector<1x8xi32>
+ %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+ return %res : vector<8xi32>
+}
+// CHECK-LABEL: func.func @fold_unit_dim_add(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8xi32>) -> vector<8xi32> {
+// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1xi32> to vector<8xi32>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8xi32> to vector<8xi32>
+// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
+// CHECK: return %[[VAL_4]] : vector<8xi32>
+// -----
+func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
+ %arg1 : vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+ %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xf32> to vector<1x8x[2]xf32>
+ %mul = arith.mulf %sc_arg0, %arg1 : vector<1x8x[2]xf32>
+ %res = vector.shape_cast %mul : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+ return %res : vector<8x[2]xf32>
+}
+// CHECK-LABEL: func.func @fold_unit_dim_mulf(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xf32> to vector<8x[2]xf32>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
+// CHECK: return %[[VAL_4]] : vector<8x[2]xf32>
+// -----
+// All shape casts are folded away
+func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
+ %arg1 : vector<8xi32>,
+ %arg2 : vector<8xi32>) -> vector<8xi32> {
+ %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
+ %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
+ %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
+ %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
+ %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+ %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+ return %res : vector<8xi32>
+}
+// CHECK-LABEL: func.func @fold_unit_dims_entirely(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<8xi32>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
+// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
+// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
+// CHECK: return %[[VAL_4]] : vector<8xi32>
+// -----
+func.func @dont_fold_unit_scalable(%arg0 : vector<8x[1]xi32>,
+ %arg1 : vector<[1]x8xi32>) -> vector<8xi32> {
+ %sc_arg0 = vector.shape_cast %arg0 : vector<8x[1]xi32> to vector<[1]x8xi32>
+ %add = arith.addi %sc_arg0, %arg1 : vector<[1]x8xi32>
+ %res = vector.shape_cast %add : vector<[1]x8xi32> to vector<8xi32>
+ return %res : vector<8xi32>
+}
+// CHECK-LABEL: func.func @dont_fold_unit_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[1]xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<[1]x8xi32>) -> vector<8xi32> {
+// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[1]xi32> to vector<[1]x8xi32>
+// CHECK: %[[VAL_3:.*]] = arith.addi %[[VAL_2]], %[[VAL_1]] : vector<[1]x8xi32>
+// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<[1]x8xi32> to vector<8xi32>
+// CHECK: return %[[VAL_4]] : vector<8xi32>
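The match conditions exercised by these tests can be summarised as a standalone predicate. This is a hypothetical C++ sketch that mirrors the checks in `ReorderArithAndShapeCast::matchAndRewrite`. It operates on a toy type, `ToyVectorType` (a shape plus per-dim scalable flags), rather than on `mlir::VectorType`. It checks the rank first, so `front()`/`back()` are never called on an empty shape.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for mlir::VectorType: a shape plus one
// "is this dim scalable?" flag per dimension.
struct ToyVectorType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
};

// Mirrors the match conditions of ReorderArithAndShapeCast on ToyVectorType.
bool shapeCastDropsUnitDim(const ToyVectorType &src, const ToyVectorType &res) {
  assert(src.shape.size() == src.scalableDims.size());
  // 1. The shape_cast must reduce the rank by exactly one (this also
  //    guarantees that `src` has at least one dim).
  if (res.shape.size() + 1 != src.shape.size())
    return false;
  // 2. Either the leading or the trailing dim must be a non-scalable 1.
  bool leadingUnit = src.shape.front() == 1 && !src.scalableDims.front();
  bool trailingUnit = src.shape.back() == 1 && !src.scalableDims.back();
  if (!leadingUnit && !trailingUnit)
    return false;
  // 3. The result must not contain any unit dims.
  for (int64_t dim : res.shape)
    if (dim == 1)
      return false;
  return true;
}
```

With these definitions, the shape_casts producing `%res` in `@fold_unit_dim_add` and `@fold_unit_dim_mulf` match. The one in `@dont_fold_unit_scalable` does not, because its only unit dim is scalable.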