[Mlir-commits] [mlir] 39a604e - [mlir][linalg] update fusion on tensors to support linalg index operations.
Tobias Gysi
llvmlistbot at llvm.org
Mon Apr 19 23:14:36 PDT 2021
Author: Tobias Gysi
Date: 2021-04-20T06:13:04Z
New Revision: 39a604e3df8535aeed1db5a2e3a544d47b330ba1
URL: https://github.com/llvm/llvm-project/commit/39a604e3df8535aeed1db5a2e3a544d47b330ba1
DIFF: https://github.com/llvm/llvm-project/commit/39a604e3df8535aeed1db5a2e3a544d47b330ba1.diff
LOG: [mlir][linalg] update fusion on tensors to support linalg index operations.
The patch extends elementwise fusion to remap the index operations in the body of fused producers into the fused loop space, and extends fusion with reshape ops to linearize the indices of expanded dimensions.
Differential Revision: https://reviews.llvm.org/D100479
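As context for the diff below: elementwise fusion now inlines producer bodies that contain linalg.index operations (the patch accordingly deletes the old @index_op negative test in fusion-tensor.mlir). A minimal hand-written sketch of the newly supported case, modeled on that deleted test; the function name and shapes are hypothetical:

#map0 = affine_map<(i, j) -> (i, j)>
func @sketch(%arg0: tensor<1x8xindex>, %arg1: tensor<1x8xindex>) -> tensor<1x8xindex> {
  // Producer: writes the column index into every element.
  %0 = linalg.generic {
    indexing_maps = [#map0],
    iterator_types = ["parallel", "parallel"]}
    outs(%arg0 : tensor<1x8xindex>) {
  ^bb0(%a: index):
    %col = linalg.index 1 : index
    linalg.yield %col : index
  } -> tensor<1x8xindex>
  // Consumer: adds the row index to the producer result.
  %1 = linalg.generic {
    indexing_maps = [#map0, #map0],
    iterator_types = ["parallel", "parallel"]}
    ins(%0 : tensor<1x8xindex>)
    outs(%arg1 : tensor<1x8xindex>) {
  ^bb0(%a: index, %b: index):
    %row = linalg.index 0 : index
    %sum = addi %row, %a : index
    linalg.yield %sum : index
  } -> tensor<1x8xindex>
  return %1 : tensor<1x8xindex>
}
// After fusion the loop spaces match 1-1, so a single linalg.generic
// remains whose body recreates the producer index directly:
//   %col = linalg.index 1 : index
//   %row = linalg.index 0 : index
//   %sum = addi %row, %col : index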
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/Dialect/Linalg/fusion-tensor.mlir
mlir/test/Dialect/Linalg/reshape_fusion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index a404cbd560f7..4d6045a6d7b1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -28,10 +28,6 @@ using namespace mlir::linalg;
/// Implementation of fusion of generic ops and indexed_generic ops.
static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
unsigned consumerIdx) {
- // TODO: remove once index ops are supported.
- if (producer.hasIndexSemantics() || consumer.hasIndexSemantics())
- return false;
-
// Producer and consumer must have tensor semantics.
if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
return false;
@@ -138,7 +134,7 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
// 1. Map consumer indices to fusedBlock indices 1-1.
mapper.map(consumerBlock.getArguments().take_front(numConsumerIndices),
fusedBlock->getArguments().take_front(numConsumerIndices));
- // 2. Embed producer indices into fusedBlock index space 1-1.
+ // 2a. Embed producer indices into fusedBlock index space 1-1.
for (auto it :
llvm::zip(producerBlock.getArguments().take_front(numProducerIndices),
fusedBlock->getArguments().take_front(numProducerIndices))) {
@@ -148,6 +144,28 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
fusedBlock->getArguments().take_front(numFusedOpIndices));
mapper.map(std::get<0>(it), newIndex);
}
+ // 2b. Replace the producer index operations by index operations placed in the
+ // fused block using the `consumerToProducerLoopsMap` to map the index spaces.
+ unsigned numFusedOpLoops =
+ std::max(producer.getNumLoops(), consumer.getNumLoops());
+ if (producer.hasIndexSemantics()) {
+ SmallVector<Value> fusedIndices;
+ fusedIndices.reserve(numFusedOpLoops);
+ llvm::transform(llvm::seq<int64_t>(0, numFusedOpLoops),
+ std::back_inserter(fusedIndices), [&](int64_t dim) {
+ return rewriter.create<IndexOp>(producer.getLoc(), dim);
+ });
+ for (IndexOp indexOp :
+ llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
+ Value newIndex = rewriter.create<mlir::AffineApplyOp>(
+ producer.getLoc(),
+ consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
+ // Replace the producer index operation by the index value computed in the
+ // fused block. All remaining operations in the producer block are later
+ // on cloned to the fused block.
+ rewriter.replaceOp(indexOp, newIndex);
+ }
+ }
// TODO: allow fusing the producer of an output operand.
assert(consumerIdx < consumer.getNumInputs() &&
"expected producer of input operand");
@@ -329,8 +347,8 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
invProducerResultIndexMap.compose(consumerResultIndexMap);
generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer,
- consumerToProducerLoopsMap, consumerIdx,
- consumer.getNumLoops());
+ consumerToProducerLoopsMap, consumerIdx,
+ consumer.getNumLoops());
return SmallVector<Value, 1>(fusedOp->getResults());
}
@@ -602,17 +620,16 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
return success();
}
-/// To expand an indexed_generic operation, the body of the indexed generic op
-/// need to be modified appropriately. Specifically, uses of arguments for
-/// induction variables in the original operation need to be replaced with
-/// linearization of the corresponding arguments in the expanded op. That
-/// requires the shape of the expanded dimensions (at least all but the most
-/// significant. For now check that these are all statically sized. Note that
-/// this could be extended to handle dynamic case, but the implementation below
-/// uses `affine.apply` which seems to have issues when the shapes are not
-/// static.
-LogicalResult isIndexedGenericOpExpandable(LinalgOp linalgOp,
- const ExpansionInfo &expansionInfo) {
+/// Expanding the body of a linalg operation requires adaptations of the
+/// accessed loop indices. Specifically, every use of an index in the original
+/// operation needs to be replaced by a linearization of the corresponding
+/// indices in the expanded op. That requires the shape of the expanded
+/// dimensions to be static (at least all but the most significant one). For
+/// now check that these are all statically sized. Note that this could be
+/// extended to handle the dynamic case, but the implementation below uses
+/// `affine.apply` which seems to have issues when the shapes are not static.
+LogicalResult isIndexedOpExpandable(LinalgOp linalgOp,
+ const ExpansionInfo &expansionInfo) {
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
if (expandedShape.size() == 1)
@@ -734,6 +751,49 @@ static void buildExpandedIndexedGenericOpRegion(
argReplacements);
}
+/// Update the body of an expanded linalg operation having index semantics. The
+/// indices of the original operation need to be recovered by linearizing the
+/// indices of the corresponding dimensions of the expanded operation. For now
+/// it is assumed that the shapes of the expanded operation needed for
+/// linearization are static.
+static void updateExpandedIndexOpRegion(PatternRewriter &rewriter, Location loc,
+ Region &fusedRegion,
+ const ExpansionInfo &expansionInfo) {
+ // Replace the original indices by the linearization of the expanded indices.
+ for (IndexOp indexOp :
+ llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
+ ArrayRef<int64_t> expandedDims =
+ expansionInfo.getExpandedDims(indexOp.dim());
+ assert(!expandedDims.empty() && "expected valid expansion info");
+
+ // Skip index operations that are not affected by the expansion.
+ if (expandedDims.size() == 1 &&
+ expandedDims.front() == (int64_t)indexOp.dim())
+ continue;
+
+ // Linearize the expanded indices of the original index dimension.
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointAfter(indexOp);
+ ArrayRef<int64_t> expandedDimsShape =
+ expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
+ SmallVector<Value> expandedIndices;
+ expandedIndices.reserve(expandedDims.size() - 1);
+ llvm::transform(
+ expandedDims.drop_front(), std::back_inserter(expandedIndices),
+ [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
+ Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
+ for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
+ assert(!ShapedType::isDynamic(std::get<0>(it)));
+ AffineExpr idx, acc;
+ bindDims(rewriter.getContext(), idx, acc);
+ newIndex = rewriter.create<AffineApplyOp>(
+ indexOp.getLoc(), idx + acc * std::get<0>(it),
+ ValueRange{std::get<1>(it), newIndex});
+ }
+ rewriter.replaceOp(indexOp, newIndex);
+ }
+}
+
/// Implements the fusion of a tensor_reshape op and a generic/indexed_generic
/// op as explained in `isFusableWithReshapeByExpansion`. Assumes that those
/// conditions have been satisfied.
@@ -748,6 +808,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
RankedTensorType expandedType =
isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
+ bool hasIndexSemantics = linalgOp.hasIndexSemantics() ||
+ isa<IndexedGenericOp>(linalgOp.getOperation());
ExpansionInfo expansionInfo;
if (failed(expansionInfo.compute(linalgOp, fusedTensorIndex,
@@ -755,8 +817,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
expandedType.getShape())))
return llvm::None;
- if (isa<IndexedGenericOp>(linalgOp.getOperation()) &&
- failed(isIndexedGenericOpExpandable(linalgOp, expansionInfo)))
+ if (hasIndexSemantics &&
+ failed(isIndexedOpExpandable(linalgOp, expansionInfo)))
return llvm::None;
SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
@@ -823,6 +885,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
fusedRegion, expansionInfo);
}
+ // Update the index accesses after the expansion.
+ if (linalgOp.hasIndexSemantics())
+ updateExpandedIndexOpRegion(rewriter, loc, fusedRegion, expansionInfo);
+
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
SmallVector<Value, 1> resultVals;
@@ -1261,6 +1327,7 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
context, options.controlElementwiseOpsFusionFn);
populateFoldReshapeOpsByExpansionPatterns(
patterns, options.allowFoldingUnitDimReshapes);
+ AffineApplyOp::getCanonicalizationPatterns(patterns, context);
GenericOp::getCanonicalizationPatterns(patterns, context);
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
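For the expansion path, here is a minimal hand-written sketch of the IR that updateExpandedIndexOpRegion emits (all values hypothetical, expanded sizes assumed static): if dimension 0 of the original op is expanded into dimensions (0, 1, 2) of sizes (2, 3, 4), an original `linalg.index 0` is rebuilt as the row-major linearization of the expanded indices:

%i0 = linalg.index 0 : index
%i1 = linalg.index 1 : index
%i2 = linalg.index 2 : index
// Chain of affine.apply ops, most significant dimension first:
//   orig = (%i0 * 3 + %i1) * 4 + %i2
%t    = affine.apply affine_map<(d0, d1) -> (d0 + d1 * 3)>(%i1, %i0)
%orig = affine.apply affine_map<(d0, d1) -> (d0 + d1 * 4)>(%i2, %t)

Note that the size of the most significant expanded dimension (here 2) never enters the computation, which is why only the remaining sizes need to be static.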
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 00d0995a25f6..40c52657a853 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -359,6 +359,58 @@ func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
// -----
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>,
+ %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
+ %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
+ %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
+ %3 = linalg.generic {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"] }
+ ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%2 : tensor<?x?xi32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors
+ %10 = addi %arg2, %arg3 : i32
+ linalg.yield %10 : i32
+ } -> tensor<?x?xi32>
+ %4 = linalg.generic {
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel"] }
+ ins(%3 : tensor<?x?xi32>)
+ outs(%2 : tensor<?x?xi32>) {
+ ^bb0(%arg2: i32, %arg3: i32): // no predecessors
+ %idx0 = linalg.index 0 : index
+ %idx1 = linalg.index 1 : index
+ %5 = index_cast %idx0 : index to i32
+ %6 = index_cast %idx1 : index to i32
+ %7 = addi %arg2, %5 : i32
+ %8 = subi %7, %6 : i32
+ linalg.yield %8 : i32
+ } -> tensor<?x?xi32>
+ return %4 : tensor<?x?xi32>
+}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @producer_indexed_consumer_fusion
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[VAL1:.+]] = addi %[[ARG0]], %[[ARG1]] : i32
+// CHECK: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[IDX0]] : index to i32
+// CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[IDX1]] : index to i32
+// CHECK: %[[VAL2:.+]] = addi %[[VAL1]], %[[ADD_OPERAND]] : i32
+// CHECK: %[[VAL3:.+]] = subi %[[VAL2]], %[[SUB_OPERAND]] : i32
+// CHECK: linalg.yield %[[VAL3]] : i32
+// CHECK-NOT: linalg.generic
+
+// -----
+
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
%arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
@@ -409,6 +461,58 @@ func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
// -----
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>,
+ %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
+ %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
+ %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
+ %3 = linalg.generic {
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel"] }
+ ins(%arg0 : tensor<?x?xi32>)
+ outs(%2 : tensor<?x?xi32>) {
+ ^bb0(%arg4: i32, %arg5: i32): // no predecessors
+ %idx0 = linalg.index 0 : index
+ %idx1 = linalg.index 1 : index
+ %4 = index_cast %idx0 : index to i32
+ %5 = index_cast %idx1 : index to i32
+ %6 = addi %arg4, %4 : i32
+ %7 = subi %6, %5 : i32
+ linalg.yield %7 : i32
+ } -> tensor<?x?xi32>
+ %4 = linalg.generic {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"] }
+ ins(%3, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%2 : tensor<?x?xi32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors
+ %10 = addi %arg2, %arg3 : i32
+ linalg.yield %10 : i32
+ } -> tensor<?x?xi32>
+ return %4 : tensor<?x?xi32>
+}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @indexed_producer_consumer_fusion
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[IDX0]] : index to i32
+// CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[IDX1]] : index to i32
+// CHECK: %[[VAL1:.+]] = addi %[[ARG0]], %[[ADD_OPERAND]] : i32
+// CHECK: %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND]] : i32
+// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG1]] : i32
+// CHECK: linalg.yield %[[VAL3]] : i32
+// CHECK-NOT: linalg.generic
+
+// -----
+
// The indices of the first indexed_generic op are swapped after fusion.
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
@@ -465,6 +569,69 @@ func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
// -----
+// The indices of the first indexed_generic op are swapped after fusion.
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>)
+ -> tensor<?x?xi32> {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
+ %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
+ %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
+ %3 = linalg.generic {
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel"] }
+ ins(%arg0 : tensor<?x?xi32>)
+ outs(%2 : tensor<?x?xi32>) {
+ ^bb0(%arg2: i32, %arg3: i32): // no predecessors
+ %idx0 = linalg.index 0 : index
+ %idx1 = linalg.index 1 : index
+ %4 = index_cast %idx0 : index to i32
+ %5 = index_cast %idx1 : index to i32
+ %6 = addi %arg2, %4 : i32
+ %7 = subi %5, %6 : i32
+ linalg.yield %7 : i32
+ } -> tensor<?x?xi32>
+ %4 = linalg.generic {
+ indexing_maps = [#map1, #map1],
+ iterator_types = ["parallel", "parallel"] }
+ ins(%3 : tensor<?x?xi32>)
+ outs(%2 : tensor<?x?xi32>) {
+ ^bb0(%arg2: i32, %arg3: i32): // no predecessors
+ %idx0 = linalg.index 0 : index
+ %idx1 = linalg.index 1 : index
+ %5 = index_cast %idx0 : index to i32
+ %6 = index_cast %idx1 : index to i32
+ %7 = addi %arg2, %5 : i32
+ %8 = subi %7, %6 : i32
+ linalg.yield %8 : i32
+ } -> tensor<?x?xi32>
+ return %4 : tensor<?x?xi32>
+}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @indexed_producer_indexed_consumer_fusion
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK: %[[ADD_OPERAND1:.+]] = index_cast %[[IDX1]] : index to i32
+// CHECK: %[[SUB_OPERAND1:.+]] = index_cast %[[IDX0]] : index to i32
+// CHECK: %[[VAL1:.+]] = addi %[[ARG0]], %[[ADD_OPERAND1]] : i32
+// CHECK: %[[VAL2:.+]] = subi %[[SUB_OPERAND1]], %[[VAL1]] : i32
+// CHECK: %[[IDX2:.+]] = linalg.index 0 : index
+// CHECK: %[[IDX3:.+]] = linalg.index 1 : index
+// CHECK: %[[ADD_OPERAND2:.+]] = index_cast %[[IDX2]] : index to i32
+// CHECK: %[[SUB_OPERAND2:.+]] = index_cast %[[IDX3]] : index to i32
+// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ADD_OPERAND2]] : i32
+// CHECK: %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
+// CHECK: linalg.yield %[[VAL4]] : i32
+// CHECK-NOT: linalg.generic
+
+// -----
+
func @scalar_indexed_generic_fusion
(%arg0: tensor<5x1x1xf32>, %arg1 : tensor<i32>) -> tensor<10xf32>
{
@@ -507,6 +674,48 @@ func @scalar_indexed_generic_fusion
// -----
+func @scalar_generic_fusion
+ (%arg0: tensor<5x1x1xf32>, %arg1 : tensor<i32>) -> tensor<10xf32>
+{
+ %c0 = constant 0 : index
+ %cst = constant dense<1.000000e+00> : tensor<10xf32>
+ %0 = linalg.init_tensor [] : tensor<f32>
+ %1 = linalg.generic
+ {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+ iterator_types = []}
+ ins(%arg1 : tensor<i32>) outs(%0 : tensor<f32>) {
+ ^bb0(%arg2: i32, %arg3: f32): // no predecessors
+ %3 = index_cast %arg2 : i32 to index
+ %4 = tensor.extract %arg0[%3, %c0, %c0] : tensor<5x1x1xf32>
+ linalg.yield %4 : f32
+ } -> tensor<f32>
+ %2 = linalg.init_tensor [10] : tensor<10xf32>
+ %3 = linalg.generic
+ {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%1, %cst : tensor<f32>, tensor<10xf32>) outs(%2 : tensor<10xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
+ %4 = mulf %arg2, %arg3 : f32
+ linalg.yield %4 : f32
+ } -> tensor<10xf32>
+ return %3 : tensor<10xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func @scalar_generic_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<5x1x1xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<i32>
+// CHECK: %[[T0:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel"]
+// CHECK-SAME: ins(%[[ARG1]] : tensor<i32>)
+// CHECK: tensor.extract %[[ARG0]]
+// CHECK: linalg.yield
+// CHECK: return %[[T0]]
+
+// -----
+
func @constant_fusion(%arg0 : tensor<4xf32>) -> (tensor<4xf32>) {
%cst = constant dense<1.0> : tensor<4xf32>
%1 = linalg.init_tensor [4] : tensor<4xf32>
@@ -655,32 +864,6 @@ func @generic_index_op2(%arg0: tensor<1x8xf64>, %arg1: tensor<1x8xi32>) -> tenso
// -----
-// CHECK-LABEL: func @index_op(
-// CHECK-COUNT-2: linalg.generic
-func @index_op(%arg0: tensor<1x8xindex>, %arg1: tensor<1x8xindex>) -> tensor<1x8xindex> {
- %0 = linalg.generic {
- indexing_maps = [affine_map<(i, j) -> (i, j)>],
- iterator_types = ["parallel", "parallel"]}
- outs(%arg0 : tensor<1x8xindex>) {
- ^bb0(%a: index): // no predecessors
- %2 = linalg.index 1 : index
- linalg.yield %2 : index
- } -> tensor<1x8xindex>
- %1 = linalg.generic {
- indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%0 : tensor<1x8xindex>)
- outs(%arg1 : tensor<1x8xindex>) {
- ^bb0(%a: index, %b: index): // no predecessors
- %2 = linalg.index 0 : index
- %3 = addi %2, %a : index
- linalg.yield %3 : index
- } -> tensor<1x8xindex>
- return %1 : tensor<1x8xindex>
-}
-
-// -----
-
// CHECK-LABEL: func @no_fuse_constant_with_reduction
func @no_fuse_constant_with_reduction() -> tensor<3xf32>
{
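The swapped-index tests in this file exercise step 2b of generateFusedElementwiseOpRegion: every producer linalg.index is rebuilt through consumerToProducerLoopsMap. A hand-written sketch of that remapping (map and values hypothetical): with a consumer-to-producer loop map (d0, d1) -> (d1, d0), a producer `linalg.index 0` is recreated in the fused block as

%i0 = linalg.index 0 : index
%i1 = linalg.index 1 : index
%r = affine.apply affine_map<(d0, d1) -> (d1)>(%i0, %i1)
// The AffineApplyOp canonicalizations registered by this patch should fold
// %r to a plain use of %i1, so the producer's dimension 0 reads the fused
// loop's dimension 1, matching the CHECK lines of the swapped tests above.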
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 9c0fe41684ee..0e7239ea01c0 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -237,6 +237,60 @@ func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
// -----
+#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
+ %arg1 : tensor<?x?x?xi32>) ->
+ tensor<?x?x?xi32>
+{
+ %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+ affine_map<(i, j, k, l) -> (j, k)>,
+ affine_map<(i, j, k, l) -> (l)>] :
+ tensor<?x?x4x?xi32> into tensor<?x?x?xi32>
+ %1 = linalg.generic {
+ indexing_maps = [#map0, #map1, #map1],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xi32>)
+ outs(%0 : tensor<?x?x?xi32>) {
+ ^bb0(%arg3: i32, %arg4: i32, %s: i32):
+ %idx0 = linalg.index 0 : index
+ %idx1 = linalg.index 1 : index
+ %idx2 = linalg.index 2 : index
+ %1 = muli %arg3, %arg4 : i32
+ %2 = index_cast %idx0 : index to i32
+ %3 = addi %1, %2 : i32
+ %4 = index_cast %idx1 : index to i32
+ %5 = addi %3, %4 : i32
+ %6 = index_cast %idx2 : index to i32
+ %7 = addi %5, %6 : i32
+ linalg.yield %7 : i32
+ } -> tensor<?x?x?xi32>
+ return %1 : tensor<?x?x?xi32>
+}
+
+// Only check the body in the indexed version of the test.
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
+// CHECK: func @indexed_consumer_reshape_producer_fusion
+// CHECK: linalg.generic
+// CHECK: ^{{.*}}(
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32, %[[ARG4:[a-zA-Z0-9]+]]: i32,
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32)
+// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
+// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
+// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX1]], %[[IDX0]])
+// CHECK: %[[T4:.+]] = muli %[[ARG3]], %[[ARG4]]
+// CHECK: %[[T5:.+]] = index_cast %[[T3]]
+// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]]
+// CHECK: %[[T7:.+]] = index_cast %[[IDX2]]
+// CHECK: %[[T8:.+]] = addi %[[T6]], %[[T7]]
+// CHECK: %[[T9:.+]] = index_cast %[[IDX3]]
+// CHECK: %[[T10:.+]] = addi %[[T8]], %[[T9]]
+// CHECK: linalg.yield %[[T10]]
+
+// -----
+
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
%arg1 : tensor<?x?xi32>) ->
@@ -280,6 +334,53 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
// -----
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
+ %arg1 : tensor<?x?xi32>) ->
+ tensor<?x?x4x5xi32>
+{
+ %0 = linalg.generic {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%arg0 : tensor<?x?xi32>) {
+ ^bb0(%arg3: i32, %arg4: i32, %s: i32): // no predecessors
+ %idx0 = linalg.index 0 : index
+ %idx1 = linalg.index 1 : index
+ %1 = muli %arg3, %arg4 : i32
+ %2 = index_cast %idx0 : index to i32
+ %3 = addi %1, %2 : i32
+ %4 = index_cast %idx1 : index to i32
+ %5 = addi %3, %4 : i32
+ linalg.yield %5 : i32
+ } -> tensor<?x?xi32>
+ %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+ affine_map<(i, j, k, l) -> (j, k, l)>] :
+ tensor<?x?xi32> into tensor<?x?x4x5xi32>
+ return %1 : tensor<?x?x4x5xi32>
+}
+
+// Only check the body in the indexed version of the test.
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 5 + d2 * 20)>
+// CHECK: func @indexed_producer_reshape_consumer_fusion
+// CHECK: linalg.generic
+// CHECK: ^{{.*}}(
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32, %[[ARG4:[a-zA-Z0-9]+]]: i32,
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32)
+// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
+// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
+// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX3]], %[[IDX2]], %[[IDX1]])
+// CHECK: %[[T4:.+]] = muli %[[ARG3]], %[[ARG4]]
+// CHECK: %[[T5:.+]] = index_cast %[[IDX0]]
+// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]]
+// CHECK: %[[T7:.+]] = index_cast %[[T3]]
+// CHECK: %[[T8:.+]] = addi %[[T6]], %[[T7]]
+// CHECK: linalg.yield %[[T8]]
+
+// -----
+
func @reshape_as_consumer_permutation
(%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
-> tensor<2x3x4x5x6x7xi32> {
@@ -350,6 +451,82 @@ func @reshape_as_consumer_permutation
// -----
+func @reshape_as_consumer_permutation
+ (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
+ -> tensor<2x3x4x5x6x7xi32> {
+ %shape = linalg.init_tensor [6, 4, 210] : tensor<6x4x210xi32>
+ %c = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%a, %b : tensor<210x6x4xi32>, tensor<210x4xi32>)
+ outs(%shape : tensor<6x4x210xi32>) {
+ ^bb0(%arg3 : i32, %arg4: i32, %s: i32):
+ %idx0 = linalg.index 0 : index
+ %idx1 = linalg.index 1 : index
+ %idx2 = linalg.index 2 : index
+ %1 = addi %arg3, %arg4 : i32
+ %2 = index_cast %idx0 : index to i32
+ %3 = addi %1, %2 : i32
+ %4 = index_cast %idx1 : index to i32
+ %5 = addi %3, %4 : i32
+ %6 = index_cast %idx2 : index to i32
+ %7 = addi %5, %6 : i32
+ linalg.yield %7 : i32
+ } -> tensor<6x4x210xi32>
+ %d = linalg.tensor_reshape %c
+ [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
+ : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
+ return %d : tensor<2x3x4x5x6x7xi32>
+}
+
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
+// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + d2 * 42)>
+// CHECK: func @reshape_as_consumer_permutation
+// CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
+// CHECK-DAG: %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-DAG: %[[T2:.+]] = linalg.tensor_reshape %[[ARG1]]
+// CHECK-SAME: [#[[MAP3]], #[[MAP4]]]
+// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
+// CHECK: %[[T4:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
+// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
+// CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xi32>)
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
+// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: i32)
+// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
+// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
+// CHECK-DAG: %[[IDX4:.+]] = linalg.index 4 : index
+// CHECK-DAG: %[[IDX5:.+]] = linalg.index 5 : index
+// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP8]](%[[IDX1]], %[[IDX0]])
+// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP9]](%[[IDX4]], %[[IDX3]], %[[IDX2]])
+// CHECK-DAG: %[[T7:.+]] = addi %[[ARG8]], %[[ARG9]]
+// CHECK: %[[T8:.+]] = index_cast %[[T5]]
+// CHECK: %[[T9:.+]] = addi %[[T7]], %[[T8]]
+// CHECK: %[[T10:.+]] = index_cast %[[T6]]
+// CHECK: %[[T11:.+]] = addi %[[T9]], %[[T10]]
+// CHECK: %[[T12:.+]] = index_cast %[[IDX5]]
+// CHECK: %[[T13:.+]] = addi %[[T11]], %[[T12]]
+
+// -----
+
func @reshape_as_producer_projected_permutation(
%arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
{
@@ -407,6 +584,66 @@ func @reshape_as_producer_projected_permutation(
// -----
+func @reshape_as_producer_projected_permutation(
+ %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
+{
+ %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d2)>]
+ : tensor<33x8x?xi32> into tensor<264x?xi32>
+ %1 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%0 : tensor<264x?xi32>)
+ outs(%shape : tensor<264x?x4xi32>) {
+ ^bb0(%arg1: i32, %s: i32): // no predecessors
+ %idx0 = linalg.index 0 : index
+ %idx1 = linalg.index 1 : index
+ %idx2 = linalg.index 2 : index
+ %2 = index_cast %idx0 : index to i32
+ %3 = addi %arg1, %2 : i32
+ %4 = index_cast %idx1 : index to i32
+ %5 = addi %3, %4 : i32
+ %6 = index_cast %idx2 : index to i32
+ %7 = addi %5, %6 : i32
+ linalg.yield %7 : i32
+ } -> tensor<264x?x4xi32>
+ return %1 : tensor<264x?x4xi32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 8)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d2)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: @reshape_as_producer_projected_permutation
+// CHECK-SAME: %[[ARG0:.+]]: tensor<33x8x?xi32>
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<33x8x?xi32>)
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: i32,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: i32)
+// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
+// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
+// CHECK-DAG: %[[T0:.+]] = affine.apply #[[MAP2]](%[[IDX1]], %[[IDX0]])
+// CHECK: %[[T1:.+]] = index_cast %[[T0]] : index to i32
+// CHECK: %[[T2:.+]] = addi %[[ARG1]], %[[T1]] : i32
+// CHECK: %[[T3:.+]] = index_cast %[[IDX2]] : index to i32
+// CHECK: %[[T4:.+]] = addi %[[T2]], %[[T3]] : i32
+// CHECK: %[[T5:.+]] = index_cast %[[IDX3]] : index to i32
+// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]] : i32
+// CHECK: linalg.yield %[[T6]] : i32
+// CHECK: %[[RES2:.+]] = linalg.tensor_reshape %[[RES]]
+// CHECK-SAME: [#[[MAP3]], #[[MAP4]], #[[MAP5]]]
+// CHECK-SAME: : tensor<33x8x?x4xi32> into tensor<264x?x4xi32>
+// CHECK: return %[[RES2]] : tensor<264x?x4xi32>
+
+// -----
+
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1, d0)>
func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,