[Mlir-commits] [mlir] [mlir][linalg] Vectorization support for convolution of i1 type (PR #109480)

Nirvedh Meshram llvmlistbot at llvm.org
Tue Sep 24 09:48:01 PDT 2024


https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/109480

From e5fc9412991ea5331c1d7891e131bc22fe0fccbe Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Wed, 18 Sep 2024 22:31:23 +0000
Subject: [PATCH 1/4] [mlir][linalg] Vectorization support for convolution of
 i1 type

---
 .../Linalg/Transforms/Vectorization.cpp       | 34 +++++++++++++------
 .../Dialect/Linalg/vectorize-convolution.mlir | 29 ++++++++++++++++
 2 files changed, 52 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a376afa5ddab12..1cdf937742fd2e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2947,12 +2947,14 @@ struct Conv1DGenerator
 
     if (!setOperKind(reduceOp))
       return;
-    auto maybeKind = getCombinerOpKind(reduceOp);
-    if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
+    maybeKind = getCombinerOpKind(reduceOp);
+    // Typically convolution will have an `Add` CombiningKind but for i1 type it
+    // can get strength reduced to `OR` which is also supported.
+    if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
+                        *maybeKind != vector::CombiningKind::OR) &&
                        (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
       return;
     }
-
     auto rhsRank = rhsShapedType.getRank();
     switch (oper) {
     case Conv:
@@ -3156,9 +3158,9 @@ struct Conv1DGenerator
                                                    lhsVals[linearIndex(kw, w)],
                                                    rhsVals[kw], resVals[w]);
           } else {
-            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
-                                                  lhsVals[linearIndex(kw, w)],
-                                                  rhsVals[kw], resVals[w]);
+            resVals[w] = conv1dSliceAsContraction(
+                rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
+                resVals[w], maybeKind);
           }
           break;
         case Pool:
@@ -3226,18 +3228,24 @@ struct Conv1DGenerator
   }
 
   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
-  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
-                                 Value lhs, Value rhs, Value res) {
+  Value
+  conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, Value lhs,
+                           Value rhs, Value res,
+                           std::optional<vector::CombiningKind> maybeKind) {
     vector::IteratorType par = vector::IteratorType::parallel;
     vector::IteratorType red = vector::IteratorType::reduction;
     AffineExpr n, w, f, c;
     bindDims(ctx, n, w, f, c);
     lhs = promote(rewriter, loc, lhs, res.getType());
     rhs = promote(rewriter, loc, rhs, res.getType());
-    return rewriter.create<vector::ContractionOp>(
+    auto ContrationOp = rewriter.create<vector::ContractionOp>(
         loc, lhs, rhs, res,
         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
+    if (maybeKind) {
+      ContrationOp.setKind(*maybeKind);
+    }
+    return ContrationOp;
   }
 
   // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
@@ -3627,6 +3635,7 @@ struct Conv1DGenerator
   int strideW, dilationW;
   Value lhsShaped, rhsShaped, resShaped;
   ShapedType lhsShapedType, rhsShapedType, resShapedType;
+  std::optional<vector::CombiningKind> maybeKind;
 
   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
   // Returns true iff it is a valid conv/pooling op.
@@ -3642,7 +3651,8 @@ struct Conv1DGenerator
     switch (numBlockArguments) {
     case 1: {
       // Will be convolution if feeder is a MulOp.
-      // Otherwise, if it can be pooling.
+      // A strength reduced version of MulOp for i1 type is AndOp which is also
+      // supported. Otherwise, it can be pooling.
       auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                          llvm::IsaPred<BlockArgument>);
       Operation *feedOp = (*feedValIt).getDefiningOp();
@@ -3650,7 +3660,9 @@ struct Conv1DGenerator
         oper = Pool;
         isPoolExt = true;
         poolExtOp = feedOp->getName().getIdentifier();
-      } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
+      } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+                    (isa<arith::AndIOp>(feedOp) &&
+                     feedOp->getResultTypes()[0].isInteger(1))) &&
                    llvm::all_of(feedOp->getOperands(), [](Value v) {
                      if (isa<BlockArgument>(v))
                        return true;
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 93e36a69567bd5..84e790954b4d02 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -654,6 +654,35 @@ func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter
 // CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[CAST0]], %[[CAST1]], %[[READ2]]
 // CHECK: vector.transfer_write %[[CONTRACT]], %arg2[%[[I0]], %[[I0]], %[[I0]]]
 
+// -----
+
+func.func @conv2d_i1_i1_i1(%arg0: tensor<1x8x6xi1>, %arg1: tensor<8x8x1xi1>, %arg2: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
+  %0 = linalg.conv_1d_ncw_fcw 
+  {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
+   ins(%arg0, %arg1 : tensor<1x8x6xi1>, tensor<8x8x1xi1>) 
+   outs(%arg2 : tensor<1x8x6xi1>) -> tensor<1x8x6xi1>
+  return %0 : tensor<1x8x6xi1>
+}
+
+// CHECK-LABEL:  func @conv2d_i1_i1_i1
+// CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: tensor<1x8x6xi1>, %[[FILTER:[0-9a-z]+]]: tensor<8x8x1xi1>, %[[OUTPUT:[0-9a-z]+]]: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
+// CHECK-DAG:    %[[I0:.+]] = arith.constant 0 : index
+// CHECK-DAG:    %[[FALSE:.+]] = arith.constant false
+// CHECK-DAG:    %[[READ0:.+]] = vector.transfer_read %[[INPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
+// CHECK-DAG:    %[[READ1:.+]] = vector.transfer_read %[[FILTER]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
+// CHECK-DAG:    %[[READ2:.+]] = vector.transfer_read %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
+// CHECK-DAG:    %[[TREAD0:.+]] = vector.transpose %[[READ0]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1>
+// CHECK-DAG:    %[[TREAD1:.+]] = vector.transpose %[[READ1]], [2, 1, 0] : vector<8x8x1xi1> to vector<1x8x8xi1>
+// CHECK-DAG:    %[[TREAD2:.+]] = vector.transpose %[[READ2]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1>
+// CHECK:        %[[EXTRACT:.+]] = vector.extract %[[TREAD1]][0] : vector<8x8xi1> from vector<1x8x8xi1>
+// CHECK:        %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<or>} 
+// CHECK-SAME:   %[[TREAD0]], %[[EXTRACT]], %[[TREAD2]] : vector<1x6x8xi1>, vector<8x8xi1> into vector<1x6x8xi1>
+// CHECK:        %[[TCONTRACT:.+]] = vector.transpose %[[CONTRACT]], [0, 2, 1] : vector<1x6x8xi1> to vector<1x8x6xi1>
+// CHECK:        %[[RESULT:.+]] = vector.transfer_write %[[TCONTRACT]], %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]] 
+// CHECK:        return %[[RESULT]] : tensor<1x8x6xi1>
+
+
+
 // -----
 
 func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {

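For readers skimming the first patch: the new comments refer to the strength reduction that the `buildBinaryFn` helper in the Linalg dialect performs when a named convolution is instantiated on i1 operands, where the multiply is emitted as `arith.andi` and the additive accumulation as `arith.ori`. That is why the vectorizer now has to accept an `AndIOp` feeder and an `OR` combining kind. Below is a rough hand-written sketch of what such an i1 1-D convolution body looks like in generic form (the function name, map aliases, and shapes are illustrative, not taken from the patch):

#in  = affine_map<(n, w, f, c, kw) -> (n, w + kw, c)>
#flt = affine_map<(n, w, f, c, kw) -> (kw, c, f)>
#out = affine_map<(n, w, f, c, kw) -> (n, w, f)>
func.func @conv1d_nwc_i1_sketch(%input: memref<4x6x3xi1>, %filter: memref<1x3x8xi1>,
                                %output: memref<4x6x8xi1>) {
  linalg.generic {
      indexing_maps = [#in, #flt, #out],
      iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
      ins(%input, %filter : memref<4x6x3xi1>, memref<1x3x8xi1>)
      outs(%output : memref<4x6x8xi1>) {
  ^bb0(%in: i1, %flt: i1, %acc: i1):
    // On i1 the multiply is strength reduced to `andi` ...
    %prod = arith.andi %in, %flt : i1
    // ... and the additive accumulation to `ori`, i.e. an OR reduction.
    %sum = arith.ori %acc, %prod : i1
    linalg.yield %sum : i1
  }
  return
}

The change to Vectorization.cpp above teaches setOperKind/getCombinerOpKind to treat exactly this `andi` + `ori` pair the way the existing code treats `muli`/`mulf` + `addi`/`addf`.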
From 6f839f3d5da6a787c8e67cbb9e109deffe1365bd Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Mon, 23 Sep 2024 22:14:16 +0000
Subject: [PATCH 2/4] address reviewer comments

---
 .../Linalg/Transforms/Vectorization.cpp       | 32 +++----
 .../Dialect/Linalg/vectorize-convolution.mlir | 83 ++++++++++++-------
 2 files changed, 70 insertions(+), 45 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 1cdf937742fd2e..c06ad50ff4bfe0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2947,14 +2947,17 @@ struct Conv1DGenerator
 
     if (!setOperKind(reduceOp))
       return;
-    maybeKind = getCombinerOpKind(reduceOp);
+    auto maybeKind = getCombinerOpKind(reduceOp);
     // Typically convolution will have an `Add` CombiningKind but for i1 type it
-    // can get strength reduced to `OR` which is also supported.
+    // can get strength reduced to `OR` which is also supported. This strength
+    // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
     if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                         *maybeKind != vector::CombiningKind::OR) &&
                        (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
       return;
     }
+    reductionKind = maybeKind.value();
+
     auto rhsRank = rhsShapedType.getRank();
     switch (oper) {
     case Conv:
@@ -3158,9 +3161,9 @@ struct Conv1DGenerator
                                                    lhsVals[linearIndex(kw, w)],
                                                    rhsVals[kw], resVals[w]);
           } else {
-            resVals[w] = conv1dSliceAsContraction(
-                rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
-                resVals[w], maybeKind);
+            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
+                                                  lhsVals[linearIndex(kw, w)],
+                                                  rhsVals[kw], resVals[w]);
           }
           break;
         case Pool:
@@ -3228,24 +3231,20 @@ struct Conv1DGenerator
   }
 
   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
-  Value
-  conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, Value lhs,
-                           Value rhs, Value res,
-                           std::optional<vector::CombiningKind> maybeKind) {
+  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
+                                 Value lhs, Value rhs, Value res) {
     vector::IteratorType par = vector::IteratorType::parallel;
     vector::IteratorType red = vector::IteratorType::reduction;
     AffineExpr n, w, f, c;
     bindDims(ctx, n, w, f, c);
     lhs = promote(rewriter, loc, lhs, res.getType());
     rhs = promote(rewriter, loc, rhs, res.getType());
-    auto ContrationOp = rewriter.create<vector::ContractionOp>(
+    auto contrationOp = rewriter.create<vector::ContractionOp>(
         loc, lhs, rhs, res,
         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
-    if (maybeKind) {
-      ContrationOp.setKind(*maybeKind);
-    }
-    return ContrationOp;
+    contrationOp.setKind(reductionKind);
+    return contrationOp;
   }
 
   // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
@@ -3635,7 +3634,7 @@ struct Conv1DGenerator
   int strideW, dilationW;
   Value lhsShaped, rhsShaped, resShaped;
   ShapedType lhsShapedType, rhsShapedType, resShapedType;
-  std::optional<vector::CombiningKind> maybeKind;
+  vector::CombiningKind reductionKind;
 
   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
   // Returns true iff it is a valid conv/pooling op.
@@ -3652,7 +3651,8 @@ struct Conv1DGenerator
     case 1: {
       // Will be convolution if feeder is a MulOp.
       // A strength reduced version of MulOp for i1 type is AndOp which is also
-      // supported. Otherwise, it can be pooling.
+      // supported. Otherwise, it can be pooling. This strength reduction logic
+      // is in `buildBinaryFn` helper in the Linalg dialect.
       auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                          llvm::IsaPred<BlockArgument>);
       Operation *feedOp = (*feedValIt).getDefiningOp();
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 84e790954b4d02..d6cf57fc3c1448 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -61,6 +61,33 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
 
 // -----
 
+// This test is the same as above but for i1 type.
+func.func @conv1d_nwc_4x2x8_memref_i1(%input: memref<4x6x3xi1>, %filter: memref<1x3x8xi1>, %output: memref<4x2x8xi1>) {
+  linalg.conv_1d_nwc_wcf
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x6x3xi1>, memref<1x3x8xi1>)
+    outs(%output : memref<4x2x8xi1>)
+  return
+}
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+//      CHECK: func @conv1d_nwc_4x2x8_memref_i1
+/// w == 0, kw == 0
+//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
+
+/// w == 1, kw == 0
+//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
+
+// -----
+
 // The i8i8i32 case is similar to f32 case, so checking one case is enough for
 // test coverage.
 func.func @conv1d_nwc_4x2x8_i8i8i32_memref(%input: memref<4x6x3xi8>, %filter: memref<1x3x8xi8>, %output: memref<4x2x8xi32>) {
@@ -324,6 +351,33 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
 
 // -----
 
+// This test is the same as above but for i1 type.
+func.func @conv1d_ncw_4x8x2_memref_i1(%input: memref<4x3x6xi1>, %filter: memref<8x3x1xi1>, %output: memref<4x8x2xi1>) {
+  linalg.conv_1d_ncw_fcw
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x3x6xi1>, memref<8x3x1xi1>)
+    outs(%output : memref<4x8x2xi1>)
+  return
+}
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+//      CHECK: func @conv1d_ncw_4x8x2_memref_i1
+/// w == 0, kw == 0
+//      CHECK:   vector.contract {
+// CHECK-SAME:   kind = #vector.kind<or>
+// CHECK-SAME:   : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
+
+/// w == 1, kw == 0
+//      CHECK:   vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
+
+// -----
+
 func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x3x2xf32>, %output: memref<4x8x2xf32>) {
   linalg.conv_1d_ncw_fcw
     {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
@@ -654,35 +708,6 @@ func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter
 // CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[CAST0]], %[[CAST1]], %[[READ2]]
 // CHECK: vector.transfer_write %[[CONTRACT]], %arg2[%[[I0]], %[[I0]], %[[I0]]]
 
-// -----
-
-func.func @conv2d_i1_i1_i1(%arg0: tensor<1x8x6xi1>, %arg1: tensor<8x8x1xi1>, %arg2: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
-  %0 = linalg.conv_1d_ncw_fcw 
-  {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
-   ins(%arg0, %arg1 : tensor<1x8x6xi1>, tensor<8x8x1xi1>) 
-   outs(%arg2 : tensor<1x8x6xi1>) -> tensor<1x8x6xi1>
-  return %0 : tensor<1x8x6xi1>
-}
-
-// CHECK-LABEL:  func @conv2d_i1_i1_i1
-// CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: tensor<1x8x6xi1>, %[[FILTER:[0-9a-z]+]]: tensor<8x8x1xi1>, %[[OUTPUT:[0-9a-z]+]]: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
-// CHECK-DAG:    %[[I0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[FALSE:.+]] = arith.constant false
-// CHECK-DAG:    %[[READ0:.+]] = vector.transfer_read %[[INPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
-// CHECK-DAG:    %[[READ1:.+]] = vector.transfer_read %[[FILTER]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
-// CHECK-DAG:    %[[READ2:.+]] = vector.transfer_read %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
-// CHECK-DAG:    %[[TREAD0:.+]] = vector.transpose %[[READ0]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1>
-// CHECK-DAG:    %[[TREAD1:.+]] = vector.transpose %[[READ1]], [2, 1, 0] : vector<8x8x1xi1> to vector<1x8x8xi1>
-// CHECK-DAG:    %[[TREAD2:.+]] = vector.transpose %[[READ2]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1>
-// CHECK:        %[[EXTRACT:.+]] = vector.extract %[[TREAD1]][0] : vector<8x8xi1> from vector<1x8x8xi1>
-// CHECK:        %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<or>} 
-// CHECK-SAME:   %[[TREAD0]], %[[EXTRACT]], %[[TREAD2]] : vector<1x6x8xi1>, vector<8x8xi1> into vector<1x6x8xi1>
-// CHECK:        %[[TCONTRACT:.+]] = vector.transpose %[[CONTRACT]], [0, 2, 1] : vector<1x6x8xi1> to vector<1x8x6xi1>
-// CHECK:        %[[RESULT:.+]] = vector.transfer_write %[[TCONTRACT]], %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]] 
-// CHECK:        return %[[RESULT]] : tensor<1x8x6xi1>
-
-
-
 // -----
 
 func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {

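The i1 tests added in this patch pin the generated contraction down to the `or` combining kind (the follow-up patch below fills in the CHECK lines that were still missing). A standalone sketch of the slice contraction those tests expect, with the shapes and maps mirroring the nwc test above and the function name purely illustrative:

#lhs = affine_map<(n, w, f, c) -> (n, w, c)>
#rhs = affine_map<(n, w, f, c) -> (c, f)>
#acc = affine_map<(n, w, f, c) -> (n, w, f)>
func.func @i1_conv_slice_contract(%lhs: vector<4x1x3xi1>, %rhs: vector<3x8xi1>,
                                  %acc: vector<4x1x8xi1>) -> vector<4x1x8xi1> {
  // The combining kind is `or` rather than the default `add`, matching the
  // strength-reduced i1 accumulation.
  %0 = vector.contract {
           indexing_maps = [#lhs, #rhs, #acc],
           iterator_types = ["parallel", "parallel", "parallel", "reduction"],
           kind = #vector.kind<or>}
       %lhs, %rhs, %acc : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
  return %0 : vector<4x1x8xi1>
}

In the patch this is done by recording the matched kind in the new `reductionKind` member and calling `setKind` on the contraction built in `conv1dSliceAsContraction`.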
From 37f1f068f1520ceebaf29d6811b47d0167b004fd Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Tue, 24 Sep 2024 14:31:18 +0000
Subject: [PATCH 3/4] add some missing CHECKs

---
 .../Dialect/Linalg/vectorize-convolution.mlir     | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index d6cf57fc3c1448..2d3294749e0a0d 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -78,13 +78,15 @@ func.func @conv1d_nwc_4x2x8_memref_i1(%input: memref<4x6x3xi1>, %filter: memref<
 //      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME:     : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
+// CHECK-SAME:       kind = #vector.kind<or>
+// CHECK-SAME:       : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
 
 /// w == 1, kw == 0
 //      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME:     : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
+// CHECK-SAME:       kind = #vector.kind<or>
+// CHECK-SAME:       : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
 
 // -----
 
@@ -367,14 +369,17 @@ func.func @conv1d_ncw_4x8x2_memref_i1(%input: memref<4x3x6xi1>, %filter: memref<
 //      CHECK: func @conv1d_ncw_4x8x2_memref_i1
 /// w == 0, kw == 0
 //      CHECK:   vector.contract {
-// CHECK-SAME:   kind = #vector.kind<or>
-// CHECK-SAME:   : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:       kind = #vector.kind<or>
+// CHECK-SAME:       : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
 
 /// w == 1, kw == 0
 //      CHECK:   vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME:     : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
+// CHECK-SAME:       kind = #vector.kind<or>
+// CHECK-SAME:       : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
 
 // -----
 

From 9ef7c15faf56379da242b6a1bc6d6503982c4f55 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Tue, 24 Sep 2024 16:47:48 +0000
Subject: [PATCH 4/4] reviewer comments

---
 mlir/test/Dialect/Linalg/vectorize-convolution.mlir | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 2d3294749e0a0d..7f4b9b986c81b4 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -39,6 +39,7 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
 //      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:       kind = #vector.kind<add>
 // CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 
@@ -46,6 +47,7 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
 //      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:       kind = #vector.kind<add>
 // CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 
@@ -61,7 +63,8 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
 
 // -----
 
-// This test is the same as above but for i1 type.
+// This test is the same as above but for i1 type, with the only difference being
+// that the combining kind for `vector.contract` is `OR`.
 func.func @conv1d_nwc_4x2x8_memref_i1(%input: memref<4x6x3xi1>, %filter: memref<1x3x8xi1>, %output: memref<4x2x8xi1>) {
   linalg.conv_1d_nwc_wcf
     {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
@@ -328,6 +331,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
 //      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:       kind = #vector.kind<add>
 // CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 
@@ -335,6 +339,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
 //      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:       kind = #vector.kind<add>
 // CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 
@@ -353,7 +358,8 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
 
 // -----
 
-// This test is the same as above but for i1 type.
+// This test is the same as above but for i1 type, with the only difference being
+// that the combining kind for `vector.contract` is `OR`.
 func.func @conv1d_ncw_4x8x2_memref_i1(%input: memref<4x3x6xi1>, %filter: memref<8x3x1xi1>, %output: memref<4x8x2xi1>) {
   linalg.conv_1d_ncw_fcw
     {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}


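For contrast with the i1 cases, the f32 tests touched by the last patch now also spell out the (default) additive kind; the same slice contraction on f32 operands carries `kind = #vector.kind<add>`. An illustrative standalone form, not taken from the test file:

#lhs = affine_map<(n, w, f, c) -> (n, w, c)>
#rhs = affine_map<(n, w, f, c) -> (c, f)>
#acc = affine_map<(n, w, f, c) -> (n, w, f)>
func.func @f32_conv_slice_contract(%lhs: vector<4x1x3xf32>, %rhs: vector<3x8xf32>,
                                   %acc: vector<4x1x8xf32>) -> vector<4x1x8xf32> {
  // `kind` defaults to `add`; the updated CHECK lines simply make it explicit so
  // the f32 and i1 cases can be told apart.
  %0 = vector.contract {
           indexing_maps = [#lhs, #rhs, #acc],
           iterator_types = ["parallel", "parallel", "parallel", "reduction"],
           kind = #vector.kind<add>}
       %lhs, %rhs, %acc : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
  return %0 : vector<4x1x8xf32>
}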
