[Mlir-commits] [mlir] 7dff6b8 - [MLIR] Add idempotent trait folding
Andy Ly
llvmlistbot at llvm.org
Fri Oct 16 08:51:19 PDT 2020
Author: ahmedsabie
Date: 2020-10-16T15:51:04Z
New Revision: 7dff6b818b1cdd52fbc99f6256760d6eb02a7622
URL: https://github.com/llvm/llvm-project/commit/7dff6b818b1cdd52fbc99f6256760d6eb02a7622
DIFF: https://github.com/llvm/llvm-project/commit/7dff6b818b1cdd52fbc99f6256760d6eb02a7622.diff
LOG: [MLIR] Add idempotent trait folding
This trait adds a fold of f(f(x)) to f(x) for any operation that is labelled as idempotent.
Reviewed By: rriddle, andyly
Differential Revision: https://reviews.llvm.org/D89421
Added:
Modified:
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/IR/Operation.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/trait.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 72b3b1ab41f5..e09c18c0e1a1 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1720,6 +1720,8 @@ def ResultsBroadcastableShape :
NativeOpTrait<"ResultsBroadcastableShape">;
// X op Y == Y op X
def Commutative : NativeOpTrait<"IsCommutative">;
+// op op X == op X
+def Idempotent : NativeOpTrait<"IsIdempotent">;
// op op X == X
def Involution : NativeOpTrait<"IsInvolution">;
// Op behaves like a constant.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 6a78b8fca6f2..9c1cf3d841dd 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -388,10 +388,12 @@ namespace OpTrait {
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
+OpFoldResult foldIdempotent(Operation *op);
OpFoldResult foldInvolution(Operation *op);
LogicalResult verifyZeroOperands(Operation *op);
LogicalResult verifyOneOperand(Operation *op);
LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
+LogicalResult verifyIsIdempotent(Operation *op);
LogicalResult verifyIsInvolution(Operation *op);
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
LogicalResult verifyOperandsAreFloatLike(Operation *op);
@@ -1012,7 +1014,7 @@ class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
};
/// This class adds property that the operation is an involution.
-/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x)
+/// This means a unary to unary operation "f" that satisfies f(f(x)) = x
template <typename ConcreteType>
class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
public:
@@ -1033,6 +1035,28 @@ class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
}
};
+/// This trait marks an operation "f" as idempotent, i.e. f(f(x)) = f(x).
+/// It applies to a unary to unary operation whose result type matches its
+/// operand type; these requirements are enforced at compile time below.
+template <typename ConcreteType>
+class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
+public:
+ /// Statically checks the structural requirements on ConcreteType, then
+ /// defers to impl::verifyIsIdempotent for the per-operation check.
+ static LogicalResult verifyTrait(Operation *op) {
+ static_assert(ConcreteType::template hasTrait<OneResult>(),
+ "expected operation to produce one result");
+ static_assert(ConcreteType::template hasTrait<OneOperand>(),
+ "expected operation to take one operand");
+ static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
+ "expected operation to preserve type");
+ // Idempotence also requires the operation to be free of side effects,
+ // but that check is still a FIXME in impl::verifyIsIdempotent.
+ return impl::verifyIsIdempotent(op);
+ }
+
+ /// Trait-provided folder. `operands` is not used: the fold only looks at
+ /// the operation that defines the operand (see impl::foldIdempotent).
+ static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
+ return impl::foldIdempotent(op);
+ }
+};
+
/// This class verifies that all operands of the specified op have a float type,
/// a vector thereof, or a tensor thereof.
template <typename ConcreteType>
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 51d51746a120..fe86c6f12dbb 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -679,6 +679,16 @@ InFlightDiagnostic OpState::emitRemark(const Twine &message) {
// Op Trait implementations
//===----------------------------------------------------------------------===//
+/// Folds op(op(x)) to op(x). If operand 0 of `op` is defined by an operation
+/// with the same name as `op`, returns that operand; otherwise returns a
+/// default-constructed OpFoldResult.
+OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
+ // getDefiningOp() may yield null, hence the null check below.
+ auto *argumentOp = op->getOperand(0).getDefiningOp();
+ if (argumentOp && op->getName() == argumentOp->getName()) {
+ // Replace the outer operation's output with the inner operation's result,
+ // which is exactly operand 0 of the outer operation.
+ return op->getOperand(0);
+ }
+
+ return {};
+}
+
OpFoldResult OpTrait::impl::foldInvolution(Operation *op) {
auto *argumentOp = op->getOperand(0).getDefiningOp();
if (argumentOp && op->getName() == argumentOp->getName()) {
@@ -730,6 +740,14 @@ static Type getTensorOrVectorElementType(Type type) {
return type;
}
+/// Verifier hook for the IsIdempotent trait. It currently accepts every
+/// operation: `op` is not inspected and success() is always returned.
+LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) {
+ // FIXME: Add back check for no side effects on operation.
+ // Currently adding it would cause the shared library build
+ // to fail since there would be a dependency of IR on SideEffectInterfaces
+ // which is cyclical.
+ return success();
+}
+
LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) {
// FIXME: Add back check for no side effects on operation.
// Currently adding it would cause the shared library build
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index fcc677361dcc..048400fcbe81 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -840,6 +840,13 @@ def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
let results = (outs I32);
}
+// Test operation carrying the Idempotent trait: one I32 operand, one I32
+// result, same operand and result type, and no side effects.
+def TestIdempotentTraitOp
+ : TEST_Op<"op_idempotent_trait",
+ [SameOperandsAndResultType, NoSideEffect, Idempotent]> {
+ let arguments = (ins I32:$op1);
+ let results = (outs I32);
+}
+
def TestInvolutionTraitNoOperationFolderOp
: TEST_Op<"op_involution_trait_no_operation_fold",
[SameOperandsAndResultType, NoSideEffect, Involution]> {
diff --git a/mlir/test/mlir-tblgen/trait.mlir b/mlir/test/mlir-tblgen/trait.mlir
index 341e7a7695b1..cd2d8f430f3f 100644
--- a/mlir/test/mlir-tblgen/trait.mlir
+++ b/mlir/test/mlir-tblgen/trait.mlir
@@ -59,3 +59,37 @@ func @testInhibitInvolution(%arg0: i32) -> i32 {
// CHECK: return [[OP]]
return %1: i32
}
+
+//===----------------------------------------------------------------------===//
+// Test that idempotent folding works correctly
+//===----------------------------------------------------------------------===//
+
+// A single application has nothing to fold: the op is expected to remain and
+// its result to be returned.
+// CHECK-LABEL: func @testSingleIdempotent
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testSingleIdempotent(%arg0 : i32) -> i32 {
+ // CHECK: [[IDEMPOTENT:%.+]] = "test.op_idempotent_trait"([[ARG0]])
+ %0 = "test.op_idempotent_trait"(%arg0) : (i32) -> i32
+ // CHECK: return [[IDEMPOTENT]]
+ return %0: i32
+}
+
+// Two nested applications: the returned value is expected to be the result of
+// the op applied directly to the function argument, i.e. f(f(x)) became f(x).
+// CHECK-LABEL: func @testDoubleIdempotent
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testDoubleIdempotent(%arg0: i32) -> i32 {
+ // CHECK: [[IDEMPOTENT:%.+]] = "test.op_idempotent_trait"([[ARG0]])
+ %0 = "test.op_idempotent_trait"(%arg0) : (i32) -> i32
+ %1 = "test.op_idempotent_trait"(%0) : (i32) -> i32
+ // CHECK: return [[IDEMPOTENT]]
+ return %1: i32
+}
+
+// Three nested applications: as in the double case, the returned value is
+// expected to be the result of the op applied directly to the argument.
+// CHECK-LABEL: func @testTripleIdempotent
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testTripleIdempotent(%arg0: i32) -> i32 {
+ // CHECK: [[IDEMPOTENT:%.+]] = "test.op_idempotent_trait"([[ARG0]])
+ %0 = "test.op_idempotent_trait"(%arg0) : (i32) -> i32
+ %1 = "test.op_idempotent_trait"(%0) : (i32) -> i32
+ %2 = "test.op_idempotent_trait"(%1) : (i32) -> i32
+ // CHECK: return [[IDEMPOTENT]]
+ return %2: i32
+}
More information about the Mlir-commits
mailing list