[Mlir-commits] [mlir] 907403f - [mlir] Add a new `ConstantLike` trait to better identify operations that represent a "constant".

River Riddle llvmlistbot at llvm.org
Thu Mar 12 14:27:59 PDT 2020


Author: River Riddle
Date: 2020-03-12T14:26:15-07:00
New Revision: 907403f342fe661b590f930a83f940c67b3ff855

URL: https://github.com/llvm/llvm-project/commit/907403f342fe661b590f930a83f940c67b3ff855
DIFF: https://github.com/llvm/llvm-project/commit/907403f342fe661b590f930a83f940c67b3ff855.diff

LOG: [mlir] Add a new `ConstantLike` trait to better identify operations that represent a "constant".

The current mechanism for identifying constants is a bit hacky and extremely ad hoc, i.e. we explicitly check 1-result, 0-operand, no side-effect, and always foldable and then assume that this is a constant. Adding a trait adds structure to this, and makes checking for a constant much more efficient as we can guarantee that all of these things have already been verified.

Differential Revision: https://reviews.llvm.org/D76020

Added: 
    

Modified: 
    mlir/examples/toy/Ch7/include/toy/Ops.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/Matchers.h
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/lib/IR/Builders.cpp
    mlir/lib/Transforms/Utils/FoldUtils.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/IR/traits.mlir
    mlir/test/mlir-tblgen/op-decl.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td
index 9db68fe7b98e..adf56dc040d7 100644
--- a/mlir/examples/toy/Ch7/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch7/include/toy/Ops.td
@@ -49,7 +49,8 @@ def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>;
 // constant operation is marked as 'NoSideEffect' as it is a pure operation
 // and may be removed if dead.
 def ConstantOp : Toy_Op<"constant",
