[Mlir-commits] [mlir] ac0fe5d - [mlir][linalg] Remove unused payload related OutOpOperand

Stanley Winata llvmlistbot at llvm.org
Mon Oct 10 11:48:55 PDT 2022


Author: Stanley Winata
Date: 2022-10-10T11:45:46-07:00
New Revision: ac0fe5dd14734be850ffebea1574d46e393f429a

URL: https://github.com/llvm/llvm-project/commit/ac0fe5dd14734be850ffebea1574d46e393f429a
DIFF: https://github.com/llvm/llvm-project/commit/ac0fe5dd14734be850ffebea1574d46e393f429a.diff

LOG: [mlir][linalg] Remove unused payload related OutOpOperand

Some higher-level operations, such as torch.max, generate a linalg.generic
that returns both the index and the value of the max operation. However,
sometimes not all of these results are used, and the unused results still
block vectorization in certain cases, which causes performance degradation.
This patch aims to fix this issue by dropping such unused results.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D135388

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 47619ebf0b415..79961b72fcd3d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -857,6 +857,44 @@ void GenericOp::getEffects(
                         outputBuffers);
 }
 
+/// Returns true if the out operand tied to `result` does not contribute to
+/// any live value of `genericOp` (loop-bound needs are checked separately):
+/// `result` has no uses and the payload either
+/// - does not read the block argument tied to the out operand, or
+/// - only forwards that block argument, unchanged, to the generic's own
+///   linalg.yield at the same position (so its data does not flow into
+///   any other output).
+static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
+  if (!result.use_empty())
+    return false;
+  unsigned resultNumber = result.getResultNumber();
+
+  // If out operand not used in payload, we can drop it.
+  OpOperand *outputOpOperand = genericOp.getOutputOperand(resultNumber);
+  if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
+    return true;
+
+  // The out operand is read in the payload. It can still be dropped if its
+  // block argument has a single use, which is the yield at the same index.
+  BlockArgument outputArg = genericOp.getRegionOutputArgs()[resultNumber];
+  if (!outputArg.hasOneUse())
+    return false;
+
+  // Check that the single user is this op's own yield (which has no
+  // results), not a yield nested deeper in the payload.
+  Operation *yieldOp = genericOp.getBody()->getTerminator();
+  if (*outputArg.user_begin() != yieldOp)
+    return false;
+
+  // Check that the yield forwards the argument at its own position; if it
+  // is yielded at another position, its data feeds a different output and
+  // must be kept.
+  if (yieldOp->getOperand(resultNumber) != outputArg)
+    return false;
+
+  return true;
+}
+
 LogicalResult GenericOp::verify() { return success(); }
 
 namespace {
@@ -995,57 +1033,55 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
         newIndexingMaps.push_back(
             genericOp.getMatchingIndexingMap(outputOpOperand.value()));
       }
-    } else {
-      // Output argument can be dropped if the result has
-      // - no users, and
-      // - it is not used in the payload, and
-      // - the corresponding indexing maps are not needed for loop bound
-      //   computation.
-      auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
-      for (const auto &outputOpOperand :
-           llvm::enumerate(genericOp.getOutputOperands())) {
-        Value result = genericOp.getResult(outputOpOperand.index());
-        AffineMap indexingMap =
-            genericOp.getMatchingIndexingMap(outputOpOperand.value());
-        auto key =
-            std::make_tuple(outputOpOperand.value()->get(), indexingMap,
-                            yieldOp->getOperand(outputOpOperand.index()));
-
-        // Do not drop an out if its value is used in the payload.
-        if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
-          if (result.use_empty()) {
-            // Check if the opoperand can be dropped without affecting loop
-            // bound computation. Add the operand to the list of dropped op
-            // operand for checking. If it cannot be dropped, need to pop the
-            // value back.
-            droppedOpOperands.push_back(outputOpOperand.value());
-            if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
-              continue;
-            }
-            droppedOpOperands.pop_back();
-          }
-
-          // The out operand can also be dropped if it is computed redundantly
-          // by another result, the conditions for that are
-          // - The same operand is used as the out operand
-          // - The same indexing map is used
-          // - The same yield value is used.
-          auto it = dedupedOutpts.find(key);
-          if (it != dedupedOutpts.end()) {
-            origToNewPos[outputOpOperand.index()] = it->second;
-            droppedOpOperands.push_back(outputOpOperand.value());
-            continue;
-          }
+      return origToNewPos;
+    }
+    // Output argument can be dropped if the result has no users, and
+    // - it is not used in the payload, or is only yielded unchanged at its
+    //   own position (see isResultValueDead), and
+    // - the corresponding indexing maps are not needed for loop bound
+    //   computation.
+    auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
+    for (const auto &outputOpOperand :
+         llvm::enumerate(genericOp.getOutputOperands())) {
+      assert(outputOpOperand.index() < genericOp.getNumOutputs() &&
+             "output operand index out of range");
+      OpResult result = genericOp.getTiedOpResult(outputOpOperand.value());
+      AffineMap indexingMap =
+          genericOp.getMatchingIndexingMap(outputOpOperand.value());
+      auto key = std::make_tuple(outputOpOperand.value()->get(), indexingMap,
+                                 yieldOp->getOperand(outputOpOperand.index()));
+      if (isResultValueDead(genericOp, result)) {
+        // Check if the op operand can be dropped without affecting loop
+        // bound computation. Add the operand to the list of dropped op
+        // operands for checking. If it cannot be dropped, need to pop the
+        // value back.
+        droppedOpOperands.push_back(outputOpOperand.value());
+        if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
+          continue;
         }
+        droppedOpOperands.pop_back();
+      }

