[Mlir-commits] [mlir] 518e6f3 - [mlir][Linalg] Fix fusion on tensors operands / bbArg mismatch

Nicolas Vasilache llvmlistbot at llvm.org
Tue Apr 6 08:42:01 PDT 2021


Author: Nicolas Vasilache
Date: 2021-04-06T15:39:40Z
New Revision: 518e6f341dddab7824592b3769146318950a01be

URL: https://github.com/llvm/llvm-project/commit/518e6f341dddab7824592b3769146318950a01be
DIFF: https://github.com/llvm/llvm-project/commit/518e6f341dddab7824592b3769146318950a01be.diff

LOG: [mlir][Linalg] Fix fusion on tensors operands / bbArg mismatch

Linalg fusion on tensors makes different assumptions on the operand side than on the region bbArg side.
Relax the behavior on the operand/indexing map side so that we better support output operands that may also be read from.

Differential revision: https://reviews.llvm.org/D99499

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/fusion-tensor.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 4fdd9a2221f0..95e008aacc45 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -896,7 +896,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*desc=*/[{
         Return the indexing maps within the current operation.
       }],
-      /*retTy=*/"SmallVector<AffineMap, 4>",
+      /*retTy=*/"SmallVector<AffineMap>",
       /*methodName=*/"getIndexingMaps",
       /*args=*/(ins),
       /*methodBody=*/"",
@@ -931,6 +931,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return getIndexingMaps()[i];
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the input indexing maps.
+      }],
+      /*retTy=*/"SmallVector<AffineMap>",
+      /*methodName=*/"getInputIndexingMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto maps = $_op.getIndexingMaps();
+        return SmallVector<AffineMap>{maps.begin(),
+                                      maps.begin() + $_op.getNumInputs()};
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the output indexing map at index `i`.
@@ -944,6 +958,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return getIndexingMaps()[i + $_op.getNumInputs()];
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the output indexing maps.
+      }],
+      /*retTy=*/"SmallVector<AffineMap>",
+      /*methodName=*/"getOutputIndexingMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto maps = $_op.getIndexingMaps();
+        return SmallVector<AffineMap>{maps.begin() + $_op.getNumInputs(),
+                                      maps.begin() + $_op.getNumShapedOperands()};
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return whether the op has only MemRef input and outputs.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index bb1a051c78e5..34eac4bdfcaa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -61,16 +61,14 @@ static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
 /// the `producer` to use in the fused operation given the indexing map of the
 /// result of the producer in the consumer.
-static void getIndexingMapOfProducerOperandsInFusedOp(
-    LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
-    SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
+static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+    OpOperand &producerOpOperand, AffineMap producerResultIndexMap,
+    AffineMap fusedConsumerArgIndexMap) {
   // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
   // from consumer loop -> consumer arg tensor index/producer result tensor
   // index. The fused loop is same as the consumer loop. For each producer arg
   // the indexing map to be computed is a map from consumer loop -> producer
   // arg tensor index.
-
-  AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
   // producerResultIndexMap is a map from producer loop -> tensor index.
   // Compute the inverse to get map from tensor index -> producer loop.
   // The inverse is a map from producer result tensor index -> producer loop.
@@ -78,19 +76,19 @@ static void getIndexingMapOfProducerOperandsInFusedOp(
       inversePermutation(producerResultIndexMap);
   assert(invProducerResultIndexMap &&
          "expected producer result indexig map to be invertible");
-  for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
-    // argMap is a map from producer loop -> producer arg tensor index.
-    AffineMap argMap = producer.getInputIndexingMap(argNum);
-
-    // Compose argMap with invProducerResultIndexMap to get a map from
-    // producer result tensor index -> producer arg tensor index.
-    AffineMap t1 = argMap.compose(invProducerResultIndexMap);
-
-    // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
-    // consumer loop/ fused loop -> producer arg tensor index.
-    AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
-    fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
-  }
+
+  LinalgOp producer = cast<LinalgOp>(producerOpOperand.getOwner());
+  // argMap is a map from producer loop -> producer arg tensor index.
+  AffineMap argMap =
+      producer.getIndexingMap(producerOpOperand.getOperandNumber());
+
+  // Compose argMap with invProducerResultIndexMap to get a map from
+  // producer result tensor index -> producer arg tensor index.
+  AffineMap t1 = argMap.compose(invProducerResultIndexMap);
+
+  // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
+  // consumer loop/ fused loop -> producer arg tensor index.
+  return t1.compose(fusedConsumerArgIndexMap);
 }
 
 /// Generate the region of the fused tensor operation. The region of the fused
