[Mlir-commits] [mlir] [mlir][tensor] Generalize/restrict `GeneralizeOuterUnitDimsPackOpPattern` (PR #114315)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Nov 6 07:18:46 PST 2024
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/114315
From 5c84d6541cfc61e59354cde31fa68a4760abe671 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 17 Oct 2024 19:19:04 +0100
Subject: [PATCH 1/3] [mlir][tensor] Generalize/restrict
`GeneralizeOuterUnitDimsPackOpPattern`
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This PR *restricts* `GeneralizeOuterUnitDimsPackOpPattern` to follow its
intended purpose (as per the documentation), which is to:
> require all outer dimensions of tensor.pack to be 1.
One in-tree test, `@simple_KCRS_to_KRSCsr` in
"generalize-tensor-pack.mlir", violated this assumption (and happened
to work). That test has been updated to satisfy the
new requirements of the pattern.
By making the pattern follow its intended design (i.e., making it
stricter), the calculation of shapes and sizes for the various Ops that the
pattern generates (PadOp, ExtractSliceOp, EmptyOp, TransposeOp, and
InsertSliceOp) becomes much simpler and easier to document. This also
helped *generalize* the pattern to support cases like the one below:
```mlir
func.func @simple_pad_and_pack_dynamic_tile_cst(
%src: tensor<5x1xf32>,
%dest: tensor<1x1x?x2xf32>,
%pad: f32) -> tensor<1x1x?x2xf32> {
%tile_dim_0 = arith.constant 8 : index
%0 = tensor.pack %src
padding_value(%pad : f32)
inner_dims_pos = [0, 1]
inner_tiles = [%tile_dim_0, 2]
into %dest : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
return %0 : tensor<1x1x?x2xf32>
}
```
Note that the inner tile size is dynamic (an SSA value) but is a compile-time constant.
`getPackOpSourceOrPaddedSource`, which is used to generate PadOp,
detects this and generates a PadOp with static shapes. This is a good
optimization, but it means that all shapes/sizes for Ops generated by
`GeneralizeOuterUnitDimsPackOpPattern` also need to be updated to be
constant/static. By restricting the pattern and simplifying the
size/shape calculation, supporting the case above becomes much easier.
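The static/dynamic folding described above can be modelled outside of MLIR. The following standalone C++ sketch is illustrative only: `TileSize` is a hypothetical stand-in for `OpFoldResult`, `kDynamic` mirrors `ShapedType::kDynamic`, and `simplifyTileSize` / `highPadForSingleTile` are hypothetical helpers that only mimic the idea behind the patch's `getSimplifiedDimSizePair` and the padding computation; none of them are MLIR APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <utility>

// Mirrors ShapedType::kDynamic, which MLIR defines as INT64_MIN.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical stand-in for OpFoldResult: a tile size is either a static
// attribute or an SSA value. `knownValue` is set for attributes and for SSA
// values that are known compile-time constants (e.g. `arith.constant 8`).
struct TileSize {
  bool isAttr;
  std::optional<int64_t> knownValue;
};

// Sketch of the idea behind getSimplifiedDimSizePair: returns the dim to use
// in a shaped type plus the (possibly simplified) size.
std::pair<int64_t, TileSize> simplifyTileSize(const TileSize &ts) {
  if (ts.knownValue.has_value())
    // Known constant: static dim, and the size is folded into an attribute.
    return {*ts.knownValue, TileSize{true, ts.knownValue}};
  // Truly dynamic: kDynamic in the shape, original SSA value as the size.
  return {kDynamic, ts};
}

// High padding that makes a source dim fill exactly one tile. Valid only when
// the outer dim is 1 (i.e. srcDim <= tile) and both sizes are static.
int64_t highPadForSingleTile(int64_t srcDim, int64_t tile) {
  return tile - srcDim;
}
```

For `@simple_pad_and_pack_dynamic_tile_cst`, tiles `[%tile_dim_0 = 8, 2]` over `tensor<5x1xf32>` give high padding `[3, 1]` and a static `tensor<8x2xf32>`, which is what the updated CHECK lines expect.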
Notable implementation changes:
* PadOp processes the original source (no change in dimensions/rank).
ExtractSliceOp extracts the tile to pack and may reduce the rank. All
following ops work on the tile extracted by ExtractSliceOp (possibly
rank-reduced).
* All shape/size calculations assume that trailing dimensions match
inner_tiles from tensor.pack. All leading dimensions (i.e., outer
dimensions) are assumed to be 1.
* Dynamic sizes for ops like ExtractSliceOp are taken from inner_tiles
rather than computed as, for example, `tensor.dim %dest, 2`. It’s the
responsibility of the "producers" of tensor.pack to ensure that
dimensions in %dest match the specified tile sizes.
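The shape bookkeeping in the list above can be sketched as a standalone C++ function. This is not the patch's code: plain `int64_t` values stand in for `OpFoldResult`, a hypothetical `kDynamic` sentinel stands in for an SSA tile size, and the transpose permutation is passed in rather than computed by `getPackUnpackRankReducedPerm`. Like the pattern, it assumes that all outer dims are 1 and that the tiled dims are the trailing source dims.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <map>
#include <vector>

// Mirrors ShapedType::kDynamic; here it also stands in for an SSA tile size.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

struct PackShapes {
  std::vector<int64_t> extractSliceSizes; // `sizes` of tensor.extract_slice
  std::vector<int64_t> tileShape;         // rank-reduced extract_slice result
  std::vector<int64_t> emptyShape;        // tensor.empty / transpose init
  std::vector<int64_t> insertSliceSizes;  // `sizes` of tensor.insert_slice
};

// innerTiles[k] tiles source dim innerDimsPos[k]. `perm` is supplied by the
// caller (the patch obtains it from getPackUnpackRankReducedPerm).
PackShapes computePackShapes(int64_t srcRank,
                             const std::vector<int64_t> &innerDimsPos,
                             const std::vector<int64_t> &innerTiles,
                             const std::vector<int64_t> &perm) {
  const std::size_t numTiles = innerTiles.size();
  const std::size_t numOuterSrc = static_cast<std::size_t>(srcRank) - numTiles;
  const std::size_t destRank = static_cast<std::size_t>(srcRank) + numTiles;

  std::map<int64_t, int64_t> dimToTile;
  for (std::size_t k = 0; k < numTiles; ++k)
    dimToTile[innerDimsPos[k]] = innerTiles[k];

  PackShapes s;
  // 1. extract_slice: leading unit sizes (rank-reduced away), then the tile
  //    sizes in source-dim order.
  s.extractSliceSizes.assign(numOuterSrc, 1);
  for (int64_t i = 0; i < srcRank; ++i) {
    auto it = dimToTile.find(i);
    if (it == dimToTile.end())
      continue;
    s.extractSliceSizes.push_back(it->second);
    s.tileShape.push_back(it->second);
  }
  // 2. tensor.empty: tile shape permuted as applyPermutationToVector does,
  //    i.e. result[j] = tileShape[perm[j]].
  for (int64_t p : perm)
    s.emptyShape.push_back(s.tileShape[static_cast<std::size_t>(p)]);
  // 3. insert_slice: every outer dim of the destination is 1, followed by the
  //    inner tiles in inner_tiles order.
  s.insertSliceSizes.assign(destRank - numTiles, 1);
  for (int64_t t : innerTiles)
    s.insertSliceSizes.push_back(t);
  return s;
}
```

Calling it with `(2, {1, 0}, {2, kDynamic}, {1, 0})` mirrors `@simple_pad_and_pack_dynamic_tile_transpose` and gives extract_slice sizes `[?, 2]`, a `2x?` init for the transpose, and insert_slice sizes `[1, 1, 2, ?]`; calling it with `(4, {3, 2}, {2, kDynamic}, {1, 0})` mirrors `@simple_pad_and_pack_dynamic_tile_not_all_dims_tiled`.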
---
.../Dialect/Linalg/Transforms/Transforms.h | 40 ++++-
.../Dialect/Linalg/Transforms/Transforms.cpp | 125 ++++++++++----
.../Linalg/generalize-tensor-pack.mlir | 158 ++++++++++++------
3 files changed, 239 insertions(+), 84 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b5710bd78f0089..a8662a3d6f63be 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1515,9 +1515,43 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
const SmallVector<Value> &dynSizes) const;
};
-/// Rewrites a tensor::PackOp into a sequence of tensor.pad + linalg.transpose +
-/// tensor.insert_slice ops, where the tensor::PackOp has outer dims being all
-/// 1s.
+/// Rewrites a tensor::PackOp into a sequence of:
+/// * tensor::PadOp + tensor::ExtractSliceOp + tensor::EmptyOp +
+/// linalg::TransposeOp + tensor::InsertSliceOp ops.
+///
+/// Requires that all the outer dims of the input tensor::PackOp are 1.
+///
+/// Before:
+/// ```
+/// %packed = tensor.pack %input
+/// padding_value(%pad : f32)
+/// inner_dims_pos = [1, 0]
+/// inner_tiles = [2, %high]
+/// into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
+/// ```
+///
+/// After:
+/// ```
+/// // PadOp
+/// %padded = tensor.pad %input low[0, 0] high[%0, 1] {
+/// ^bb0(...):
+/// tensor.yield %pad : f32
+/// } : tensor<5x1xf32> to tensor<?x2xf32>
+/// // ExtractSliceOp
+/// %extracted_slice = tensor.extract_slice %padded[0, 0]
+/// [%high, 2] [1, 1]
+/// : tensor<?x2xf32> to tensor<?x2xf32>
+/// // EmptyOp + TransposeOp
+/// %empty = tensor.empty(%high) : tensor<2x?xf32>
+/// %transposed = linalg.transpose
+/// ins(%extracted_slice : tensor<?x2xf32>)
+/// outs(%empty : tensor<2x?xf32>)
+/// permutation = [1, 0]
+/// // InsertSliceOp
+/// %inserted_slice = tensor.insert_slice %transposed
+/// into %output[0, 0, 0, 0] [1, 1, 2, %high] [1, 1, 1, 1]
+/// : tensor<2x?xf32> into tensor<1x1x2x?xf32>
+/// ```
struct GeneralizeOuterUnitDimsPackOpPattern
: public OpRewritePattern<tensor::PackOp> {
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index da5233049aaf69..ed5f1bd602d7f4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -27,6 +27,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -1138,6 +1139,29 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
return perm;
}
+// A helper function to generate a dim-and-size pair for Ops like
+// ExtractSliceOp that require both:
+// * dims to specify the output shape, and
+// * sizes for the sizes attribute (or similar).
+// For dynamic sizes, if the corresponding size is a compile time constant:
+// * the return size becomes the attribute encapsulating the known size, and
+// * dim is updated from kDynamic to its actual known value.
+static std::pair<int64_t, OpFoldResult>
+getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, PatternRewriter &rewriter) {
+ int64_t tileSizeForShape =
+ getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
+
+ OpFoldResult tileSizeOfrSimplified;
+ if (tileSizeForShape != ShapedType::kDynamic) {
+ tileSizeOfrSimplified = rewriter.getIndexAttr(tileSizeForShape);
+ } else {
+ tileSizeOfrSimplified = tileSizeOfr;
+ }
+
+ return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
+ tileSizeOfrSimplified);
+}
+
LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
tensor::PackOp packOp, PatternRewriter &rewriter) const {
// TODO: support the case that outer dimensions are not all 1s. A
@@ -1148,69 +1172,104 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
packOp, "require the tiled outer dimensions of the result are all 1s");
}
- // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
- // outer dims.
+ Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
+ Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();
+
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
auto inputShape = packOp.getSourceType().getShape();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
- Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
- Attribute oneIdxAttr = rewriter.getIndexAttr(1);
int64_t srcRank = packOp.getSourceRank();
+
+ int64_t destRank = packOp.getDestRank();
+ size_t numTiles = destRank - srcRank;
+
+ // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+ // %extracted_tile = tensor.extract_slice(%pack_op_input)
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
- SmallVector<OpFoldResult> readSizes;
- SmallVector<OpFoldResult> transShapeForEmpty;
- SmallVector<int64_t> readShapeForExtractSlice;
+
+ // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
+ // all outer dims are 1.
+ SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
+ // The shape of the output for ExtractSliceOp. All leading unit dims are
+ // effectively rank-reduced, hence skipped.
+ SmallVector<int64_t> outputShapeForExtractSlice;
+
+ // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
+ // be equal to the inner tile sizes.
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
- readShapeForExtractSlice.push_back(
- getConstantIntValue(dimAndTileMapping[i])
- .value_or(ShapedType::kDynamic));
- readSizes.push_back(dimAndTileMapping[i]);
- transShapeForEmpty.push_back(dimAndTileMapping[i]);
- continue;
- }
- if (ShapedType::isDynamic(inputShape[i])) {
- readSizes.push_back(
- rewriter.create<tensor::DimOp>(loc, input, i).getResult());
- } else {
- readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
- }
- if (inputShape[i] != 1) {
- readShapeForExtractSlice.push_back(inputShape[i]);
- transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+ auto [tileSize, tileSizeOfr] =
+ getSimplifiedDimSizePair(dimAndTileMapping[i], rewriter);
+ extractSliceSizes.push_back(tileSizeOfr);
+ outputShapeForExtractSlice.push_back(tileSize);
}
}
Type elemType = packOp.getSourceType().getElementType();
- auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
+ auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
Value tile = rewriter.create<tensor::ExtractSliceOp>(
- loc, readType, input, readOffsets, readSizes, readStrides);
+ loc, readType, input, readOffsets, extractSliceSizes, readStrides);
- // 2. Transpose the tile to match the inner tile order.
+ // 2. Transpose the tile to match the inner tile order:
+ // %init = tensor.empty()
+ // %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
+ // NOTE: Outer dims are 1 and hence effectively ignored.
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
- applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
+ // 2.1 Create tensor.empty (init value for TransposeOp)
+ SmallVector<OpFoldResult> transShapeForEmptyOpDynamic;
+ SmallVector<int64_t> transShapeForEmptyOpStatic;
+
+ // Acquire tensor shape required to create EmptyOp. This will match the inner
+ // tile sizes, but the actual data format will depend on whether the tile
+ // sizes are static or dynamic (each case leads to a different builder for
+ // EmptyOp). Conservatively, prepare for both scenarios.
+ size_t idx = numTiles;
+ while (idx != 0) {
+ transShapeForEmptyOpDynamic.push_back(extractSliceSizes[srcRank - idx]);
+ transShapeForEmptyOpStatic.push_back(
+ outputShapeForExtractSlice[numTiles - idx]);
+ idx--;
+ }
- Value empty =
- rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
+ applyPermutationToVector<int64_t>(transShapeForEmptyOpStatic, perm);
+ applyPermutationToVector<OpFoldResult>(transShapeForEmptyOpDynamic, perm);
+
+ Value empty = ShapedType::isDynamicShape(transShapeForEmptyOpStatic)
+ ? rewriter.create<tensor::EmptyOp>(
+ loc, transShapeForEmptyOpDynamic, elemType)
+ : rewriter.create<tensor::EmptyOp>(
+ loc, transShapeForEmptyOpStatic, elemType);
+
+ // 2.2 Create linalg.transpose
auto transposedOp =
rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
- // 3. Insert the inner tile to the destination.
- int64_t destRank = packOp.getDestRank();
+ // 3. Insert the inner tile to the destination:
+ // %inserted_tile = tensor.insert_slice(%transposed_tile)
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
- SmallVector<OpFoldResult> writeSizes =
- tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+ // Outer dims are all 1s!
+ SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
+ oneIdxAttr);
+ SmallVector<int64_t> writeShape;
+
+ for (auto tileSize : packOp.getMixedTiles()) {
+ auto [tileSizeStatic, tileSizeOfr] =
+ getSimplifiedDimSizePair(tileSize, rewriter);
+ writeSizes.push_back(tileSizeOfr);
+ writeShape.push_back(tileSizeStatic);
+ }
+ // 4. Replace tensor.packOp with tensor.insert_slice created above
auto insert = rewriter.create<tensor::InsertSliceOp>(
loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
writeSizes, writeStrides);
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 7f6b5e279f6857..8abf7a11bed5c9 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -1,21 +1,32 @@
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-pack" %s | FileCheck %s
-func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x1x1x8x32xf32>) -> tensor<1x1x1x1x8x32xf32> {
- %0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x32x8xf32> -> tensor<1x1x1x1x8x32xf32>
- return %0 : tensor<1x1x1x1x8x32xf32>
+
+func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32> {
+ %c8 = arith.constant 8 : index
+ %c5 = arith.constant 5 : i32
+ %pack = tensor.pack %arg0 padding_value(%c5 : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %arg1 : tensor<?x?xi32> -> tensor<1x1x?x1xi32>
+ return %pack : tensor<1x1x?x1xi32>
}
-// CHECK-LABEL: func.func @simple_KCRS_to_KCRSsr
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
-// CHECK: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[TILE]] : tensor<32x8xf32>)
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x32xf32>)
-// CHECK-SAME: permutation = [1, 0]
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
-// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
-// CHECK: return %[[INSERT]]
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (-s0 + 8)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<()[s0] -> (-s0 + 1)>
+
+// CHECK-LABEL: func.func @simple_KCRS_to_KCRSsr(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<?x?xi32>,
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32>
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 5 : i32
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[SRC]], %[[VAL_4]] : tensor<?x?xi32>
+// CHECK: %[[VAL_6:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[VAL_5]]]
+// CHECK: %[[VAL_7:.*]] = tensor.dim %[[SRC]], %[[VAL_2]] : tensor<?x?xi32>
+// CHECK: %[[VAL_8:.*]] = affine.apply #[[$ATTR_1]](){{\[}}%[[VAL_7]]]
+// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[VAL_6]], %[[VAL_8]]] {
+// CHECK: ^bb0(%[[VAL_10:.*]]: index, %[[VAL_11:.*]]: index):
+// CHECK: tensor.yield %[[VAL_3]] : i32
+// CHECK: } : tensor<?x?xi32> to tensor<8x1xi32>
+// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[PAD]] into %[[DEST]][0, 0, 0, 0] [1, 1, 8, 1] [1, 1, 1, 1] : tensor<8x1xi32> into tensor<1x1x?x1xi32>
+// CHECK: return %[[INSERT]] : tensor<1x1x?x1xi32>
// -----
@@ -39,26 +50,59 @@ func.func @simple_pad_and_pack_static_tiles(%input: tensor<5x1xf32>, %output: te
/// Same as example above, but with 1 dynamic tile size.
-func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %high: index) -> tensor<1x1x?x2xf32> {
- %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %tile_dim_0: index) -> tensor<1x1x?x2xf32> {
+ %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
return %0 : tensor<1x1x?x2xf32>
}
-
// CHECK-LABEL: func.func @simple_pad_and_pack_dynamic_tile(
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[HIGH_VAL:[a-zA-Z0-9]+]]: index) -> tensor<1x1x?x2xf32> {
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL]]]
+// CHECK-SAME: %[[TILE_DIM_0:[a-zA-Z0-9]+]]: index) -> tensor<1x1x?x2xf32> {
+// CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[TILE_DIM_0]]]
// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
// CHECK: tensor.yield %[[PAD_VAL]] : f32
// CHECK-NOT: linalg.transpose
-// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[VAL_10:.*]][0, 0] {{\[}}%[[HIGH_VAL]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK: %[[DIM:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
-// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD]][0, 0] {{\[}}%[[TILE_DIM_0]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
+func.func @simple_pad_and_pack_dynamic_tile_cst(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
+ %tile_dim_0 = arith.constant 8 : index
+ %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+ return %0 : tensor<1x1x?x2xf32>
+}
+// CHECK-LABEL: func.func @simple_pad_and_pack_dynamic_tile_cst(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
+// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high[3, 1] {
+// CHECK: tensor.yield %[[PAD_VAL]] : f32
+// CHECK: } : tensor<5x1xf32> to tensor<8x2xf32>
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[DEST]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] : tensor<8x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
+
+func.func @simple_pad_and_pack_dynamic_tile_transpose(%input: tensor<5x1xf32>, %output: tensor<1x1x2x?xf32>, %pad: f32, %tile_dim_1: index) -> tensor<1x1x2x?xf32> {
+ %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim_1] into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
+ return %0 : tensor<1x1x2x?xf32>
+}
+// CHECK-LABEL: func.func @simple_pad_and_pack_dynamic_tile_transpose(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[TILE_DIM_1:[a-zA-Z0-9]+]]: index) -> tensor<1x1x2x?xf32> {
+// CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[TILE_DIM_1]]]
+// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK: tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NEXT: } : tensor<5x1xf32> to tensor<?x2xf32>
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD]][0, 0] {{\[}}%[[TILE_DIM_1]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[TILE_DIM_1]]) : tensor<2x?xf32>
+// CHECK: %[[TR:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[SLICE]] : tensor<?x2xf32>) outs(%[[EMPTY]] : tensor<2x?xf32>)
+// CHECK-SAME: permutation = [1, 0]
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]][0, 0, 0, 0] [1, 1, 2, %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x2x?xf32>
+// CHECK: return %[[RES]] : tensor<1x1x2x?xf32>
+
/// Same as example above, but with 1 scalable tile size.
/// NOTE: For this example to make sense in practice, the "?" in the output shape
@@ -77,7 +121,6 @@ func.func @simple_pad_and_pack_scalable_tile(%input: tensor<5x1xf32>, %output: t
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
// CHECK-DAG: %[[VS:.+]] = vector.vscale
// CHECK: %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
@@ -86,37 +129,56 @@ func.func @simple_pad_and_pack_scalable_tile(%input: tensor<5x1xf32>, %output: t
// CHECK: tensor.yield %[[PAD_VAL]] : f32
// CHECK-NOT: linalg.transpose
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK: %[[DIM:.+]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
-// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
/// Same as example above, but with both tile sizes dynamic.
-func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x?x?xf32>, %pad: f32, %high_1: index, %high_2: index) -> tensor<1x1x?x?xf32> {
- %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high_1, %high_2] into %output : tensor<5x1xf32> -> tensor<1x1x?x?xf32>
+func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x?x?xf32>, %pad: f32, %tile_dim_0: index, %tile_dim_1: index) -> tensor<1x1x?x?xf32> {
+ %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, %tile_dim_1] into %output : tensor<5x1xf32> -> tensor<1x1x?x?xf32>
return %0 : tensor<1x1x?x?xf32>
}
// CHECK-LABEL: func.func @simple_pad_and_pack_dynamic_tiles(
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x?xf32>,
// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]: f32,
-// CHECK-SAME: %[[HIGH_VAL_1:[a-zA-Z0-9]+]]: index,
-// CHECK-SAME: %[[HIGH_VAL_2:[a-zA-Z0-9]+]]: index) -> tensor<1x1x?x?xf32> {
-// CHECK: %[[C3:.*]] = arith.constant 3 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[PAD_HIGH_1:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL_1]]]
-// CHECK: %[[PAD_HIGH_2:.*]] = affine.apply #[[$ATTR_1]](){{\[}}%[[HIGH_VAL_2]]]
+// CHECK-SAME: %[[TILE_DIM_0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[TILE_DIM_1:[a-zA-Z0-9]+]]: index) -> tensor<1x1x?x?xf32> {
+// CHECK: %[[PAD_HIGH_1:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[TILE_DIM_0]]]
+// CHECK: %[[PAD_HIGH_2:.*]] = affine.apply #[[$ATTR_1]](){{\[}}%[[TILE_DIM_1]]]
// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH_1]], %[[PAD_HIGH_2]]] {
// CHECK: tensor.yield %[[PAD_VAL]] : f32
// CHECK-NOT: linalg.transpose
-// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[HIGH_VAL_1]], %[[HIGH_VAL_2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK: %[[DIM_1:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x?xf32>
-// CHECK: %[[DIM_2:.*]] = tensor.dim %[[DEST]], %[[C3]] : tensor<1x1x?x?xf32>
-// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM_1]], %[[DIM_2]]] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<1x1x?x?xf32>
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD]][0, 0] {{\[}}%[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<1x1x?x?xf32>
// CHECK: return %[[RES]] : tensor<1x1x?x?xf32>
// -----
+func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x1x5x1xf32>, %output: tensor<1x1x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<1x1x1x1x2x?xf32> {
+ %0 = tensor.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<1x1x5x1xf32> -> tensor<1x1x1x1x2x?xf32>
+ return %0 : tensor<1x1x1x1x2x?xf32>
+}
+// CHECK: #[[$ATTR_2:.+]] = affine_map<()[s0] -> (s0 - 5)>
+// CHECK-LABEL: func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x5x1xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x1x1x1x2x?xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: f32,
+// CHECK-SAME: %[[VAL_3:.*]]: index) -> tensor<1x1x1x1x2x?xf32> {
+// CHECK: %[[VAL_4:.*]] = affine.apply #[[$ATTR_2]](){{\[}}%[[VAL_3]]]
+// CHECK: %[[VAL_5:.*]] = tensor.pad %[[VAL_0]] low[0, 0, 0, 0] high[0, 0, %[[VAL_4]], 1] {
+// CHECK: ^bb0(%[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: index):
+// CHECK: tensor.yield %[[VAL_2]] : f32
+// CHECK: } : tensor<1x1x5x1xf32> to tensor<1x1x?x2xf32>
+// CHECK: %[[VAL_10:.*]] = tensor.extract_slice %[[VAL_5]][0, 0, 0, 0] [1, 1, %[[VAL_3]], 2] [1, 1, 1, 1] : tensor<1x1x?x2xf32> to tensor<?x2xf32>
+// CHECK: %[[VAL_12:.*]] = tensor.empty(%[[VAL_3]]) : tensor<2x?xf32>
+// CHECK: %[[VAL_13:.*]] = linalg.transpose ins(%[[VAL_10]] : tensor<?x2xf32>) outs(%[[VAL_12]] : tensor<2x?xf32>) permutation = [1, 0]
+// CHECK: %[[VAL_14:.*]] = tensor.insert_slice %[[VAL_13]] into %[[VAL_1]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, %[[VAL_3]]] [1, 1, 1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x1x1x2x?xf32>
+// CHECK: return %[[VAL_14]] : tensor<1x1x1x1x2x?xf32>
+// CHECK: }
+
+// -----
+
func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{
%0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x8xf32> -> tensor<1x1x32x8xf32>
return %0 : tensor<1x1x32x8xf32>
@@ -149,19 +211,19 @@ func.func @simple_CHW_to_CHWhwc(%arg0: tensor<3x5x7xf32>, %arg1: tensor<1x1x1x5x
// -----
-func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<3x1x32x8xf32>, %arg1: tensor<3x1x1x1x8x32xf32>) -> tensor<3x1x1x1x8x32xf32> {
- %0 = tensor.pack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<3x1x32x8xf32> -> tensor<3x1x1x1x8x32xf32>
- return %0 : tensor<3x1x1x1x8x32xf32>
+func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x1x1x8x32xf32>) -> tensor<1x1x1x1x8x32xf32> {
+ %0 = tensor.pack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x32x8xf32> -> tensor<1x1x1x1x8x32xf32>
+ return %0 : tensor<1x1x1x1x8x32xf32>
}
// CHECK-LABEL: func.func @simple_KCRS_to_KRSCsr
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [3, 1, 32, 8] [1, 1, 1, 1]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x8x32xf32>
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[TILE]] : tensor<3x32x8xf32>)
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<3x8x32xf32>)
-// CHECK-SAME: permutation = [0, 2, 1]
+// CHECK-SAME: ins(%[[TILE]] : tensor<32x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x32xf32>)
+// CHECK-SAME: permutation = [1, 0]
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
-// CHECK-SAME: [0, 0, 0, 0, 0, 0] [3, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
+// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
// CHECK: return %[[INSERT]]
From 6c445eca8e00847d02d6ec94069f5bfb6e2ee1bc Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 6 Nov 2024 08:56:59 +0000
Subject: [PATCH 2/3] fixup! [mlir][tensor] Generalize/restrict
`GeneralizeOuterUnitDimsPackOpPattern`
Skip calculating static shapes for EmptyOp
---
.../Dialect/Linalg/Transforms/Transforms.cpp | 20 +++++--------------
1 file changed, 5 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index ed5f1bd602d7f4..ac275116e1172b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1147,13 +1147,13 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
// * the return size becomes the attribute encapsulating the known size, and
// * dim is updated from kDynamic to its actual known value.
static std::pair<int64_t, OpFoldResult>
-getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, PatternRewriter &rewriter) {
+getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, Builder &b) {
int64_t tileSizeForShape =
getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
OpFoldResult tileSizeOfrSimplified;
if (tileSizeForShape != ShapedType::kDynamic) {
- tileSizeOfrSimplified = rewriter.getIndexAttr(tileSizeForShape);
+ tileSizeOfrSimplified = b.getIndexAttr(tileSizeForShape);
} else {
tileSizeOfrSimplified = tileSizeOfr;
}
@@ -1226,28 +1226,18 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
// 2.1 Create tensor.empty (init value for TransposeOp)
SmallVector<OpFoldResult> transShapeForEmptyOpDynamic;
- SmallVector<int64_t> transShapeForEmptyOpStatic;
// Acquire tensor shape required to create EmptyOp. This will match the inner
- // tile sizes, but the actual data format will depend on whether the tile
- // sizes are static or dynamic (each case leads to a different builder for
- // EmptyOp). Conservatively, prepare for both scenarios.
+ // tile sizes.
size_t idx = numTiles;
while (idx != 0) {
transShapeForEmptyOpDynamic.push_back(extractSliceSizes[srcRank - idx]);
- transShapeForEmptyOpStatic.push_back(
- outputShapeForExtractSlice[numTiles - idx]);
idx--;
}
- applyPermutationToVector<int64_t>(transShapeForEmptyOpStatic, perm);
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOpDynamic, perm);
-
- Value empty = ShapedType::isDynamicShape(transShapeForEmptyOpStatic)
- ? rewriter.create<tensor::EmptyOp>(
- loc, transShapeForEmptyOpDynamic, elemType)
- : rewriter.create<tensor::EmptyOp>(
- loc, transShapeForEmptyOpStatic, elemType);
+ Value empty = rewriter.create<tensor::EmptyOp>(
+ loc, transShapeForEmptyOpDynamic, elemType);
// 2.2 Create linalg.transpose
auto transposedOp =
From 50ab228266f66d613c0ced0b40882472894707cb Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 6 Nov 2024 10:49:45 +0000
Subject: [PATCH 3/3] fixup! fixup! [mlir][tensor] Generalize/restrict
`GeneralizeOuterUnitDimsPackOpPattern`
Rename and move getSimplifiedDimSizePair
---
.../mlir/Dialect/Utils/StaticValueUtils.h | 19 ++++++++++
.../Dialect/Linalg/Transforms/Transforms.cpp | 35 ++++---------------
mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 14 ++++++++
3 files changed, 39 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 4d7aa1ae17fdb1..d1f7ab1156248f 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -60,6 +60,25 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
SmallVectorImpl<Value> &dynamicVec,
SmallVectorImpl<int64_t> &staticVec);
+/// Given OpFoldResult representing dim size value (*), generates a pair of
+/// sizes:
+/// * 1st result, static value, contains an int64_t dim size that can be used
+/// to build ShapedType (ShapedType::kDynamic is used for truly dynamic dims),
+/// * 2nd result, dynamic value, contains OpFoldResult encapsulating the
+/// actual dim size (either original or updated input value).
+/// For input sizes for which it is possible to extract a constant Attribute,
+/// replaces the original size value with an integer attribute (unless it's
+/// already a constant Attribute). The 1st return value also becomes the actual
+/// integer size (as opposed to ShapedType::kDynamic).
+///
+/// (*) This hook is usually used when, given input sizes as OpFoldResult,
+/// it's required to generate two vectors:
+/// * sizes as int64_t to generate a shape,
+/// * sizes as OpFoldResult for sizes-like attribute.
+/// Please update this comment if you identify other use cases.
+std::pair<int64_t, OpFoldResult>
+getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b);
+
/// Extract integer values from the assumed ArrayAttr of IntegerAttr.
template <typename IntTy>
SmallVector<IntTy> extractFromIntegerArrayAttr(Attribute attr) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index ac275116e1172b..2097883530528a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1139,37 +1139,14 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
return perm;
}
-// A helper function to generate a dim-and-size pair for Ops like
-// ExtractSliceOp that require both:
-// * dims to specify the output shape, and
-// * sizes for the sizes attribute (or similar).
-// For dynamic sizes, if the corresponding size is a compile time constant:
-// * the return size becomes the attribute encapsulating the known size, and
-// * dim is updated from kDynamic to its actual known value.
-static std::pair<int64_t, OpFoldResult>
-getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, Builder &b) {
- int64_t tileSizeForShape =
- getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
-
- OpFoldResult tileSizeOfrSimplified;
- if (tileSizeForShape != ShapedType::kDynamic) {
- tileSizeOfrSimplified = b.getIndexAttr(tileSizeForShape);
- } else {
- tileSizeOfrSimplified = tileSizeOfr;
- }
-
- return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
- tileSizeOfrSimplified);
-}
-
LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
tensor::PackOp packOp, PatternRewriter &rewriter) const {
// TODO: support the case that outer dimensions are not all 1s. A
// tensor.expand_shape will be generated in this case.
- if (llvm::any_of(packOp.getTiledOuterDims(),
+ if (llvm::any_of(packOp.getAllOuterDims(),
[](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
- packOp, "require the tiled outer dimensions of the result are all 1s");
+ packOp, "not all outer dimensions of the result are 1s");
}
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
@@ -1202,7 +1179,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
auto [tileSize, tileSizeOfr] =
- getSimplifiedDimSizePair(dimAndTileMapping[i], rewriter);
+ getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
extractSliceSizes.push_back(tileSizeOfr);
outputShapeForExtractSlice.push_back(tileSize);
}
@@ -1236,8 +1213,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
}
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOpDynamic, perm);
- Value empty = rewriter.create<tensor::EmptyOp>(
- loc, transShapeForEmptyOpDynamic, elemType);
+ Value empty = rewriter.create<tensor::EmptyOp>(
+ loc, transShapeForEmptyOpDynamic, elemType);
// 2.2 Create linalg.transpose
auto transposedOp =
@@ -1254,7 +1231,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
for (auto tileSize : packOp.getMixedTiles()) {
auto [tileSizeStatic, tileSizeOfr] =
- getSimplifiedDimSizePair(tileSize, rewriter);
+ getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
writeSizes.push_back(tileSizeOfr);
writeShape.push_back(tileSizeStatic);
}
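The functional change at the top of `matchAndRewrite` swaps `getTiledOuterDims()` for `getAllOuterDims()`, so the pattern now bails out whenever any outer dimension, tiled or not, differs from 1. The minimal sketch below illustrates the difference. It assumes that a packed shape is laid out as the outer dims followed by the inner tiles and that there is no `outer_dims_perm`. `allOuterDims`, `tiledOuterDims` and `allOnes` are hypothetical stand-ins, not the `tensor::PackOp` accessors.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical: the first `srcRank` entries of the packed shape are the outer
// dims (assumes no outer_dims_perm).
std::vector<int64_t> allOuterDims(const std::vector<int64_t> &destShape,
                                  size_t srcRank) {
  assert(srcRank <= destShape.size());
  return std::vector<int64_t>(
      destShape.begin(),
      destShape.begin() + static_cast<std::ptrdiff_t>(srcRank));
}

// Hypothetical: only the outer dims that are tiled (listed in inner_dims_pos).
std::vector<int64_t> tiledOuterDims(const std::vector<int64_t> &destShape,
                                    const std::vector<int64_t> &innerDimsPos) {
  std::vector<int64_t> dims;
  for (int64_t pos : innerDimsPos)
    dims.push_back(destShape[static_cast<size_t>(pos)]);
  return dims;
}

// The match condition: every checked dim must be 1.
bool allOnes(const std::vector<int64_t> &dims) {
  return std::all_of(dims.begin(), dims.end(),
                     [](int64_t d) { return d == 1; });
}
```

With a hypothetical `tensor<1x3x8xf32>` result (dim 0 tiled by 8, dim 1 left untiled), the old tiled-only check passes while the new all-outer-dims check rejects the op. The `tensor<1x1x?x2xf32>` destination from the PR description passes both.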
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 3eb6215a7a0b9b..0b399fba3f2635 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -58,6 +58,20 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
staticVec.push_back(ShapedType::kDynamic);
}
+std::pair<int64_t, OpFoldResult>
+getSimplifiedOfrAndStaticSizePair(OpFoldResult tileSizeOfr, Builder &b) {
+ int64_t tileSizeForShape =
+ getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
+
+ OpFoldResult tileSizeOfrSimplified =
+ (tileSizeForShape != ShapedType::kDynamic)
+ ? b.getIndexAttr(tileSizeForShape)
+ : tileSizeOfr;
+
+ return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
+ tileSizeOfrSimplified);
+}
+
void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
SmallVectorImpl<Value> &dynamicVec,
SmallVectorImpl<int64_t> &staticVec) {
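`getSimplifiedOfrAndStaticSizePair` takes a size that is known at compile time (e.g. `%tile_dim_0 = arith.constant 8 : index`) and turns it into both a static shape entry and an index attribute. Truly dynamic sizes pass through unchanged with `ShapedType::kDynamic`. The sketch below models this behaviour without MLIR, under the following assumptions:

* `Ofr` is a hypothetical `std::variant` standing in for `OpFoldResult`. An `int64_t` plays the role of an `IntegerAttr`, and `Value` is an SSA value that may be defined by a constant.
* `constantIntValue` is a simplified stand-in for `getConstantIntValue`.
* `kDynamic` assumes `ShapedType::kDynamic` is `std::numeric_limits<int64_t>::min()`.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <utility>
#include <variant>

// Assumption: matches ShapedType::kDynamic in recent MLIR.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical SSA value; `constant` is set if it is defined by a constant op.
struct Value {
  std::string name;
  std::optional<int64_t> constant;
};

// Hypothetical OpFoldResult: either an integer "attribute" or a Value.
using Ofr = std::variant<int64_t, Value>;

// Simplified stand-in for getConstantIntValue.
std::optional<int64_t> constantIntValue(const Ofr &ofr) {
  if (const auto *attr = std::get_if<int64_t>(&ofr))
    return *attr;
  return std::get<Value>(ofr).constant;
}

// Returns {static size for the shape, simplified size}: known sizes become an
// "attribute", unknown ones keep the original Value and kDynamic.
std::pair<int64_t, Ofr> simplifiedOfrAndStaticSize(const Ofr &ofr) {
  int64_t staticSize = constantIntValue(ofr).value_or(kDynamic);
  Ofr simplified = (staticSize != kDynamic) ? Ofr(staticSize) : ofr;
  return {staticSize, simplified};
}
```

A constant-defined value such as `%tile_dim_0` yields `{8, 8}`. An unknown `%n` yields `{kDynamic, %n}`, so callers can fill the `int64_t` shape and the sizes list in a single pass.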