[Mlir-commits] [mlir] [mlir][Linalg] Add dropI64ArrayAttrElem helper to reduce duplication (NFC) (PR #174279)
Nick Kreeger
llvmlistbot at llvm.org
Sat Jan 17 07:41:56 PST 2026
https://github.com/nkreeger updated https://github.com/llvm/llvm-project/pull/174279
>From 5c79cbf4e008124b24f2b98fa9b6c0f6b1414538 Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at gmail.com>
Date: Sat, 3 Jan 2026 08:58:30 -0600
Subject: [PATCH 1/2] [mlir][Linalg] Add dropI64ArrayAttrElem helper to
consolidate duplicate code (NFC)
Adds a utility function to drop an element from a DenseIntElementsAttr at a
specified index. This consolidates duplicate code in Conv2D rank-reduction
transformations.
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 7 +++++
.../Dialect/Linalg/Transforms/Transforms.cpp | 27 ++++++-------------
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 8 ++++++
3 files changed, 23 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 9da01f30b52d2..ae3d2b2e5ddfa 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -392,6 +392,13 @@ SmallVector<NamedAttribute> getPrunedAttributeList(OpTy op) {
return getPrunedAttributeList(op, elidedAttrs);
}
+/// Returns a new 1-D i64 vector attribute with the values of `attr` minus the
+/// element at `index`, which must be less than the number of elements. Useful
+/// for rank reduction, e.g. dropping a dimension from strides or dilations.
+DenseIntElementsAttr dropI64ArrayAttrElem(OpBuilder &builder,
+ DenseIntElementsAttr attr,
+ unsigned index);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 96cc378f6c21a..37d6f6d079799 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1483,16 +1483,10 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
rewriter, loc, output, newOutputType);
// Rank-reduce strides and dilations too.
- // TODO: dropDim 1-liner helper.
- auto strides =
- llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
- strides.erase(strides.begin() + (removeH ? 0 : 1));
- auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
- auto dilations =
- llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
- dilations.erase(dilations.begin() + (removeH ? 0 : 1));
- auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
+ auto stridesAttr =
+ dropI64ArrayAttrElem(rewriter, convOp.getStrides(), removeH ? 0 : 1);
+ auto dilationsAttr =
+ dropI64ArrayAttrElem(rewriter, convOp.getDilations(), removeH ? 0 : 1);
auto conv1DOp = Conv1DOp::create(
rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
@@ -1571,15 +1565,10 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
rewriter, loc, output, newOutputType);
// Rank-reduce strides and dilations too.
- // TODO: dropDim 1-liner helper.
- auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
- strides.erase(strides.begin() + (removeH ? 0 : 1));
- auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
- auto dilations =
- llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
- dilations.erase(dilations.begin() + (removeH ? 0 : 1));
- auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
+ auto stridesAttr =
+ dropI64ArrayAttrElem(rewriter, convOp.getStrides(), removeH ? 0 : 1);
+ auto dilationsAttr =
+ dropI64ArrayAttrElem(rewriter, convOp.getDilations(), removeH ? 0 : 1);
auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 2718124251c18..a64b82c183c1d 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -2573,5 +2573,14 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
return reassociation;
}
+DenseIntElementsAttr dropI64ArrayAttrElem(OpBuilder &builder,
+ DenseIntElementsAttr attr,
+ unsigned index) {
+ auto values = llvm::to_vector<4>(attr.getValues<int64_t>());
+ assert(index < values.size() && "index out of bounds for attribute");
+ values.erase(values.begin() + index);
+ return builder.getI64VectorAttr(values);
+}
+
} // namespace linalg
} // namespace mlir
>From 7b47377e7db568009948ad7c9063c54539ea6b32 Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at microsoft.com>
Date: Sat, 17 Jan 2026 15:41:41 +0000
Subject: [PATCH 2/2] Fix the build.
---
.../lib/Dialect/Linalg/Transforms/Transforms.cpp | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index b3bf6d16111f9..b8d58bf791d30 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1487,10 +1487,10 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
rewriter, loc, output, newOutputType);
// Rank-reduce strides and dilations too.
- auto stridesAttr =
- dropI64ArrayAttrElem(rewriter, convOp.getStrides(), removeH ? 0 : 1);
- auto dilationsAttr =
- dropI64ArrayAttrElem(rewriter, convOp.getDilations(), removeH ? 0 : 1);
+ strides.erase(strides.begin() + (removeH ? 0 : 1));
+ auto stridesAttr = rewriter.getI64VectorAttr(strides);
+ dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+ auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
auto conv1DOp = Conv1DOp::create(
rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
@@ -1577,10 +1577,10 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
rewriter, loc, output, newOutputType);
// Rank-reduce strides and dilations too.
- auto stridesAttr =
- dropI64ArrayAttrElem(rewriter, convOp.getStrides(), removeH ? 0 : 1);
- auto dilationsAttr =
- dropI64ArrayAttrElem(rewriter, convOp.getDilations(), removeH ? 0 : 1);
+ strides.erase(strides.begin() + (removeH ? 0 : 1));
+ auto stridesAttr = rewriter.getI64VectorAttr(strides);
+ dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+ auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
More information about the Mlir-commits
mailing list