[Mlir-commits] [mlir] 344eee6 - [MLIR] Allow `Idempotent` trait to be applied to binary ops.

Mehdi Amini llvmlistbot at llvm.org
Fri Nov 26 10:31:57 PST 2021


Author: Chris Jones
Date: 2021-11-26T18:22:49Z
New Revision: 344eee6f384caea3d64df28ef17f4204febc5e94

URL: https://github.com/llvm/llvm-project/commit/344eee6f384caea3d64df28ef17f4204febc5e94
DIFF: https://github.com/llvm/llvm-project/commit/344eee6f384caea3d64df28ef17f4204febc5e94.diff

LOG: [MLIR] Allow `Idempotent` trait to be applied to binary ops.

Add `Idempotent` trait to `arith.{andi,ori}`.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D114574

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/lib/IR/Operation.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/trait.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index f5857c90a5804..57459b518bf5c 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -408,7 +408,7 @@ def Arith_RemSIOp : Arith_IntBinaryOp<"remsi"> {
 // AndIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_AndIOp : Arith_IntBinaryOp<"andi", [Commutative]> {
+def Arith_AndIOp : Arith_IntBinaryOp<"andi", [Commutative, Idempotent]> {
   let summary = "integer binary and";
   let description = [{
     The `andi` operation takes two operands and returns one result, each of
@@ -436,7 +436,7 @@ def Arith_AndIOp : Arith_IntBinaryOp<"andi", [Commutative]> {
 // OrIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_OrIOp : Arith_IntBinaryOp<"ori", [Commutative]> {
+def Arith_OrIOp : Arith_IntBinaryOp<"ori", [Commutative, Idempotent]> {
   let summary = "integer binary or";
   let description = [{
     The `ori` operation takes two operands and returns one result, each of these

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 15d1ffd1c70f0..26212d397575e 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1945,7 +1945,7 @@ def ResultsBroadcastableShape :
   NativeOpTrait<"ResultsBroadcastableShape">;
 // X op Y == Y op X
 def Commutative  : NativeOpTrait<"IsCommutative">;
-// op op X == op X
+// op op X == op X (unary) / X op X == X (binary)
 def Idempotent  : NativeOpTrait<"IsIdempotent">;
 // op op X == X
 def Involution  : NativeOpTrait<"IsInvolution">;

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 5b4a936611895..2fc3cfbf08092 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1090,15 +1090,17 @@ class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
 };
 
 /// This class adds property that the operation is idempotent.
-/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x)
+/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x),
+/// or a binary operation "g" that satisfies g(x, x) = x.
 template <typename ConcreteType>
 class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
 public:
   static LogicalResult verifyTrait(Operation *op) {
     static_assert(ConcreteType::template hasTrait<OneResult>(),
                   "expected operation to produce one result");
-    static_assert(ConcreteType::template hasTrait<OneOperand>(),
-                  "expected operation to take one operand");
+    static_assert(ConcreteType::template hasTrait<OneOperand>() ||
+                      ConcreteType::template hasTrait<NOperands<2>::Impl>(),
+                  "expected operation to take one or two operands");
     static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
                   "expected operation to preserve type");
     // Idempotent requires the operation to be side effect free as well

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 29e938964363b..b65b152124dfa 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -494,9 +494,6 @@ OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
   APInt intValue;
   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
     return getLhs();
-  /// and(x, x) -> x
-  if (getLhs() == getRhs())
-    return getRhs();
 
   return constFoldBinaryOp<IntegerAttr>(operands,
                                         [](APInt a, APInt b) { return a & b; });
@@ -510,9 +507,6 @@ OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
   /// or(x, 0) -> x
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
-  /// or(x, x) -> x
-  if (getLhs() == getRhs())
-    return getRhs();
   /// or(x, <all ones>) -> <all ones>
   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
     if (rhsAttr.getValue().isAllOnes())

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index b1a23a225732b..164aeff04f677 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -629,9 +629,13 @@ InFlightDiagnostic OpState::emitRemark(const Twine &message) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
-  auto *argumentOp = op->getOperand(0).getDefiningOp();
-  if (argumentOp && op->getName() == argumentOp->getName()) {
-    // Replace the outer operation output with the inner operation.
+  if (op->getNumOperands() == 1) {
+    auto *argumentOp = op->getOperand(0).getDefiningOp();
+    if (argumentOp && op->getName() == argumentOp->getName()) {
+      // Replace the outer operation output with the inner operation.
+      return op->getOperand(0);
+    }
+  } else if (op->getOperand(0) == op->getOperand(1)) {
     return op->getOperand(0);
   }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 8ac9049ee4864..2bea95017fa5f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1001,6 +1001,13 @@ def TestIdempotentTraitOp
   let results = (outs I32);
 }
 
+def TestIdempotentTraitBinaryOp
+    : TEST_Op<"op_idempotent_trait_binary",
+              [SameOperandsAndResultType, NoSideEffect, Idempotent]> {
+  let arguments = (ins I32:$op1, I32:$op2);
+  let results = (outs I32);
+}
+
 def TestInvolutionTraitNoOperationFolderOp
  : TEST_Op<"op_involution_trait_no_operation_fold",
            [SameOperandsAndResultType, NoSideEffect, Involution]> {

diff  --git a/mlir/test/mlir-tblgen/trait.mlir b/mlir/test/mlir-tblgen/trait.mlir
index cd2d8f430f3f1..c487585e9ecef 100644
--- a/mlir/test/mlir-tblgen/trait.mlir
+++ b/mlir/test/mlir-tblgen/trait.mlir
@@ -93,3 +93,11 @@ func @testTripleIdempotent(%arg0: i32) -> i32 {
   // CHECK: return [[IDEMPOTENT]]
   return %2: i32
 }
+
+// CHECK-LABEL: func @testBinaryIdempotent
+// CHECK-SAME:  ([[ARG0:%.+]]: i32)
+func @testBinaryIdempotent(%arg0 : i32) -> i32 {
+  %0 = "test.op_idempotent_trait_binary"(%arg0, %arg0) : (i32, i32) -> i32
+  // CHECK: return [[ARG0]]
+  return %0: i32
+}


        


More information about the Mlir-commits mailing list