[Mlir-commits] [mlir] f11bda7 - [mlir][linalg] Use vector.shuffle to flatten conv filter (#75038)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 15 09:57:04 PST 2023


Author: Andrzej Warzyński
Date: 2023-12-15T17:56:59Z
New Revision: f11bda78c8fc551cf3e22cd5caa4005c329b904f

URL: https://github.com/llvm/llvm-project/commit/f11bda78c8fc551cf3e22cd5caa4005c329b904f
DIFF: https://github.com/llvm/llvm-project/commit/f11bda78c8fc551cf3e22cd5caa4005c329b904f.diff

LOG: [mlir][linalg] Use vector.shuffle to flatten conv filter (#75038)

Updates the vectorisation of 1D depthwise convolution when flattening
the channel dimension (introduced in #71918), specifically how the
convolution filter is "flattened". Currently, the vectoriser uses
`vector.shape_cast`:

```mlir
  %b_filter = vector.broadcast %filter : vector<4xf32> to vector<3x2x4xf32>
  %sc_filter = vector.shape_cast %b_filter : vector<3x2x4xf32> to vector<3x8xf32>
```

This lowering is not ideal: `vector.shape_cast` is convenient when it
gets folded away, but that does not happen in this case. Instead, this
patch updates the vectoriser to use `vector.shuffle` followed by
`vector.broadcast` (the overall result is identical):

```mlir
  %sh_filter = vector.shuffle %filter, %filter
      [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
  %b_filter = vector.broadcast %sh_filter : vector<8xf32> to vector<3x8xf32>
```

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c21d007c931b9b..d956fd4fdd9bd8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2910,17 +2910,16 @@ struct Conv1DGenerator
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         Value lhsVal = lhsVals[linearIndex(kw, w)];
         Value resVal = resVals[w];
-        ShapedType filterBCastTy = cast<ShapedType>(resVal.getType());
         if (flatten) {
-          // Flatten the input and filter vectors (collapse the channel
+          // Flatten the input and output vectors (collapse the channel
           // dimension)
           lhsVal = rewriter.create<vector::ShapeCastOp>(
               loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
           resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
                                                         resVals[w]);
         }
-        resVals[w] = depthwiseConv1dSliceAsMulAcc(
-            rewriter, loc, lhsVal, rhsVals[kw], resVal, filterBCastTy, flatten);
+        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
+                                                  rhsVals[kw], resVal, flatten);
         if (flatten) {
           // Un-flatten the output vector (restore the channel dimension)
           resVals[w] = rewriter.create<vector::ShapeCastOp>(
@@ -2964,20 +2963,32 @@ struct Conv1DGenerator
   /// to MulAcc.
   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                      Value lhs, Value rhs, Value res,
-                                     ShapedType bcastTy, bool flatten) {
+                                     bool flatten) {
     auto rhsTy = cast<ShapedType>(rhs.getType());
     auto resTy = cast<ShapedType>(res.getType());
 
     // TODO(suderman): Change this to use a vector.ima intrinsic.
     lhs = promote(rewriter, loc, lhs, resTy);
 
-    rhs = rewriter.create<vector::BroadcastOp>(
-        loc, bcastTy.clone(rhsTy.getElementType()), rhs);
     if (flatten) {
-      // Flatten the channel dimension
-      rhs = rewriter.create<vector::ShapeCastOp>(
-          loc, resTy.clone(rhsTy.getElementType()), rhs);
+      // There are two options for handling the filter:
+      //  * shape_cast(broadcast(filter))
+      //  * broadcast(shuffle(filter))
+      // Opt for the option without shape_cast to simplify the codegen.
+      auto rhsSize = rhs.getType().cast<VectorType>().getShape()[0];
+      auto resSize = res.getType().cast<VectorType>().getShape()[1];
+
+      SmallVector<int64_t, 16> indicies;
+      for (int i = 0; i < resSize / rhsSize; ++i) {
+        for (int j = 0; j < rhsSize; ++j)
+          indicies.push_back(j);
+      }
+
+      rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indicies);
     }
+    // Broadcast the filter to match the output vector
+    rhs = rewriter.create<vector::BroadcastOp>(
+        loc, resTy.clone(rhsTy.getElementType()), rhs);
 
     rhs = promote(rewriter, loc, rhs, resTy);
 

diff  --git a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
index a242d09671825b..afb59cb26188a6 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
@@ -36,9 +36,10 @@ module attributes {transform.with_named_sequence} {
 /// w == 0, kw = 0
 // CHECK:           %[[SC_INPUT:.*]] = vector.shape_cast %[[V_INPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8>
 // CHECK:           %[[SC_OUTPUT:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8>
-// CHECK:           %[[B_FILTER:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<3xi8> to vector<1x8x3xi8>
-// CHECK:           %[[SC_FILTER:.*]] = vector.shape_cast %[[B_FILTER]] : vector<1x8x3xi8> to vector<1x24xi8>
-// CHECK:           %[[MULI:.*]] = arith.muli %[[SC_INPUT]], %[[SC_FILTER]] : vector<1x24xi8>
+// CHECK:           %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
+// CHECK-SAME:        [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
+// CHECK:           %[[B_FILTER:.*]] = vector.broadcast %[[SH_FILTER_0]] : vector<24xi8> to vector<1x24xi8>
+// CHECK:           %[[MULI:.*]] = arith.muli %[[SC_INPUT]], %[[B_FILTER]] : vector<1x24xi8>
 // CHECK:           %[[ADDI:.*]] = arith.addi %[[MULI]], %[[SC_OUTPUT]] : vector<1x24xi8>
 
 // Write the result back in one shot.
@@ -80,15 +81,17 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3
 /// w == 0, kw = 0
 // CHECK:           %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xf32> to vector<3x8xf32>
 // CHECK:           %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK:           %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32>
-// CHECK:           %[[SC_B_FILTER_0:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK:           %[[FMA_0:.*]] = vector.fma %[[SC_V_INPUT_0]], %[[SC_B_FILTER_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xf32>
+// CHECK:           %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]] 
+// CHECK-SAME:        [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
+// CHECK:           %[[B_FILTER_0:.*]] = vector.broadcast %[[SH_FILTER_0]] : vector<8xf32> to vector<3x8xf32>
+// CHECK:           %[[FMA_0:.*]] = vector.fma %[[SC_V_INPUT_0]], %[[B_FILTER_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xf32>
 
 /// w == 0, kw = 1
 // CHECK:           %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK:           %[[B_V_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xf32> to vector<3x2x4xf32>
-// CHECK:           %[[SC_B_FILTER_1:.*]] = vector.shape_cast %[[B_V_FILTER_1]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK:           %[[FMA_1:.*]] = vector.fma %[[SC_V_INPUT_1]], %[[SC_B_FILTER_1]], %[[FMA_0]] : vector<3x8xf32>
+// CHECK:           %[[SH_FILTER_1:.*]] = vector.shuffle %[[V_FILTER_1]], %[[V_FILTER_1]] 
+// CHECK-SAME:        [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
+// CHECK:           %[[B_FILTER_1:.*]] = vector.broadcast %[[SH_FILTER_1]] : vector<8xf32> to vector<3x8xf32>
+// CHECK:           %[[FMA_1:.*]] = vector.fma %[[SC_V_INPUT_1]], %[[B_FILTER_1]], %[[FMA_0]] : vector<3x8xf32>
 
 // Write the result back in one shot.
 //      CHECK:   %[[SC_FMA_1:.*]] = vector.shape_cast %[[FMA_1]] : vector<3x8xf32> to vector<3x2x4xf32>
@@ -138,19 +141,21 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2(%input: memref<3x5
 //      CHECK:  %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xi8> to vector<3x8xi8>
 //      CHECK:  %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xi32> to vector<3x8xi32>
 //      CHECK:  %[[EXT_INPUT_0:.*]] = arith.extsi %[[SC_V_INPUT_0]] : vector<3x8xi8> to vector<3x8xi32>
-//      CHECK:  %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x2x4xi8>
-//      CHECK:  %[[SC_B_FILTER_0:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x2x4xi8> to vector<3x8xi8>
-//      CHECK:  %[[EXT_FILTER_0:.*]] = arith.extsi %[[SC_B_FILTER_0]] : vector<3x8xi8> to vector<3x8xi32>
-//      CHECK:  %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[EXT_FILTER_0]] : vector<3x8xi32>
+//      CHECK:  %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
+//      CHECK-SAME:  [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
+//      CHECK:  %[[EXT_FILTER_0:.*]] = arith.extsi %[[SH_FILTER_0]] : vector<8xi8> to vector<8xi32>
+//      CHECK:  %[[B_FILTER_0:.*]] = vector.broadcast %[[EXT_FILTER_0]] : vector<8xi32> to vector<3x8xi32>
+//      CHECK:  %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[B_FILTER_0]] : vector<3x8xi32>
 //      CHECK:  %[[ADD_0:.*]] = arith.addi %[[MUL_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xi32>
 
 /// w == 0, kw = 1
 //      CHECK:  %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xi8> to vector<3x8xi8>
 //      CHECK:  %[[EXT_INPUT_1:.*]] = arith.extsi %[[SC_V_INPUT_1]] : vector<3x8xi8> to vector<3x8xi32>
-//      CHECK:  %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x2x4xi8>
-//      CHECK:  %[[SC_B_FILTER_1:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x2x4xi8> to vector<3x8xi8>
-//      CHECK:  %[[EXT_FILTER_1:.*]] = arith.extsi %[[SC_B_FILTER_1]] : vector<3x8xi8> to vector<3x8xi32>
-//      CHECK:  %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[EXT_FILTER_1]] : vector<3x8xi32>
+//      CHECK:  %[[SH_FILTER_1:.*]] = vector.shuffle %[[V_FILTER_1]], %[[V_FILTER_1]]
+//      CHECK-SAME:  [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
+//      CHECK:  %[[EXT_FILTER_1:.*]] = arith.extsi %[[SH_FILTER_1]] : vector<8xi8> to vector<8xi32>
+//      CHECK:  %[[B_FILTER_1:.*]] = vector.broadcast %[[EXT_FILTER_1]] : vector<8xi32> to vector<3x8xi32>
+//      CHECK:  %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[B_FILTER_1]] : vector<3x8xi32>
 //      CHECK:  %[[ADD_1:.*]] = arith.addi %[[MUL_1]], %[[ADD_0]] : vector<3x8xi32>
 
 // Write the result back in one shot.
@@ -223,69 +228,60 @@ func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2(%input: tensor<3x9x4
 /// w == 0, kw == 0
 // CHECK:           %[[VAL_23:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8>
 // CHECK:           %[[VAL_24:.*]] = vector.shape_cast %[[V_OUTPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK:           %[[VAL_26:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[VAL_27:.*]] = arith.muli %[[VAL_23]], %[[VAL_26]] : vector<3x4xi8>
+// CHECK:           %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_27:.*]] = arith.muli %[[VAL_23]], %[[B_FILTER_0]] : vector<3x4xi8>
 // CHECK:           %[[VAL_28:.*]] = arith.addi %[[VAL_27]], %[[VAL_24]] : vector<3x4xi8>
 
 /// w == 1, kw == 0
 // CHECK:           %[[VAL_29:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8>
 // CHECK:           %[[VAL_30:.*]] = vector.shape_cast %[[V_OUTPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK:           %[[VAL_32:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[VAL_33:.*]] = arith.muli %[[VAL_29]], %[[VAL_32]] : vector<3x4xi8>
+// CHECK:           %[[B_FILTER_0_1:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_33:.*]] = arith.muli %[[VAL_29]], %[[B_FILTER_0_1]] : vector<3x4xi8>
 // CHECK:           %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_30]] : vector<3x4xi8>
 
 /// w == 2, kw == 0
 // CHECK:           %[[VAL_35:.*]] = vector.shape_cast %[[V_INPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8>
 // CHECK:           %[[VAL_36:.*]] = vector.shape_cast %[[V_OUTPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK:           %[[VAL_38:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[VAL_39:.*]] = arith.muli %[[VAL_35]], %[[VAL_38]] : vector<3x4xi8>
+// CHECK:           %[[B_FILTER_0_2:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_39:.*]] = arith.muli %[[VAL_35]], %[[B_FILTER_0_2]] : vector<3x4xi8>
 // CHECK:           %[[VAL_40:.*]] = arith.addi %[[VAL_39]], %[[VAL_36]] : vector<3x4xi8>
 
 /// w == 3, kw == 1
 // CHECK:           %[[VAL_41:.*]] = vector.shape_cast %[[V_INPUT_3]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK:           %[[VAL_43:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[VAL_44:.*]] = arith.muli %[[VAL_41]], %[[VAL_43]] : vector<3x4xi8>
+// CHECK:           %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_44:.*]] = arith.muli %[[VAL_41]], %[[B_FILTER_1]] : vector<3x4xi8>
 // CHECK:           %[[VAL_45:.*]] = arith.addi %[[VAL_44]], %[[VAL_28]] : vector<3x4xi8>
 
 /// w == 4, kw == 1
 // CHECK:           %[[VAL_46:.*]] = vector.shape_cast %[[V_INPUT_4]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK:           %[[VAL_48:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[VAL_49:.*]] = arith.muli %[[VAL_46]], %[[VAL_48]] : vector<3x4xi8>
+// CHECK:           %[[B_FILTER_1_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_49:.*]] = arith.muli %[[VAL_46]], %[[B_FILTER_1_1]] : vector<3x4xi8>
 // CHECK:           %[[VAL_50:.*]] = arith.addi %[[VAL_49]], %[[VAL_34]] : vector<3x4xi8>
 
 /// w == 5, kw == 1
 // CHECK:           %[[VAL_51:.*]] = vector.shape_cast %[[V_INPUT_5]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK:           %[[VAL_53:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[VAL_54:.*]] = arith.muli %[[VAL_51]], %[[VAL_53]] : vector<3x4xi8>
+// CHECK:           %[[B_FILTER_1_2:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_54:.*]] = arith.muli %[[VAL_51]], %[[B_FILTER_1_2]] : vector<3x4xi8>
 // CHECK:           %[[VAL_55:.*]] = arith.addi %[[VAL_54]], %[[VAL_40]] : vector<3x4xi8>
 
 /// w == 6, kw == 2
 // CHECK:           %[[VAL_56:.*]] = vector.shape_cast %[[V_INPUT_6]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK:           %[[VAL_58:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[VAL_59:.*]] = arith.muli %[[VAL_56]], %[[VAL_58]] : vector<3x4xi8>
+// CHECK:           %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_59:.*]] = arith.muli %[[VAL_56]], %[[B_FILTER_2]] : vector<3x4xi8>
 // CHECK:           %[[VAL_60:.*]] = arith.addi %[[VAL_59]], %[[VAL_45]] : vector<3x4xi8>
 
 /// w == 7, kw == 2
 // CHECK:           %[[VAL_61:.*]] = vector.shape_cast %[[VAL_60]] : vector<3x4xi8> to vector<3x1x4xi8>
 // CHECK:           %[[VAL_62:.*]] = vector.shape_cast %[[V_INPUT_7]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK:           %[[VAL_64:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[VAL_65:.*]] = arith.muli %[[VAL_62]], %[[VAL_64]] : vector<3x4xi8>
+// CHECK:           %[[B_FILTER_2_1:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_65:.*]] = arith.muli %[[VAL_62]], %[[B_FILTER_2_1]] : vector<3x4xi8>
 // CHECK:           %[[VAL_66:.*]] = arith.addi %[[VAL_65]], %[[VAL_50]] : vector<3x4xi8>
 
 /// w == 8, kw == 2
 // CHECK:           %[[VAL_67:.*]] = vector.shape_cast %[[VAL_66]] : vector<3x4xi8> to vector<3x1x4xi8>
 // CHECK:           %[[VAL_68:.*]] = vector.shape_cast %[[V_INPUT_8]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK:           %[[VAL_70:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK:           %[[VAL_71:.*]] = arith.muli %[[VAL_68]], %[[VAL_70]] : vector<3x4xi8>
+// CHECK:           %[[B_FILTER_2_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_71:.*]] = arith.muli %[[VAL_68]], %[[B_FILTER_2_2]] : vector<3x4xi8>
 // CHECK:           %[[VAL_72:.*]] = arith.addi %[[VAL_71]], %[[VAL_55]] : vector<3x4xi8>
 
 // Write the result back.


        


More information about the Mlir-commits mailing list