[Mlir-commits] [mlir] 518e6f3 - [mlir][Linalg] Fix fusion on tensors operands / bbArg mismatch
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Apr 6 08:42:01 PDT 2021
Author: Nicolas Vasilache
Date: 2021-04-06T15:39:40Z
New Revision: 518e6f341dddab7824592b3769146318950a01be
URL: https://github.com/llvm/llvm-project/commit/518e6f341dddab7824592b3769146318950a01be
DIFF: https://github.com/llvm/llvm-project/commit/518e6f341dddab7824592b3769146318950a01be.diff
LOG: [mlir][Linalg] Fix fusion on tensors operands / bbArg mismatch
Linalg fusion on tensors makes mismatched assumptions on the operand side and on the region bbArg side.
Relax the behavior on the operand / indexing-map side so that output operands that may also be read from are better supported.
Differential Revision: https://reviews.llvm.org/D99499
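For readers following along, here is a minimal sketch (not part of the patch) of the per-producer-operand map composition that the diff below factors into getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp. The standalone helper name composeProducerOperandMap and its free-function form are illustrative only; inversePermutation and AffineMap::compose are the existing MLIR APIs the patch uses.

#include <cassert>

#include "mlir/IR/AffineMap.h"

using namespace mlir;

// Compute: (consumer / fused loop) -> (producer arg tensor index), given
//   producerArgMap           : producer loop -> producer arg tensor index
//   producerResultIndexMap   : producer loop -> producer result tensor index
//   fusedConsumerArgIndexMap : consumer loop -> consumer arg tensor index
static AffineMap composeProducerOperandMap(AffineMap producerArgMap,
                                           AffineMap producerResultIndexMap,
                                           AffineMap fusedConsumerArgIndexMap) {
  // Invert to get: producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // Compose: producer result tensor index -> producer arg tensor index.
  AffineMap t1 = producerArgMap.compose(invProducerResultIndexMap);
  // Compose: consumer (fused) loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}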
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/Dialect/Linalg/fusion-tensor.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 4fdd9a2221f0..95e008aacc45 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -896,7 +896,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*desc=*/[{
Return the indexing maps within the current operation.
}],
- /*retTy=*/"SmallVector<AffineMap, 4>",
+ /*retTy=*/"SmallVector<AffineMap>",
/*methodName=*/"getIndexingMaps",
/*args=*/(ins),
/*methodBody=*/"",
@@ -931,6 +931,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return getIndexingMaps()[i];
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the input indexing maps.
+ }],
+ /*retTy=*/"SmallVector<AffineMap>",
+ /*methodName=*/"getInputIndexingMaps",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto maps = $_op.getIndexingMaps();
+ return SmallVector<AffineMap>{maps.begin(),
+ maps.begin() + $_op.getNumInputs()};
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Return the output indexing map at index `i`.
@@ -944,6 +958,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return getIndexingMaps()[i + $_op.getNumInputs()];
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the output indexing maps.
+ }],
+ /*retTy=*/"SmallVector<AffineMap>",
+ /*methodName=*/"getOutputIndexingMaps",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto maps = $_op.getIndexingMaps();
+ return SmallVector<AffineMap>{maps.begin() + $_op.getNumInputs(),
+ maps.begin() + $_op.getNumShapedOperands()};
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Return whether the op has only MemRef input and outputs.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index bb1a051c78e5..34eac4bdfcaa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -61,16 +61,14 @@ static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
/// the `producer` to use in the fused operation given the indexing map of the
/// result of the producer in the consumer.
-static void getIndexingMapOfProducerOperandsInFusedOp(
- LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
- SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
+static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+ OpOperand &producerOpOperand, AffineMap producerResultIndexMap,
+ AffineMap fusedConsumerArgIndexMap) {
// The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
// from consumer loop -> consumer arg tensor index/producer result tensor
// index. The fused loop is same as the consumer loop. For each producer arg
// the indexing map to be computed is a map from consumer loop -> producer
// arg tensor index.
-
- AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
// producerResultIndexMap is a map from producer loop -> tensor index.
// Compute the inverse to get map from tensor index -> producer loop.
// The inverse is a map from producer result tensor index -> producer loop.
@@ -78,19 +76,19 @@ static void getIndexingMapOfProducerOperandsInFusedOp(
inversePermutation(producerResultIndexMap);
assert(invProducerResultIndexMap &&
"expected producer result indexig map to be invertible");
- for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
- // argMap is a map from producer loop -> producer arg tensor index.
- AffineMap argMap = producer.getInputIndexingMap(argNum);
-
- // Compose argMap with invProducerResultIndexMap to get a map from
- // producer result tensor index -> producer arg tensor index.
- AffineMap t1 = argMap.compose(invProducerResultIndexMap);
-
- // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
- // consumer loop/ fused loop -> producer arg tensor index.
- AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
- fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
- }
+
+ LinalgOp producer = cast<LinalgOp>(producerOpOperand.getOwner());
+ // argMap is a map from producer loop -> producer arg tensor index.
+ AffineMap argMap =
+ producer.getIndexingMap(producerOpOperand.getOperandNumber());
+
+ // Compose argMap with invProducerResultIndexMap to get a map from
+ // producer result tensor index -> producer arg tensor index.
+ AffineMap t1 = argMap.compose(invProducerResultIndexMap);
+
+ // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
+ // consumer loop/ fused loop -> producer arg tensor index.
+ return t1.compose(fusedConsumerArgIndexMap);
}
/// Generate the region of the fused tensor operation. The region of the fused
@@ -163,6 +161,18 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
.drop_front(numProducerIndices)
.take_front(producer.getNumInputs()))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
+
+ // 4.b. Producer output operand/map that is fused needs to be mapped to the
+ // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
+ assert(producer->getNumResults() == 1 && "expected single result producer");
+ if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
+ BlockArgument bbArg =
+ producerBlock.getArguments()
+ .drop_front(numConsumerIndices + producer.getNumInputs())
+ // TODO: bbArg index of
+ .front();
+ mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
+ }
// 5. Remaining consumer's input operands (drop past index `consumerIdx`).
for (BlockArgument bbArg : consumerBlock.getArguments()
.drop_front(numConsumerIndices)
@@ -221,73 +231,90 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
!controlFn(producer->getResult(0), consumerOpOperand))
return llvm::None;
- unsigned numFusedOperands =
- producer.getNumInputs() + consumer.getNumInputs() - 1;
-
- // Compute the fused operands list,
- SmallVector<Value, 2> fusedOperands;
- fusedOperands.reserve(numFusedOperands);
- auto consumerOperands = consumer.getInputs();
- auto producerOperands = producer.getInputs();
- fusedOperands.assign(consumerOperands.begin(),
- std::next(consumerOperands.begin(), consumerIdx));
- fusedOperands.append(producerOperands.begin(), producerOperands.end());
- fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
- consumerOperands.end());
-
- // Compute indexing_maps for the fused operation. The indexing_maps for the
- // operands of the consumers that aren't fused are the same. The
- // indexing_maps for the producers need to be computed based on the
- // indexing_map of the operand at consumerIdx in the consumer.
- SmallVector<Attribute, 4> fusedIndexMaps;
- auto consumerIndexMaps = consumer.indexing_maps();
- fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumOutputs());
- fusedIndexMaps.assign(consumerIndexMaps.begin(),
- std::next(consumerIndexMaps.begin(), consumerIdx));
- // Compute indexing maps for the producer args in the fused operation.
- getIndexingMapOfProducerOperandsInFusedOp(
- producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);
-
- // Append the indexing maps for the remaining consumer operands.
- fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
- consumerIndexMaps.end());
+ // TODO: allow fusing the producer of an output operand.
+ assert(consumerIdx < consumer.getNumInputs() &&
+ "expected producer of input operand");
+
+ // Compute the fused operands list and indexing maps.
+ SmallVector<Value> fusedOperands;
+ SmallVector<AffineMap> fusedIndexMaps;
+ fusedOperands.reserve(producer->getNumOperands() +
+ consumer->getNumOperands());
+ fusedIndexMaps.reserve(producer->getNumOperands() +
+ consumer->getNumOperands());
+ // In the following, numbering matches that of `generateFusedElementwiseOpRegion`.
+ // 3. Consumer input operands/maps up to consumerIdx (exclusive).
+ llvm::append_range(fusedOperands,
+ consumer.getInputs().take_front(consumerIdx));
+ llvm::append_range(
+ fusedIndexMaps,
+ ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.take_front(
+ consumerIdx));
+ // 4. Splice in producer's input operands/maps.
+ llvm::append_range(fusedOperands, producer.getInputs());
+ assert(producer->getNumResults() == 1 && "expected single result producer");
+ AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+ for (auto &inputOpOperand : producer.getInputOpOperands()) {
+ // Compute indexing maps for the producer args in the fused operation.
+ AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+ inputOpOperand, producerResultIndexMap,
+ consumer.getInputIndexingMap(consumerIdx));
+ fusedIndexMaps.push_back(map);
+ }
+ // 4.b. Producer output operand/map that is fused needs to be passed if it is
+ // an "initTensor" (i.e. its value is actually read).
+ assert(producer->getNumResults() == 1 && "expected single result producer");
+ if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
+ llvm::append_range(fusedOperands, producer.getOutputs().take_front());
+ // Compute indexing maps for the producer args in the fused operation.
+ AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+ producer.getOutputOpOperands().front(), producerResultIndexMap,
+ consumer.getOutputIndexingMap(0));
+ fusedIndexMaps.push_back(map);
+ }
+ // 5. Remaining consumer's input operands/maps (drop past index
+ // `consumerIdx`).
+ llvm::append_range(fusedOperands,
+ consumer.getInputs().drop_front(consumerIdx + 1));
+ llvm::append_range(
+ fusedIndexMaps,
+ ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.drop_front(
+ consumerIdx + 1));
+ // 6. All of consumer's output operands (skip operands: added by the builder).
+ // llvm::append_range(fusedOperands, consumer.getOutputs());
+ llvm::append_range(fusedIndexMaps, consumer.getOutputIndexingMaps());
+ // 7. All of producer's output operands/maps except the one fused.
+ // TODO: allow fusion of multi-result producers.
+ assert(producer->getNumResults() == 1 && "expected single result producer");
// Generate the fused op.
- LinalgOp fusedOp;
+ Operation *fusedOp;
if (isa<GenericOp>(producer.getOperation()) &&
isa<GenericOp>(consumer.getOperation())) {
- fusedOp =
- rewriter
- .create<GenericOp>(consumer.getLoc(), consumer->getResultTypes(),
- /*inputs=*/fusedOperands,
- // TODO: handle outputs.
- consumer.getOutputs(),
- rewriter.getArrayAttr(fusedIndexMaps),
- consumer.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr,
- /*sparse=*/nullptr)
- .getOperation();
+ fusedOp = rewriter.create<GenericOp>(
+ consumer.getLoc(), consumer->getResultTypes(),
+ /*inputs=*/fusedOperands,
+ // TODO: handle outputs.
+ consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+ consumer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr,
+ /*sparse=*/nullptr);
} else {
- fusedOp =
- rewriter
- .create<IndexedGenericOp>(
- consumer.getLoc(), consumer->getResultTypes(),
- /*inputs=*/fusedOperands,
- // TODO: handle outputs.
- consumer.getOutputs(), rewriter.getArrayAttr(fusedIndexMaps),
- consumer.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr,
- /*sparse=*/nullptr)
- .getOperation();
+ fusedOp = rewriter.create<IndexedGenericOp>(
+ consumer.getLoc(), consumer->getResultTypes(),
+ /*inputs=*/fusedOperands,
+ // TODO: handle outputs.
+ consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+ consumer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr,
+ /*sparse=*/nullptr);
}
// Construct an AffineMap from consumer loops to producer loops.
// consumer loop -> tensor index
AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx);
- // producer loop -> tensor index
- AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
// tensor index -> producer loop
AffineMap invProducerResultIndexMap =
inversePermutation(producerResultIndexMap);
@@ -297,9 +324,9 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
AffineMap consumerToProducerLoopsMap =
invProducerResultIndexMap.compose(consumerResultIndexMap);
- generateFusedElementwiseOpRegion(rewriter, fusedOp.getOperation(), producer,
- consumer, consumerToProducerLoopsMap,
- consumerIdx, consumer.getNumLoops());
+ generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer,
+ consumerToProducerLoopsMap, consumerIdx,
+ consumer.getNumLoops());
return SmallVector<Value, 1>(fusedOp->getResults());
}
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 13109bd98c19..b0a006398c99 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -616,3 +616,39 @@ func @sigmoid_dynamic_dim(%0: tensor<?x1xf32>) -> tensor<?x1xf32> {
} -> tensor<?x1xf32>
return %2 : tensor<?x1xf32>
}
+
+// -----
+
+func private @compute1(%a: f64) -> f64
+func private @compute2(%a: f64, %b: i32) -> i32
+
+// CHECK-LABEL: func @generic_index_op2(
+func @generic_index_op2(%arg0: tensor<1x8xf64>, %arg1: tensor<1x8xi32>) -> tensor<1x8xi32> {
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(i, j) -> (i, j)>],
+ iterator_types = ["parallel", "parallel"]}
+ outs(%arg0 : tensor<1x8xf64>) {
+ ^bb0(%a: f64):
+ %r = call @compute1(%a) : (f64) -> f64
+ linalg.yield %r : f64
+ } -> tensor<1x8xf64>
+
+ // CHECK-NEXT: %[[R:.*]] = linalg.generic
+ // CHECK: bb0(%[[BBA:[0-9a-z]*]]: f64, %[[BBB:[0-9a-z]*]]: i32):
+ // CHECK-NEXT: %[[A:.*]] = call @compute1(%[[BBA]]) : (f64) -> f64
+ // CHECK-NEXT: %[[B:.*]] = call @compute2(%[[A]], %[[BBB]]) : (f64, i32) -> i32
+ // CHECK-NEXT: linalg.yield %[[B]] : i32
+ // CHECK-NEXT: } -> tensor<1x8xi32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0 : tensor<1x8xf64>)
+ outs(%arg1 : tensor<1x8xi32>) {
+ ^bb0(%a: f64, %b: i32):
+ %r = call @compute2(%a, %b) : (f64, i32) -> i32
+ linalg.yield %r : i32
+ } -> tensor<1x8xi32>
+
+ // CHECK-NEXT: return %[[R]] : tensor<1x8xi32>
+ return %1 : tensor<1x8xi32>
+}