[Mlir-commits] [mlir] 234193b - [mlir][linalg] Vectorization support for convolution of i1 type (#109480)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 24 10:25:11 PDT 2024


Author: Nirvedh Meshram
Date: 2024-09-24T12:24:59-05:00
New Revision: 234193bae6cf8b19703c6e543b100517bb99a9f7

URL: https://github.com/llvm/llvm-project/commit/234193bae6cf8b19703c6e543b100517bb99a9f7
DIFF: https://github.com/llvm/llvm-project/commit/234193bae6cf8b19703c6e543b100517bb99a9f7.diff

LOG: [mlir][linalg] Vectorization support for convolution of i1 type (#109480)

Normally, convolutions are presented with the following linalg op region:
```
^bb0(%arg14: i4, %arg15: i4, %arg16: i4):
  %17 = arith.muli %arg14, %arg15 : i4
  %18 = arith.addi %arg16, %17 : i4
  linalg.yield %18 : i4
```
However, for i1, due to strength reduction, we get something like:
```
^bb0(%arg14: i1, %arg15: i1, %arg16: i1):
  %17 = arith.andi %arg14, %arg15 : i1
  %18 = arith.ori %arg16, %17 : i1
  linalg.yield %18 : i1
```
This PR updates the logic to support this region for i1 types.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c332307da4d333..fa20001f661822 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2987,10 +2987,15 @@ struct Conv1DGenerator
     if (!setOperKind(reduceOp))
       return;
     auto maybeKind = getCombinerOpKind(reduceOp);
-    if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
+    // Typically convolution will have a `Add` CombiningKind but for i1 type it
+    // can get strength reduced to `OR` which is also supported. This strength
+    // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
+    if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
+                        *maybeKind != vector::CombiningKind::OR) &&
                        (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
       return;
     }
+    reductionKind = maybeKind.value();
 
     auto rhsRank = rhsShapedType.getRank();
     switch (oper) {
@@ -3273,10 +3278,12 @@ struct Conv1DGenerator
     bindDims(ctx, n, w, f, c);
     lhs = promote(rewriter, loc, lhs, res.getType());
     rhs = promote(rewriter, loc, rhs, res.getType());
-    return rewriter.create<vector::ContractionOp>(
+    auto contrationOp = rewriter.create<vector::ContractionOp>(
         loc, lhs, rhs, res,
         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
+    contrationOp.setKind(reductionKind);
+    return contrationOp;
   }
 
   // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
@@ -3666,6 +3673,7 @@ struct Conv1DGenerator
   int strideW, dilationW;
   Value lhsShaped, rhsShaped, resShaped;
   ShapedType lhsShapedType, rhsShapedType, resShapedType;
+  vector::CombiningKind reductionKind;
 
   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
   // Returns true iff it is a valid conv/pooling op.
@@ -3681,7 +3689,9 @@ struct Conv1DGenerator
     switch (numBlockArguments) {
     case 1: {
       // Will be convolution if feeder is a MulOp.
-      // Otherwise, if it can be pooling.
+      // A strength reduced version of MulOp for i1 type is AndOp which is also
+      // supported. Otherwise, it can be pooling. This strength reduction logic
+      // is in `buildBinaryFn` helper in the Linalg dialect.
       auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                          llvm::IsaPred<BlockArgument>);
       Operation *feedOp = (*feedValIt).getDefiningOp();
@@ -3689,7 +3699,9 @@ struct Conv1DGenerator
         oper = Pool;
         isPoolExt = true;
         poolExtOp = feedOp->getName().getIdentifier();
-      } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
+      } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+                    (isa<arith::AndIOp>(feedOp) &&
+                     feedOp->getResultTypes()[0].isInteger(1))) &&
                    llvm::all_of(feedOp->getOperands(), [](Value v) {
                      if (isa<BlockArgument>(v))
                        return true;

diff  --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 93e36a69567bd5..7f4b9b986c81b4 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -39,6 +39,7 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
 //      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:       kind = #vector.kind<add>
 // CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 
@@ -46,6 +47,7 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
 //      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:       kind = #vector.kind<add>
 // CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 
@@ -61,6 +63,36 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
 
 // -----
 
+// This test is same as above but for i1 type with the only difference being that
+// the combining kind for `vector.contract` is `OR`.
+func.func @conv1d_nwc_4x2x8_memref_i1(%input: memref<4x6x3xi1>, %filter: memref<1x3x8xi1>, %output: memref<4x2x8xi1>) {
+  linalg.conv_1d_nwc_wcf
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x6x3xi1>, memref<1x3x8xi1>)
+    outs(%output : memref<4x2x8xi1>)
+  return
+}
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+//      CHECK: func @conv1d_nwc_4x2x8_memref_i1
+/// w == 0, kw == 0
+//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:       kind = #vector.kind<or>
+// CHECK-SAME:       : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
+
+/// w == 1, kw == 0
+//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:       kind = #vector.kind<or>
+// CHECK-SAME:       : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
+
+// -----
+
 // The i8i8i32 case is similar to f32 case, so checking one case is enough for
 // test coverage.
 func.func @conv1d_nwc_4x2x8_i8i8i32_memref(%input: memref<4x6x3xi8>, %filter: memref<1x3x8xi8>, %output: memref<4x2x8xi32>) {
@@ -299,6 +331,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
 //      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:       kind = #vector.kind<add>
 // CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 
@@ -306,6 +339,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
 //      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:       kind = #vector.kind<add>
 // CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 
@@ -324,6 +358,37 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
 
 // -----
 
+// This test is same as above but for i1 type with the only difference being that
+// the combining kind for `vector.contract` is `OR`.
+func.func @conv1d_ncw_4x8x2_memref_i1(%input: memref<4x3x6xi1>, %filter: memref<8x3x1xi1>, %output: memref<4x8x2xi1>) {
+  linalg.conv_1d_ncw_fcw
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x3x6xi1>, memref<8x3x1xi1>)
+    outs(%output : memref<4x8x2xi1>)
+  return
+}
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+//      CHECK: func @conv1d_ncw_4x8x2_memref_i1
+/// w == 0, kw == 0
+//      CHECK:   vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:       kind = #vector.kind<or>
+// CHECK-SAME:       : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
+
+/// w == 1, kw == 0
+//      CHECK:   vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:       kind = #vector.kind<or>
+// CHECK-SAME:       : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
+
+// -----
+
 func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x3x2xf32>, %output: memref<4x8x2xf32>) {
   linalg.conv_1d_ncw_fcw
     {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}


        


More information about the Mlir-commits mailing list