[Mlir-commits] [mlir] 344eee6 - [MLIR] Allow `Idempotent` trait to be applied to binary ops.
Mehdi Amini
llvmlistbot at llvm.org
Fri Nov 26 10:31:57 PST 2021
Author: Chris Jones
Date: 2021-11-26T18:22:49Z
New Revision: 344eee6f384caea3d64df28ef17f4204febc5e94
URL: https://github.com/llvm/llvm-project/commit/344eee6f384caea3d64df28ef17f4204febc5e94
DIFF: https://github.com/llvm/llvm-project/commit/344eee6f384caea3d64df28ef17f4204febc5e94.diff
LOG: [MLIR] Allow `Idempotent` trait to be applied to binary ops.
Add `Idempotent` trait to `arith.{andi,ori}`.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D114574
Added:
Modified:
mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/lib/IR/Operation.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/trait.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index f5857c90a5804..57459b518bf5c 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -408,7 +408,7 @@ def Arith_RemSIOp : Arith_IntBinaryOp<"remsi"> {
// AndIOp
//===----------------------------------------------------------------------===//
-def Arith_AndIOp : Arith_IntBinaryOp<"andi", [Commutative]> {
+def Arith_AndIOp : Arith_IntBinaryOp<"andi", [Commutative, Idempotent]> {
let summary = "integer binary and";
let description = [{
The `andi` operation takes two operands and returns one result, each of
@@ -436,7 +436,7 @@ def Arith_AndIOp : Arith_IntBinaryOp<"andi", [Commutative]> {
// OrIOp
//===----------------------------------------------------------------------===//
-def Arith_OrIOp : Arith_IntBinaryOp<"ori", [Commutative]> {
+def Arith_OrIOp : Arith_IntBinaryOp<"ori", [Commutative, Idempotent]> {
let summary = "integer binary or";
let description = [{
The `ori` operation takes two operands and returns one result, each of these
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 15d1ffd1c70f0..26212d397575e 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1945,7 +1945,7 @@ def ResultsBroadcastableShape :
NativeOpTrait<"ResultsBroadcastableShape">;
// X op Y == Y op X
def Commutative : NativeOpTrait<"IsCommutative">;
-// op op X == op X
+// op op X == op X (unary) / X op X == X (binary)
def Idempotent : NativeOpTrait<"IsIdempotent">;
// op op X == X
def Involution : NativeOpTrait<"IsInvolution">;
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 5b4a936611895..2fc3cfbf08092 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1090,15 +1090,17 @@ class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
};
/// This class adds property that the operation is idempotent.
-/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x)
+/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x),
+/// or a binary operation "g" that satisfies g(x, x) = x.
template <typename ConcreteType>
class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
public:
static LogicalResult verifyTrait(Operation *op) {
static_assert(ConcreteType::template hasTrait<OneResult>(),
"expected operation to produce one result");
- static_assert(ConcreteType::template hasTrait<OneOperand>(),
- "expected operation to take one operand");
+ static_assert(ConcreteType::template hasTrait<OneOperand>() ||
+ ConcreteType::template hasTrait<NOperands<2>::Impl>(),
+ "expected operation to take one or two operands");
static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
"expected operation to preserve type");
// Idempotent requires the operation to be side effect free as well
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 29e938964363b..b65b152124dfa 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -494,9 +494,6 @@ OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
APInt intValue;
if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
return getLhs();
- /// and(x, x) -> x
- if (getLhs() == getRhs())
- return getRhs();
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a & b; });
@@ -510,9 +507,6 @@ OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
/// or(x, 0) -> x
if (matchPattern(getRhs(), m_Zero()))
return getLhs();
- /// or(x, x) -> x
- if (getLhs() == getRhs())
- return getRhs();
/// or(x, <all ones>) -> <all ones>
if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
if (rhsAttr.getValue().isAllOnes())
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index b1a23a225732b..164aeff04f677 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -629,9 +629,13 @@ InFlightDiagnostic OpState::emitRemark(const Twine &message) {
//===----------------------------------------------------------------------===//
OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
- auto *argumentOp = op->getOperand(0).getDefiningOp();
- if (argumentOp && op->getName() == argumentOp->getName()) {
- // Replace the outer operation output with the inner operation.
+ if (op->getNumOperands() == 1) {
+ auto *argumentOp = op->getOperand(0).getDefiningOp();
+ if (argumentOp && op->getName() == argumentOp->getName()) {
+ // Replace the outer operation output with the inner operation.
+ return op->getOperand(0);
+ }
+ } else if (op->getOperand(0) == op->getOperand(1)) {
return op->getOperand(0);
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 8ac9049ee4864..2bea95017fa5f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1001,6 +1001,13 @@ def TestIdempotentTraitOp
let results = (outs I32);
}
+def TestIdempotentTraitBinaryOp
+ : TEST_Op<"op_idempotent_trait_binary",
+ [SameOperandsAndResultType, NoSideEffect, Idempotent]> {
+ let arguments = (ins I32:$op1, I32:$op2);
+ let results = (outs I32);
+}
+
def TestInvolutionTraitNoOperationFolderOp
: TEST_Op<"op_involution_trait_no_operation_fold",
[SameOperandsAndResultType, NoSideEffect, Involution]> {
diff --git a/mlir/test/mlir-tblgen/trait.mlir b/mlir/test/mlir-tblgen/trait.mlir
index cd2d8f430f3f1..c487585e9ecef 100644
--- a/mlir/test/mlir-tblgen/trait.mlir
+++ b/mlir/test/mlir-tblgen/trait.mlir
@@ -93,3 +93,11 @@ func @testTripleIdempotent(%arg0: i32) -> i32 {
// CHECK: return [[IDEMPOTENT]]
return %2: i32
}
+
+// CHECK-LABEL: func @testBinaryIdempotent
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testBinaryIdempotent(%arg0 : i32) -> i32 {
+ %0 = "test.op_idempotent_trait_binary"(%arg0, %arg0) : (i32, i32) -> i32
+ // CHECK: return [[ARG0]]
+ return %0: i32
+}
More information about the Mlir-commits mailing list