[Mlir-commits] [flang] [mlir] [mlir] Handle simple commutative cases in CSE. (PR #75274)
Jacques Pienaar
llvmlistbot at llvm.org
Wed Dec 13 05:29:57 PST 2023
https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/75274
>From d3cefdd3d0adbd4e307cb4b71c98b81fe49ce08b Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Tue, 12 Dec 2023 18:30:57 -0800
Subject: [PATCH 1/2] [mlir] Handle simple commutative cases in CSE.
This keeps the change simple while handling obvious CSE opportunities. For more complicated cases, the expectation is still that the operand-sorting pass runs beforehand. Although simple, this case did occur in a real deployment, where it had a large (>10% end-to-end) impact. The approach can be refined further.
---
mlir/include/mlir/IR/OperationSupport.h | 8 ++-
mlir/lib/IR/OperationSupport.cpp | 85 +++++++++++++++++++++----
2 files changed, 77 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 6a5ec129ad564b..ba66dffeeb8e96 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1271,7 +1271,9 @@ struct OperationEquivalence {
isEquivalentTo(Operation *lhs, Operation *rhs,
function_ref<LogicalResult(Value, Value)> checkEquivalent,
function_ref<void(Value, Value)> markEquivalent = nullptr,
- Flags flags = Flags::None);
+ Flags flags = Flags::None,
+ function_ref<LogicalResult(ValueRange, ValueRange)>
+ checkCommutativeEquivalent = nullptr);
/// Compare two operations and return if they are equivalent.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
@@ -1282,7 +1284,9 @@ struct OperationEquivalence {
Region *lhs, Region *rhs,
function_ref<LogicalResult(Value, Value)> checkEquivalent,
function_ref<void(Value, Value)> markEquivalent,
- OperationEquivalence::Flags flags);
+ OperationEquivalence::Flags flags,
+ function_ref<LogicalResult(ValueRange, ValueRange)>
+ checkCommutativeEquivalent = nullptr);
/// Compare two regions and return if they are equivalent.
static bool isRegionEquivalentTo(Region *lhs, Region *rhs,
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index fc5ccd23b5108d..630a3bc5016fff 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -683,8 +683,16 @@ llvm::hash_code OperationEquivalence::computeHash(
hash = llvm::hash_combine(hash, op->getLoc());
// - Operands
- for (Value operand : op->getOperands())
- hash = llvm::hash_combine(hash, hashOperands(operand));
+ if (op->hasTrait<mlir::OpTrait::IsCommutative>() && op->getNumOperands() > 0) {
+ // If commutative, don't hash the operands as hash is not order independent
+ // and even if it were would not be sufficient for CSE usage.
+ // FIXME: This has the effect of resulting in more hash collisions
+ // for the sake of CSE, this could be improved.
+ hash = llvm::hash_combine(hash, op->getNumOperands());
+ } else {
+ for (Value operand : op->getOperands())
+ hash = llvm::hash_combine(hash, hashOperands(operand));
+ }
// - Results
for (Value result : op->getResults())
@@ -696,7 +704,9 @@ llvm::hash_code OperationEquivalence::computeHash(
Region *lhs, Region *rhs,
function_ref<LogicalResult(Value, Value)> checkEquivalent,
function_ref<void(Value, Value)> markEquivalent,
- OperationEquivalence::Flags flags) {
+ OperationEquivalence::Flags flags,
+ function_ref<LogicalResult(ValueRange, ValueRange)>
+ checkCommutativeEquivalent) {
DenseMap<Block *, Block *> blocksMap;
auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
// Check block arguments.
@@ -751,6 +761,36 @@ struct ValueEquivalenceCache {
return success(lhsValue == rhsValue ||
equivalentValues.lookup(lhsValue) == rhsValue);
}
+ LogicalResult checkCommutativeEquivalent(ValueRange lhsRange,
+ ValueRange rhsRange) {
+ // Operand lists of different lengths are never equivalent.
+ if (lhsRange.size() != rhsRange.size())
+ return failure();
+
+ // Fast path: compare operands pairwise in order, stopping at first mismatch.
+ auto lhsIt = lhsRange.begin();
+ auto rhsIt = rhsRange.begin();
+ for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
+ if (failed(checkEquivalent(*lhsIt, *rhsIt)))
+ break;
+ }
+ if (lhsIt == lhsRange.end())
+ return success();
+
+ // Otherwise, check whether the remaining operands are a permutation of each
+ // other. Values are compared by identity (opaque pointer), not through
+ // equivalentValues, so this cheap check only catches simple cases.
+ auto sortValues = [](ValueRange values) {
+ SmallVector<Value> sortedValues = llvm::to_vector(values);
+ llvm::sort(sortedValues, [](Value a, Value b) {
+ return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+ });
+ return sortedValues;
+ };
+ auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
+ auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
+ return success(lhsSorted == rhsSorted);
+ }
void markEquivalent(Value lhsResult, Value rhsResult) {
auto insertion = equivalentValues.insert({lhsResult, rhsResult});
// Make sure that the value was not already marked equivalent to some other
@@ -773,13 +813,18 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
[&](Value lhsResult, Value rhsResult) {
cache.markEquivalent(lhsResult, rhsResult);
},
- flags);
+ flags,
+ [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
+ return cache.checkCommutativeEquivalent(lhs, rhs);
+ });
}
/*static*/ bool OperationEquivalence::isEquivalentTo(
Operation *lhs, Operation *rhs,
function_ref<LogicalResult(Value, Value)> checkEquivalent,
- function_ref<void(Value, Value)> markEquivalent, Flags flags) {
+ function_ref<void(Value, Value)> markEquivalent, Flags flags,
+ function_ref<LogicalResult(ValueRange, ValueRange)>
+ checkCommutativeEquivalent) {
if (lhs == rhs)
return true;
@@ -798,15 +843,24 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
return false;
// 2. Compare operands.
- for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
- Value curArg = std::get<0>(operandPair);
- Value otherArg = std::get<1>(operandPair);
- if (curArg == otherArg)
- continue;
- if (curArg.getType() != otherArg.getType())
- return false;
- if (failed(checkEquivalent(curArg, otherArg)))
+ if (checkCommutativeEquivalent &&
+ lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
+ auto lhsRange = lhs->getOperands();
+ auto rhsRange = rhs->getOperands();
+ if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
return false;
+ } else {
+ // Check pair wise for equivalence.
+ for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
+ Value curArg = std::get<0>(operandPair);
+ Value otherArg = std::get<1>(operandPair);
+ if (curArg == otherArg)
+ continue;
+ if (curArg.getType() != otherArg.getType())
+ return false;
+ if (failed(checkEquivalent(curArg, otherArg)))
+ return false;
+ }
}
// 3. Compare result types and mark results as equivalent.
@@ -841,7 +895,10 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
[&](Value lhsResult, Value rhsResult) {
cache.markEquivalent(lhsResult, rhsResult);
},
- flags);
+ flags,
+ [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
+ return cache.checkCommutativeEquivalent(lhs, rhs);
+ });
}
//===----------------------------------------------------------------------===//
>From 6fc2f9d23168503abf5778084b09d1cc614cf3da Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Wed, 13 Dec 2023 13:28:44 +0000
Subject: [PATCH 2/2] Format and enable the now-passing test.
---
flang/test/Fir/commute.fir | 3 ---
mlir/lib/IR/OperationSupport.cpp | 3 ++-
2 files changed, 2 insertions(+), 4 deletions(-)
diff --git a/flang/test/Fir/commute.fir b/flang/test/Fir/commute.fir
index b39fe48145b3f3..be25fd61953deb 100644
--- a/flang/test/Fir/commute.fir
+++ b/flang/test/Fir/commute.fir
@@ -1,7 +1,4 @@
// RUN: fir-opt %s | tco | FileCheck %s
-//
-// XFAIL:*
-// See: https://github.com/llvm/llvm-project/issues/63784
// CHECK-LABEL: define i32 @f1(i32 %0, i32 %1)
func.func @f1(%a : i32, %b : i32) -> i32 {
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 630a3bc5016fff..14e21dc1d1a87c 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -683,7 +683,8 @@ llvm::hash_code OperationEquivalence::computeHash(
hash = llvm::hash_combine(hash, op->getLoc());
// - Operands
- if (op->hasTrait<mlir::OpTrait::IsCommutative>() && op->getNumOperands() > 0) {
+ if (op->hasTrait<mlir::OpTrait::IsCommutative>() &&
+ op->getNumOperands() > 0) {
// If commutative, don't hash the operands as hash is not order independent
// and even if it were would not be sufficient for CSE usage.
// FIXME: This has the effect of resulting in more hash collisions
More information about the Mlir-commits
mailing list