-        origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
-        dedupedOutpts[key] = newOutputOperands.size();
-        newOutputOperands.push_back(outputOpOperand.value()->get());
-        newIndexingMaps.push_back(
-            genericOp.getMatchingIndexingMap(outputOpOperand.value()));
+      if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
+        // The out operand can also be dropped if it is computed redundantly
+        // by another result, the conditions for that are
+        // - The same operand is used as the out operand
+        // - The same indexing map is used
+        // - The same yield value is used.
+        auto it = dedupedOutpts.find(key);
+        if (it != dedupedOutpts.end()) {
+          origToNewPos[outputOpOperand.index()] = it->second;
+          droppedOpOperands.push_back(outputOpOperand.value());
+          continue;
+        }
       }
-    }

+      // Otherwise keep this out operand.
+      origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
+      dedupedOutpts[key] = newOutputOperands.size();
+      newOutputOperands.push_back(outputOpOperand.value()->get());
+      newIndexingMaps.push_back(indexingMap);
+    }
     return origToNewPos;
   }
 
@@ -1085,12 +1121,10 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
     updateReplacements(origOutputOperands, newOutputOperands,
                        origOutsToNewOutsPos);
 
-    rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
-
     // Drop the unused yield args.
     if (newOp.getNumOutputs() != genericOp.getNumOutputs()) {
       OpBuilder::InsertionGuard g(rewriter);
-      YieldOp origYieldOp = cast<YieldOp>(newOpBlock->getTerminator());
+      YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
       rewriter.setInsertionPoint(origYieldOp);
 
       SmallVector<Value> newYieldVals(newOp.getNumOutputs(), nullptr);
@@ -1103,6 +1137,8 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
       }
       rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
     }
+
+    rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
   }
 };
 
@@ -1178,13 +1214,75 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
     return success();
   }
 };
+
+/// Remove unused cycles.
+/// A "cycle" is a payload op that reads the block argument of an out
+/// operand and whose only use is the yield at that same out position.
+/// Such an op can be replaced by the block argument itself if:
+/// - Result from out operand is dead.
+/// - Block arg from out operand has a single use, in the cycle op.
+/// - Cycle op has a single result, used once, by the generic's own yield
+///   at the out operand's position.
+/// - Cycle op has no side effects, so erasing it is safe.
+/// The deduplication pattern can then drop the out operand entirely.
+struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    // If the op doesn't have tensor semantics, preserve the outputs as is.
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+
+    Operation *yieldOp = genericOp.getBody()->getTerminator();
+    bool hasRemovedCycles = false;
+    // Iterate over output operands and remove any unused cycles.
+    for (const auto &outputOpOperand :
+         llvm::enumerate(genericOp.getOutputOperands())) {
+      unsigned outputIdx = outputOpOperand.index();
+
+      // Check that result from out operand is dead.
+      if (!genericOp.getResult(outputIdx).use_empty())
+        continue;
+
+      // Check that outputArg has exactly one use (in the cycle op).
+      BlockArgument outputArg = genericOp.getRegionOutputArgs()[outputIdx];
+      if (!outputArg.hasOneUse())
+        continue;
+
+      // The cycle op must have exactly one result (replaceOp below expects
+      // one replacement value), and that result must have a single use.
+      Operation *cycleOp = *outputArg.user_begin();
+      if (cycleOp->getNumResults() != 1 || !cycleOp->hasOneUse())
+        continue;
+
+      // Check that the only user is this generic's own yield, and that the
+      // value is yielded at the out operand's position; otherwise its data
+      // feeds another output and must be kept.
+      if (*cycleOp->user_begin() != yieldOp ||
+          yieldOp->getOperand(outputIdx) != cycleOp->getResult(0))
+        continue;
+
+      // Do not erase ops with side effects even if their value is unused.
+      if (!wouldOpBeTriviallyDead(cycleOp))
+        continue;
+
+      // Directly replace the cycle with the blockArg such that
+      // Deduplicate pattern can eliminate it along with unused yield.
+      rewriter.replaceOp(cycleOp, outputArg);
+      rewriter.updateRootInPlace(genericOp, [] {});
+      hasRemovedCycles = true;
+    }
+
+    return success(hasRemovedCycles);
+  }
+};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results
-      .add<DeduplicateAndRemoveDeadOperandsAndResults, EraseIdentityGenericOp>(
-          context);
+  results.add<DeduplicateAndRemoveDeadOperandsAndResults>(context)
+      .add<EraseIdentityGenericOp, RemoveUnusedCycleInGenericOp>(context);
 }
 
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,

