[Mlir-commits] [mlir] 1ceaffd - [MLIR] Add a foldTrait() mechanism to allow traits to define folding and test it with an Involution trait
Mehdi Amini
llvmlistbot at llvm.org
Thu Oct 8 20:26:24 PDT 2020
Author: ahmedsabie
Date: 2020-10-09T03:25:53Z
New Revision: 1ceaffd95a6bdc4b7d2193e049bcd6b40ee9ff50
URL: https://github.com/llvm/llvm-project/commit/1ceaffd95a6bdc4b7d2193e049bcd6b40ee9ff50
DIFF: https://github.com/llvm/llvm-project/commit/1ceaffd95a6bdc4b7d2193e049bcd6b40ee9ff50.diff
LOG: [MLIR] Add a foldTrait() mechanism to allow traits to define folding and test it with an Involution trait
This change allows folds to be done on a newly introduced involution trait rather than having to manually rewrite this optimization for every instance of an involution.
Reviewed By: rriddle, andyly, stephenneuendorffer
Differential Revision: https://reviews.llvm.org/D88809
Added:
mlir/test/lib/Dialect/Test/TestTraits.cpp
mlir/test/mlir-tblgen/trait.mlir
Modified:
mlir/docs/Traits.md
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/IR/Operation.cpp
mlir/test/lib/Dialect/Test/CMakeLists.txt
mlir/test/lib/Dialect/Test/TestOps.td
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md
index 488da39e6504..8b1cf0a03b99 100644
--- a/mlir/docs/Traits.md
+++ b/mlir/docs/Traits.md
@@ -56,6 +56,47 @@ Note: It is generally good practice to define the implementation of the
`verifyTrait` hook out-of-line as a free function when possible to avoid
instantiating the implementation for every concrete operation type.
+Operation traits may also provide a `foldTrait` hook that is called when
+folding the concrete operation. The trait folders will only be invoked if
+the concrete operation fold is either not implemented, fails, or performs
+an in-place fold.
+
+The following `foldTrait` signature will be called if it is implemented
+and the op has a single result.
+
+```c++
+template <typename ConcreteType>
+class MyTrait : public OpTrait::TraitBase<ConcreteType, MyTrait> {
+public:
+ /// Override the 'foldTrait' hook to support trait based folding on the
+ /// concrete operation.
+ static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
+ // ...
+ }
+};
+```
+
+Otherwise, if the operation has a single result and the above signature is
+not implemented, or the operation has multiple results, then the following signature
+will be used (if implemented):
+
+```c++
+template <typename ConcreteType>
+class MyTrait : public OpTrait::TraitBase<ConcreteType, MyTrait> {
+public:
+ /// Override the 'foldTrait' hook to support trait based folding on the
+ /// concrete operation.
+ static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ // ...
+ }
+};
+```
+
+Note: It is generally good practice to define the implementation of the
+`foldTrait` hook out-of-line as a free function when possible to avoid
+instantiating the implementation for every concrete operation type.
+
### Parametric Traits
The above demonstrates the definition of a simple self-contained trait. It is
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index eaaf5b75230e..5845371161bf 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1723,6 +1723,8 @@ def ResultsBroadcastableShape :
NativeOpTrait<"ResultsBroadcastableShape">;
// X op Y == Y op X
def Commutative : NativeOpTrait<"IsCommutative">;
+// op op X == X, i.e. op(op(X)) == X: the op is its own inverse.
+def Involution : NativeOpTrait<"IsInvolution">;
// Op behaves like a constant.
def ConstantLike : NativeOpTrait<"ConstantLike">;
// Op behaves like a function.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 6861523e0d04..b6fe53b6b524 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -21,6 +21,7 @@
#include "mlir/IR/Operation.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
+
#include <type_traits>
namespace mlir {
@@ -277,7 +278,16 @@ class FoldingHook {
/// AbstractOperation.
static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
- return cast<ConcreteType>(op).fold(operands, results);
+ auto operationFoldResult = cast<ConcreteType>(op).fold(operands, results);
+ // A failed fold, or an in-place fold (success with no replacement results), both mean the trait folders may still run.
+ if (failed(operationFoldResult) || results.empty()) {
+ auto traitFoldResult = ConcreteType::foldTraits(op, operands, results); // NOTE(review): `results` is not cleared first; if fold() failed after appending partial results, the trait folders append after them — confirm fold() leaves `results` empty on failure.
+ // Only return the trait fold result if it is a success since
+ // operationFoldResult might have been a success originally (an in-place fold) that a failed trait fold must not override.
+ if (succeeded(traitFoldResult))
+ return traitFoldResult;
+ }
+ return operationFoldResult; // the op's own outcome: success (possibly in-place) or failure
}
/// This hook implements a generalized folder for this operation. Operations
@@ -326,6 +336,14 @@ class FoldingHook<ConcreteType, isSingleResult,
static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
auto result = cast<ConcreteType>(op).fold(operands);
+ // Failure to fold or in place fold both mean we can continue folding.
+ if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
+ // Only consider the trait fold result if it is a success since
+ // the operation fold might have been a success originally.
+ if (auto traitFoldResult = ConcreteType::foldTraits(op, operands))
+ result = traitFoldResult;
+ }
+
if (!result)
return failure();
@@ -370,9 +388,11 @@ namespace OpTrait {
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
+OpFoldResult foldInvolution(Operation *op);
LogicalResult verifyZeroOperands(Operation *op);
LogicalResult verifyOneOperand(Operation *op);
LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
+LogicalResult verifyIsInvolution(Operation *op);
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
LogicalResult verifyOperandsAreFloatLike(Operation *op);
LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
@@ -426,6 +446,23 @@ class TraitBase {
static AbstractOperation::OperationProperties getTraitProperties() {
return 0;
}
+
+ static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) { // Default single-result hook: adapts the multi-result foldTrait to a single OpFoldResult.
+ SmallVector<OpFoldResult, 1> results;
+ if (failed(foldTrait(op, operands, results))) // NOTE(review): unqualified lookup here finds TraitBase's own 3-arg foldTrait (which always returns failure), not an override declared in the derived trait, so a trait's multi-result hook is never reached through this adapter; a trait declaring only the 3-arg hook also hides this overload from First::foldTrait(op, operands). Confirm against the documented fallback and dispatch through the derived trait type.
+ return {}; // null result: the trait did not fold
+ if (results.empty())
+ return op->getResult(0); // success with no results encodes an in-place fold
+ assert(results.size() == 1 &&
+ "Single result op cannot return multiple fold results");
+
+ return results[0];
+ }
+
+  /// Default multi-result trait folder: it never folds, so it reports failure
+  /// and leaves `results` untouched.
+  static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
+                                 SmallVectorImpl<OpFoldResult> &results) {
+    LogicalResult nothingFolded = failure();
+    return nothingFolded;
+  }
};
//===----------------------------------------------------------------------===//
@@ -974,6 +1011,26 @@ class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
}
};
+/// This class adds the property that the operation is an involution, i.e. a
+/// unary operation "f" that is its own inverse: f(f(x)) = x for every x.
+/// Ops with this trait must have one operand, one result of the same type,
+/// and no side effects; a directly nested pair f(f(x)) folds to x.
+template <typename ConcreteType>
+class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    // Operand/result shape is enforced at compile time; the absence of side
+    // effects is checked when the op is verified.
+    static_assert(ConcreteType::template hasTrait<OneResult>(),
+                  "expected operation to produce one result");
+    static_assert(ConcreteType::template hasTrait<OneOperand>(),
+                  "expected operation to take one operand");
+    static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
+                  "expected operation to preserve type");
+    return impl::verifyIsInvolution(op);
+  }
+
+  /// Folds f(f(x)) to x. Only the single-result hook is provided because the
+  /// trait requires exactly one result.
+  static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
+    return impl::foldInvolution(op);
+  }
+};
+
/// This class verifies that all operands of the specified op have a float type,
/// a vector thereof, or a tensor thereof.
template <typename ConcreteType>
@@ -1306,6 +1363,19 @@ class Op : public OpState,
failed(cast<ConcreteType>(op).verify()));
}
+  /// Tries to fold the given single-result operation using only its traits.
+  /// Each trait's `foldTrait` hook is tried in the order of the op's trait
+  /// list, and the first result that is neither null nor an in-place fold
+  /// (the op's own result) wins. The op's folding hook calls this when the
+  /// op's own fold() fails or folds in place.
+  static OpFoldResult foldTraits(Operation *op, ArrayRef<Attribute> operands) {
+    return BaseFolder<Traits<ConcreteType>...>::foldTraits(op, operands);
+  }
+
+  /// Multi-result variant of foldTraits: runs each trait's multi-result
+  /// `foldTrait` hook in the order of the op's trait list; the hooks may write
+  /// their fold results into `results`.
+  static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
+                                  SmallVectorImpl<OpFoldResult> &results) {
+    using TraitFolder = BaseFolder<Traits<ConcreteType>...>;
+    return TraitFolder::foldTraits(op, operands, results);
+  }
+
// Returns the properties of an operation by combining the properties of the
// traits of the op.
static AbstractOperation::OperationProperties getOperationProperties() {
@@ -1358,6 +1428,53 @@ class Op : public OpState,
}
};
+  /// Folds an operation with each trait in `Types`, one after another. The
+  /// primary template is only declared here; the recursive case and the empty
+  /// base case follow.
+  template <typename... Types>
+  struct BaseFolder;
+
+  template <typename First, typename... Rest>
+  struct BaseFolder<First, Rest...> {
+    /// Single-result folding: `First` is tried before `Rest`, and the first
+    /// result that is neither null nor an in-place fold (the op's own result)
+    /// is kept. Otherwise an in-place result, if any, is returned; else null.
+    static OpFoldResult foldTraits(Operation *op,
+                                   ArrayRef<Attribute> operands) {
+      auto result = First::foldTrait(op, operands);
+      // Failure to fold or in place fold both mean we can continue folding.
+      if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
+        // Only adopt the remaining traits' result if one of them folded, so
+        // that an in-place fold by `First` is not replaced by a null result.
+        auto resultRemaining = BaseFolder<Rest...>::foldTraits(op, operands);
+        if (resultRemaining)
+          result = resultRemaining;
+      }
+
+      return result;
+    }
+
+    /// Multi-result folding: `First` is tried before `Rest`; a success that
+    /// leaves `results` empty is an in-place fold and lets later traits try.
+    static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
+                                    SmallVectorImpl<OpFoldResult> &results) {
+      auto result = First::foldTrait(op, operands, results);
+      // Failure to fold or in place fold both mean we can continue folding.
+      if (failed(result) || results.empty()) {
+        // Only adopt the remaining traits' outcome if it is a success, so an
+        // in-place fold by `First` is not turned into a failure.
+        // NOTE(review): `results` is not cleared if `First` failed after
+        // appending partial results; later traits would append after them.
+        auto resultRemaining =
+            BaseFolder<Rest...>::foldTraits(op, operands, results);
+        if (succeeded(resultRemaining))
+          result = resultRemaining;
+      }
+
+      return result;
+    }
+  };
+
+  /// Base case: with no traits left to try, nothing folds.
+  template <typename...>
+  struct BaseFolder {
+    static OpFoldResult foldTraits(Operation *op,
+                                   ArrayRef<Attribute> operands) {
+      return {};
+    }
+    static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
+                                    SmallVectorImpl<OpFoldResult> &results) {
+      return failure();
+    }
+  };
+
template <typename...> struct BaseProperties {
static AbstractOperation::OperationProperties getTraitProperties() {
return 0;
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index f531a6097c25..280b16a589d1 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FoldInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <numeric>
using namespace mlir;
@@ -679,6 +680,16 @@ InFlightDiagnostic OpState::emitRemark(const Twine &message) {
// Op Trait implementations
//===----------------------------------------------------------------------===//
+/// Folds f(f(x)) -> x for an involution f: when the first (for Involution ops,
+/// the only) operand of `op` is produced by another op with the same name,
+/// forward that inner op's operand. Returns a null result otherwise.
+OpFoldResult OpTrait::impl::foldInvolution(Operation *op) {
+  Operation *inner = op->getOperand(0).getDefiningOp();
+  if (!inner)
+    return {};
+  if (inner->getName() == op->getName())
+    return inner->getOperand(0);
+  return {};
+}
+
LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) {
if (op->getNumOperands() != 0)
return op->emitOpError() << "requires zero operands";
@@ -720,6 +731,12 @@ static Type getTensorOrVectorElementType(Type type) {
return type;
}
+/// Verifies that an op carrying the Involution trait has no memory effects.
+/// Presumably required because foldInvolution drops both applications of a
+/// nested pair f(f(x)), which is only sound for effect-free ops.
+LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) {
+  if (MemoryEffectOpInterface::hasNoEffect(op))
+    return success();
+  return op->emitOpError() << "requires operation to have no side effects";
+}
+
LogicalResult
OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) {
for (auto opType : op->getOperandTypes()) {
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 696b43992971..31c8cccae36e 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -1,6 +1,7 @@
set(LLVM_OPTIONAL_SOURCES
TestDialect.cpp
TestPatterns.cpp
+ TestTraits.cpp
)
set(LLVM_TARGET_DEFINITIONS TestInterfaces.td)
@@ -23,6 +24,7 @@ add_public_tablegen_target(MLIRTestOpsIncGen)
add_mlir_library(MLIRTestDialect
TestDialect.cpp
TestPatterns.cpp
+ TestTraits.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 73610457cf7b..d36d7bd58ea8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -798,6 +798,29 @@ def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
let results = (outs I32);
}
+// Involution op with no fold() of its own: only the Involution trait folder
+// can fold it.
+def TestInvolutionTraitNoOperationFolderOp
+    : TEST_Op<"op_involution_trait_no_operation_fold",
+              [SameOperandsAndResultType, NoSideEffect, Involution]> {
+  let arguments = (ins I32:$op1);
+  let results = (outs I32);
+}
+
+// Involution op whose own fold() always fails, so folding falls back to the
+// Involution trait folder.
+def TestInvolutionTraitFailingOperationFolderOp
+    : TEST_Op<"op_involution_trait_failing_operation_fold",
+              [SameOperandsAndResultType, NoSideEffect, Involution]> {
+  let arguments = (ins I32:$op1);
+  let results = (outs I32);
+  let hasFolder = 1;
+}
+
+// Involution op whose own fold() succeeds whenever its operand has a defining
+// op, which suppresses the Involution trait folder.
+def TestInvolutionTraitSuccesfulOperationFolderOp
+    : TEST_Op<"op_involution_trait_succesful_operation_fold",
+              [SameOperandsAndResultType, NoSideEffect, Involution]> {
+  let arguments = (ins I32:$op1);
+  let results = (outs I32);
+  let hasFolder = 1;
+}
+
def TestOpInPlaceFoldAnchor : TEST_Op<"op_in_place_fold_anchor"> {
let arguments = (ins I32);
let results = (outs I32);
diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp
new file mode 100644
index 000000000000..3cbc95ce6c74
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp
@@ -0,0 +1,45 @@
+//===- TestTraits.cpp - Test trait folding --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Trait Folder.
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TestInvolutionTraitFailingOperationFolderOp::fold(
+    ArrayRef<Attribute> operands) {
+  // Never fold: a null OpFoldResult reports failure, which makes the folding
+  // hook fall back to the op's trait folders (here, the Involution trait).
+  OpFoldResult noFold;
+  return noFold;
+}
+
+OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold(
+    ArrayRef<Attribute> operands) {
+  // Fold to the operand whenever it is produced by some operation. Because
+  // this fold then succeeds, the Involution trait folder is suppressed.
+  auto input = getOperand();
+  if (!input.getDefiningOp())
+    return {};
+  return input;
+}
+
+namespace {
+/// Test pass that runs the greedy rewrite driver over a function with a
+/// default-constructed pattern list. With no patterns supplied, the rewrites
+/// come from op folding (including the trait folders) and whatever clean-up
+/// the greedy driver itself performs.
+struct TestTraitFolder : public PassWrapper<TestTraitFolder, FunctionPass> {
+  void runOnFunction() override {
+    auto function = getFunction();
+    applyPatternsAndFoldGreedily(function, {});
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+/// Registers the trait-folding test pass under the `-test-trait-folder` flag.
+void registerTestTraitsPass() {
+  PassRegistration<TestTraitFolder> registration("test-trait-folder",
+                                                 "Run trait folding");
+}
+} // namespace mlir
diff --git a/mlir/test/mlir-tblgen/trait.mlir b/mlir/test/mlir-tblgen/trait.mlir
new file mode 100644
index 000000000000..341e7a7695b1
--- /dev/null
+++ b/mlir/test/mlir-tblgen/trait.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -test-trait-folder %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test that involutions fold correctly
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @testSingleInvolution
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testSingleInvolution(%arg0 : i32) -> i32 { // A lone involution has no inner application to cancel with, so it must survive folding.
+ // CHECK: [[INVOLUTION:%.+]] = "test.op_involution_trait_no_operation_fold"([[ARG0]])
+ %0 = "test.op_involution_trait_no_operation_fold"(%arg0) : (i32) -> i32
+ // CHECK: return [[INVOLUTION]]
+ return %0: i32
+}
+
+// CHECK-LABEL: func @testDoubleInvolution
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testDoubleInvolution(%arg0: i32) -> i32 { // f(f(x)) folds to x, so the argument is returned directly.
+ %0 = "test.op_involution_trait_no_operation_fold"(%arg0) : (i32) -> i32
+ %1 = "test.op_involution_trait_no_operation_fold"(%0) : (i32) -> i32
+ // CHECK: return [[ARG0]]
+ return %1: i32
+}
+
+// CHECK-LABEL: func @testTripleInvolution
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testTripleInvolution(%arg0: i32) -> i32 { // f(f(f(x))) folds to f(x): the outer pair cancels, leaving the innermost op.
+ // CHECK: [[INVOLUTION:%.+]] = "test.op_involution_trait_no_operation_fold"([[ARG0]])
+ %0 = "test.op_involution_trait_no_operation_fold"(%arg0) : (i32) -> i32
+ %1 = "test.op_involution_trait_no_operation_fold"(%0) : (i32) -> i32
+ %2 = "test.op_involution_trait_no_operation_fold"(%1) : (i32) -> i32
+ // CHECK: return [[INVOLUTION]]
+ return %2: i32
+}
+
+//===----------------------------------------------------------------------===//
+// Test that the involution fold occurs if the operation fold fails
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @testFailingOperationFolder
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testFailingOperationFolder(%arg0: i32) -> i32 { // fold() on this op always fails, so the Involution trait folder still cancels the pair.
+ %0 = "test.op_involution_trait_failing_operation_fold"(%arg0) : (i32) -> i32
+ %1 = "test.op_involution_trait_failing_operation_fold"(%0) : (i32) -> i32
+ // CHECK: return [[ARG0]]
+ return %1: i32
+}
+
+//===----------------------------------------------------------------------===//
+// Test that involution fold does not occur if operation fold succeeds
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @testInhibitInvolution
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testInhibitInvolution(%arg0: i32) -> i32 { // The op's own fold() folds %1 to %0, so the Involution trait folder never cancels the pair and %0 survives.
+ // CHECK: [[OP:%.+]] = "test.op_involution_trait_succesful_operation_fold"([[ARG0]])
+ %0 = "test.op_involution_trait_succesful_operation_fold"(%arg0) : (i32) -> i32
+ %1 = "test.op_involution_trait_succesful_operation_fold"(%0) : (i32) -> i32
+ // CHECK: return [[OP]]
+ return %1: i32
+}
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 0389c70be3d6..5b035659f8d7 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -77,6 +77,7 @@ void registerTestRecursiveTypesPass();
void registerTestReducer();
void registerTestSpirvEntryPointABIPass();
void registerTestSCFUtilsPass();
+void registerTestTraitsPass();
void registerTestVectorConversions();
void registerVectorizerTestPass();
} // namespace mlir
@@ -134,6 +135,7 @@ void registerTestPasses() {
registerTestGpuParallelLoopMappingPass();
registerTestSpirvEntryPointABIPass();
registerTestSCFUtilsPass();
+ registerTestTraitsPass();
registerTestVectorConversions();
registerVectorizerTestPass();
}
More information about the Mlir-commits
mailing list