@@ -163,6 +161,18 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
                                  .drop_front(numProducerIndices)
                                  .take_front(producer.getNumInputs()))
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
+
+  // 4.b. Producer output operand/map that is fused needs to be mapped to the
+  // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
+  assert(producer->getNumResults() == 1 && "expected single result producer");
+  if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
+    BlockArgument bbArg =
+        producerBlock.getArguments()
+            .drop_front(numConsumerIndices + producer.getNumInputs())
+            // TODO: bbArg index of the fused output; only output 0 is handled.
+            .front();
+    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
+  }
   // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
   for (BlockArgument bbArg : consumerBlock.getArguments()
                                  .drop_front(numConsumerIndices)
@@ -221,73 +231,90 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
       !controlFn(producer->getResult(0), consumerOpOperand))
     return llvm::None;
 
-  unsigned numFusedOperands =
-      producer.getNumInputs() + consumer.getNumInputs() - 1;
-
-  // Compute the fused operands list,
-  SmallVector<Value, 2> fusedOperands;
-  fusedOperands.reserve(numFusedOperands);
-  auto consumerOperands = consumer.getInputs();
-  auto producerOperands = producer.getInputs();
-  fusedOperands.assign(consumerOperands.begin(),
-                       std::next(consumerOperands.begin(), consumerIdx));
-  fusedOperands.append(producerOperands.begin(), producerOperands.end());
-  fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
-                       consumerOperands.end());
-
-  // Compute indexing_maps for the fused operation. The indexing_maps for the
-  // operands of the consumers that aren't fused are the same. The
-  // indexing_maps for the producers need to be computed based on the
-  // indexing_map of the operand at consumerIdx in the consumer.
-  SmallVector<Attribute, 4> fusedIndexMaps;
-  auto consumerIndexMaps = consumer.indexing_maps();
-  fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumOutputs());
-  fusedIndexMaps.assign(consumerIndexMaps.begin(),
-                        std::next(consumerIndexMaps.begin(), consumerIdx));
-  // Compute indexing maps for the producer args in the fused operation.
-  getIndexingMapOfProducerOperandsInFusedOp(
-      producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);
-
-  // Append the indexing maps for the remaining consumer operands.
-  fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
-                        consumerIndexMaps.end());
+  // TODO: allow fusing the producer of an output operand.
+  assert(consumerIdx < consumer.getNumInputs() &&
+         "expected producer of input operand");
+
+  // Compute the fused operands list and indexing maps.
+  SmallVector<Value> fusedOperands;
+  SmallVector<AffineMap> fusedIndexMaps;
+  fusedOperands.reserve(producer->getNumOperands() +
+                        consumer->getNumOperands());
+  fusedIndexMaps.reserve(producer->getNumOperands() +
+                         consumer->getNumOperands());
+  // Numbering below matches that of `generateFusedElementwiseOpRegion`.
+  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
+  llvm::append_range(fusedOperands,
+                     consumer.getInputs().take_front(consumerIdx));
+  llvm::append_range(
+      fusedIndexMaps,
+      ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.take_front(
+          consumerIdx));
+  // 4. Splice in producer's input operands/maps.
+  llvm::append_range(fusedOperands, producer.getInputs());
+  assert(producer->getNumResults() == 1 && "expected single result producer");
+  AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+  for (auto &inputOpOperand : producer.getInputOpOperands()) {
+    // Compute indexing maps for the producer args in the fused operation.
+    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+        inputOpOperand, producerResultIndexMap,
+        consumer.getInputIndexingMap(consumerIdx));
+    fusedIndexMaps.push_back(map);
+  }
+  // 4.b. Producer output operand/map that is fused needs to be passed if it is
+  // an "initTensor" (i.e. its value is actually read).
+  assert(producer->getNumResults() == 1 && "expected single result producer");
+  if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
+    llvm::append_range(fusedOperands, producer.getOutputs().take_front());
+    // Compute indexing maps for the producer args in the fused operation.
+    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+        producer.getOutputOpOperands().front(), producerResultIndexMap,
+        consumer.getOutputIndexingMap(0));
+    fusedIndexMaps.push_back(map);
+  }
+  // 5. Remaining consumer's input operands/maps (drop past index
+  // `consumerIdx`).
+  llvm::append_range(fusedOperands,
+                     consumer.getInputs().drop_front(consumerIdx + 1));
+  llvm::append_range(
+      fusedIndexMaps,
+      ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.drop_front(
+          consumerIdx + 1));
+  // 6. All of consumer's output operands (skip operands: added by the builder).
+  // llvm::append_range(fusedOperands, consumer.getOutputs());
+  llvm::append_range(fusedIndexMaps, consumer.getOutputIndexingMaps());
+  // 7. All of producer's output operands/maps except the one fused.
+  // TODO: allow fusion of multi-result producers.
+  assert(producer->getNumResults() == 1 && "expected single result producer");
 
   // Generate the fused op.
