[Mlir-commits] [mlir] b241226 - [mlir][linalg] Avoid illegal elementwise fusion into reductions

Stephan Herhut llvmlistbot at llvm.org
Thu Nov 11 06:56:25 PST 2021


Author: Stephan Herhut
Date: 2021-11-11T15:56:12+01:00
New Revision: b241226aec1bbbff25e06de78adc4b00389ffd12

URL: https://github.com/llvm/llvm-project/commit/b241226aec1bbbff25e06de78adc4b00389ffd12
DIFF: https://github.com/llvm/llvm-project/commit/b241226aec1bbbff25e06de78adc4b00389ffd12.diff

LOG: [mlir][linalg] Avoid illegal elementwise fusion into reductions

Fusing into a reduction is only valid if doing so does not erase information about a reduction dimension's size.

Differential Revision: https://reviews.llvm.org/D113500
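
The problematic pattern is a producer that broadcasts a value into a dimension
over which the consumer then reduces. As a hedged sketch (not part of this
commit), if fusion fired on the test added below, the fused op would look
roughly as follows; the operand names are assumed:

    // Hypothetical fused op that the pass must NOT create: the scalar input
    // is indexed by (d0, d1) -> () and the output by (d0, d1) -> (d0), so no
    // operand shape determines the extent of the reduction dimension d1.
    %bad = linalg.generic {
        indexing_maps = [affine_map<(d0, d1) -> ()>,
                         affine_map<(d0, d1) -> (d0)>],
        iterator_types = ["parallel", "reduction"]
      } ins(%arg0 : tensor<f32>) outs(%fill : tensor<?xf32>) {
      ^bb0(%in: f32, %out: f32):
        %m = maxf %in, %out : f32
        linalg.yield %m : f32
      } -> tensor<?xf32>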

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index fa836ed9577a..783e1be84920 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -26,6 +26,38 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
+/// Returns the indexing map to use for the given `producer` operand in the
+/// fused operation, given the indexing map of the producer's result in the
+/// consumer.
+static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
+    AffineMap fusedConsumerArgIndexMap) {
+  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
+  // from consumer loop -> consumer arg tensor index/producer result tensor
+  // index. The fused loop is the same as the consumer loop. For each producer
+  // arg, the indexing map to be computed is a map from consumer loop ->
+  // producer arg tensor index.
+  // producerResultIndexMap is a map from producer loop -> tensor index.
+  // Compute the inverse to get map from tensor index -> producer loop.
+  // The inverse is a map from producer result tensor index -> producer loop.
+  AffineMap invProducerResultIndexMap =
+      inversePermutation(producerResultIndexMap);
+  assert(invProducerResultIndexMap &&
+         "expected producer result indexig map to be invertible");
+
+  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
+  // argMap is a map from producer loop -> producer arg tensor index.
+  AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
+
+  // Compose argMap with invProducerResultIndexMap to get a map from
+  // producer result tensor index -> producer arg tensor index.
+  AffineMap t1 = argMap.compose(invProducerResultIndexMap);
+
+  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
+  // consumer loop / fused loop -> producer arg tensor index.
+  return t1.compose(fusedConsumerArgIndexMap);
+}
+
 /// Conditions for elementwise fusion of generic operations.
 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
                                      OpOperand *consumerOpOperand) {
@@ -57,39 +89,42 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
   // verify it is a permutation.
   AffineMap producerResultIndexMap =
       producer.getTiedIndexingMap(producer.getOutputOperand(0));
-  return producerResultIndexMap.isPermutation();
-}
+  if (!producerResultIndexMap.isPermutation())
+    return false;
 
-/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
-/// the `producer` to use in the fused operation given the indexing map of the
-/// result of the producer in the consumer.
-static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
-    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
-    AffineMap fusedConsumerArgIndexMap) {
-  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
-  // from consumer loop -> consumer arg tensor index/producer result tensor
-  // index. The fused loop is same as the consumer loop. For each producer arg
-  // the indexing map to be computed is a map from consumer loop -> producer
-  // arg tensor index.
-  // producerResultIndexMap is a map from producer loop -> tensor index.
-  // Compute the inverse to get map from tensor index -> producer loop.
-  // The inverse is a map from producer result tensor index -> producer loop.
-  AffineMap invProducerResultIndexMap =
-      inversePermutation(producerResultIndexMap);
-  assert(invProducerResultIndexMap &&
-         "expected producer result indexig map to be invertible");
+  // Ensure that the fusion does not remove size information required to
+  // get the loop bounds. For non-reduction generics, this is trivially the
+  // case due to the output operand. For reductions, we need to check that after
+  // the fusion, each loop dimension has at least one input that defines it.
+  if (consumer.getNumReductionLoops()) {
+    llvm::BitVector coveredDims(consumer.getNumLoops(), false);
+
+    auto addToCoveredDims = [&](AffineMap map) {
+      for (auto result : map.getResults())
+        if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
+          coveredDims[dimExpr.getPosition()] = true;
+    };
 
-  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
-  // argMap is a map from producer loop -> producer arg tensor index.
-  AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
+    for (auto pair :
+         llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
+      Value operand = std::get<0>(pair);
+      if (operand == consumerOpOperand->get())
+        continue;
+      AffineMap operandMap = std::get<1>(pair);
+      addToCoveredDims(operandMap);
+    }
 
-  // Compose argMap with invProducerResultIndexMap to get a map from
-  // producer result tensor index -> producer arg tensor index.
-  AffineMap t1 = argMap.compose(invProducerResultIndexMap);
+    for (OpOperand *operand : producer.getInputOperands()) {
+      AffineMap newIndexingMap =
+          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+              operand, producerResultIndexMap, consumerIndexMap);
+      addToCoveredDims(newIndexingMap);
+    }
+    if (!coveredDims.all())
+      return false;
+  }
 
-  // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
-  // consumer loop/ fused loop -> producer arg tensor index.
-  return t1.compose(fusedConsumerArgIndexMap);
+  return true;
 }
 
 /// Generate the region of the fused tensor operation. The region of the fused
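
As a worked example of the composition implemented by
getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp, consider a transposed
producer result map (assumed maps, hypothetical walkthrough):

    // producerResultIndexMap   : affine_map<(d0, d1) -> (d1, d0)>  (producer loops -> result index)
    // inversePermutation(...)  : affine_map<(d0, d1) -> (d1, d0)>  (result index -> producer loops)
    // argMap                   : affine_map<(d0, d1) -> (d0)>      (producer loops -> arg index)
    // t1 = argMap o inverse    : affine_map<(d0, d1) -> (d1)>      (result index -> arg index)
    // fusedConsumerArgIndexMap : affine_map<(d0, d1) -> (d0, d1)>  (consumer loops -> result index)
    // result = t1 o fused map  : affine_map<(d0, d1) -> (d1)>      (consumer loops -> arg index)

The transposed result map is its own inverse; composing the argument map
through it and then through the consumer's map rewrites the producer
argument's indexing into the fused (consumer) iteration space.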

diff  --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 3dea5367b619..652286f98184 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -907,3 +907,41 @@ func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2
   } -> tensor<3x2xf32>
   return %1 : tensor<3x2xf32>
 }
+
+// -----
+
+// Fusing the broadcast into a reduction would require inserting extra
+// knowledge about the size of the reduction dimension. As long as this is
+// not implemented, we check that the two linalg operations remain unfused.
+// TODO: Support this case in element-wise fusion.
+
+#map0 = affine_map<(d0, d1) -> ()>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+#map2 = affine_map<(d0, d1) -> (d1, d0)>
+#map3 = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-LABEL: @no_fusion_missing_reduction_shape
+// CHECK: linalg.generic
+// CHECK: linalg.generic
+func @no_fusion_missing_reduction_shape(%arg0: tensor<f32>, %arg1: index) -> tensor<?xf32> {
+  %cst = arith.constant 0xFF800000 : f32
+  %4 = linalg.init_tensor [%arg1, %arg1] : tensor<?x?xf32>
+  %5 = linalg.generic {
+    indexing_maps = [#map0, #map1],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0 : tensor<f32>) outs(%4 : tensor<?x?xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+    linalg.yield %arg2 : f32
+  } -> tensor<?x?xf32>
+  %6 = linalg.init_tensor [%arg1] : tensor<?xf32>
+  %7 = linalg.fill(%cst, %6) : f32, tensor<?xf32> -> tensor<?xf32>
+  %8 = linalg.generic {
+    indexing_maps = [#map2, #map3],
+    iterator_types = ["parallel", "reduction"]
+  } ins(%5 : tensor<?x?xf32>) outs(%7 : tensor<?xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+    %9 = maxf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> tensor<?xf32>
+  return %8 : tensor<?xf32>
+}
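
Tracing the new coverage check over this test (a hypothetical walkthrough of
the logic above): the fused loops are the consumer's (d0, d1), where d1 is the
reduction dimension.

    // Consumer operands other than the fused input %5: only outs(%7), whose
    // map #map3 = (d0, d1) -> (d0) covers d0.
    // Producer input %arg0 rewritten into consumer coordinates: its map
    // #map0 = (d0, d1) -> (), composed through the inverse of #map1 and then
    // #map2, stays (d0, d1) -> () and covers nothing.
    // d1 remains uncovered, so areElementwiseOpsFusable returns false and the
    // two linalg.generic ops remain, as the CHECK lines verify.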

