[Mlir-commits] [mlir] 907403f - [mlir] Add a new `ConstantLike` trait to better identify operations that represent a "constant".
River Riddle
llvmlistbot at llvm.org
Thu Mar 12 14:27:59 PDT 2020
Author: River Riddle
Date: 2020-03-12T14:26:15-07:00
New Revision: 907403f342fe661b590f930a83f940c67b3ff855
URL: https://github.com/llvm/llvm-project/commit/907403f342fe661b590f930a83f940c67b3ff855
DIFF: https://github.com/llvm/llvm-project/commit/907403f342fe661b590f930a83f940c67b3ff855.diff
LOG: [mlir] Add a new `ConstantLike` trait to better identify operations that represent a "constant".
The current mechanism for identifying constants is a bit hacky and extremely ad hoc, i.e. we explicitly check for one result, zero operands, no side effects, and foldability, and then assume that the operation is a constant. Adding a trait adds structure to this, and makes checking for a constant much more efficient, as these properties can be relied upon instead of being re-checked at every query (the trait verifier currently checks the result and operand counts; foldability is not yet verified).
Differential Revision: https://reviews.llvm.org/D76020
Added:
Modified:
mlir/examples/toy/Ch7/include/toy/Ops.td
mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/IR/Matchers.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/IR/Builders.cpp
mlir/lib/Transforms/Utils/FoldUtils.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/IR/traits.mlir
mlir/test/mlir-tblgen/op-decl.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td
index 9db68fe7b98e..adf56dc040d7 100644
--- a/mlir/examples/toy/Ch7/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch7/include/toy/Ops.td
@@ -49,7 +49,8 @@ def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>;
// constant operation is marked as 'NoSideEffect' as it is a pure operation
// and may be removed if dead.
def ConstantOp : Toy_Op<"constant",
- [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+ [ConstantLike, NoSideEffect,
+ DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
// Provide a summary and description for this operation. This can be used to
// auto-generate documentation of the operations within our dialect.
let summary = "constant";
@@ -295,7 +296,7 @@ def StructAccessOp : Toy_Op<"struct_access", [NoSideEffect]> {
let hasFolder = 1;
}
-def StructConstantOp : Toy_Op<"struct_constant", [NoSideEffect]> {
+def StructConstantOp : Toy_Op<"struct_constant", [ConstantLike, NoSideEffect]> {
let summary = "struct constant";
let description = [{
Constant operation turns a literal struct value into an SSA value. The data
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index b6affd876f54..26f8510a718d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -67,7 +67,7 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> {
// -----
-def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
+def SPV_ConstantOp : SPV_Op<"constant", [ConstantLike, NoSideEffect]> {
let summary = "The op that declares a SPIR-V normal constant";
let description = [{
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 00e25363ec5b..daf5da739c50 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -796,7 +796,7 @@ def CondBranchOp : Std_Op<"cond_br",
//===----------------------------------------------------------------------===//
def ConstantOp : Std_Op<"constant",
- [NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
+ [ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
let summary = "constant";
let arguments = (ins AnyAttr:$value);
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index d9979b8467ee..12da468f3d75 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -48,8 +48,13 @@ struct attr_value_binder {
}
};
-/// The matcher that matches a constant foldable operation that has no side
-/// effect, no operands and produces a single result.
+/// Matches any operation with the `ConstantLike` trait, without folding it.
+struct constant_op_matcher {
+ bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); }
+};
+
+/// The matcher that matches operations that have the `ConstantLike` trait, and
+/// binds the folded attribute value.
template <typename AttrT> struct constant_op_binder {
AttrT *bind_value;
@@ -60,20 +65,19 @@ template <typename AttrT> struct constant_op_binder {
constant_op_binder() : bind_value(nullptr) {}
bool match(Operation *op) {
- if (op->getNumOperands() > 0 || op->getNumResults() != 1)
- return false;
- if (!op->hasNoSideEffect())
+ if (!op->hasTrait<OpTrait::ConstantLike>())
return false;
+ // Fold the constant to an attribute.
SmallVector<OpFoldResult, 1> foldedOp;
- if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
- if (auto attr = foldedOp.front().dyn_cast<Attribute>()) {
- if (auto attrT = attr.dyn_cast<AttrT>()) {
- if (bind_value)
- *bind_value = attrT;
- return true;
- }
- }
+ LogicalResult result = op->fold(/*operands=*/llvm::None, foldedOp);
+ (void)result;
+ assert(succeeded(result) && "expected constant to be foldable");
+
+ if (auto attr = foldedOp.front().get<Attribute>().dyn_cast<AttrT>()) {
+ if (bind_value)
+ *bind_value = attr;
+ return true;
}
return false;
}
@@ -201,8 +205,8 @@ struct RecursivePatternMatcher {
} // end namespace detail
/// Matches a constant foldable operation.
-inline detail::constant_op_binder<Attribute> m_Constant() {
- return detail::constant_op_binder<Attribute>();
+inline detail::constant_op_matcher m_Constant() {
+ return detail::constant_op_matcher();
}
/// Matches a value from a constant foldable operation and writes the value to
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 62b811ee6101..fa890e22e823 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1549,6 +1549,8 @@ def ResultsBroadcastableShape :
def Broadcastable : NativeOpTrait<"ResultsBroadcastableShape">;
// X op Y == Y op X
def Commutative : NativeOpTrait<"IsCommutative">;
+// Op behaves like a constant: one result, zero operands, always foldable.
+def ConstantLike : NativeOpTrait<"ConstantLike">;
// Op behaves like a function.
def FunctionLike : NativeOpTrait<"FunctionLike">;
// Op is isolated from above.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 315256cad35b..7d663d363097 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -902,6 +902,25 @@ class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
}
};
+/// This class provides the API for the subset of ops that are known to be
+/// constant-like: side-effect-free operations with one result and zero
+/// operands that can always be folded to a specific attribute value.
+template <typename ConcreteType>
+class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
+public:
+ static LogicalResult verifyTrait(Operation *op) {
+ static_assert(ConcreteType::template hasTrait<OneResult>(),
+ "expected operation to produce one result");
+ static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
+ "expected operation to take zero operands");
+ // Only the result/operand counts are checked (statically); freedom from side
+ // effects is not. TODO: We should verify that the operation can always be
+ // folded, but this requires that the attributes of the op already be verified.
+ // Supporting trait verification "after" the operation would enable this.
+ return success();
+ }
+};
+
/// This class provides the API for ops that are known to be isolated from
/// above.
template <typename ConcreteType>
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index c578dcfd1032..dab357975d85 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -399,7 +399,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
cst->erase();
return cleanupFailure();
}
- assert(matchPattern(constOp, m_Constant(&attr)));
+ assert(matchPattern(constOp, m_Constant()));
generatedConstants.push_back(constOp);
results.push_back(constOp->getResult(0));
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index c34896f6a6b0..f374d3803baa 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -57,7 +57,7 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
// Ask the dialect to materialize a constant operation for this value.
if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
assert(insertPt == builder.getInsertionPoint());
- assert(matchPattern(constOp, m_Constant(&value)));
+ assert(matchPattern(constOp, m_Constant()));
return constOp;
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 444b91bbe19e..59e4a764afcc 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -454,7 +454,7 @@ func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off
// -----
func @generic_fun_result_0_element_type(%arg0: memref<?xf32>) {
- // expected-error @+1 {{'linalg.dot' op expected 3 or more operands}}
+ // expected-error @+1 {{'linalg.dot' op expected 3 operands, but found 2}}
linalg.dot(%arg0, %arg0): memref<?xf32>, memref<?xf32>
}
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 42044bde5dcd..655fbb89b2d4 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -24,7 +24,7 @@ func @failedSameOperandElementType(%t1f: tensor<1xf32>, %t1i: tensor<1xi32>) {
// -----
func @failedSameOperandAndResultElementType_no_operands() {
- // expected-error@+1 {{expected 1 or more operands}}
+ // expected-error@+1 {{expected 2 operands, but found 0}}
"test.same_operand_element_type"() : () -> tensor<1xf32>
}
diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index 2bad43dca074..a6719cfa5117 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -55,7 +55,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect, NoSideEffect]> {
// CHECK: ArrayRef<Value> tblgen_operands;
// CHECK: };
-// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl
+// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::HasNoSideEffect
// CHECK: public:
// CHECK: using Op::Op;
// CHECK: using OperandAdaptor = AOpOperandAdaptor;
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 69d5e97fa7d4..853f399af5e2 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1523,14 +1523,6 @@ void OpEmitter::genTraits() {
unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
- // Add the native and interface traits.
- for (const auto &trait : op.getTraits()) {
- if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
- opClass.addTrait(opTrait->getTrait());
- else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
- opClass.addTrait(opTrait->getTrait());
- }
-
// Add variadic size trait and normal op traits.
int numOperands = op.getNumOperands();
int numVariadicOperands = op.getNumVariadicOperands();
@@ -1555,6 +1547,14 @@ void OpEmitter::genTraits() {
break;
}
}
+
+ // Add the native and interface traits.
+ for (const auto &trait : op.getTraits()) {
+ if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
+ opClass.addTrait(opTrait->getTrait());
+ else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
+ opClass.addTrait(opTrait->getTrait());
+ }
}
void OpEmitter::genOpNameGetter() {
More information about the Mlir-commits
mailing list