-  LinalgOp fusedOp;
+  Operation *fusedOp;
   if (isa<GenericOp>(producer.getOperation()) &&
       isa<GenericOp>(consumer.getOperation())) {
-    fusedOp =
-        rewriter
-            .create<GenericOp>(consumer.getLoc(), consumer->getResultTypes(),
-                               /*inputs=*/fusedOperands,
-                               // TODO: handle outputs.
-                               consumer.getOutputs(),
-                               rewriter.getArrayAttr(fusedIndexMaps),
-                               consumer.iterator_types(),
-                               /*doc=*/nullptr,
-                               /*library_call=*/nullptr,
-                               /*sparse=*/nullptr)
-            .getOperation();
+    fusedOp = rewriter.create<GenericOp>(
+        consumer.getLoc(), consumer->getResultTypes(),
+        /*inputs=*/fusedOperands,
+        // TODO: handle outputs.
+        consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+        consumer.iterator_types(),
+        /*doc=*/nullptr,
+        /*library_call=*/nullptr,
+        /*sparse=*/nullptr);
   } else {
-    fusedOp =
-        rewriter
-            .create<IndexedGenericOp>(
-                consumer.getLoc(), consumer->getResultTypes(),
-                /*inputs=*/fusedOperands,
-                // TODO: handle outputs.
-                consumer.getOutputs(), rewriter.getArrayAttr(fusedIndexMaps),
-                consumer.iterator_types(),
-                /*doc=*/nullptr,
-                /*library_call=*/nullptr,
-                /*sparse=*/nullptr)
-            .getOperation();
+    fusedOp = rewriter.create<IndexedGenericOp>(
+        consumer.getLoc(), consumer->getResultTypes(),
+        /*inputs=*/fusedOperands,
+        // TODO: handle outputs.
+        consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+        consumer.iterator_types(),
+        /*doc=*/nullptr,
+        /*library_call=*/nullptr,
+        /*sparse=*/nullptr);
   }
 
   // Construct an AffineMap from consumer loops to producer loops.
   // consumer loop -> tensor index
   AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx);
-  // producer loop -> tensor index
-  AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
   // tensor index -> producer loop
   AffineMap invProducerResultIndexMap =
       inversePermutation(producerResultIndexMap);
@@ -297,9 +324,9 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
   AffineMap consumerToProducerLoopsMap =
       invProducerResultIndexMap.compose(consumerResultIndexMap);
 
-  generateFusedElementwiseOpRegion(rewriter, fusedOp.getOperation(), producer,
-                                   consumer, consumerToProducerLoopsMap,
-                                   consumerIdx, consumer.getNumLoops());
+  generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer,
+                              consumerToProducerLoopsMap, consumerIdx,
+                              consumer.getNumLoops());
   return SmallVector<Value, 1>(fusedOp->getResults());
 }
 

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 13109bd98c19..b0a006398c99 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -616,3 +616,39 @@ func @sigmoid_dynamic_dim(%0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   } -> tensor<?x1xf32>
   return %2 : tensor<?x1xf32>
 }
+
+// -----
+
+func private @compute1(%a: f64) -> f64
+func private @compute2(%a: f64, %b: i32) -> i32
+
+// CHECK-LABEL: func @generic_index_op2(
+func @generic_index_op2(%arg0: tensor<1x8xf64>, %arg1: tensor<1x8xi32>) -> tensor<1x8xi32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]}
+  outs(%arg0 : tensor<1x8xf64>) {
+  ^bb0(%a: f64):
+    %r = call @compute1(%a) : (f64) -> f64
+    linalg.yield %r : f64
+  } -> tensor<1x8xf64>
+
+  // CHECK-NEXT:   %[[R:.*]] = linalg.generic
+  //      CHECK:     bb0(%[[BBA:[0-9a-z]*]]: f64, %[[BBB:[0-9a-z]*]]: i32):
+  // CHECK-NEXT:       %[[A:.*]] = call @compute1(%[[BBA]]) : (f64) -> f64
+  // CHECK-NEXT:       %[[B:.*]] = call @compute2(%[[A]], %[[BBB]]) : (f64, i32) -> i32
+  // CHECK-NEXT:       linalg.yield %[[B]] : i32
+  // CHECK-NEXT:   } -> tensor<1x8xi32>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]}
+  ins(%0 : tensor<1x8xf64>)
+  outs(%arg1 : tensor<1x8xi32>) {
+  ^bb0(%a: f64, %b: i32):
+    %r = call @compute2(%a, %b) : (f64, i32) -> i32
+    linalg.yield %r : i32
+  } -> tensor<1x8xi32>
+
+  // CHECK-NEXT:   return %[[R]] : tensor<1x8xi32>
+  return %1 : tensor<1x8xi32>
+}


        


More information about the Mlir-commits mailing list