[Mlir-commits] [mlir] 9d5239d - [mlir][Linalg] Add fusion of IndexedGenericOp with TensorReshapeOp by expansion.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 27 16:15:50 PDT 2020
Author: MaheshRavishankar
Date: 2020-10-27T16:15:34-07:00
New Revision: 9d5239d39e48b8b171a0fbc47dbbb22381f4d9be
URL: https://github.com/llvm/llvm-project/commit/9d5239d39e48b8b171a0fbc47dbbb22381f4d9be
DIFF: https://github.com/llvm/llvm-project/commit/9d5239d39e48b8b171a0fbc47dbbb22381f4d9be.diff
LOG: [mlir][Linalg] Add fusion of IndexedGenericOp with TensorReshapeOp by expansion.
This patch adds support for fusing linalg.indexed_generic op with
linalg.tensor_reshape op by expansion, i.e.
- linalg.indexed_generic op -> linalg.tensor_reshape op when the
latter is expanding.
- linalg.tensor_reshape op -> linalg.indexed_generic op when the
former is folding.
Differential Revision: https://reviews.llvm.org/D90082
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/Dialect/Linalg/reshape_fusion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index f2a3fb7d7766..59e2d4cc5673 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -413,13 +413,13 @@ static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
unsigned fusedTensorIndex) {
// Is fusable only if:
- // - The linalgOp is a generic op.
+ // - The linalgOp is a generic op, or an indexed_generic.
// - All the indexing maps for operands in linalgOp are projected
// permutations.
// - The indexing map at the position representing the fused tensor is a
// permutation.
// - All the loops in linalgOp are parallel loops.
- return isa<GenericOp>(linalgOp.getOperation()) &&
+ return isa<GenericOp, IndexedGenericOp>(linalgOp.getOperation()) &&
linalgOp.hasTensorSemantics() &&
llvm::all_of(linalgOp.indexing_maps().getValue().take_front(
linalgOp.getNumInputs()),
@@ -460,7 +460,7 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
ArrayRef<int64_t> expandedShape = expandedType.getShape();
SmallVector<unsigned, 4> numFoldedDims(foldedType.getRank(), 0);
SmallVector<SmallVector<int64_t, 4>, 4> expandedDimsShape(
- expandedType.getRank());
+ foldedType.getRank());
auto reassociationMaps = reshapeOp.getReassociationMaps();
for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
@@ -472,6 +472,26 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
expandedDimsShape[pos].assign(shape.begin(), shape.end());
}
+ if (isa<IndexedGenericOp>(linalgOp.getOperation())) {
+ // For indexed generic op, the region contains arguments that represent the
+ // induction variable value of the loops. In the fused op these values are
+ // obtained by linearizing the expanded dimensions. For now just check that
+ // the extents used in the linearization (all the expanded dims except the
+ // front) are statically know. For dynamic case, we would need shape
+ // information on these dimensions to get these.
+ for (auto &expandedShape : expandedDimsShape) {
+ for (int64_t expandedDimShape : llvm::make_range(
+ std::next(expandedShape.begin()), expandedShape.end())) {
+ if (ShapedType::isDynamic(expandedDimShape)) {
+ linalgOp.emitError(
+ "unable to fuse indexed generic op where the expanded dim is "
+ "dynamic");
+ return llvm::None;
+ }
+ }
+ }
+ }
+
// The remapping of the indices is then the prefix sum (inclusive) of the
// numFoldedDims.
SmallVector<unsigned, 4> remapping(numFoldedDims.size() + 1, 0);
@@ -563,10 +583,56 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
/*outputBuffers=*/ValueRange{},
/*initTensors=*/ValueRange{}, expandedOpIndexingMaps, iteratorTypes);
Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
- // TODO: Add support for indexed generic op, which would need mapping the
- // expanded dimensions to the original dimension arguments.
- rewriter.cloneRegionBefore(linalgOp.getOperation()->getRegion(0), fusedRegion,
- fusedRegion.begin());
+ Region &originalRegion = linalgOp.getOperation()->getRegion(0);
+
+ if (isa<GenericOp>(linalgOp.getOperation())) {
+ rewriter.cloneRegionBefore(originalRegion, fusedRegion,
+ fusedRegion.begin());
+ } else {
+ assert(isa<IndexedGenericOp>(linalgOp.getOperation()));
+ // Create an entry block in the fused Region with same number of arguments
+ // as the fused op
+ Block *fusedEntryBlock = new Block;
+ fusedRegion.push_back(fusedEntryBlock);
+ rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.end());
+
+ // Merge the entry block of the fused op with the cloned blocks. For this
+ // compute the value for arguments of the region in the original operation
+ // in terms of the arguments of the fused op. Since the original operation
+ // is expanded, the expanded dimensions need to be folded back to get the
+ // replacement value for the arguments corresponding to iteration index.
+ // For now this expects that all the loop ranges are constants, which is
+ // true if the shapes are all static. This has already been checked in the
+ // precondition.
+ using namespace edsc::op;
+ using namespace edsc::intrinsics;
+ OpBuilder::InsertionGuard guard(rewriter);
+ SmallVector<Value, 4> argReplacements(originalRegion.getNumArguments());
+ rewriter.setInsertionPointToStart(fusedEntryBlock);
+ edsc::ScopedContext scopedContext(rewriter, fusedOp.getLoc());
+ IndexType indexType = rewriter.getIndexType();
+ for (unsigned i : llvm::seq<unsigned>(0, numFoldedDims.size())) {
+ Value linearizedIndex = fusedEntryBlock->addArgument(indexType);
+ for (unsigned foldedDim = remapping[i] + 1; foldedDim != remapping[i + 1];
+ foldedDim++) {
+ int64_t expandedDimExtent =
+ expandedDimsShape[i][foldedDim - remapping[i]];
+ assert(!ShapedType::isDynamic(expandedDimExtent));
+ linearizedIndex =
+ linearizedIndex * std_constant_index(expandedDimExtent);
+ linearizedIndex =
+ linearizedIndex + fusedEntryBlock->addArgument(indexType);
+ }
+ argReplacements[i] = linearizedIndex;
+ }
+ for (unsigned i :
+ llvm::seq<unsigned>(numFoldedDims.size(), argReplacements.size())) {
+ argReplacements[i] =
+ fusedEntryBlock->addArgument(originalRegion.getArgument(i).getType());
+ }
+ rewriter.mergeBlocks(fusedEntryBlock->getNextNode(), fusedEntryBlock,
+ argReplacements);
+ }
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
@@ -670,14 +736,15 @@ struct FoldProducerReshapeOpByLinearization
}
};
-/// Pattern to fuse a tensor_reshape op with its consumer generic op, when the
-/// reshape op is collapsing dimensions. The dimensionality of the loop in the
-/// consumer generic op is expanded.
+/// Pattern to fuse a tensor_reshape op with its consumer
+/// generic/indexed_generic op, when the reshape op is collapsing
+/// dimensions. The dimensionality of the loop in the consumer is expanded.
+template <typename GenericOpTy>
struct FoldWithProducerReshapeOpByExpansion
- : public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
+ : public OpRewritePattern<GenericOpTy> {
+ using OpRewritePattern<GenericOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(GenericOp genericOp,
+ LogicalResult matchAndRewrite(GenericOpTy genericOp,
PatternRewriter &rewriter) const override {
LinalgOp linalgOp = cast<LinalgOp>(genericOp.getOperation());
for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
@@ -942,7 +1009,9 @@ void mlir::populateFoldReshapeOpsByLinearizationPatterns(
void mlir::populateFoldReshapeOpsByExpansionPatterns(
MLIRContext *context, OwningRewritePatternList &patterns) {
patterns.insert<FoldReshapeWithGenericOpByExpansion,
- FoldWithProducerReshapeOpByExpansion>(context);
+ FoldWithProducerReshapeOpByExpansion<GenericOp>,
+ FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
+ context);
}
void mlir::populateLinalgTensorOpsFusionPatterns(
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 865b10b51696..1f201f78fe74 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -190,3 +190,157 @@ func @scalar_reshape(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>)
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]] : tensor<f32>)
// CHECK: return %[[T1]] : tensor<1x10xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
+ %arg1 : tensor<?x?x?xi32>) ->
+ tensor<?x?x?xi32>
+{
+ %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+ affine_map<(i, j, k, l) -> (j, k)>,
+ affine_map<(i, j, k, l) -> (l)>] :
+ tensor<?x?x4x?xi32> into tensor<?x?x?xi32>
+ %1 = linalg.indexed_generic {
+ indexing_maps = [#map0, #map1, #map1],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xi32>) {
+ ^bb0(%arg3 : index, %arg4 : index, %arg5 : index, %arg6: i32, %arg7: i32):
+ %1 = muli %arg6, %arg7 : i32
+ %2 = index_cast %arg3 : index to i32
+ %3 = addi %1, %2 : i32
+ %4 = index_cast %arg4 : index to i32
+ %5 = addi %3, %4 : i32
+ %6 = index_cast %arg5 : index to i32
+ %7 = addi %5, %6 : i32
+ linalg.yield %7 : i32
+ } -> tensor<?x?x?xi32>
+ return %1 : tensor<?x?x?xi32>
+}
+
+// The generic op version of the test checks for the op structure. Only
+// checking the op body here.
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
+// CHECK: func @indexed_generic_op_reshape_producer_fusion
+// CHECK: linalg.indexed_generic
+// CHECK: ^{{.*}}(
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32)
+// CHECK: %[[T3:.+]] = affine.apply #[[MAP]](%[[ARG2]], %[[ARG3]])
+// CHECK: %[[T4:.+]] = muli %[[ARG6]], %[[ARG7]]
+// CHECK: %[[T5:.+]] = index_cast %[[T3]]
+// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]]
+// CHECK: %[[T7:.+]] = index_cast %[[ARG4]]
+// CHECK: %[[T8:.+]] = addi %[[T6]], %[[T7]]
+// CHECK: %[[T9:.+]] = index_cast %[[ARG5]]
+// CHECK: %[[T10:.+]] = addi %[[T8]], %[[T9]]
+// CHECK: linalg.yield %[[T10]]
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
+ %arg1 : tensor<?x?xi32>) ->
+ tensor<?x?x4x5xi32>
+{
+ %0 = linalg.indexed_generic {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) {
+ ^bb0(%arg3 : index, %arg4 : index, %arg5: i32, %arg6: i32): // no predecessors
+ %1 = muli %arg5, %arg6 : i32
+ %2 = index_cast %arg3 : index to i32
+ %3 = addi %1, %2 : i32
+ %4 = index_cast %arg4 : index to i32
+ %5 = addi %3, %4 : i32
+ linalg.yield %5 : i32
+ } -> tensor<?x?xi32>
+ %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+ affine_map<(i, j, k, l) -> (j, k, l)>] :
+ tensor<?x?xi32> into tensor<?x?x4x5xi32>
+ return %1 : tensor<?x?x4x5xi32>
+}
+// The generic op version of the test checks for the op structure. Only
+// checking the op body here.
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0 * 20 + d1 * 5 + d2)>
+// CHECK: func @indexed_generic_op_reshape_consumer_fusion
+// CHECK: linalg.indexed_generic
+// CHECK: ^{{.*}}(
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32)
+// CHECK: %[[T3:.+]] = affine.apply #[[MAP]](%[[ARG3]], %[[ARG4]], %[[ARG5]])
+// CHECK: %[[T4:.+]] = muli %[[ARG6]], %[[ARG7]]
+// CHECK: %[[T5:.+]] = index_cast %[[ARG2]]
+// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]]
+// CHECK: %[[T7:.+]] = index_cast %[[T3]]
+// CHECK: %[[T8:.+]] = addi %[[T6]], %[[T7]]
+// CHECK: linalg.yield %[[T8]]
+
+// -----
+
+func @reshape_as_consumer_permutation
+ (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
+ -> tensor<2x3x4x5x6x7xi32> {
+ %c = linalg.indexed_generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%a, %b : tensor<210x6x4xi32>, tensor<210x4xi32>) {
+ ^bb0(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i32, %arg4: i32):
+ %1 = addi %arg3, %arg4 : i32
+ %2 = index_cast %arg0 : index to i32
+ %3 = addi %1, %2 : i32
+ %4 = index_cast %arg1 : index to i32
+ %5 = addi %3, %4 : i32
+ %6 = index_cast %arg2 : index to i32
+ %7 = addi %5, %6 : i32
+ linalg.yield %7 : i32
+ } -> tensor<6x4x210xi32>
+ %d = linalg.tensor_reshape %c
+ [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
+ : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
+ return %d : tensor<2x3x4x5x6x7xi32>
+}
+
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 * 3 + d1)>
+// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2) -> (d0 * 42 + d1 * 7 + d2)>
+// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+// CHECK: func @reshape_as_consumer_permutation
+// CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
+// CHECK-DAG: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-DAG: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
+// CHECK-SAME: [#[[MAP3]], #[[MAP4]]]
+// CHECK: %[[T2:.+]] = linalg.indexed_generic
+// CHECK-SAME: indexing_maps = [#[[MAP7]], #[[MAP8]], #[[MAP9]]]
+// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<{{.+}}>, tensor<{{.+}}>)
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index, %[[ARG7:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32)
+// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP5]](%[[ARG2]], %[[ARG3]])
+// CHECK-DAG: %[[T4:.+]] = affine.apply #[[MAP6]](%[[ARG4]], %[[ARG5]], %[[ARG6]])
+// CHECK-DAG: %[[T5:.+]] = addi %[[ARG8]], %[[ARG9]]
+// CHECK: %[[T6:.+]] = index_cast %[[T3]]
+// CHECK: %[[T7:.+]] = addi %[[T5]], %[[T6]]
+// CHECK: %[[T8:.+]] = index_cast %[[T4]]
+// CHECK: %[[T9:.+]] = addi %[[T7]], %[[T8]]
+// CHECK: %[[T10:.+]] = index_cast %[[ARG7]]
+// CHECK: %[[T11:.+]] = addi %[[T9]], %[[T10]]
More information about the Mlir-commits
mailing list