-    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+    [ConstantLike, NoSideEffect,
+     DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   // Provide a summary and description for this operation. This can be used to
   // auto-generate documentation of the operations within our dialect.
   let summary = "constant";
@@ -295,7 +296,7 @@ def StructAccessOp : Toy_Op<"struct_access", [NoSideEffect]> {
   let hasFolder = 1;
 }
 
-def StructConstantOp : Toy_Op<"struct_constant", [NoSideEffect]> {
+def StructConstantOp : Toy_Op<"struct_constant", [ConstantLike, NoSideEffect]> {
   let summary = "struct constant";
   let description = [{
     Constant operation turns a literal struct value into an SSA value. The data

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index b6affd876f54..26f8510a718d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -67,7 +67,7 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> {
 
 // -----
 
-def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
+def SPV_ConstantOp : SPV_Op<"constant", [ConstantLike, NoSideEffect]> {
   let summary = "The op that declares a SPIR-V normal constant";
 
   let description = [{

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 00e25363ec5b..daf5da739c50 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -796,7 +796,7 @@ def CondBranchOp : Std_Op<"cond_br",
 //===----------------------------------------------------------------------===//
 
 def ConstantOp : Std_Op<"constant",
-    [NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
+    [ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
   let summary = "constant";
 
   let arguments = (ins AnyAttr:$value);

diff  --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index d9979b8467ee..12da468f3d75 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -48,8 +48,13 @@ struct attr_value_binder {
   }
 };
 
-/// The matcher that matches a constant foldable operation that has no side
-/// effect, no operands and produces a single result.
+/// The matcher that matches operations that have the `ConstantLike` trait.
+struct constant_op_matcher {
+  bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); }
+};
+
+/// The matcher that matches operations that have the `ConstantLike` trait, and
+/// binds the folded attribute value.
 template <typename AttrT> struct constant_op_binder {
   AttrT *bind_value;
 
@@ -60,20 +65,19 @@ template <typename AttrT> struct constant_op_binder {
   constant_op_binder() : bind_value(nullptr) {}
 
   bool match(Operation *op) {
-    if (op->getNumOperands() > 0 || op->getNumResults() != 1)
-      return false;
-    if (!op->hasNoSideEffect())
+    if (!op->hasTrait<OpTrait::ConstantLike>())
       return false;
 
+    // Fold the constant to an attribute.
     SmallVector<OpFoldResult, 1> foldedOp;
-    if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
-      if (auto attr = foldedOp.front().dyn_cast<Attribute>()) {
-        if (auto attrT = attr.dyn_cast<AttrT>()) {
-          if (bind_value)
-            *bind_value = attrT;
-          return true;
-        }
-      }
+    LogicalResult result = op->fold(/*operands=*/llvm::None, foldedOp);
+    (void)result;
+    assert(succeeded(result) && "expected constant to be foldable");
+
+    if (auto attr = foldedOp.front().get<Attribute>().dyn_cast<AttrT>()) {
+      if (bind_value)
+        *bind_value = attr;
+      return true;
     }
     return false;
   }
@@ -201,8 +205,8 @@ struct RecursivePatternMatcher {
 } // end namespace detail
 
 /// Matches a constant foldable operation.
-inline detail::constant_op_binder<Attribute> m_Constant() {
-  return detail::constant_op_binder<Attribute>();
+inline detail::constant_op_matcher m_Constant() {
+  return detail::constant_op_matcher();
 }
 
 /// Matches a value from a constant foldable operation and writes the value to

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 62b811ee6101..fa890e22e823 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1549,6 +1549,8 @@ def ResultsBroadcastableShape :
 def Broadcastable : NativeOpTrait<"ResultsBroadcastableShape">;
 // X op Y == Y op X
 def Commutative  : NativeOpTrait<"IsCommutative">;
+// Op behaves like a constant.
+def ConstantLike : NativeOpTrait<"ConstantLike">;
 // Op behaves like a function.
 def FunctionLike : NativeOpTrait<"FunctionLike">;
 // Op is isolated from above.

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 315256cad35b..7d663d363097 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -902,6 +902,25 @@ class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
   }
 };
 
+/// This class provides the API for a sub-set of ops that are known to be
+/// constant-like. These are non-side effecting operations with one result and
+/// zero operands that can always be folded to a specific attribute value.
+template <typename ConcreteType>
+class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    static_assert(ConcreteType::template hasTrait<OneResult>(),
+                  "expected operation to produce one result");
+    static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
+                  "expected operation to take zero operands");
+    // TODO: We should verify that the operation can always be folded, but this
+    // requires that the attributes of the op already be verified. We should add
+    // support for verifying traits "after" the operation to enable this use
+    // case.
+    return success();
+  }
+};
+
 /// This class provides the API for ops that are known to be isolated from
 /// above.
 template <typename ConcreteType>

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index c578dcfd1032..dab357975d85 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -399,7 +399,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
         cst->erase();
       return cleanupFailure();
     }
-    assert(matchPattern(constOp, m_Constant(&attr)));
+    assert(matchPattern(constOp, m_Constant()));
 
     generatedConstants.push_back(constOp);
     results.push_back(constOp->getResult(0));

diff  --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index c34896f6a6b0..f374d3803baa 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -57,7 +57,7 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
   // Ask the dialect to materialize a constant operation for this value.
   if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
     assert(insertPt == builder.getInsertionPoint());
-    assert(matchPattern(constOp, m_Constant(&value)));
+    assert(matchPattern(constOp, m_Constant()));
     return constOp;
   }
 

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 444b91bbe19e..59e4a764afcc 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -454,7 +454,7 @@ func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off
 // -----
 
 func @generic_fun_result_0_element_type(%arg0: memref<?xf32>) {
-  // expected-error @+1 {{'linalg.dot' op expected 3 or more operands}}
+  // expected-error @+1 {{'linalg.dot' op expected 3 operands, but found 2}}
   linalg.dot(%arg0, %arg0): memref<?xf32>, memref<?xf32>
 }
 

diff  --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 42044bde5dcd..655fbb89b2d4 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -24,7 +24,7 @@ func @failedSameOperandElementType(%t1f: tensor<1xf32>, %t1i: tensor<1xi32>) {
 // -----
 
 func @failedSameOperandAndResultElementType_no_operands() {
-  // expected-error@+1 {{expected 1 or more operands}}
+  // expected-error@+1 {{expected 2 operands, but found 0}}
   "test.same_operand_element_type"() : () -> tensor<1xf32>
 }
 

diff  --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index 2bad43dca074..a6719cfa5117 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -55,7 +55,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect, NoSideEffect]> {
 // CHECK:   ArrayRef<Value> tblgen_operands;
 // CHECK: };
 
-// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl
+// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::HasNoSideEffect
 // CHECK: public:
 // CHECK:   using Op::Op;
 // CHECK:   using OperandAdaptor = AOpOperandAdaptor;

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 69d5e97fa7d4..853f399af5e2 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1523,14 +1523,6 @@ void OpEmitter::genTraits() {
   unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
   addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
 
-  // Add the native and interface traits.
-  for (const auto &trait : op.getTraits()) {
-    if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
-      opClass.addTrait(opTrait->getTrait());
-    else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
-      opClass.addTrait(opTrait->getTrait());
-  }
-
   // Add variadic size trait and normal op traits.
   int numOperands = op.getNumOperands();
   int numVariadicOperands = op.getNumVariadicOperands();
@@ -1555,6 +1547,14 @@ void OpEmitter::genTraits() {
       break;
     }
   }
+
+  // Add the native and interface traits.
+  for (const auto &trait : op.getTraits()) {
+    if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
+      opClass.addTrait(opTrait->getTrait());
+    else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
+      opClass.addTrait(opTrait->getTrait());
+  }
 }
 
 void OpEmitter::genOpNameGetter() {


        


More information about the Mlir-commits mailing list