[Mlir-commits] [mlir] [mlir][linalg] Allow extra DPS inputs in convolution dim inference (PR #198462)
Ahmad Tameem
llvmlistbot at llvm.org
Mon May 25 11:50:31 PDT 2026
https://github.com/Tameem-10xE updated https://github.com/llvm/llvm-project/pull/198462
>From 45262ed502b5024f32ededee90d361f6802a012c Mon Sep 17 00:00:00 2001
From: Tameem-10xE <ahmad.tameem at 10xengineers.ai>
Date: Mon, 18 May 2026 21:35:11 +0500
Subject: [PATCH 1/2] [mlir][linalg] Allow extra DPS inputs in convolution dim
inference
Signed-off-by: Tameem-10xE <ahmad.tameem at 10xengineers.ai>
---
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 2ba77cea8f16e..b761782541ff3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -884,7 +884,7 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
FailureOr<ConvolutionDimensions>
mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) {
- if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
+ if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() < 2)
return failure();
auto indexingMaps = linalgOp.getIndexingMapsArray();
>From 3ecc71cbecc05223ae8c6a797c588bb59d2524f3 Mon Sep 17 00:00:00 2001
From: Tameem-10xE <ahmad.tameem at 10xengineers.ai>
Date: Mon, 25 May 2026 23:45:01 +0500
Subject: [PATCH 2/2] [mlir][linalg] Add inferConvolutionDims test for
quantized convolution with zero-points
Signed-off-by: Tameem-10xE <ahmad.tameem at 10xengineers.ai>
---
.../Linalg/InferConvolutionDimsTest.cpp | 117 ++++++++++++++++++
1 file changed, 117 insertions(+)
diff --git a/mlir/unittests/Dialect/Linalg/InferConvolutionDimsTest.cpp b/mlir/unittests/Dialect/Linalg/InferConvolutionDimsTest.cpp
index 7f495a4859064..45ef5100931ba 100644
--- a/mlir/unittests/Dialect/Linalg/InferConvolutionDimsTest.cpp
+++ b/mlir/unittests/Dialect/Linalg/InferConvolutionDimsTest.cpp
@@ -130,6 +130,82 @@ createConv2DWithSwappedFilterLoops(OpBuilder &builder,
});
}
+/// Creates a quantized 2D convolution over a 2-D input and a 2-D filter,
+/// with two extra scalar operands: the input and filter zero points.
+///
+/// Loop order:
+/// d0 = output height (oh), parallel
+/// d1 = output width (ow), parallel
+/// d2 = kernel height (kh), reduction
+/// d3 = kernel width (kw), reduction
+///
+/// Indexing maps:
+/// input: (d0 + d2, d1 + d3)
+/// filter: (d2, d3)
+/// input zp: scalar
+/// filter zp: scalar
+/// output: (d0, d1)
+///
+/// Semantic pairing: d0 <-> d2, d1 <-> d3
+static linalg::GenericOp createQConv2DOp(OpBuilder &builder, int64_t oh,
+ int64_t ow, int64_t kh,
+ int64_t kw) {
+ Location loc = builder.getUnknownLoc();
+ MLIRContext *ctx = builder.getContext();
+
+ auto i8Type = builder.getI8Type();
+ auto i32Type = builder.getI32Type();
+
+ int64_t ih = oh + kh - 1;
+ int64_t iw = ow + kw - 1;
+
+ auto inputType = RankedTensorType::get({ih, iw}, i8Type);
+ auto filterType = RankedTensorType::get({kh, kw}, i8Type);
+ auto outputType = RankedTensorType::get({oh, ow}, i32Type);
+
+ Value input = tensor::EmptyOp::create(builder, loc, inputType.getShape(),
+ inputType.getElementType());
+ Value filter = tensor::EmptyOp::create(builder, loc, filterType.getShape(),
+ filterType.getElementType());
+ Value inputZeroPoint = arith::ConstantIntOp::create(builder, loc, 0, 32);
+ Value filterZeroPoint = arith::ConstantIntOp::create(builder, loc, 0, 32);
+ Value output = tensor::EmptyOp::create(builder, loc, outputType.getShape(),
+ outputType.getElementType());
+
+ AffineExpr d0, d1, d2, d3;
+ bindDims(ctx, d0, d1, d2, d3);
+
+ auto inputMap = AffineMap::get(4, 0, {d0 + d2, d1 + d3}, ctx);
+ auto filterMap = AffineMap::get(4, 0, {d2, d3}, ctx);
+ auto scalarMap = AffineMap::get(4, 0, ArrayRef<AffineExpr>{}, ctx);
+ auto outputMap = AffineMap::get(4, 0, {d0, d1}, ctx);
+
+ SmallVector<AffineMap> indexingMaps = {inputMap, filterMap, scalarMap,
+ scalarMap, outputMap};
+
+ SmallVector<utils::IteratorType> iterTypes = {
+ utils::IteratorType::parallel, utils::IteratorType::parallel,
+ utils::IteratorType::reduction, utils::IteratorType::reduction};
+
+ return linalg::GenericOp::create(
+ builder, loc, outputType,
+ ValueRange{input, filter, inputZeroPoint, filterZeroPoint},
+ ValueRange{output}, indexingMaps, iterTypes,
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ Value inputI32 =
+ arith::ExtSIOp::create(b, loc, b.getI32Type(), args[0]);
+ Value filterI32 =
+ arith::ExtSIOp::create(b, loc, b.getI32Type(), args[1]);
+
+ inputI32 = arith::SubIOp::create(b, loc, inputI32, args[2]);
+ filterI32 = arith::SubIOp::create(b, loc, filterI32, args[3]);
+
+ Value mul = arith::MulIOp::create(b, loc, inputI32, filterI32);
+ Value add = arith::AddIOp::create(b, loc, args[4], mul);
+ linalg::YieldOp::create(b, loc, add);
+ });
+}
+
TEST_F(InferConvolutionDimsTest, Conv2DPairing) {
// Use non-square kernel to ensure dimension swapping is tested properly.
const int64_t oh = 6, ow = 12, kh = 3, kw = 5;
@@ -176,4 +252,45 @@ TEST_F(InferConvolutionDimsTest, Conv2DPairing) {
<< "outputImage[1]=1 should pair with filterLoop[1]=2 (ow <-> kw)";
}
+TEST_F(InferConvolutionDimsTest, QConv2DWithZeroPoints) {
+ // Use non-square kernel to ensure dimension swapping is tested properly.
+ const int64_t oh = 6, ow = 12, kh = 3, kw = 5;
+
+ // Create a module to own all test operations and ensure proper cleanup.
+ OpBuilder builder(ctx.get());
+ OwningOpRef<ModuleOp> module = ModuleOp::create(builder.getUnknownLoc());
+ builder.setInsertionPointToStart(module->getBody());
+
+ // Create a quantized conv op with two extra scalar zero-point inputs.
+ linalg::GenericOp qConvOp = createQConv2DOp(builder, oh, ow, kh, kw);
+
+ // The qconv op should have:
+ // input, filter, input zero point, filter zero point
+ ASSERT_EQ(qConvOp.getNumDpsInputs(), 4u);
+ ASSERT_EQ(qConvOp.getNumDpsInits(), 1u);
+
+ auto indexingMaps = qConvOp.getIndexingMapsArray();
+ ASSERT_EQ(indexingMaps.size(),
+ static_cast<size_t>(qConvOp.getNumDpsInputs() +
+ qConvOp.getNumDpsInits()));
+
+ // The two extra quantized conv operands must be scalar inputs.
+ EXPECT_EQ(indexingMaps[2].getNumResults(), 0u);
+ EXPECT_EQ(indexingMaps[3].getNumResults(), 0u);
+ EXPECT_EQ(indexingMaps[2].getNumDims(), 4u);
+ EXPECT_EQ(indexingMaps[3].getNumDims(), 4u);
+
+ FailureOr<ConvolutionDimensions> qConvDims = inferConvolutionDims(qConvOp);
+ ASSERT_TRUE(succeeded(qConvDims));
+ ASSERT_EQ(qConvDims->outputImage.size(), 2u);
+ ASSERT_EQ(qConvDims->filterLoop.size(), 2u);
+
+ // Standard pairing: outputImage=[0,1], filterLoop=[2,3]
+ // d0 <-> d2 (oh <-> kh), d1 <-> d3 (ow <-> kw)
+ EXPECT_EQ(qConvDims->outputImage[0], 0u);
+ EXPECT_EQ(qConvDims->outputImage[1], 1u);
+ EXPECT_EQ(qConvDims->filterLoop[0], 2u);
+ EXPECT_EQ(qConvDims->filterLoop[1], 3u);
+}
+
} // namespace
More information about the Mlir-commits
mailing list