[Mlir-commits] [mlir] 7dff6b8 - [MLIR] Add idempotent trait folding

Andy Ly llvmlistbot at llvm.org
Fri Oct 16 08:51:19 PDT 2020


Author: ahmedsabie
Date: 2020-10-16T15:51:04Z
New Revision: 7dff6b818b1cdd52fbc99f6256760d6eb02a7622

URL: https://github.com/llvm/llvm-project/commit/7dff6b818b1cdd52fbc99f6256760d6eb02a7622
DIFF: https://github.com/llvm/llvm-project/commit/7dff6b818b1cdd52fbc99f6256760d6eb02a7622.diff

LOG: [MLIR] Add idempotent trait folding

This trait adds a fold f(f(x)) = f(x) for operations labelled as idempotent.

Reviewed By: rriddle, andyly

Differential Revision: https://reviews.llvm.org/D89421

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/lib/IR/Operation.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/trait.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 72b3b1ab41f5..e09c18c0e1a1 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1720,6 +1720,8 @@ def ResultsBroadcastableShape :
   NativeOpTrait<"ResultsBroadcastableShape">;
 // X op Y == Y op X
 def Commutative  : NativeOpTrait<"IsCommutative">;
+// op op X == op X (folded to op X by the IsIdempotent trait)
+def Idempotent  : NativeOpTrait<"IsIdempotent">;
 // op op X == X
 def Involution  : NativeOpTrait<"IsInvolution">;
 // Op behaves like a constant.

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 6a78b8fca6f2..9c1cf3d841dd 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -388,10 +388,12 @@ namespace OpTrait {
 // corresponding trait classes.  This avoids them being template
 // instantiated/duplicated.
 namespace impl {
+OpFoldResult foldIdempotent(Operation *op);
 OpFoldResult foldInvolution(Operation *op);
 LogicalResult verifyZeroOperands(Operation *op);
 LogicalResult verifyOneOperand(Operation *op);
 LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
+LogicalResult verifyIsIdempotent(Operation *op);
 LogicalResult verifyIsInvolution(Operation *op);
 LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
 LogicalResult verifyOperandsAreFloatLike(Operation *op);
@@ -1012,7 +1014,7 @@ class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
 };
 
 /// This class adds property that the operation is an involution.
-/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x)
+/// This means a unary to unary operation "f" that satisfies f(f(x)) = x
 template <typename ConcreteType>
 class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
 public:
@@ -1033,6 +1035,28 @@ class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
   }
 };
 
+/// This class adds the property that the operation is idempotent, i.e. it is
+/// a unary, type-preserving operation "f" that satisfies f(f(x)) = f(x).
+template <typename ConcreteType>
+class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    static_assert(ConcreteType::template hasTrait<OneResult>(),
+                  "expected operation to produce one result");
+    static_assert(ConcreteType::template hasTrait<OneOperand>(),
+                  "expected operation to take one operand");
+    static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
+                  "expected operation to preserve type");
+    // Idempotence also requires the op to be side-effect free, but that check
+    // is still a FIXME in impl::verifyIsIdempotent and is not performed yet.
+    return impl::verifyIsIdempotent(op);
+  }
+
+  static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
+    return impl::foldIdempotent(op);
+  }
+};
+
 /// This class verifies that all operands of the specified op have a float type,
 /// a vector thereof, or a tensor thereof.
 template <typename ConcreteType>

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 51d51746a120..fe86c6f12dbb 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -679,6 +679,16 @@ InFlightDiagnostic OpState::emitRemark(const Twine &message) {
 // Op Trait implementations
 //===----------------------------------------------------------------------===//
 
+OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
+  auto *argumentOp = op->getOperand(0).getDefiningOp();
+  if (argumentOp && op->getName() == argumentOp->getName()) {
+    // f(f(x)) == f(x): reuse the inner op's result (our operand) as ours.
+    return op->getOperand(0);
+  }
+  // Operand is a block argument or comes from a different op: no fold.
+  return {};
+}
+
 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) {
   auto *argumentOp = op->getOperand(0).getDefiningOp();
   if (argumentOp && op->getName() == argumentOp->getName()) {
@@ -730,6 +740,14 @@ static Type getTensorOrVectorElementType(Type type) {
   return type;
 }
 
+LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) {
+  // FIXME: Add back the check that the operation has no side effects. It is
+  // omitted because it would make IR depend on SideEffectInterfaces, a
+  // cyclical dependency that breaks the shared library build. Until then,
+  // this verifier accepts every operation.
+  return success();
+}
+
 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) {
   // FIXME: Add back check for no side effects on operation.
   // Currently adding it would cause the shared library build

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index fcc677361dcc..048400fcbe81 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -840,6 +840,13 @@ def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
   let results = (outs I32);
 }
 
