[Mlir-commits] [mlir] cc11ced - [mlir][Linalg] Add support for fusion between indexed_generic ops and generic ops on tensors.
Hanhan Wang
llvmlistbot at llvm.org
Wed Jun 3 14:59:25 PDT 2020
Author: Hanhan Wang
Date: 2020-06-03T14:58:43-07:00
New Revision: cc11ceda165b5ba0a87e812fbd6ed1bce4fefd2f
URL: https://github.com/llvm/llvm-project/commit/cc11ceda165b5ba0a87e812fbd6ed1bce4fefd2f
DIFF: https://github.com/llvm/llvm-project/commit/cc11ceda165b5ba0a87e812fbd6ed1bce4fefd2f.diff
LOG: [mlir][Linalg] Add support for fusion between indexed_generic ops and generic ops on tensors.
Summary:
Unlike fusion between generic ops, fusion involving indexed_generic ops must
also handle loop indices. The indices of the producer need to be re-mapped
because the fused op is built from the consumer's perspective. This patch
supports all combinations of fusion between indexed_generic ops and generic
ops, and adds test cases for:
1) generic op as producer and indexed_generic op as consumer.
2) indexed_generic op as producer and generic op as consumer.
3) indexed_generic op as producer and indexed_generic op as consumer.
Differential Revision: https://reviews.llvm.org/D80347
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/test/Dialect/Linalg/fusion-tensor.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 3f3c1c53fc3a..9964e1355097 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
@@ -435,9 +436,9 @@ static void fuseLinalgOpsGreedily(FuncOp f) {
namespace {
-/// Implementation of fusion of generic ops.
+/// Implementation of fusion of generic ops and indexed_generic ops.
struct FuseGenericOpsOnTensors {
- static bool isFusible(GenericOp producer, GenericOp consumer,
+ static bool isFusible(LinalgOp producer, LinalgOp consumer,
unsigned consumerIdx) {
// Verify that
// - the producer has all "parallel" iterator type.
@@ -456,7 +457,7 @@ struct FuseGenericOpsOnTensors {
return producerResultIndexMap.isPermutation();
}
- static Operation *fuse(GenericOp producer, GenericOp consumer,
+ static Operation *fuse(LinalgOp producer, LinalgOp consumer,
unsigned consumerIdx, PatternRewriter &rewriter,
OperationFolder *folder = nullptr) {
if (!isFusible(producer, consumer, consumerIdx))
@@ -482,7 +483,8 @@ struct FuseGenericOpsOnTensors {
// indexing_map of the operand at consumerIdx in the consumer.
SmallVector<Attribute, 4> fusedIndexMaps;
auto consumerIndexMaps = consumer.indexing_maps();
- fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumResults());
+ fusedIndexMaps.reserve(fusedOperands.size() +
+ consumer.getOperation()->getNumResults());
fusedIndexMaps.assign(consumerIndexMaps.begin(),
std::next(consumerIndexMaps.begin(), consumerIdx));
// Compute indexing maps for the producer args in the fused operation.
@@ -494,15 +496,56 @@ struct FuseGenericOpsOnTensors {
consumerIndexMaps.end());
// Generate the fused op.
- auto fusedOp = rewriter.create<GenericOp>(
- rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
- rewriter.getI64IntegerAttr(fusedOperands.size()),
- rewriter.getI64IntegerAttr(consumer.getNumResults()),
- rewriter.getArrayAttr(fusedIndexMaps), consumer.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr);
- generateFusedRegion(rewriter, fusedOp.region(), producer.region(),
- consumer.region(), consumerIdx);
+ LinalgOp fusedOp;
+ if (isa<GenericOp>(producer.getOperation()) &&
+ isa<GenericOp>(consumer.getOperation())) {
+ fusedOp =
+ rewriter
+ .create<GenericOp>(
+ rewriter.getUnknownLoc(),
+ consumer.getOperation()->getResultTypes(), fusedOperands,
+ rewriter.getI64IntegerAttr(fusedOperands.size()),
+ rewriter.getI64IntegerAttr(
+ consumer.getOperation()->getNumResults()),
+ rewriter.getArrayAttr(fusedIndexMaps),
+ consumer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr)
+ .getOperation();
+ } else {
+ fusedOp =
+ rewriter
+ .create<IndexedGenericOp>(
+ rewriter.getUnknownLoc(),
+ consumer.getOperation()->getResultTypes(), fusedOperands,
+ rewriter.getI64IntegerAttr(fusedOperands.size()),
+ rewriter.getI64IntegerAttr(
+ consumer.getOperation()->getNumResults()),
+ rewriter.getArrayAttr(fusedIndexMaps),
+ consumer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr)
+ .getOperation();
+ }
+
+ // Construct an AffineMap from consumer loops to producer loops.
+ // consumer loop -> tensor index
+ AffineMap consumerResultIndexMap =
+ consumer.getInputIndexingMap(consumerIdx);
+ // producer loop -> tensor index
+ AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+ // tensor index -> producer loop
+ AffineMap invProducerResultIndexMap =
+ inversePermutation(producerResultIndexMap);
+ assert(invProducerResultIndexMap &&
+ "expected producer result indexig map to be invertible");
+ // consumer loop -> producer loop
+ AffineMap consumerToProducerLoopsMap =
+ invProducerResultIndexMap.compose(consumerResultIndexMap);
+
+ generateFusedRegion(rewriter, fusedOp, producer, consumer,
+ consumerToProducerLoopsMap, consumerIdx,
+ consumer.getNumLoops());
return fusedOp;
}
@@ -511,7 +554,7 @@ struct FuseGenericOpsOnTensors {
/// the `producer` to use in the fused operation given the indexing map of the
/// result of the producer in the consumer.
static void computeProducerOperandIndex(
- GenericOp producer, AffineMap fusedConsumerArgIndexMap,
+ LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
// The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
// from consumer loop -> consumer arg tensor index/producer result tensor
@@ -544,29 +587,68 @@ struct FuseGenericOpsOnTensors {
/// Generate the region of the fused operation. The region of the fused op
/// must be empty.
- static void generateFusedRegion(PatternRewriter &rewriter,
- Region &fusedRegion, Region &producerRegion,
- Region &consumerRegion,
- unsigned consumerIdx) {
+ static void generateFusedRegion(PatternRewriter &rewriter, Operation *fusedOp,
+ LinalgOp producer, LinalgOp consumer,
+ AffineMap consumerToProducerLoopsMap,
+ unsigned consumerIdx, unsigned nloops) {
// Build the region of the fused op.
- Block &producerBlock = producerRegion.front();
- Block &consumerBlock = consumerRegion.front();
+ Block &producerBlock = producer.getOperation()->getRegion(0).front();
+ Block &consumerBlock = consumer.getOperation()->getRegion(0).front();
Block *fusedBlock = new Block();
- fusedRegion.push_back(fusedBlock);
+ fusedOp->getRegion(0).push_back(fusedBlock);
BlockAndValueMapping mapper;
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(fusedBlock);
+
+ // The block arguments are
+ // [index_0, index_1, ... ,
+ // consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
+ // producer_operand_0, ... , producer_operand_(n-1)],
+ // consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)]
+ // , where n is the number of producer's operand and m is the number
+ // consumer's operand.
+ // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
+ // generic op. In this case, there are no indices in block arguments.
+ unsigned numProducerIndices =
+ isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
+ unsigned numConsumerIndices =
+ isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
+ // Firstly, add all the indices to the block arguments.
+ for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices);
+ i < e; ++i)
+ fusedBlock->addArgument(rewriter.getIndexType());
// Map the arguments for the unmodified args from the consumer.
for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
- if (consumerArg.index() == consumerIdx) {
+ if (consumerArg.index() == consumerIdx + numConsumerIndices) {
// Map the arguments for the args from the producer.
- for (auto producerArg : producerBlock.getArguments())
- mapper.map(producerArg,
- fusedBlock->addArgument(producerArg.getType()));
+ for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) {
+ // If producer is an indexed_generic op, map the indices from consumer
+ // loop to producer loop (because the fusedOp is built based on
+ // consumer's perspective).
+ if (producerArg.index() < numProducerIndices) {
+ auto newIndex = rewriter.create<mlir::AffineApplyOp>(
+ producer.getLoc(),
+ consumerToProducerLoopsMap.getSubMap(producerArg.index()),
+ fusedBlock->getArguments().take_front(nloops));
+ mapper.map(producerArg.value(), newIndex);
+ } else {
+ mapper.map(producerArg.value(),
+ fusedBlock->addArgument(producerArg.value().getType()));
+ }
+ }
continue;
}
- mapper.map(consumerArg.value(),
- fusedBlock->addArgument(consumerArg.value().getType()));
+
+ // If consumer is an indexed_generic op, map the indices to the block
+ // arguments directly. Otherwise, add the same type of arugment and map to
+ // it.
+ if (consumerArg.index() < numConsumerIndices) {
+ mapper.map(consumerArg.value(),
+ fusedBlock->getArgument(consumerArg.index()));
+ } else {
+ mapper.map(consumerArg.value(),
+ fusedBlock->addArgument(consumerArg.value().getType()));
+ }
}
// Add operations from producer (except the yield operation) to the fused
@@ -576,7 +658,9 @@ struct FuseGenericOpsOnTensors {
// Lookup the value the yield operation is mapped to.
Value yieldVal = yieldOp.getOperand(0);
if (Value clonedVal = mapper.lookupOrNull(yieldVal))
- mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal);
+ mapper.map(
+ consumerBlock.getArgument(consumerIdx + numConsumerIndices),
+ clonedVal);
continue;
}
rewriter.clone(op, mapper);
@@ -838,20 +922,28 @@ Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
if (!producer || producer->getNumResults() != 1)
return nullptr;
- // Fuse when consumer is GenericOp.
- if (GenericOp genericOp = dyn_cast<GenericOp>(consumer)) {
- if (!genericOp.hasTensorSemantics())
+ // Fuse when consumer is GenericOp or IndexedGenericOp.
+ if (isa<GenericOp>(consumer) || isa<IndexedGenericOp>(consumer)) {
+ auto linalgOpConsumer = cast<LinalgOp>(consumer);
+ if (!linalgOpConsumer.hasTensorSemantics())
return nullptr;
- if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) {
- if (genericOpProducer.hasTensorSemantics())
- return FuseGenericOpsOnTensors::fuse(genericOpProducer, genericOp,
+ if (isa<GenericOp>(producer) || isa<IndexedGenericOp>(producer)) {
+ auto linalgOpProducer = cast<LinalgOp>(producer);
+ if (linalgOpProducer.hasTensorSemantics())
+ return FuseGenericOpsOnTensors::fuse(linalgOpProducer, linalgOpConsumer,
consumerIdx, rewriter, folder);
} else if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer)) {
- return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
- reshapeOpProducer, genericOp, consumerIdx, rewriter, folder);
+ if (auto genericOpConsumer = dyn_cast<GenericOp>(consumer)) {
+ return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
+ reshapeOpProducer, genericOpConsumer, consumerIdx, rewriter,
+ folder);
+ }
} else if (auto constantOpProducer = dyn_cast<ConstantOp>(producer)) {
- return FuseConstantOpAsProducer<GenericOp>::fuse(
- constantOpProducer, genericOp, consumerIdx, rewriter, folder);
+ if (auto genericOpConsumer = dyn_cast<GenericOp>(consumer)) {
+ return FuseConstantOpAsProducer<GenericOp>::fuse(
+ constantOpProducer, genericOpConsumer, consumerIdx, rewriter,
+ folder);
+ }
}
return nullptr;
}
@@ -865,6 +957,7 @@ Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
}
return nullptr;
}
+
return nullptr;
}
@@ -911,8 +1004,8 @@ struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
void mlir::populateLinalgTensorOpsFusionPatterns(
MLIRContext *context, OwningRewritePatternList &patterns) {
- patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<TensorReshapeOp>>(
- context);
+ patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
+ FuseTensorOps<TensorReshapeOp>>(context);
}
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 83bd1753eb28..9b73c02a4ed2 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -274,3 +274,150 @@ func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
// CHECK-SAME: args_out = 1 : i64
// CHECK: ^{{.*}}(%[[ARG1:.*]]: f32)
// CHECK: mulf %[[CST]], %[[ARG1]]
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
+ %arg1: tensor<?x?xi32>) {
+ %0 = linalg.generic {
+ args_in = 2 : i64,
+ args_out = 1 : i64,
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"] } %arg0, %arg1 {
+ ^bb0(%arg2: i32, %arg3: i32): // no predecessors
+ %10 = addi %arg2, %arg3 : i32
+ linalg.yield %10 : i32
+ } : tensor<?x?xi32>, tensor<?x?xi32> -> tensor<?x?xi32>
+ %1 = linalg.indexed_generic {
+ args_in = 1 : i64,
+ args_out = 1 : i64,
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel"] } %0 {
+ ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
+ %2 = index_cast %arg2 : index to i32
+ %3 = index_cast %arg3 : index to i32
+ %4 = addi %arg4, %2 : i32
+ %5 = subi %4, %3 : i32
+ linalg.yield %5 : i32
+ }: tensor<?x?xi32> -> tensor<?x?xi32>
+ return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @generic_op_indexed_generic_op_fusion
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 2
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ARG3]] : i32
+// CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32
+// CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32
+// CHECK: %[[VAL2:.+]] = addi %[[VAL1]], %[[ADD_OPERAND]] : i32
+// CHECK: %[[VAL3:.+]] = subi %[[VAL2]], %[[SUB_OPERAND]] : i32
+// CHECK: linalg.yield %[[VAL3]] : i32
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
+ %arg1: tensor<?x?xi32>) {
+ %0 = linalg.indexed_generic {
+ args_in = 1 : i64,
+ args_out = 1 : i64,
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel"] } %arg0 {
+ ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
+ %2 = index_cast %arg2 : index to i32
+ %3 = index_cast %arg3 : index to i32
+ %4 = addi %arg4, %2 : i32
+ %5 = subi %4, %3 : i32
+ linalg.yield %5 : i32
+ }: tensor<?x?xi32> -> tensor<?x?xi32>
+ %1 = linalg.generic {
+ args_in = 2 : i64,
+ args_out = 1 : i64,
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"] } %0, %arg1 {
+ ^bb0(%arg2: i32, %arg3: i32): // no predecessors
+ %10 = addi %arg2, %arg3 : i32
+ linalg.yield %10 : i32
+ } : tensor<?x?xi32>, tensor<?x?xi32> -> tensor<?x?xi32>
+ return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @indexed_generic_op_generic_op_fusion
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 2
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32
+// CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32
+// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND]] : i32
+// CHECK: %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND]] : i32
+// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG3]] : i32
+// CHECK: linalg.yield %[[VAL3]] : i32
+// CHECK-NOT: linalg.generic
+
+// -----
+
+// The indices of the first indexed_generic op are swapped after fusion.
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) {
+ %0 = linalg.indexed_generic {
+ args_in = 1 : i64,
+ args_out = 1 : i64,
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel"] } %arg0 {
+ ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
+ %2 = index_cast %arg2 : index to i32
+ %3 = index_cast %arg3 : index to i32
+ %4 = addi %arg4, %2 : i32
+ %5 = subi %4, %3 : i32
+ linalg.yield %5 : i32
+ }: tensor<?x?xi32> -> tensor<?x?xi32>
+ %1 = linalg.indexed_generic {
+ args_in = 1 : i64,
+ args_out = 1 : i64,
+ indexing_maps = [#map1, #map1],
+ iterator_types = ["parallel", "parallel"] } %0 {
+ ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
+ %2 = index_cast %arg2 : index to i32
+ %3 = index_cast %arg3 : index to i32
+ %4 = addi %arg4, %2 : i32
+ %5 = subi %4, %3 : i32
+ linalg.yield %5 : i32
+ }: tensor<?x?xi32> -> tensor<?x?xi32>
+ return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @indexed_generic_op_fusion
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 1
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[ADD_OPERAND1:.+]] = index_cast %[[ARG1]] : index to i32
+// CHECK: %[[SUB_OPERAND1:.+]] = index_cast %[[ARG0]] : index to i32
+// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND1]] : i32
+// CHECK: %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND1]] : i32
+// CHECK: %[[ADD_OPERAND2:.+]] = index_cast %[[ARG0]] : index to i32
+// CHECK: %[[SUB_OPERAND2:.+]] = index_cast %[[ARG1]] : index to i32
+// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ADD_OPERAND2]] : i32
+// CHECK: %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
+// CHECK: linalg.yield %[[VAL4]] : i32
+// CHECK-NOT: linalg.indexed_generic
More information about the Mlir-commits
mailing list