[Mlir-commits] [mlir] [mlir][tensor] Add new builders for insert_slice/extract_slice Ops (nfc) (PR #169533)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Nov 25 09:36:41 PST 2025
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/169533
Adds new builders for `tensor.insert_slice` and `tensor.extract_slice`
Ops for which the _offsets_ and the _strides_ are all 0s and 1s,
respectively. This allows us to write:
```cpp
tensor::InsertSliceOp::create(rewriter, loc, src, dest, writeSizes);
```
instead of:
```cpp
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
tensor::InsertSliceOp::create(rewriter, loc, src, dest, writeOffsets,
writeSizes, writeStrides);
```
>From dd275ad51edbbdcc2a78c7114560e2d9ac42b3ca Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 25 Nov 2025 17:30:54 +0000
Subject: [PATCH] [mlir][tensor] Add new builders for
insert_slice/extract_slice Ops (nfc)
Adds new builders for `tensor.insert_slice` and `tensor.extract_slice`
Ops for which the _offsets_ and the _strides_ are all 0s and 1s,
respectively. This allows us to write:
```cpp
tensor::InsertSliceOp::create(rewriter, loc, src, dest, writeSizes);
```
instead of:
```cpp
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
tensor::InsertSliceOp::create(rewriter, loc, src, dest, writeOffsets,
writeSizes, writeStrides);
```
---
.../mlir/Dialect/Tensor/IR/TensorOps.td | 12 ++++++-
.../Dialect/Linalg/Transforms/Transforms.cpp | 31 +++----------------
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 25 +++++++++++++++
3 files changed, 41 insertions(+), 27 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index ac40d5e454281..35d2b6007c628 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -471,6 +471,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
// a Range vector.
OpBuilder<(ins "Value":$source, "ArrayRef<Range>":$ranges,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+ // Build an ExtractSliceOp with mixed static and dynamic sizes, offsets set
+ // to 0 and strides set to 1. If `resultType` is null, it is inferred.
+ OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
+ "ArrayRef<OpFoldResult>":$sizes,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -930,7 +935,12 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
// a Range vector and inferred result type.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ArrayRef<Range>":$ranges,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+ // Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
+ // to 0, strides set to 1 and inferred result type.
+ OpBuilder<(ins "Value":$source, "Value":$dest,
+ "ArrayRef<OpFoldResult>":$sizes,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 027268cc20ddd..67e2b9f8d6077 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1167,12 +1167,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
"this is not supported ATM!");
}
- Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
- Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();
int64_t srcRank = packOp.getSourceRank();
- int64_t destRank = packOp.getDestRank();
// 1. Get the input that is going to be packed. If the input requires padding,
// add a padding operation and return that as the input.
@@ -1262,14 +1259,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
writeSizes.push_back(tileSizeOfr);
}
- // TODO: Add a constructor for tensor.insert_slice that doesn't require
- // strides nor offsets.
- SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
- SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
-
auto insert = tensor::InsertSliceOp::create(
- rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
- writeOffsets, writeSizes, writeStrides);
+ rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);
// 4. Replace tensor.packOp with tensor.insert_slice created above
rewriter.replaceOp(packOp, insert.getResult());
@@ -1279,7 +1270,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
- int64_t srcRank = unpackOp.getSourceRank();
int64_t destRank = unpackOp.getDestRank();
ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
@@ -1296,7 +1286,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
Value source = unpackOp.getSource();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
unpackOp.getDimAndTileMapping();
- Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
// The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
@@ -1307,9 +1296,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
// outer-tiled-dims being all 1), this will be
// [ outer-untiled-dims, tile-sizes ]
SmallVector<OpFoldResult> extractSliceSizes;
- // The offset and strides attributes for ExtractSliceOp.
- SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
- SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
// Shape for EmptyOp that's used as the init value for TransposeOp below.
// This should be:
@@ -1364,8 +1350,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
Type elemType = unpackOp.getSourceType().getElementType();
auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
Value innerTile = tensor::ExtractSliceOp::create(
- rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets,
- extractSliceSizes, extractSliceStrides);
+ rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes);
// 2. Transpose the tile to match the outer corresponding tile order.
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
@@ -1381,9 +1366,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
// 3. Handle in-complete tiles if needed. It truncates trailing data from the
// transposed tile.
- int numLoops = shapeForEmptyOp.size();
- SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
- SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
SmallVector<OpFoldResult> tileSizes;
ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
for (auto i : llvm::seq<unsigned>(0, destRank)) {
@@ -1393,13 +1375,11 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
}
auto partialTile =
- tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0],
- tileOffsets, tileSizes, tileStrides);
+ tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
+ transposedOp.getResult()[0], tileSizes);
// 4. Insert the result to the destination tensor.
SmallVector<OpFoldResult> writeSizes;
- SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
- SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
for (int i = 0, idx = 0; i < destRank; ++i) {
if (dimAndTileMapping.count(i) || destShape[i] != 1)
writeSizes.push_back(tileSizes[idx++]);
@@ -1407,8 +1387,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
writeSizes.push_back(oneIdxAttr);
}
auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
- unpackOp.getDest(), writeOffsets,
- writeSizes, writeStrides);
+ unpackOp.getDest(), writeSizes);
rewriter.replaceOp(unpackOp, insert.getResult());
return success();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 5a58d7cbed30f..204e9bb73e12c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2445,6 +2445,19 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
}
}
+/// Build an ExtractSliceOp with mixed static and dynamic sizes, offsets set
+/// to 0 and strides set to 1. `resultType` is forwarded unchanged to the
+/// builder that takes explicit offsets/sizes/strides; passing a null type
+/// makes that builder infer the result type from `source` and the slice.
+void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
+                           RankedTensorType resultType, Value source,
+                           ArrayRef<OpFoldResult> sizes,
+                           ArrayRef<NamedAttribute> attrs) {
+  Attribute zeroIdxAttr = b.getIndexAttr(0);
+  Attribute oneIdxAttr = b.getIndexAttr(1);
+  // One offset and one stride per entry in `sizes`, so all three lists have
+  // matching lengths.
+  SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
+  SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
+  build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
+}
+
/// Verifier for ExtractSliceOp.
LogicalResult ExtractSliceOp::verify() {
RankedTensorType sourceType = getSourceType();
@@ -3889,6 +3902,18 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
+/// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
+/// to 0 and strides set to 1. The result type is not a parameter here; it is
+/// inferred by the delegated builder.
+void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
+                          Value dest, ArrayRef<OpFoldResult> sizes,
+                          ArrayRef<NamedAttribute> attrs) {
+  Attribute zeroIdxAttr = b.getIndexAttr(0);
+  Attribute oneIdxAttr = b.getIndexAttr(1);
+  // One offset and one stride per entry in `sizes`, so all three lists have
+  // matching lengths.
+  SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
+  SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
+  build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
+}
+
LogicalResult ParallelInsertSliceOp::verify() {
if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected InParallelOpInterface parent, got:")
More information about the Mlir-commits
mailing list