[Mlir-commits] [mlir] f11bda7 - [mlir][linalg] Use vector.shuffle to flatten conv filter (#75038)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 15 09:57:04 PST 2023
Author: Andrzej WarzyĆski
Date: 2023-12-15T17:56:59Z
New Revision: f11bda78c8fc551cf3e22cd5caa4005c329b904f
URL: https://github.com/llvm/llvm-project/commit/f11bda78c8fc551cf3e22cd5caa4005c329b904f
DIFF: https://github.com/llvm/llvm-project/commit/f11bda78c8fc551cf3e22cd5caa4005c329b904f.diff
LOG: [mlir][linalg] Use vector.shuffle to flatten conv filter (#75038)
Updates the vectorisation of 1D depthwise convolution when flattening
the channel dimension (introduced in #71918). In particular - how the
convolution filter is "flattened". ATM, the vectoriser will use
`vector.shape_cast`:
```mlir
%b_filter = vector.broadcast %filter : vector<4xf32> to vector<3x2x4xf32>
%sc_filter = vector.shape_cast %b_filter : vector<3x2x4xf32> to vector<3x8xf32>
```
This lowering is not ideal - `vector.shape_cast` can be convenient when
it's folded away, but that's not happening in this case. Instead, this
patch updates the vectoriser to use `vector.shuffle` (the overall result
is identical):
```mlir
%sh_filter = vector.shuffle %filter, %filter
[0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
%b_filter = vector.broadcast %sh_filter : vector<8xf32> to vector<3x8xf32>
```
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c21d007c931b9b..d956fd4fdd9bd8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2910,17 +2910,16 @@ struct Conv1DGenerator
for (int64_t w = 0; w < wSize; w += wSizeStep) {
Value lhsVal = lhsVals[linearIndex(kw, w)];
Value resVal = resVals[w];
- ShapedType filterBCastTy = cast<ShapedType>(resVal.getType());
if (flatten) {
- // Flatten the input and filter vectors (collapse the channel
+ // Flatten the input and output vectors (collapse the channel
// dimension)
lhsVal = rewriter.create<vector::ShapeCastOp>(
loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
resVals[w]);
}
- resVals[w] = depthwiseConv1dSliceAsMulAcc(
- rewriter, loc, lhsVal, rhsVals[kw], resVal, filterBCastTy, flatten);
+ resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
+ rhsVals[kw], resVal, flatten);
if (flatten) {
// Un-flatten the output vector (restore the channel dimension)
resVals[w] = rewriter.create<vector::ShapeCastOp>(
@@ -2964,20 +2963,32 @@ struct Conv1DGenerator
/// to MulAcc.
Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
Value lhs, Value rhs, Value res,
- ShapedType bcastTy, bool flatten) {
+ bool flatten) {
auto rhsTy = cast<ShapedType>(rhs.getType());
auto resTy = cast<ShapedType>(res.getType());
// TODO(suderman): Change this to use a vector.ima intrinsic.
lhs = promote(rewriter, loc, lhs, resTy);
- rhs = rewriter.create<vector::BroadcastOp>(
- loc, bcastTy.clone(rhsTy.getElementType()), rhs);
if (flatten) {
- // Flatten the channel dimension
- rhs = rewriter.create<vector::ShapeCastOp>(
- loc, resTy.clone(rhsTy.getElementType()), rhs);
+ // There are two options for handling the filter:
+ // * shape_cast(broadcast(filter))
+ // * broadcast(shuffle(filter))
+ // Opt for the option without shape_cast to simplify the codegen.
+ auto rhsSize = rhs.getType().cast<VectorType>().getShape()[0];
+ auto resSize = res.getType().cast<VectorType>().getShape()[1];
+
+ SmallVector<int64_t, 16> indicies;
+ for (int i = 0; i < resSize / rhsSize; ++i) {
+ for (int j = 0; j < rhsSize; ++j)
+ indicies.push_back(j);
+ }
+
+ rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indicies);
}
+ // Broadcast the filter to match the output vector
+ rhs = rewriter.create<vector::BroadcastOp>(
+ loc, resTy.clone(rhsTy.getElementType()), rhs);
rhs = promote(rewriter, loc, rhs, resTy);
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
index a242d09671825b..afb59cb26188a6 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
@@ -36,9 +36,10 @@ module attributes {transform.with_named_sequence} {
/// w == 0, kw = 0
// CHECK: %[[SC_INPUT:.*]] = vector.shape_cast %[[V_INPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8>
// CHECK: %[[SC_OUTPUT:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8>
-// CHECK: %[[B_FILTER:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<3xi8> to vector<1x8x3xi8>
-// CHECK: %[[SC_FILTER:.*]] = vector.shape_cast %[[B_FILTER]] : vector<1x8x3xi8> to vector<1x24xi8>
-// CHECK: %[[MULI:.*]] = arith.muli %[[SC_INPUT]], %[[SC_FILTER]] : vector<1x24xi8>
+// CHECK: %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
+// CHECK-SAME: [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
+// CHECK: %[[B_FILTER:.*]] = vector.broadcast %[[SH_FILTER_0]] : vector<24xi8> to vector<1x24xi8>
+// CHECK: %[[MULI:.*]] = arith.muli %[[SC_INPUT]], %[[B_FILTER]] : vector<1x24xi8>
// CHECK: %[[ADDI:.*]] = arith.addi %[[MULI]], %[[SC_OUTPUT]] : vector<1x24xi8>
// Write the result back in one shot.
@@ -80,15 +81,17 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3
/// w == 0, kw = 0
// CHECK: %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xf32> to vector<3x8xf32>
// CHECK: %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32>
-// CHECK: %[[SC_B_FILTER_0:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK: %[[FMA_0:.*]] = vector.fma %[[SC_V_INPUT_0]], %[[SC_B_FILTER_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xf32>
+// CHECK: %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
+// CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[SH_FILTER_0]] : vector<8xf32> to vector<3x8xf32>
+// CHECK: %[[FMA_0:.*]] = vector.fma %[[SC_V_INPUT_0]], %[[B_FILTER_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xf32>
/// w == 0, kw = 1
// CHECK: %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK: %[[B_V_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xf32> to vector<3x2x4xf32>
-// CHECK: %[[SC_B_FILTER_1:.*]] = vector.shape_cast %[[B_V_FILTER_1]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK: %[[FMA_1:.*]] = vector.fma %[[SC_V_INPUT_1]], %[[SC_B_FILTER_1]], %[[FMA_0]] : vector<3x8xf32>
+// CHECK: %[[SH_FILTER_1:.*]] = vector.shuffle %[[V_FILTER_1]], %[[V_FILTER_1]]
+// CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[SH_FILTER_1]] : vector<8xf32> to vector<3x8xf32>
+// CHECK: %[[FMA_1:.*]] = vector.fma %[[SC_V_INPUT_1]], %[[B_FILTER_1]], %[[FMA_0]] : vector<3x8xf32>
// Write the result back in one shot.
// CHECK: %[[SC_FMA_1:.*]] = vector.shape_cast %[[FMA_1]] : vector<3x8xf32> to vector<3x2x4xf32>
@@ -138,19 +141,21 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2(%input: memref<3x5
// CHECK: %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xi8> to vector<3x8xi8>
// CHECK: %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xi32> to vector<3x8xi32>
// CHECK: %[[EXT_INPUT_0:.*]] = arith.extsi %[[SC_V_INPUT_0]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x2x4xi8>
-// CHECK: %[[SC_B_FILTER_0:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x2x4xi8> to vector<3x8xi8>
-// CHECK: %[[EXT_FILTER_0:.*]] = arith.extsi %[[SC_B_FILTER_0]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK: %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[EXT_FILTER_0]] : vector<3x8xi32>
+// CHECK: %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
+// CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
+// CHECK: %[[EXT_FILTER_0:.*]] = arith.extsi %[[SH_FILTER_0]] : vector<8xi8> to vector<8xi32>
+// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[EXT_FILTER_0]] : vector<8xi32> to vector<3x8xi32>
+// CHECK: %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[B_FILTER_0]] : vector<3x8xi32>
// CHECK: %[[ADD_0:.*]] = arith.addi %[[MUL_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xi32>
/// w == 0, kw = 1
// CHECK: %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xi8> to vector<3x8xi8>
// CHECK: %[[EXT_INPUT_1:.*]] = arith.extsi %[[SC_V_INPUT_1]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x2x4xi8>
-// CHECK: %[[SC_B_FILTER_1:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x2x4xi8> to vector<3x8xi8>
-// CHECK: %[[EXT_FILTER_1:.*]] = arith.extsi %[[SC_B_FILTER_1]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK: %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[EXT_FILTER_1]] : vector<3x8xi32>
+// CHECK: %[[SH_FILTER_1:.*]] = vector.shuffle %[[V_FILTER_1]], %[[V_FILTER_1]]
+// CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
+// CHECK: %[[EXT_FILTER_1:.*]] = arith.extsi %[[SH_FILTER_1]] : vector<8xi8> to vector<8xi32>
+// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[EXT_FILTER_1]] : vector<8xi32> to vector<3x8xi32>
+// CHECK: %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[B_FILTER_1]] : vector<3x8xi32>
// CHECK: %[[ADD_1:.*]] = arith.addi %[[MUL_1]], %[[ADD_0]] : vector<3x8xi32>
// Write the result back in one shot.
@@ -223,69 +228,60 @@ func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2(%input: tensor<3x9x4
/// w == 0, kw == 0
// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK: %[[VAL_24:.*]] = vector.shape_cast %[[V_OUTPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_26:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_27:.*]] = arith.muli %[[VAL_23]], %[[VAL_26]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_27:.*]] = arith.muli %[[VAL_23]], %[[B_FILTER_0]] : vector<3x4xi8>
// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_27]], %[[VAL_24]] : vector<3x4xi8>
/// w == 1, kw == 0
// CHECK: %[[VAL_29:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[V_OUTPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_32:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_33:.*]] = arith.muli %[[VAL_29]], %[[VAL_32]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_0_1:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_33:.*]] = arith.muli %[[VAL_29]], %[[B_FILTER_0_1]] : vector<3x4xi8>
// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_30]] : vector<3x4xi8>
/// w == 2, kw == 0
// CHECK: %[[VAL_35:.*]] = vector.shape_cast %[[V_INPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK: %[[VAL_36:.*]] = vector.shape_cast %[[V_OUTPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_38:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_35]], %[[VAL_38]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_0_2:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_35]], %[[B_FILTER_0_2]] : vector<3x4xi8>
// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_39]], %[[VAL_36]] : vector<3x4xi8>
/// w == 3, kw == 1
// CHECK: %[[VAL_41:.*]] = vector.shape_cast %[[V_INPUT_3]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_43:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_44:.*]] = arith.muli %[[VAL_41]], %[[VAL_43]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_44:.*]] = arith.muli %[[VAL_41]], %[[B_FILTER_1]] : vector<3x4xi8>
// CHECK: %[[VAL_45:.*]] = arith.addi %[[VAL_44]], %[[VAL_28]] : vector<3x4xi8>
/// w == 4, kw == 1
// CHECK: %[[VAL_46:.*]] = vector.shape_cast %[[V_INPUT_4]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_48:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_49:.*]] = arith.muli %[[VAL_46]], %[[VAL_48]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_1_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_49:.*]] = arith.muli %[[VAL_46]], %[[B_FILTER_1_1]] : vector<3x4xi8>
// CHECK: %[[VAL_50:.*]] = arith.addi %[[VAL_49]], %[[VAL_34]] : vector<3x4xi8>
/// w == 5, kw == 1
// CHECK: %[[VAL_51:.*]] = vector.shape_cast %[[V_INPUT_5]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_53:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_54:.*]] = arith.muli %[[VAL_51]], %[[VAL_53]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_1_2:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_54:.*]] = arith.muli %[[VAL_51]], %[[B_FILTER_1_2]] : vector<3x4xi8>
// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_54]], %[[VAL_40]] : vector<3x4xi8>
/// w == 6, kw == 2
// CHECK: %[[VAL_56:.*]] = vector.shape_cast %[[V_INPUT_6]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_58:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_59:.*]] = arith.muli %[[VAL_56]], %[[VAL_58]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_59:.*]] = arith.muli %[[VAL_56]], %[[B_FILTER_2]] : vector<3x4xi8>
// CHECK: %[[VAL_60:.*]] = arith.addi %[[VAL_59]], %[[VAL_45]] : vector<3x4xi8>
/// w == 7, kw == 2
// CHECK: %[[VAL_61:.*]] = vector.shape_cast %[[VAL_60]] : vector<3x4xi8> to vector<3x1x4xi8>
// CHECK: %[[VAL_62:.*]] = vector.shape_cast %[[V_INPUT_7]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_64:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_65:.*]] = arith.muli %[[VAL_62]], %[[VAL_64]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_2_1:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_65:.*]] = arith.muli %[[VAL_62]], %[[B_FILTER_2_1]] : vector<3x4xi8>
// CHECK: %[[VAL_66:.*]] = arith.addi %[[VAL_65]], %[[VAL_50]] : vector<3x4xi8>
/// w == 8, kw == 2
// CHECK: %[[VAL_67:.*]] = vector.shape_cast %[[VAL_66]] : vector<3x4xi8> to vector<3x1x4xi8>
// CHECK: %[[VAL_68:.*]] = vector.shape_cast %[[V_INPUT_8]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_70:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_71:.*]] = arith.muli %[[VAL_68]], %[[VAL_70]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_2_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_71:.*]] = arith.muli %[[VAL_68]], %[[B_FILTER_2_2]] : vector<3x4xi8>
// CHECK: %[[VAL_72:.*]] = arith.addi %[[VAL_71]], %[[VAL_55]] : vector<3x4xi8>
// Write the result back.
More information about the Mlir-commits
mailing list