[Mlir-commits] [mlir] [Linalg] Support i1 data type in matchConvolutionOpOfType utility (PR #176704)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 19 00:21:53 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Abhishek Varma (Abhishek-Varma)

<details>
<summary>Changes</summary>

- Extend `bodyMatcherForConvolutionOps` to recognize `arith.ori` (accumulation)
  and `arith.andi` (multiplication) for i1 element types, in addition to the
  existing add/mul ops for integer and float types. For i1 operands, Linalg's
  named-op region builder emits `arith.ori`/`arith.andi` in place of add/mul.
- Similarly, extend `bodyMatcherForSumPoolOps` to recognize `arith.ori` for i1
  accumulation, in addition to the existing add ops for integer and float types.

Signed-off-by: Abhishek Varma <abhvarma@<!-- -->amd.com>

---
Full diff: https://github.com/llvm/llvm-project/pull/176704.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+20-8) 
- (modified) mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir (+26) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index daf02442bb21a..a1ee6b307caf5 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -320,28 +320,37 @@ static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp,
 ///     %out + (%lhs * %rhs)
 ///   where: %lhs, %rhs and %out are block arguments and
 ///          %lhs and %rhs can have optional upcast operation.
+/// For i1 element types, the pattern matches:
+///     %out | (%lhs & %rhs)
+///   using arith.ori for accumulation and arith.andi for multiplication.
 /// NOTE: In case of zero point offset convolution ops %lhs and %rhs would be :-
 ///       %input - %input_scalar
 ///          where, %input_scalar can have optional upcast operation.
 static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
                                          bool containsZeroPointOffset = false) {
-  Operation *addOp = yieldVal.getDefiningOp();
-  if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp))
-    return false;
+  bool isOrOp = false;
+  Operation *accOp = yieldVal.getDefiningOp();
+  if (!isa_and_present<arith::AddIOp, arith::AddFOp>(accOp)) {
+    if (!isa_and_present<arith::OrIOp>(accOp))
+      return false;
+    isOrOp = true;
+  }
 
-  Operation *mulOp = addOp->getOperand(1).getDefiningOp();
-  if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
+  Operation *mulOp = accOp->getOperand(1).getDefiningOp();
+  if (!isOrOp && !isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
+    return false;
+  if (isOrOp && !isa_and_present<arith::AndIOp>(mulOp))
     return false;
 
   if (containsZeroPointOffset) {
-    return bodyMatcherForZeroPointOffsets(addOp, mulOp, body);
+    return bodyMatcherForZeroPointOffsets(accOp, mulOp, body);
   }
   BlockArgument lhsBlockArg =
       getBlockArgumentWithOptionalCastOps(mulOp->getOperand(0));
   BlockArgument rhsBlockArg =
       getBlockArgumentWithOptionalCastOps(mulOp->getOperand(1));
   BlockArgument outBlockArg =
-      getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
+      getBlockArgumentWithOptionalCastOps(accOp->getOperand(0));
   if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
       lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
       outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
@@ -386,8 +395,11 @@ static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
   return bodyMatcherForPoolOps<arith::MinUIOp>(yieldVal, body);
 }
 
+/// Matches sum pooling body pattern. For i1 element types, arith.ori is used
+/// instead of arith.addi/arith.addf for accumulation.
 static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
-  return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
+  return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp, arith::OrIOp>(
+      yieldVal, body);
 }
 
 static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex,
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
index 809a4208f8db0..1c4106cb75b1c 100644
--- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -73,6 +73,19 @@ func.func @conv_2d_nhwc_hwcf(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x
 
 // -----
 
+func.func @conv_2d_nhwc_hwcf_i1(%input: tensor<?x?x?x?xi1>, %filter: tensor<?x?x?x?xi1>, %output: tensor<?x?x?x?xi1>) -> tensor<?x?x?x?xi1> {
+  %0 = linalg.conv_2d_nhwc_hwcf
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+         ins (%input, %filter: tensor<?x?x?x?xi1>, tensor<?x?x?x?xi1>)
+         outs (%output: tensor<?x?x?x?xi1>) -> tensor<?x?x?x?xi1>
+  return %0 : tensor<?x?x?x?xi1>
+}
+//      CHECK: @conv_2d_nhwc_hwcf_i1
+//      CHECK:   linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
 func.func @conv_2d_nhwc_hwcf_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?x?xi8>, %output: tensor<?x?x?x?xi32>, %zp_input: i32, %zp_filter: i32) -> tensor<?x?x?x?xi32> {
   %0 = linalg.conv_2d_nhwc_hwcf_q
          {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
@@ -429,6 +442,19 @@ func.func @pooling_nhwc_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32
 
 // -----
 
+func.func @pooling_nhwc_max_i1(%input: tensor<?x?x?x?xi1>, %filter: tensor<?x?xi1>, %output: tensor<?x?x?x?xi1>) -> tensor<?x?x?x?xi1> {
+  %0 = linalg.pooling_nhwc_max
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+         ins (%input, %filter: tensor<?x?x?x?xi1>, tensor<?x?xi1>)
+         outs (%output: tensor<?x?x?x?xi1>) -> tensor<?x?x?x?xi1>
+  return %0 : tensor<?x?x?x?xi1>
+}
+//      CHECK: @pooling_nhwc_max_i1
+//      CHECK:   linalg.pooling_nhwc_max
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
 func.func @pooling_nhwc_min(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
   %0 = linalg.pooling_nhwc_min
          {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}

``````````

</details>


https://github.com/llvm/llvm-project/pull/176704


More information about the Mlir-commits mailing list