[Mlir-commits] [mlir] 7f61396 - [mlir][Interfaces] Add implicit casts from concrete operation types to the interfaces they implement.

River Riddle llvmlistbot at llvm.org
Thu Nov 12 23:00:10 PST 2020


Author: River Riddle
Date: 2020-11-12T22:56:08-08:00
New Revision: 7f61396cfac5f114707a4240a314dec28e03a1d5

URL: https://github.com/llvm/llvm-project/commit/7f61396cfac5f114707a4240a314dec28e03a1d5
DIFF: https://github.com/llvm/llvm-project/commit/7f61396cfac5f114707a4240a314dec28e03a1d5.diff

LOG: [mlir][Interfaces] Add implicit casts from concrete operation types to the interfaces they implement.

This removes the need for an explicit `cast<>`, given that a concrete operation type is always known to be (`isa`) an instance of the interfaces it implements.

Differential Revision: https://reviews.llvm.org/D91304

Added: 
    

Modified: 
    mlir/include/mlir/Support/InterfaceSupport.h
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
index fa3ef3e14fa2..44b0f67d1e30 100644
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -16,7 +16,6 @@
 #include "mlir/Support/TypeID.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/TypeName.h"
-#include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
 namespace detail {
@@ -75,10 +74,28 @@ class Interface : public BaseType {
   using InterfaceBase =
       Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;
 
+  /// This is a special trait that registers a given interface with an object.
+  template <typename ConcreteT>
+  struct Trait : public BaseTrait<ConcreteT, Trait> {
+    using ModelT = Model<ConcreteT>;
+
+    /// Define an accessor for the ID of this interface.
+    static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
+  };
+
+  /// Construct an interface from an instance of the value type.
   Interface(ValueT t = ValueT())
       : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
-    assert((!t || impl) &&
-           "instantiating an interface with an unregistered operation");
+    assert((!t || impl) && "expected value to provide interface instance");
+  }
+
+  /// Construct an interface instance from a type that implements this
+  /// interface's trait.
+  template <typename T, typename std::enable_if_t<
+                            std::is_base_of<Trait<T>, T>::value> * = nullptr>
+  Interface(T t)
+      : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
+    assert((!t || impl) && "expected value to provide interface instance");
   }
 
   /// Support 'classof' by checking if the given object defines the concrete
@@ -88,15 +105,6 @@ class Interface : public BaseType {
   /// Define an accessor for the ID of this interface.
   static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
 
-  /// This is a special trait that registers a given interface with an object.
-  template <typename ConcreteT>
-  struct Trait : public BaseTrait<ConcreteT, Trait> {
-    using ModelT = Model<ConcreteT>;
-
-    /// Define an accessor for the ID of this interface.
-    static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
-  };
-
 protected:
   /// Get the raw concept in the correct derived concept type.
   const Concept *getImpl() const { return impl; }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 7cb9bb5b13bf..abc10e8f486a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -341,10 +341,9 @@ template <typename... Args>
 static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
                                          Args... args) {
   if (isa<GenericOp>(op.getOperation()))
-    return cast<LinalgOp>(rewriter.create<GenericOp>(args...).getOperation());
+    return rewriter.create<GenericOp>(args...);
   if (isa<IndexedGenericOp>(op.getOperation()))
-    return cast<LinalgOp>(
-        rewriter.create<IndexedGenericOp>(args...).getOperation());
+    return rewriter.create<IndexedGenericOp>(args...);
   llvm_unreachable(
       "expected only linalg.generic or linalg.indexed_generic ops");
   return nullptr;


        


More information about the Mlir-commits mailing list