[Mlir-commits] [mlir] fcaf6dd - [mlir][Transforms] CSE of ops with a single block.
Mahesh Ravishankar
llvmlistbot at llvm.org
Tue Nov 15 18:56:45 PST 2022
Author: Mahesh Ravishankar
Date: 2022-11-16T02:55:43Z
New Revision: fcaf6dd597ea93eb8f746dda236d859e071346c5
URL: https://github.com/llvm/llvm-project/commit/fcaf6dd597ea93eb8f746dda236d859e071346c5
DIFF: https://github.com/llvm/llvm-project/commit/fcaf6dd597ea93eb8f746dda236d859e071346c5.diff
LOG: [mlir][Transforms] CSE of ops with a single block.
Currently CSE does not support CSE of ops with regions. This patch
extends the CSE support to ops with a single region.
Differential Revision: https://reviews.llvm.org/D134306
Depends on D137857
Added:
Modified:
mlir/lib/IR/OperationSupport.cpp
mlir/lib/Transforms/CSE.cpp
mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
mlir/test/Transforms/cse.mlir
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index d46f1b46bf7b8..97d09eba7eb2f 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -721,16 +721,34 @@ bool OperationEquivalence::isEquivalentTo(
ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
- lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end());
- llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool {
- return a.getAsOpaquePointer() < b.getAsOpaquePointer();
- });
- lhsOperands = lhsOperandStorage;
+ auto sortValues = [](ValueRange values) {
+ SmallVector<Value> sortedValues = llvm::to_vector(values);
+ llvm::sort(sortedValues, [](Value a, Value b) {
+ auto aArg = a.dyn_cast<BlockArgument>();
+ auto bArg = b.dyn_cast<BlockArgument>();
+
+ // Case 1. Both `a` and `b` are `BlockArgument`s.
+ if (aArg && bArg) {
+ if (aArg.getParentBlock() == bArg.getParentBlock())
+ return aArg.getArgNumber() < bArg.getArgNumber();
+ return aArg.getParentBlock() < bArg.getParentBlock();
+ }
- rhsOperandStorage.append(rhsOperands.begin(), rhsOperands.end());
- llvm::sort(rhsOperandStorage, [](Value a, Value b) -> bool {
- return a.getAsOpaquePointer() < b.getAsOpaquePointer();
- });
+ // Case 2. One of them is a `BlockArgument` and the other is not. Treat
+ // `BlockArgument` as lesser.
+ if (aArg && !bArg)
+ return true;
+ if (bArg && !aArg)
+ return false;
+
+ // Case 3. Both are values.
+ return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+ });
+ return sortedValues;
+ };
+ lhsOperandStorage = sortValues(lhsOperands);
+ lhsOperands = lhsOperandStorage;
+ rhsOperandStorage = sortValues(rhsOperands);
rhsOperands = rhsOperandStorage;
}
auto checkValueRangeMapping =
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 3df419c01504a..97f6cfd20f28f 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -47,11 +47,70 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
rhs == getTombstoneKey() || rhs == getEmptyKey())
return false;
+
+ // If op has no regions, operation equivalence w.r.t operands alone is
+ // enough.
+ if (lhs->getNumRegions() == 0 && rhs->getNumRegions() == 0) {
+ return OperationEquivalence::isEquivalentTo(
+ const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
+ OperationEquivalence::exactValueMatch,
+ OperationEquivalence::ignoreValueEquivalence,
+ OperationEquivalence::IgnoreLocations);
+ }
+
+ // If lhs or rhs does not have a single region with a single block, they
+ // aren't CSEed for now.
+ if (lhs->getNumRegions() != 1 || rhs->getNumRegions() != 1 ||
+ !llvm::hasSingleElement(lhs->getRegion(0)) ||
+ !llvm::hasSingleElement(rhs->getRegion(0)))
+ return false;
+
+ // Compare the two blocks.
+ Block &lhsBlock = lhs->getRegion(0).front();
+ Block &rhsBlock = rhs->getRegion(0).front();
+
+ // Don't CSE if number of arguments differ.
+ if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
+ return false;
+
+ // Map to store `Value`s from `lhsBlock` that are equivalent to `Value`s in
+ // `rhsBlock`. `Value`s from `lhsBlock` are the key.
+ DenseMap<Value, Value> areEquivalentValues;
+ for (auto bbArgs : llvm::zip(lhs->getRegion(0).getArguments(),
+ rhs->getRegion(0).getArguments())) {
+ areEquivalentValues[std::get<0>(bbArgs)] = std::get<1>(bbArgs);
+ }
+
+ // Helper function to get the parent operation.
+ auto getParent = [](Value v) -> Operation * {
+ if (auto blockArg = v.dyn_cast<BlockArgument>())
+ return blockArg.getParentBlock()->getParentOp();
+ return v.getDefiningOp()->getParentOp();
+ };
+
+ // Callback to compare if operands of ops in the region of `lhs` and `rhs`
+ // are equivalent.
+ auto mapOperands = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
+ if (lhsValue == rhsValue)
+ return success();
+ if (areEquivalentValues.lookup(lhsValue) == rhsValue)
+ return success();
+ return failure();
+ };
+
+ // Callback to compare if results of ops in the region of `lhs` and `rhs`
+ // are equivalent.
+ auto mapResults = [&](Value lhsResult, Value rhsResult) -> LogicalResult {
+ if (getParent(lhsResult) == lhs && getParent(rhsResult) == rhs) {
+ auto insertion = areEquivalentValues.insert({lhsResult, rhsResult});
+ return success(insertion.first->second == rhsResult);
+ }
+ return success();
+ };
+
return OperationEquivalence::isEquivalentTo(
const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
- /*mapOperands=*/OperationEquivalence::exactValueMatch,
- /*mapResults=*/OperationEquivalence::ignoreValueEquivalence,
- OperationEquivalence::IgnoreLocations);
+ mapOperands, mapResults, OperationEquivalence::IgnoreLocations);
}
};
} // namespace
@@ -204,7 +263,8 @@ LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
// Don't simplify operations with nested blocks. We don't currently model
// equality comparisons correctly among other things. It is also unclear
// whether we would want to CSE such operations.
- if (op->getNumRegions() != 0)
+ if (!(op->getNumRegions() == 0 ||
+ (op->getNumRegions() == 1 && llvm::hasSingleElement(op->getRegion(0)))))
return failure();
// Some simple use case of operation with memory side-effect are dealt with
diff --git a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
index cc5f861482182..dede407eda201 100644
--- a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
@@ -17,7 +17,6 @@
// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T4]] : memref<16xindex>)
// CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64>
// CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to memref<?xf64>
-// CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[T6]] : memref<16xf64>)
// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>)
// CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<1xindex>
// CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]]
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index dbc2d5efb36ad..08429f7ec191d 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -322,3 +322,127 @@ func.func @check_cummutative_cse(%a : i32, %b : i32) -> i32 {
%3 = arith.muli %1, %2 : i32
return %3 : i32
}
+
+// Check that an operation with a single region can CSE.
+func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %0 = test.cse_of_single_block_op inputs(%a, %b) {
+ ^bb0(%arg0 : f32):
+ test.region_yield %arg0 : f32
+ } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ %1 = test.cse_of_single_block_op inputs(%a, %b) {
+ ^bb0(%arg0 : f32):
+ test.region_yield %arg0 : f32
+ } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @cse_single_block_ops
+// CHECK: %[[OP:.+]] = test.cse_of_single_block_op
+// CHECK-NOT: test.cse_of_single_block_op
+// CHECK: return %[[OP]], %[[OP]]
+
+// Operations with different number of bbArgs don't CSE.
+func.func @no_cse_varied_bbargs(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %0 = test.cse_of_single_block_op inputs(%a, %b) {
+ ^bb0(%arg0 : f32, %arg1 : f32):
+ test.region_yield %arg0 : f32
+ } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ %1 = test.cse_of_single_block_op inputs(%a, %b) {
+ ^bb0(%arg0 : f32):
+ test.region_yield %arg0 : f32
+ } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @no_cse_varied_bbargs
+// CHECK: %[[OP0:.+]] = test.cse_of_single_block_op
+// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
+// CHECK: return %[[OP0]], %[[OP1]]
+
+// Operations with different regions don't CSE.
+func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %0 = test.cse_of_single_block_op inputs(%a, %b) {
+ ^bb0(%arg0 : f32, %arg1 : f32):
+ test.region_yield %arg0 : f32
+ } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ %1 = test.cse_of_single_block_op inputs(%a, %b) {
+ ^bb0(%arg0 : f32, %arg1 : f32):
+ test.region_yield %arg1 : f32
+ } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @no_cse_region_difference_simple
+// CHECK: %[[OP0:.+]] = test.cse_of_single_block_op
+// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
+// CHECK: return %[[OP0]], %[[OP1]]
+
+// Operation with identical region with multiple statements CSE.
+func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %0 = test.cse_of_single_block_op inputs(%a, %b) {
+ ^bb0(%arg0 : f32, %arg1 : f32):
+ %1 = arith.divf %arg0, %arg1 : f32
+ %2 = arith.remf %arg0, %c : f32
+ %3 = arith.select %d, %1, %2 : f32
+ test.region_yield %3 : f32
+ } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ %1 = test.cse_of_single_block_op inputs(%a, %b) {
+ ^bb0(%arg0 : f32, %arg1 : f32):
+ %1 = arith.divf %arg0, %arg1 : f32
+ %2 = arith.remf %arg0, %c : f32
+ %3 = arith.select %d, %1, %2 : f32
+ test.region_yield %3 : f32
+ } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @cse_single_block_ops_identical_bodies
+// CHECK: %[[OP:.+]] = test.cse_of_single_block_op
+// CHECK-NOT: test.cse_of_single_block_op
+// CHECK: return %[[OP]], %[[OP]]
+
+// Operation with non-identical regions dont CSE.
+func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %0 = test.cse_of_single_block_op inputs(%a, %b) {
+ ^bb0(%arg0 : f32, %arg1 : f32):
+ %1 = arith.divf %arg0, %arg1 : f32
+ %2 = arith.remf %arg0, %c : f32
+ %3 = arith.select %d, %1, %2 : f32
+ test.region_yield %3 : f32
+ } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ %1 = test.cse_of_single_block_op inputs(%a, %b) {
+ ^bb0(%arg0 : f32, %arg1 : f32):
+ %1 = arith.divf %arg0, %arg1 : f32
+ %2 = arith.remf %arg0, %c : f32
+ %3 = arith.select %d, %2, %1 : f32
+ test.region_yield %3 : f32
+ } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @no_cse_single_block_ops_different_bodies
+// CHECK: %[[OP0:.+]] = test.cse_of_single_block_op
+// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
+// CHECK: return %[[OP0]], %[[OP1]]
+
+// Account for commutative ops within regions during CSE.
+func.func @cse_single_block_with_commutative_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %0 = test.cse_of_single_block_op inputs(%a, %b) {
+ ^bb0(%arg0 : f32, %arg1 : f32):
+ %1 = arith.addf %arg0, %arg1 : f32
+ %2 = arith.mulf %1, %c : f32
+ test.region_yield %2 : f32
+ } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ %1 = test.cse_of_single_block_op inputs(%a, %b) {
+ ^bb0(%arg0 : f32, %arg1 : f32):
+ %1 = arith.addf %arg1, %arg0 : f32
+ %2 = arith.mulf %c, %1 : f32
+ test.region_yield %2 : f32
+ } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @cse_single_block_with_commutative_ops
+// CHECK: %[[OP:.+]] = test.cse_of_single_block_op
+// CHECK-NOT: test.cse_of_single_block_op
+// CHECK: return %[[OP]], %[[OP]]
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 84dd37f03fa1c..cd447d7fbe97b 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -670,8 +670,8 @@ def TestProducingBranchOp : TEST_Op<"producing_br",
// Produces an error value on the error path
def TestInternalBranchOp : TEST_Op<"internal_br",
- [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator,
- AttrSizedOperandSegments]> {
+ [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator,
+ AttrSizedOperandSegments]> {
let arguments = (ins Variadic<AnyType>:$successOperands,
Variadic<AnyType>:$errorOperands);
@@ -3045,4 +3045,19 @@ def RecursivelySpeculatableOp : TEST_Op<"recursively_speculatable_op", [
let regions = (region SizedRegion<1>:$body);
}
+//===---------------------------------------------------------------------===//
+// Test CSE
+//===---------------------------------------------------------------------===//
+
+def TestCSEOfSingleBlockOp : TEST_Op<"cse_of_single_block_op",
+ [SingleBlockImplicitTerminator<"RegionYieldOp">, Pure]> {
+ let arguments = (ins Variadic<AnyType>:$inputs);
+ let results = (outs Variadic<AnyType>:$outputs);
+ let regions = (region SizedRegion<1>:$region);
+ let assemblyFormat = [{
+ attr-dict `inputs` `(` $inputs `)`
+ $region `:` type($inputs) `->` type($outputs)
+ }];
+}
+
#endif // TEST_OPS
More information about the Mlir-commits
mailing list