diff  --git a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
index 108d870cb581f..a0950017662a9 100644
--- a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
@@ -286,3 +286,162 @@ func.func @drop_redundant_results(
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       outs(%[[ARG0]] :
 //      CHECK:   return %[[GENERIC]]
+
+// -----
+
+// Drop dead results with different tensors.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+func.func @drop_dead_results_with_different_tensors(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %c2 = arith.constant 2 : index
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %init0 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+  %0:4 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3, #map4],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?x?xf32>)
+      outs(%arg0, %arg0, %init0, %init0
+          : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32) :
+      linalg.yield %b0, %b0, %b3, %b4 : f32, f32, f32, f32
+    } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %0#0, %0#1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+//      CHECK: func @drop_dead_results_with_different_tensors(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>)
+//      CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:       outs(%[[ARG0]], %[[ARG0]] :
+//      CHECK:   return %[[GENERIC]]#0, %[[GENERIC]]#1
+
+// -----
+
+// Drop dead results whose out args only feed ops yielded at the same index.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+func.func @drop_dead_results_with_unused_cycles(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %c2 = arith.constant 2 : index
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %init0 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+  %0:4 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3, #map4],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?x?xf32>)
+      outs(%arg0, %arg0, %init0, %init0
+          : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32) :
+      %1 = arith.addf %b0, %b0: f32
+      %2 = arith.addf %b0, %b3: f32
+      %3 = arith.addf %b0, %b4: f32
+      linalg.yield %1, %1, %2, %3 : f32, f32, f32, f32
+    } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %0#0, %0#1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>     
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+//      CHECK: func @drop_dead_results_with_unused_cycles(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>)
+//      CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:       outs(%[[ARG0]], %[[ARG0]] :
+//      CHECK:   return %[[GENERIC]]#0, %[[GENERIC]]#1
+
+// -----
+
+// Drop only the dead results whose out args are not yielded into other results.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+func.func @drop_only_the_results_not_used_by_others(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %c2 = arith.constant 2 : index
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %init0 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+  %0:3 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?x?xf32>)
+      outs(%arg0, %init0, %init0
+          : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32) :
+      linalg.yield %b2, %b1, %b3 : f32, f32, f32
+    } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %0#0 : tensor<?x?x?xf32>
+}
+
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+//      CHECK: func @drop_only_the_results_not_used_by_others(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty
+//      CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:       outs(%[[ARG0]], %[[INIT]] :
+//      CHECK:   return %[[GENERIC]]#0
+
+// -----
+
+// Remove only the cycles that are yielded at their own out position.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+func.func @drop_only_the_cycles_not_used_by_others(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %c2 = arith.constant 2 : index
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %init0 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+  %0:3 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?x?xf32>)
+      outs(%arg0, %init0, %init0
+          : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32) :
+      %1 = arith.addf %b1, %b2: f32
+      %2 = arith.addf %b1, %b3 : f32
+      linalg.yield %1, %b1, %2 : f32, f32, f32
+    } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %0#0 : tensor<?x?x?xf32>
+}
+
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+//      CHECK: func @drop_only_the_cycles_not_used_by_others(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty
+//      CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:       outs(%[[ARG0]], %[[INIT]] :
+//      CHECK:   return %[[GENERIC]]#0


        


More information about the Mlir-commits mailing list