[Mlir-commits] [mlir] [mlir][tensor] Add new builders for insert_slice/extract_slice Ops (nfc) (PR #169533)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Nov 25 09:36:41 PST 2025


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/169533

Adds new builders for `tensor.insert_slice` and `tensor.extract_slice`
Ops for which the _offsets_ and the _strides_ are all 0s and 1s,
respectively. This allows us to write:
```cpp
tensor::InsertSliceOp::create(rewriter, loc, src, dest, writeSizes);
```

instead of:
```cpp
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);

Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);

tensor::InsertSliceOp::create(rewriter, loc, src, dest, writeOffsets,
                              writeSizes, writeStrides);
```


>From dd275ad51edbbdcc2a78c7114560e2d9ac42b3ca Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 25 Nov 2025 17:30:54 +0000
Subject: [PATCH] [mlir][tensor] Add new builders for
 insert_slice/extract_slice Ops (nfc)

Adds new builders for `tensor.insert_slice` and `tensor.extract_slice`
Ops for which the _offsets_ and the _strides_ are all 0s and 1s,
respectively. This allows us to write:
```cpp
tensor::InsertSliceOp::create(rewriter, loc, src, dest, writeSizes);
```

instead of:
```cpp
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);

Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);

tensor::InsertSliceOp::create(rewriter, loc, src, dest, writeOffsets,
                              writeSizes, writeStrides);
```
---
 .../mlir/Dialect/Tensor/IR/TensorOps.td       | 12 ++++++-
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 31 +++----------------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 25 +++++++++++++++
 3 files changed, 41 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index ac40d5e454281..35d2b6007c628 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -471,6 +471,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
     // a Range vector.
     OpBuilder<(ins "Value":$source, "ArrayRef<Range>":$ranges,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an ExtractSliceOp with mixed static and dynamic sizes, inferred
+    // result type, offsets set to 0 and strides set to 1.
+    OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
+      "ArrayRef<OpFoldResult>":$sizes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
   ];
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -930,7 +935,12 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
     // a Range vector and inferred result type.
     OpBuilder<(ins "Value":$source, "Value":$dest,
       "ArrayRef<Range>":$ranges,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
+    // to 0, strides set to 1 and inferred result type.
+    OpBuilder<(ins "Value":$source, "Value":$dest,
+      "ArrayRef<OpFoldResult>":$sizes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
   ];
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 027268cc20ddd..67e2b9f8d6077 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1167,12 +1167,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
                 "this is not supported ATM!");
   }
 
-  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
-  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   Location loc = packOp.getLoc();
 
   int64_t srcRank = packOp.getSourceRank();
-  int64_t destRank = packOp.getDestRank();
 
   // 1. Get the input that is going to be packed. If the input requires padding,
   // add a padding operation and return that as the input.
@@ -1262,14 +1259,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
     writeSizes.push_back(tileSizeOfr);
   }
 
-  // TODO: Add a constructor for tensor.insert_slice that doesn't require
-  // strides nor offsets.
-  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
-  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
-
   auto insert = tensor::InsertSliceOp::create(
-      rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
-      writeOffsets, writeSizes, writeStrides);
+      rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);
 
   // 4. Replace tensor.packOp with tensor.insert_slice created above
   rewriter.replaceOp(packOp, insert.getResult());
@@ -1279,7 +1270,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
 
 LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
     linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
-  int64_t srcRank = unpackOp.getSourceRank();
   int64_t destRank = unpackOp.getDestRank();
   ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
   ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
@@ -1296,7 +1286,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   Value source = unpackOp.getSource();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       unpackOp.getDimAndTileMapping();
-  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
 
   // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
@@ -1307,9 +1296,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   // outer-tiled-dims being all 1), this will be
   //    [ outer-untiled-dims, tile-sizes ]
   SmallVector<OpFoldResult> extractSliceSizes;
-  // The offset and strides attributes for ExtractSliceOp.
-  SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
-  SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
 
   // Shape for EmptyOp that's used as the init value for TransposeOp below.
   // This should be:
@@ -1364,8 +1350,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   Type elemType = unpackOp.getSourceType().getElementType();
   auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
   Value innerTile = tensor::ExtractSliceOp::create(
-      rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets,
-      extractSliceSizes, extractSliceStrides);
+      rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes);
 
   // 2. Transpose the tile to match the outer corresponding tile order.
   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
@@ -1381,9 +1366,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
 
   // 3. Handle in-complete tiles if needed. It truncates trailing data from the
   // transposed tile.
-  int numLoops = shapeForEmptyOp.size();
-  SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
-  SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
   SmallVector<OpFoldResult> tileSizes;
   ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
   for (auto i : llvm::seq<unsigned>(0, destRank)) {
@@ -1393,13 +1375,11 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   }
 
   auto partialTile =
-      tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0],
-                                     tileOffsets, tileSizes, tileStrides);
+      tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
+                                     transposedOp.getResult()[0], tileSizes);
 
   // 4. Insert the result to the destination tensor.
   SmallVector<OpFoldResult> writeSizes;
-  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
-  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
   for (int i = 0, idx = 0; i < destRank; ++i) {
     if (dimAndTileMapping.count(i) || destShape[i] != 1)
       writeSizes.push_back(tileSizes[idx++]);
@@ -1407,8 +1387,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
       writeSizes.push_back(oneIdxAttr);
   }
   auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
-                                              unpackOp.getDest(), writeOffsets,
-                                              writeSizes, writeStrides);
+                                              unpackOp.getDest(), writeSizes);
   rewriter.replaceOp(unpackOp, insert.getResult());
 
   return success();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 5a58d7cbed30f..204e9bb73e12c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2445,6 +2445,19 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
   }
 }
 
+/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred
+/// result type, offsets set to 0 and strides set to 1.
+void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
+                           RankedTensorType resultType, Value source,
+                           ArrayRef<OpFoldResult> sizes,
+                           ArrayRef<NamedAttribute> attrs) {
+  Attribute zeroIdxAttr = b.getIndexAttr(0);
+  Attribute oneIdxAttr = b.getIndexAttr(1);
+  SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
+  SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
+  build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
+}
+
 /// Verifier for ExtractSliceOp.
 LogicalResult ExtractSliceOp::verify() {
   RankedTensorType sourceType = getSourceType();
@@ -3889,6 +3902,18 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }
 
+// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
+// to 0, strides set to 1 and inferred result type.
+void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
+                          Value dest, ArrayRef<OpFoldResult> sizes,
+                          ArrayRef<NamedAttribute> attrs) {
+  Attribute zeroIdxAttr = b.getIndexAttr(0);
+  Attribute oneIdxAttr = b.getIndexAttr(1);
+  SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
+  SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
+  build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
+}
+
 LogicalResult ParallelInsertSliceOp::verify() {
   if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
     return this->emitError("expected InParallelOpInterface parent, got:")



More information about the Mlir-commits mailing list