[Mlir-commits] [mlir] bcd6424 - [mlir][Linalg] Fix linalg on tensor fusion

Nicolas Vasilache llvmlistbot at llvm.org
Mon Mar 22 06:30:02 PDT 2021


Author: Nicolas Vasilache
Date: 2021-03-22T13:29:40Z
New Revision: bcd6424f9b693af57b29a0f03c52d6991be35d41

URL: https://github.com/llvm/llvm-project/commit/bcd6424f9b693af57b29a0f03c52d6991be35d41
DIFF: https://github.com/llvm/llvm-project/commit/bcd6424f9b693af57b29a0f03c52d6991be35d41.diff

LOG: [mlir][Linalg] Fix linalg on tensor fusion

- Drop unnecessary occurrences of rewriter.eraseOp: dead linalg ops on tensors should be cleaned up by DCE.
- Reimplement the part of Linalg on tensors fusion that constructs the fused body and block arguments: the previous implementation relied on too much magic. The new version spells out all cases explicitly and asserts on, or introduces TODOs for, the cases that are not yet supported (a sketch of the resulting fused-region layout follows the revision link below).

As a consequence, we can use the default traversal order for this pattern.

Differential Revision: https://reviews.llvm.org/D99070
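
For reference only (not part of the commit), here is a minimal sketch of the argument layout that the rewritten generateFusedTensorOpRegion produces when a single-result elementwise producer feeds the second input operand of a consumer. The function name, value names, and tensor shapes are hypothetical; the syntax follows the fusion-tensor.mlir test style used in the diff below.

#map = affine_map<(d0, d1) -> (d0, d1)>

// Hypothetical input IR: %prod is only used as the second input of %cons.
func @fusion_sketch(%A: tensor<4x8xf32>, %B: tensor<4x8xf32>,
                    %init: tensor<4x8xf32>) -> tensor<4x8xf32> {
  %prod = linalg.generic {indexing_maps = [#map, #map],
                          iterator_types = ["parallel", "parallel"]}
      ins(%A : tensor<4x8xf32>) outs(%init : tensor<4x8xf32>) {
    ^bb0(%a: f32, %out0: f32):
      %0 = addf %a, %a : f32
      linalg.yield %0 : f32
  } -> tensor<4x8xf32>
  %cons = linalg.generic {indexing_maps = [#map, #map, #map],
                          iterator_types = ["parallel", "parallel"]}
      ins(%B, %prod : tensor<4x8xf32>, tensor<4x8xf32>)
      outs(%init : tensor<4x8xf32>) {
    ^bb0(%b: f32, %p: f32, %out1: f32):
      %1 = mulf %b, %p : f32
      linalg.yield %1 : f32
  } -> tensor<4x8xf32>
  return %cons : tensor<4x8xf32>
}

// Expected fused region (consumerIdx == 1), following steps 3.-6. of the new
// construction: block arguments are, in order, the consumer inputs before
// consumerIdx (%b), the producer inputs (%a), the remaining consumer inputs
// (none here), and the consumer outputs (%out1). The producer body is cloned
// first and its yielded value replaces the use of %p in the consumer body:
//   ^bb0(%b: f32, %a: f32, %out1: f32):
//     %0 = addf %a, %a : f32
//     %1 = mulf %b, %0 : f32
//     linalg.yield %1 : f32
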

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/fusion-tensor.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 1e94dfd3ef94..a6d0fd5dd7b7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -37,6 +37,11 @@ static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer,
   if (producer.getNumParallelLoops() != producer.getNumLoops())
     return false;
 
+  // Only allow fusing the producer of an input operand for now.
+  // TODO: allow fusing the producer of an output operand.
+  if (consumerIdx >= consumer.getNumInputs())
+    return false;
+
   // Get the consumer index map. The number of results of the consumer index
   // map must match the number of loops of the producer.
   AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
@@ -120,60 +125,86 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
        isa<IndexedGenericOp>(consumer.getOperation()))
           ? std::max(producer.getNumLoops(), consumer.getNumLoops())
           : 0;
-  // Firstly, add all the indices to the block arguments.
+
+  // 0. Firstly, add all the indices to the block arguments.
   for (unsigned i = 0, e = numFusedOpIndices; i < e; ++i)
     fusedBlock->addArgument(rewriter.getIndexType());
-  // Map the arguments for the unmodified args from the consumer.
-  for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
-    if (consumerArg.index() == consumerIdx + numConsumerIndices) {
-      // Map the arguments for the args from the producer.
-      for (auto producerArg :
-           llvm::enumerate(producerBlock.getArguments().take_front(
-               producer.getNumInputs() + numProducerIndices))) {
-        // If producer is an indexed_generic op, map the indices from consumer
-        // loop to producer loop (because the fusedOp is built based on
-        // consumer's perspective).
-        if (producerArg.index() < numProducerIndices) {
-          auto newIndex = rewriter.create<mlir::AffineApplyOp>(
-              producer.getLoc(),
-              consumerToProducerLoopsMap.getSubMap(producerArg.index()),
-              fusedBlock->getArguments().take_front(numFusedOpIndices));
-          mapper.map(producerArg.value(), newIndex);
-        } else {
-          mapper.map(producerArg.value(),
-                     fusedBlock->addArgument(producerArg.value().getType()));
-        }
-      }
-      continue;
-    }
-
-    // If consumer is an indexed_generic op, map the indices to the block
-    // arguments directly. Otherwise, add the same type of argument and map to
-    // it.
-    if (consumerArg.index() < numConsumerIndices) {
-      mapper.map(consumerArg.value(),
-                 fusedBlock->getArgument(consumerArg.index()));
-    } else {
-      mapper.map(consumerArg.value(),
-                 fusedBlock->addArgument(consumerArg.value().getType()));
-    }
+  // 1. Map consumer indices to fusedBlock indices 1-1.
+  mapper.map(consumerBlock.getArguments().take_front(numConsumerIndices),
+             fusedBlock->getArguments().take_front(numConsumerIndices));
+  // 2. Embed producer indices into fusedBlock index space 1-1.
+  for (auto it :
+       llvm::zip(producerBlock.getArguments().take_front(numProducerIndices),
+                 fusedBlock->getArguments().take_front(numProducerIndices))) {
+    auto newIndex = rewriter.create<mlir::AffineApplyOp>(
+        producer.getLoc(),
+        consumerToProducerLoopsMap.getSubMap(std::get<0>(it).getArgNumber()),
+        fusedBlock->getArguments().take_front(numFusedOpIndices));
+    mapper.map(std::get<0>(it), newIndex);
   }
-
-  // Add operations from producer (except the yield operation) to the fused
+  // TODO: allow fusing the producer of an output operand.
+  assert(consumerIdx < consumer.getNumInputs() &&
+         "expected producer of input operand");
+  // 3. Consumer input operands up to consumerIdx (exclusive).
+  for (BlockArgument bbArg : consumerBlock.getArguments()
+                                 .drop_front(numConsumerIndices)
+                                 .take_front(consumerIdx)) // input assumption.
+    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
+
+  // Replacing consumerIdx requires getting the cloned, yielded, value from
+  // the (cloned) producer block. This happens in step 9.
+
+  // 4. Splice in producer's input operands.
+  for (BlockArgument bbArg : producerBlock.getArguments()
+                                 .drop_front(numProducerIndices)
+                                 .take_front(producer.getNumInputs()))
+    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
+  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
+  for (BlockArgument bbArg : consumerBlock.getArguments()
+                                 .drop_front(numConsumerIndices)
+                                 .take_front(consumer.getNumInputs())
+                                 .drop_front(consumerIdx + 1))
+    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
+  // 6. All of consumer's output operands.
+  for (BlockArgument bbArg :
+       consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
+    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
+  // 7. All of producer's output operands except the one fused.
+  // TODO: allow fusion of multi-result producers.
+  assert(producer->getNumResults() == 1 && "expected single result producer");
+
+  // 8. Clone operations from producer (except the yield operation) to the fused
   // op.
-  for (auto &op : producerBlock.getOperations()) {
-    if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
-      // Lookup the value the yield operation is mapped to.
-      Value yieldVal = yieldOp.getOperand(0);
-      if (Value clonedVal = mapper.lookupOrNull(yieldVal))
-        mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices),
-                   clonedVal);
-      continue;
-    }
+  for (auto &op : producerBlock.without_terminator())
     rewriter.clone(op, mapper);
+  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
+  // forward the yield operand.
+  auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
+  // TODO: allow fusion of multi-result producers.
+  assert(producer->getNumResults() == 1 && "expected single result producer");
+  unsigned producerResultNumber = 0;
+  Value replacement =
+      mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
+  // Sanity checks, if replacement is not already in the mapper then it must be
+  // produced outside.
+  if (replacement == yieldOp.getOperand(producerResultNumber)) {
+    if (auto bb = replacement.dyn_cast<BlockArgument>())
+      assert(bb.getOwner() != &producerBlock &&
+             "yielded block argument must have been mapped");
+    else
+      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
+             "yielded value must have been mapped");
   }
+  mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices),
+             replacement);
+  // 10. Clone operations from the consumer to the fused op.
   for (auto &op : consumerBlock.getOperations())
     rewriter.clone(op, mapper);
+
+  // Sanity checks.
+  assert(fusedBlock->getNumArguments() ==
+             fusedOp->getNumOperands() + numFusedOpIndices &&
+         "Ill-formed LinalgOp region");
 }
 
 static Optional<SmallVector<Value, 1>>
@@ -856,8 +887,6 @@ struct FoldProducerReshapeOpByLinearization
       op->setOperands(fusedOperands);
       op.indexing_mapsAttr(rewriter.getAffineMapArrayAttr(fusedIndexMaps));
       rewriter.finalizeRootUpdate(op);
-      if (reshapeOp.use_empty())
-        rewriter.eraseOp(reshapeOp);
       return success();
     }
     return failure();
@@ -897,8 +926,6 @@ struct FoldWithProducerReshapeOpByExpansion
       if (!replacementValues)
         return failure();
       rewriter.replaceOp(genericOp, replacementValues.getValue());
-      if (reshapeOp.use_empty())
-        rewriter.eraseOp(reshapeOp);
       return success();
     }
     return failure();
@@ -963,8 +990,6 @@ struct FoldConsumerReshapeOpByLinearization
     rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
                                fusedRegion.begin());
     rewriter.replaceOp(reshapeOp, fusedOp->getResults());
-    if (producer.use_empty())
-      rewriter.eraseOp(producer);
     return success();
   }
 };
@@ -995,8 +1020,6 @@ struct FoldReshapeWithGenericOpByExpansion
     if (!replacementValues)
       return failure();
     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
-    if (producer.use_empty())
-      rewriter.eraseOp(producer);
     return success();
   }
 };
@@ -1057,8 +1080,6 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
       rewriter.cloneRegionBefore(linalgOpRegion, fusedRegion,
                                  fusedRegion.begin(), mapping);
       rewriter.replaceOp(linalgOp, fusedOp->getResults());
-      if (constantOp.use_empty())
-        rewriter.eraseOp(constantOp);
       return success();
     }
     return failure();
@@ -1092,15 +1113,14 @@ struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
                                 PatternRewriter &rewriter) const override {
     // Find the first operand that is defined by another generic op on tensors.
     for (OpOperand &opOperand : op.getShapedOpOperands()) {
-      Operation *producer = opOperand.get().getDefiningOp();
-      if (!producer)
+      LinalgOp producerOp =
+          dyn_cast_or_null<LinalgOp>(opOperand.get().getDefiningOp());
+      if (!producerOp || !producerOp.hasTensorSemantics())
         continue;
       Optional<SmallVector<Value, 1>> fusedOpResults =
           fuseTensorOps(rewriter, opOperand);
       if (fusedOpResults) {
         rewriter.replaceOp(op, *fusedOpResults);
-        if (producer->use_empty())
-          rewriter.eraseOp(producer);
         return success();
       }
     }
@@ -1115,8 +1135,7 @@ struct FusionOfTensorOpsPass
     Operation *op = getOperation();
     OwningRewritePatternList patterns(op->getContext());
     populateLinalgTensorOpsFusionPatterns(patterns);
-    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
-                                       /*useTopDown=*/false);
+    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
 

diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index a4071897b4d8..13109bd98c19 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -578,3 +578,41 @@ func @consumer_with_reduction(%arg0: tensor<1x10xf32>,
 //      CHECK:     %[[T4:.+]] = addf %[[T3]], %[[T2]] : f32
 //      CHECK:     linalg.yield %[[T4]]
 //      CHECK:   return %[[RES]]
+
+// -----
+
+// CHECK-LABEL: func @sigmoid_dynamic_dim(
+//       CHECK:   %[[RES:.*]] = linalg.generic
+//   CHECK-NOT:   linalg.generic
+//       CHECK:   return %[[RES]]
+func @sigmoid_dynamic_dim(%0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  %cp5 = constant 5.000000e-01 : f32
+  %c0 = constant 0 : index
+  %shape = shape.shape_of %0 : tensor<?x1xf32> -> tensor<?xindex>
+  %extend = shape.to_extent_tensor %shape : tensor<?xindex> -> tensor<2xindex>
+  %extracted = tensor.extract %extend[%c0] : tensor<2xindex>
+  %init0 = linalg.init_tensor [%extracted, 1] : tensor<?x1xf32>
+  %1 = linalg.generic {indexing_maps = [
+    affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+     outs(%init0 : tensor<?x1xf32>) {
+    ^bb0(%a: f32):  // no predecessors
+      linalg.yield %cp5 : f32
+  } -> tensor<?x1xf32>
+  %d0 = memref.dim %0, %c0 : tensor<?x1xf32>
+  %init1 = linalg.init_tensor [%d0, 1] : tensor<?x1xf32>
+  %2 = linalg.generic {indexing_maps = [
+    affine_map<(d0, d1) -> (d0, d1)>,
+    affine_map<(d0, d1) -> (d0, d1)>,
+    affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+      ins(%0, %1 : tensor<?x1xf32>, tensor<?x1xf32>)
+     outs(%init1 : tensor<?x1xf32>) {
+  ^bb0(%a: f32, %b: f32, %c: f32):  // no predecessors
+      %m = mulf %a, %b : f32
+      linalg.yield %m : f32
+  } -> tensor<?x1xf32>
+  return %2 : tensor<?x1xf32>
+}


        