+def TestIdempotentTraitOp // Unary i32 op for testing Idempotent folding.
+ : TEST_Op<"op_idempotent_trait",
+           [SameOperandsAndResultType, NoSideEffect, Idempotent]> {
+  let arguments = (ins I32:$op1);
+  let results = (outs I32);
+}
+
 def TestInvolutionTraitNoOperationFolderOp
  : TEST_Op<"op_involution_trait_no_operation_fold",
            [SameOperandsAndResultType, NoSideEffect, Involution]> {

diff  --git a/mlir/test/mlir-tblgen/trait.mlir b/mlir/test/mlir-tblgen/trait.mlir
index 341e7a7695b1..cd2d8f430f3f 100644
--- a/mlir/test/mlir-tblgen/trait.mlir
+++ b/mlir/test/mlir-tblgen/trait.mlir
@@ -59,3 +59,37 @@ func @testInhibitInvolution(%arg0: i32) -> i32 {
   // CHECK: return [[OP]]
   return %1: i32
 }
+
+//===----------------------------------------------------------------------===//
+// Test that ops with the Idempotent trait fold f(f(x)) to f(x)
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @testSingleIdempotent
+// CHECK-SAME:  ([[ARG0:%.+]]: i32)
+func @testSingleIdempotent(%arg0: i32) -> i32 {
+  // CHECK: [[IDEMPOTENT:%.+]] = "test.op_idempotent_trait"([[ARG0]])
+  %0 = "test.op_idempotent_trait"(%arg0) : (i32) -> i32
+  // CHECK-NEXT: return [[IDEMPOTENT]]
+  return %0: i32
+}
+
+// CHECK-LABEL: func @testDoubleIdempotent
+// CHECK-SAME:  ([[ARG0:%.+]]: i32)
+func @testDoubleIdempotent(%arg0: i32) -> i32 {
+  // CHECK: [[IDEMPOTENT:%.+]] = "test.op_idempotent_trait"([[ARG0]])
+  %0 = "test.op_idempotent_trait"(%arg0) : (i32) -> i32
+  %1 = "test.op_idempotent_trait"(%0) : (i32) -> i32
+  // CHECK-NEXT: return [[IDEMPOTENT]]
+  return %1: i32
+}
+
+// CHECK-LABEL: func @testTripleIdempotent
+// CHECK-SAME:  ([[ARG0:%.+]]: i32)
+func @testTripleIdempotent(%arg0: i32) -> i32 {
+  // CHECK: [[IDEMPOTENT:%.+]] = "test.op_idempotent_trait"([[ARG0]])
+  %0 = "test.op_idempotent_trait"(%arg0) : (i32) -> i32
+  %1 = "test.op_idempotent_trait"(%0) : (i32) -> i32
+  %2 = "test.op_idempotent_trait"(%1) : (i32) -> i32
+  // CHECK-NEXT: return [[IDEMPOTENT]]
+  return %2: i32
+}


        


More information about the Mlir-commits mailing list