[Mlir-commits] [mlir] [Linalg] Support i1 data type in matchConvolutionOpOfType utility (PR #176704)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 19 00:21:53 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Abhishek Varma (Abhishek-Varma)
<details>
<summary>Changes</summary>
-- Extend bodyMatcherForConvolutionOps to recognize arith.ori (accumulation)
and arith.andi (multiplication) for i1 element types, in addition to the
existing add/mul matching for integer/float types.
-- Similarly, extend bodyMatcherForSumPoolOps to recognize arith.ori for
i1 accumulation, in addition to the existing add matching for integer/float types.
Signed-off-by: Abhishek Varma <abhvarma@amd.com>
---
Full diff: https://github.com/llvm/llvm-project/pull/176704.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+20-8)
- (modified) mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir (+26)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index daf02442bb21a..a1ee6b307caf5 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -320,28 +320,37 @@ static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp,
/// %out + (%lhs * %rhs)
/// where: %lhs, %rhs and %out are block arguments and
/// %lhs and %rhs can have optional upcast operation.
+/// For i1 element types, the pattern matches:
+/// %out | (%lhs & %rhs)
+/// using arith.ori for accumulation and arith.andi for multiplication.
/// NOTE: In case of zero point offset convolution ops %lhs and %rhs would be :-
/// %input - %input_scalar
/// where, %input_scalar can have optional upcast operation.
static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
bool containsZeroPointOffset = false) {
- Operation *addOp = yieldVal.getDefiningOp();
- if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp))
- return false;
+ bool isOrOp = false;
+ Operation *accOp = yieldVal.getDefiningOp();
+ if (!isa_and_present<arith::AddIOp, arith::AddFOp>(accOp)) {
+ if (!isa_and_present<arith::OrIOp>(accOp))
+ return false;
+ isOrOp = true;
+ }
- Operation *mulOp = addOp->getOperand(1).getDefiningOp();
- if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
+ Operation *mulOp = accOp->getOperand(1).getDefiningOp();
+ if (!isOrOp && !isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
+ return false;
+ if (isOrOp && !isa_and_present<arith::AndIOp>(mulOp))
return false;
if (containsZeroPointOffset) {
- return bodyMatcherForZeroPointOffsets(addOp, mulOp, body);
+ return bodyMatcherForZeroPointOffsets(accOp, mulOp, body);
}
BlockArgument lhsBlockArg =
getBlockArgumentWithOptionalCastOps(mulOp->getOperand(0));
BlockArgument rhsBlockArg =
getBlockArgumentWithOptionalCastOps(mulOp->getOperand(1));
BlockArgument outBlockArg =
- getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
+ getBlockArgumentWithOptionalCastOps(accOp->getOperand(0));
if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
@@ -386,8 +395,11 @@ static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
return bodyMatcherForPoolOps<arith::MinUIOp>(yieldVal, body);
}
+/// Matches sum pooling body pattern. For i1 element types, arith.ori is used
+/// instead of arith.addi/arith.addf for accumulation.
static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
- return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
+ return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp, arith::OrIOp>(
+ yieldVal, body);
}
static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex,
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
index 809a4208f8db0..1c4106cb75b1c 100644
--- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -73,6 +73,19 @@ func.func @conv_2d_nhwc_hwcf(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x
// -----
+func.func @conv_2d_nhwc_hwcf_i1(%input: tensor<?x?x?x?xi1>, %filter: tensor<?x?x?x?xi1>, %output: tensor<?x?x?x?xi1>) -> tensor<?x?x?x?xi1> {
+ %0 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xi1>, tensor<?x?x?x?xi1>)
+ outs (%output: tensor<?x?x?x?xi1>) -> tensor<?x?x?x?xi1>
+ return %0 : tensor<?x?x?x?xi1>
+}
+// CHECK: @conv_2d_nhwc_hwcf_i1
+// CHECK: linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
func.func @conv_2d_nhwc_hwcf_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?x?xi8>, %output: tensor<?x?x?x?xi32>, %zp_input: i32, %zp_filter: i32) -> tensor<?x?x?x?xi32> {
%0 = linalg.conv_2d_nhwc_hwcf_q
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
@@ -429,6 +442,19 @@ func.func @pooling_nhwc_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32
// -----
+func.func @pooling_nhwc_max_i1(%input: tensor<?x?x?x?xi1>, %filter: tensor<?x?xi1>, %output: tensor<?x?x?x?xi1>) -> tensor<?x?x?x?xi1> {
+ %0 = linalg.pooling_nhwc_max
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xi1>, tensor<?x?xi1>)
+ outs (%output: tensor<?x?x?x?xi1>) -> tensor<?x?x?x?xi1>
+ return %0 : tensor<?x?x?x?xi1>
+}
+// CHECK: @pooling_nhwc_max_i1
+// CHECK: linalg.pooling_nhwc_max
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
func.func @pooling_nhwc_min(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
%0 = linalg.pooling_nhwc_min
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
``````````
</details>
https://github.com/llvm/llvm-project/pull/176704
More information about the Mlir-commits mailing list