[Mlir-commits] [mlir] 5d4603a - [mlir][tensor] Support more cases in MergeConsecutiveExtractSlice
Lei Zhang
llvmlistbot at llvm.org
Tue Sep 20 17:16:27 PDT 2022
Author: Lei Zhang
Date: 2022-09-20T20:16:03-04:00
New Revision: 5d4603a02d0c3e0106b10d245322b1d2072c0c3d
URL: https://github.com/llvm/llvm-project/commit/5d4603a02d0c3e0106b10d245322b1d2072c0c3d
DIFF: https://github.com/llvm/llvm-project/commit/5d4603a02d0c3e0106b10d245322b1d2072c0c3d.diff
LOG: [mlir][tensor] Support more cases in MergeConsecutiveExtractSlice
This commit adds utility functions to perform general merging of
OffsetSizeAndStrideOpInterface ops, with support for rank-reducing
producers and non-unit strides. With them, MergeConsecutiveExtractSlice
can be extended to support more cases.
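For example (adapted from the new test case added below; the SSA value
names here are illustrative), two consecutive 1-D extract_slice ops with
dynamic, non-unit strides now fold into a single extract_slice:

  %0 = tensor.extract_slice %src[%offset0] [%size0] [%stride0] : tensor<?xf32> to tensor<?xf32>
  %1 = tensor.extract_slice %0[%offset1] [%size1] [%stride1] : tensor<?xf32> to tensor<?xf32>

becomes

  %offset = affine.apply affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>()[%offset1, %stride0, %offset0]
  %stride = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%stride1, %stride0]
  %1 = tensor.extract_slice %src[%offset] [%size1] [%stride] : tensor<?xf32> to tensor<?xf32>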
Co-authored-by: Mahesh Ravishankar <ravishankarm at google.com>
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D134294
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
index 2ca556275af12..e1e6a033d778f 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
@@ -14,6 +14,37 @@
namespace mlir {
namespace tensor {
+/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
+/// when combining a producer slice **into** a consumer slice.
+///
+/// This function performs the following computation:
+/// - Combined offsets = consumer_offsets * producer_strides + producer_offsets
+/// - Combined sizes = consumer_sizes
+/// - Combined strides = producer_strides * consumer_strides
+LogicalResult
+mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> producerOffsets,
+ ArrayRef<OpFoldResult> producerSizes,
+ ArrayRef<OpFoldResult> producerStrides,
+ const llvm::SmallBitVector &droppedProducerDims,
+ ArrayRef<OpFoldResult> consumerOffsets,
+ ArrayRef<OpFoldResult> consumerSizes,
+ ArrayRef<OpFoldResult> consumerStrides,
+ SmallVector<OpFoldResult> &combinedOffsets,
+ SmallVector<OpFoldResult> &combinedSizes,
+ SmallVector<OpFoldResult> &combinedStrides);
+
+/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
+/// when combining a `producer` slice op **into** a `consumer` slice op.
+LogicalResult
+mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
+ OffsetSizeAndStrideOpInterface producer,
+ OffsetSizeAndStrideOpInterface consumer,
+ const llvm::SmallBitVector &droppedProducerDims,
+ SmallVector<OpFoldResult> &combinedOffsets,
+ SmallVector<OpFoldResult> &combinedSizes,
+ SmallVector<OpFoldResult> &combinedStrides);
+
//===----------------------------------------------------------------------===//
// Extract slice from `tensor.collapse_shape`
//===----------------------------------------------------------------------===//
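As a concrete check of the offset/stride formula in the new header comment above
(illustrative numbers only, not part of the patch): a producer slice with offset 2
and stride 3, combined with a consumer slice with offset 4, size 5, and stride 6,
yields a merged slice with offset 2 + 4 * 3 = 14, size 5, and stride 3 * 6 = 18.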
diff --git a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
index 48977a90ffb3b..e4489448ec5ff 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
@@ -17,29 +17,101 @@
using namespace mlir;
using namespace mlir::tensor;
-/// Adds each corresponding pair of offsets in `offsets1` and `offsets2` and
-/// returns the results.
-static SmallVector<OpFoldResult> mergeOffsets(Location loc,
- ArrayRef<OpFoldResult> offsets1,
- ArrayRef<OpFoldResult> offsets2,
- OpBuilder &builder) {
- SmallVector<OpFoldResult> foldedOffsets;
- assert(offsets1.size() == offsets2.size());
- foldedOffsets.reserve(offsets1.size());
-
- AffineExpr dim1, dim2;
- bindDims(builder.getContext(), dim1, dim2);
-
- for (const auto &pair : llvm::zip(offsets1, offsets2)) {
- auto offset0 =
- getValueOrCreateConstantIndexOp(builder, loc, std::get<0>(pair));
- auto offset1 =
- getValueOrCreateConstantIndexOp(builder, loc, std::get<1>(pair));
- auto foldedOffset =
- makeComposedAffineApply(builder, loc, dim1 + dim2, {offset0, offset1});
- foldedOffsets.push_back(foldedOffset.getResult());
+/// Creates an AffineExpr from `ofr`: if the OpFoldResult is a Value, creates
+/// an AffineSymbolExpr and appends it to `symbols`; otherwise creates an
+/// AffineConstantExpr.
+static AffineExpr getAffineExpr(OpFoldResult ofr,
+ SmallVector<OpFoldResult> &symbols) {
+ if (auto attr = ofr.dyn_cast<Attribute>()) {
+ return getAffineConstantExpr(attr.cast<IntegerAttr>().getInt(),
+ attr.getContext());
}
- return foldedOffsets;
+ Value v = ofr.get<Value>();
+ AffineExpr expr = getAffineSymbolExpr(symbols.size(), v.getContext());
+ symbols.push_back(v);
+ return expr;
+}
+
+/// Builds the AffineExpr incrementally for arithmetic operations.
+static AffineExpr add(AffineExpr expr, OpFoldResult ofr,
+ SmallVector<OpFoldResult> &symbols) {
+ return expr + getAffineExpr(ofr, symbols);
+}
+static AffineExpr mul(OpFoldResult lhs, OpFoldResult rhs,
+ SmallVector<OpFoldResult> &symbols) {
+ return getAffineExpr(lhs, symbols) * getAffineExpr(rhs, symbols);
+}
+
+/// Converts an AffineExpr to OpFoldResult by generating an `affine.apply`
+/// op and folding it.
+static OpFoldResult getOpFoldResult(OpBuilder &builder, Location loc,
+ AffineExpr expr,
+ SmallVector<OpFoldResult> &symbols) {
+ AffineMap m = AffineMap::get(0, symbols.size(), expr);
+ return makeComposedFoldedAffineApply(builder, loc, m, symbols);
+}
+
+LogicalResult tensor::mergeOffsetsSizesAndStrides(
+ OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> producerOffsets,
+ ArrayRef<OpFoldResult> producerSizes,
+ ArrayRef<OpFoldResult> producerStrides,
+ const llvm::SmallBitVector &droppedProducerDims,
+ ArrayRef<OpFoldResult> consumerOffsets,
+ ArrayRef<OpFoldResult> consumerSizes,
+ ArrayRef<OpFoldResult> consumerStrides,
+ SmallVector<OpFoldResult> &combinedOffsets,
+ SmallVector<OpFoldResult> &combinedSizes,
+ SmallVector<OpFoldResult> &combinedStrides) {
+ combinedOffsets.resize(producerOffsets.size());
+ combinedSizes.resize(producerOffsets.size());
+ combinedStrides.resize(producerOffsets.size());
+ unsigned consumerPos = 0;
+ for (auto i : llvm::seq<unsigned>(0, producerOffsets.size())) {
+ if (droppedProducerDims.test(i)) {
+ // For dropped dims, get the values from the producer.
+ combinedOffsets[i] = producerOffsets[i];
+ combinedSizes[i] = producerSizes[i];
+ combinedStrides[i] = producerStrides[i];
+ continue;
+ }
+ SmallVector<OpFoldResult> offsetSymbols, strideSymbols;
+ // The combined offset is computed as
+ // producer_offset + consumer_offset * producer_strides.
+ combinedOffsets[i] =
+ getOpFoldResult(builder, loc,
+ add(mul(consumerOffsets[consumerPos],
+ producerStrides[i], offsetSymbols),
+ producerOffsets[i], offsetSymbols),
+ offsetSymbols);
+ combinedSizes[i] = consumerSizes[consumerPos];
+ // The combined stride is computed as
+ // consumer_stride * producer_stride.
+ combinedStrides[i] = getOpFoldResult(
+ builder, loc,
+ mul(consumerStrides[consumerPos], producerStrides[i], strideSymbols),
+ strideSymbols);
+ consumerPos++;
+ }
+ return success();
+}
+
+LogicalResult tensor::mergeOffsetsSizesAndStrides(
+ OpBuilder &builder, Location loc, OffsetSizeAndStrideOpInterface producer,
+ OffsetSizeAndStrideOpInterface consumer,
+ const llvm::SmallBitVector &droppedProducerDims,
+ SmallVector<OpFoldResult> &combinedOffsets,
+ SmallVector<OpFoldResult> &combinedSizes,
+ SmallVector<OpFoldResult> &combinedStrides) {
+ SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets();
+ SmallVector<OpFoldResult> consumerSizes = consumer.getMixedSizes();
+ SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides();
+ SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets();
+ SmallVector<OpFoldResult> producerSizes = producer.getMixedSizes();
+ SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides();
+ return tensor::mergeOffsetsSizesAndStrides(
+ builder, loc, producerOffsets, producerSizes, producerStrides,
+ droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides,
+ combinedOffsets, combinedSizes, combinedStrides);
}
namespace {
@@ -53,24 +125,15 @@ struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
if (!prevOp)
return failure();
- if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
+ SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
+ if (failed(mergeOffsetsSizesAndStrides(rewriter, nextOp.getLoc(), prevOp,
+ nextOp, prevOp.getDroppedDims(),
+ newOffsets, newSizes, newStrides)))
return failure();
- auto prevResultType = prevOp.getType().cast<ShapedType>();
- if (prevOp.getSourceType().getRank() != prevResultType.getRank())
- return rewriter.notifyMatchFailure(
- prevOp, "rank-reducing producder case unimplemented");
-
- Location loc = nextOp.getLoc();
-
- SmallVector<OpFoldResult> prevOffsets = prevOp.getMixedOffsets();
- SmallVector<OpFoldResult> nextOffsets = nextOp.getMixedOffsets();
- SmallVector<OpFoldResult> foldedOffsets =
- mergeOffsets(loc, prevOffsets, nextOffsets, rewriter);
-
- rewriter.replaceOpWithNewOp<ExtractSliceOp>(
- nextOp, nextOp.getType(), prevOp.getSource(), foldedOffsets,
- nextOp.getMixedSizes(), nextOp.getMixedStrides());
+ rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(),
+ prevOp.getSource(), newOffsets,
+ newSizes, newStrides);
return success();
}
};
diff --git a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
index 45a3f37ea0679..f5d77f63561cc 100644
--- a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
+++ b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
@@ -9,10 +9,12 @@ func.func @extract_slice_same_rank(
// CHECK-LABEL: func.func @extract_slice_same_rank
// CHECK-SAME: (%[[SOURCE:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
-// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET0]], %[[OFFSET1]]]
+// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]]
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][7, 9, 11, %[[OFFSET]]] [8, 16, 32, %[[SIZE1]]] [1, 1, 1, 1]
// CHECK: return %[[EXTRACT]] : tensor<8x16x32x?xf32>
+// -----
+
func.func @extract_slice_rank_reducing_consumer(
%src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<16x?xf32> {
%0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
@@ -23,6 +25,8 @@ func.func @extract_slice_rank_reducing_consumer(
// CHECK-LABEL: func.func @extract_slice_rank_reducing_consumer
// CHECK: tensor.extract_slice %{{.+}}[7, 9, 11, %{{.+}}] [1, 16, 1, %{{.+}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<16x?xf32>
+// -----
+
func.func @extract_slice_rank_reducing_producer(
%src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x?xf32> {
%0 = tensor.extract_slice %src[0, 1, 2, %offset0] [1, 128, 1, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x?xf32>
@@ -30,8 +34,27 @@ func.func @extract_slice_rank_reducing_producer(
return %1: tensor<8x?xf32>
}
-// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
-// CHECK-COUNT-2: tensor.extract_slice
+// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
+// CHECK-SAME: (%[[SRC:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
+// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][0, 8, 2, %[[OFFSET]]] [1, 8, 1, %[[SIZE1]]] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<8x?xf32>
+// CHECK: return %[[EXTRACT]] : tensor<8x?xf32>
+
+// -----
+
+func.func @extract_slice_non_one_stride(
+ %src: tensor<?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index, %stride0: index, %stride1: index) -> tensor<?xf32> {
+ %0 = tensor.extract_slice %src[%offset0] [%size0] [%stride0] : tensor<?xf32> to tensor<?xf32>
+ %1 = tensor.extract_slice %0[%offset1] [%size1] [%stride1] : tensor<?xf32> to tensor<?xf32>
+ return %1: tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @extract_slice_non_one_stride
+// CHECK-SAME: (%[[SRC:.+]]: tensor<?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index, %[[STRIDE0:.+]]: index, %[[STRIDE1:.+]]: index)
+// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>()[%[[OFFSET1]], %[[STRIDE0]], %[[OFFSET0]]]
+// CHECK: %[[STRIDE:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%[[STRIDE1]], %[[STRIDE0]]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][%[[OFFSET]]] [%[[SIZE1]]] [%[[STRIDE]]] : tensor<?xf32> to tensor<?xf32>
+// CHECK: return %[[EXTRACT]] : tensor<?xf32>
// -----
@@ -47,6 +70,8 @@ func.func @insert_slice_rank_reducing(
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DST]][6, 7, 8, %[[IDX]]] [1, 1, 16, 1] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]
+// -----
+
func.func @insert_slice_rank_reducing_dynamic_shape(
%dst: tensor<128x128x128x128xf32>, %mid: tensor<1x?x1xf32>, %src: tensor<?xf32>, %offset: index, %size: index) -> tensor<128x128x128x128xf32> {
%0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, %size, 1] [1, 1, 1] : tensor<?xf32> into tensor<1x?x1xf32>