[Mlir-commits] [mlir] b241226 - [mlir][linalg] Avoid illegal elementwise fusion into reductions
Stephan Herhut
llvmlistbot at llvm.org
Thu Nov 11 06:56:25 PST 2021
Author: Stephan Herhut
Date: 2021-11-11T15:56:12+01:00
New Revision: b241226aec1bbbff25e06de78adc4b00389ffd12
URL: https://github.com/llvm/llvm-project/commit/b241226aec1bbbff25e06de78adc4b00389ffd12
DIFF: https://github.com/llvm/llvm-project/commit/b241226aec1bbbff25e06de78adc4b00389ffd12.diff
LOG: [mlir][linalg] Avoid illegal elementwise fusion into reductions
Fusing into a reduction is only valid if doing so does not erase information about a reduction dimension's size.
Differential Revision: https://reviews.llvm.org/D113500
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index fa836ed9577a..783e1be84920 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -26,6 +26,38 @@
using namespace mlir;
using namespace mlir::linalg;
+/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
+/// the `producer` to use in the fused operation given the indexing map of the
+/// result of the producer in the consumer.
+static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+ OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
+ AffineMap fusedConsumerArgIndexMap) {
+ // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
+ // from consumer loop -> consumer arg tensor index/producer result tensor
+ // index. The fused loop is same as the consumer loop. For each producer arg
+ // the indexing map to be computed is a map from consumer loop -> producer
+ // arg tensor index.
+ // producerResultIndexMap is a map from producer loop -> tensor index.
+ // Compute the inverse to get map from tensor index -> producer loop.
+ // The inverse is a map from producer result tensor index -> producer loop.
+ AffineMap invProducerResultIndexMap =
+ inversePermutation(producerResultIndexMap);
+ assert(invProducerResultIndexMap &&
+ "expected producer result indexing map to be invertible");
+
+ LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
+ // argMap is a map from producer loop -> producer arg tensor index.
+ AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
+
+ // Compose argMap with invProducerResultIndexMap to get a map from
+ // producer result tensor index -> producer arg tensor index.
+ AffineMap t1 = argMap.compose(invProducerResultIndexMap);
+
+ // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
+ // consumer loop/ fused loop -> producer arg tensor index.
+ return t1.compose(fusedConsumerArgIndexMap);
+}
+
/// Conditions for elementwise fusion of generic operations.
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
OpOperand *consumerOpOperand) {
@@ -57,39 +89,42 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
// verify it is a permutation.
AffineMap producerResultIndexMap =
producer.getTiedIndexingMap(producer.getOutputOperand(0));
- return producerResultIndexMap.isPermutation();
-}
+ if (!producerResultIndexMap.isPermutation())
+ return false;
-/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
-/// the `producer` to use in the fused operation given the indexing map of the
-/// result of the producer in the consumer.
-static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
- OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
- AffineMap fusedConsumerArgIndexMap) {
- // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
- // from consumer loop -> consumer arg tensor index/producer result tensor
- // index. The fused loop is same as the consumer loop. For each producer arg
- // the indexing map to be computed is a map from consumer loop -> producer
- // arg tensor index.
- // producerResultIndexMap is a map from producer loop -> tensor index.
- // Compute the inverse to get map from tensor index -> producer loop.
- // The inverse is a map from producer result tensor index -> producer loop.
- AffineMap invProducerResultIndexMap =
- inversePermutation(producerResultIndexMap);
- assert(invProducerResultIndexMap &&
- "expected producer result indexig map to be invertible");
+ // Ensure that the fusion does not remove size information required to
+ // get the loop bounds. For non-reduction generics, this is trivially the
+ // case due to the output operand. For reductions, we need to check that after
+ // the fusion, each loop dimension has at least one input that defines it.
+ if ((consumer.getNumReductionLoops())) {
+ llvm::BitVector coveredDims(consumer.getNumLoops(), false);
+
+ auto addToCoveredDims = [&](AffineMap map) {
+ for (auto result : map.getResults())
+ if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
+ coveredDims[dimExpr.getPosition()] = true;
+ };
- LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
- // argMap is a map from producer loop -> producer arg tensor index.
- AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
+ for (auto pair :
+ llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
+ Value operand = std::get<0>(pair);
+ if (operand == consumerOpOperand->get())
+ continue;
+ AffineMap operandMap = std::get<1>(pair);
+ addToCoveredDims(operandMap);
+ }
- // Compose argMap with invProducerResultIndexMap to get a map from
- // producer result tensor index -> producer arg tensor index.
- AffineMap t1 = argMap.compose(invProducerResultIndexMap);
+ for (OpOperand *operand : producer.getInputOperands()) {
+ AffineMap newIndexingMap =
+ getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+ operand, producerResultIndexMap, consumerIndexMap);
+ addToCoveredDims(newIndexingMap);
+ }
+ if (!coveredDims.all())
+ return false;
+ }
- // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
- // consumer loop/ fused loop -> producer arg tensor index.
- return t1.compose(fusedConsumerArgIndexMap);
+ return true;
}
/// Generate the region of the fused tensor operation. The region of the fused
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 3dea5367b619..652286f98184 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -907,3 +907,41 @@ func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2
} -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}
+
+// -----
+
+// Fusing the broadcast into a reduction would require to insert extra knowledge
+// about the size of the reduction dimension. As long as this is not
+// implemented, we check that two linalg operations remain.
+// TODO: Support this case in element-wise fusion.
+
+#map0 = affine_map<(d0, d1) -> ()>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+#map2 = affine_map<(d0, d1) -> (d1, d0)>
+#map3 = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-LABEL: @no_fusion_missing_reduction_shape
+// CHECK: linalg.generic
+// CHECK: linalg.generic
+func @no_fusion_missing_reduction_shape(%arg0: tensor<f32>, %arg1: index) -> tensor<?xf32> {
+ %cst = arith.constant 0xFF800000 : f32
+ %4 = linalg.init_tensor [%arg1, %arg1] : tensor<?x?xf32>
+ %5 = linalg.generic {
+ indexing_maps = [#map0, #map1],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%arg0 : tensor<f32>) outs(%4 : tensor<?x?xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32): // no predecessors
+ linalg.yield %arg2 : f32
+ } -> tensor<?x?xf32>
+ %6 = linalg.init_tensor [%arg1] : tensor<?xf32>
+ %7 = linalg.fill(%cst, %6) : f32, tensor<?xf32> -> tensor<?xf32>
+ %8 = linalg.generic {
+ indexing_maps = [#map2, #map3],
+ iterator_types = ["parallel", "reduction"]
+ } ins(%5 : tensor<?x?xf32>) outs(%7 : tensor<?xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32): // no predecessors
+ %9 = maxf %arg2, %arg3 : f32
+ linalg.yield %9 : f32
+ } -> tensor<?xf32>
+ return %8 : tensor<?xf32>
+}
More information about the Mlir-commits
mailing list