[Mlir-commits] [mlir] 6ccf2d6 - [mlir] Add an interface for Cast-Like operations

River Riddle llvmlistbot at llvm.org
Wed Jan 20 16:28:29 PST 2021


Author: River Riddle
Date: 2021-01-20T16:28:17-08:00
New Revision: 6ccf2d62b4876c88427ae97d0cd3c9ed4330560a

URL: https://github.com/llvm/llvm-project/commit/6ccf2d62b4876c88427ae97d0cd3c9ed4330560a
DIFF: https://github.com/llvm/llvm-project/commit/6ccf2d62b4876c88427ae97d0cd3c9ed4330560a.diff

LOG: [mlir] Add an interface for Cast-Like operations

A cast-like operation is one that converts from a set of input types to a set of output types. The arity of the inputs may be from 0-N, whereas the arity of the outputs may be anything from 1-N. Cast-like operations are removable in cases where they produce a "no-op", i.e., when the input types and output types match 1-1.

Differential Revision: https://reviews.llvm.org/D94831

Added: 
    mlir/include/mlir/Interfaces/CastInterfaces.h
    mlir/include/mlir/Interfaces/CastInterfaces.td
    mlir/lib/Interfaces/CastInterfaces.cpp

Modified: 
    mlir/docs/Tutorials/Toy/Ch-4.md
    mlir/examples/toy/Ch4/CMakeLists.txt
    mlir/examples/toy/Ch4/include/toy/Dialect.h
    mlir/examples/toy/Ch4/include/toy/Ops.td
    mlir/examples/toy/Ch4/mlir/Dialect.cpp
    mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
    mlir/examples/toy/Ch5/CMakeLists.txt
    mlir/examples/toy/Ch5/include/toy/Dialect.h
    mlir/examples/toy/Ch5/include/toy/Ops.td
    mlir/examples/toy/Ch5/mlir/Dialect.cpp
    mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
    mlir/examples/toy/Ch6/CMakeLists.txt
    mlir/examples/toy/Ch6/include/toy/Dialect.h
    mlir/examples/toy/Ch6/include/toy/Ops.td
    mlir/examples/toy/Ch6/mlir/Dialect.cpp
    mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
    mlir/examples/toy/Ch7/CMakeLists.txt
    mlir/examples/toy/Ch7/include/toy/Dialect.h
    mlir/examples/toy/Ch7/include/toy/Ops.td
    mlir/examples/toy/Ch7/mlir/Dialect.cpp
    mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/include/mlir/IR/Diagnostics.h
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/Interfaces/CMakeLists.txt
    mlir/lib/Dialect/Shape/IR/CMakeLists.txt
    mlir/lib/Dialect/StandardOps/CMakeLists.txt
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/IR/Operation.cpp
    mlir/lib/Interfaces/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md
index dc1314419320..c454a762e2b5 100644
--- a/mlir/docs/Tutorials/Toy/Ch-4.md
+++ b/mlir/docs/Tutorials/Toy/Ch-4.md
@@ -182,26 +182,50 @@ to add a new operation to the Toy dialect, `ToyCastOp`(toy.cast), to represent
 casts between two different shapes.
 
 ```tablegen
-def CastOp : Toy_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
+def CastOp : Toy_Op<"cast", [
+    DeclareOpInterfaceMethods<CastOpInterface>,
+    NoSideEffect,
+    SameOperandsAndResultShape]
+  > {
   let summary = "shape cast operation";
   let description = [{
     The "cast" operation converts a tensor from one type to an equivalent type
     without changing any data elements. The source and destination types
-    must both be tensor types with the same element type. If both are ranked
-    then the rank should be the same and static dimensions should match. The
-    operation is invalid if converting to a mismatching constant dimension.
+    must both be tensor types with the same element type. If both are ranked,
+    then shape is required to match. The operation is invalid if converting
+    to a mismatching constant dimension.
   }];
 
   let arguments = (ins F64Tensor:$input);
   let results = (outs F64Tensor:$output);
+}
+```
+
+Note that the definition of this cast operation adds a `CastOpInterface` to the
+traits list. This interface provides several utilities for cast-like operations,
+such as folding identity casts and verification. We hook into this interface by
+providing a definition for the `areCastCompatible` method:
 
-  // Set the folder bit so that we can fold redundant cast operations.
-  let hasFolder = 1;
+```c++
+/// Returns true if the given set of input and result types are compatible with
+/// this cast operation. This is required by the `CastOpInterface` to verify
+/// this operation and provide other additional utilities.
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  // Both the input and the output must be tensors with the same element type.
+  TensorType input = inputs.front().dyn_cast<TensorType>();
+  TensorType output = outputs.front().dyn_cast<TensorType>();
+  if (!input || !output || input.getElementType() != output.getElementType())
+    return false;
+  // Compatible if either type is unranked; otherwise the types must be equal.
+  return !input.hasRank() || !output.hasRank() || input == output;
 }
+
 ```
 
-We can then override the necessary hook on the ToyInlinerInterface to insert
-this for us when necessary:
+With a proper cast operation, we can now override the necessary hook on the
+ToyInlinerInterface to insert it for us when necessary:
 
 ```c++
 struct ToyInlinerInterface : public DialectInlinerInterface {

diff  --git a/mlir/examples/toy/Ch4/CMakeLists.txt b/mlir/examples/toy/Ch4/CMakeLists.txt
index f303e2e40afd..468a46718d34 100644
--- a/mlir/examples/toy/Ch4/CMakeLists.txt
+++ b/mlir/examples/toy/Ch4/CMakeLists.txt
@@ -29,6 +29,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
 target_link_libraries(toyc-ch4
   PRIVATE
     MLIRAnalysis
+    MLIRCastInterfaces
     MLIRCallInterfaces
     MLIRIR
     MLIRParser

diff  --git a/mlir/examples/toy/Ch4/include/toy/Dialect.h b/mlir/examples/toy/Ch4/include/toy/Dialect.h
index 3c266cf02d6d..41d20fa7de22 100644
--- a/mlir/examples/toy/Ch4/include/toy/Dialect.h
+++ b/mlir/examples/toy/Ch4/include/toy/Dialect.h
@@ -17,6 +17,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "toy/ShapeInferenceInterface.h"
 

diff  --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td
index adf8ab2aec80..8ba8f1a69e33 100644
--- a/mlir/examples/toy/Ch4/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch4/include/toy/Ops.td
@@ -14,6 +14,7 @@
 #define TOY_OPS
 
 include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "toy/ShapeInferenceInterface.td"
 
@@ -102,25 +103,25 @@ def AddOp : Toy_Op<"add",
   ];
 }
 
-def CastOp : Toy_Op<"cast",
-    [DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
-     SameOperandsAndResultShape]> {
+def CastOp : Toy_Op<"cast", [
+     DeclareOpInterfaceMethods<CastOpInterface>,
+     DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
+     NoSideEffect,
+     SameOperandsAndResultShape
+  ]> {
   let summary = "shape cast operation";
   let description = [{
     The "cast" operation converts a tensor from one type to an equivalent type
-    without changing any data elements. The source and destination types
-    must both be tensor types with the same element type. If both are ranked
-    then the rank should be the same and static dimensions should match. The
-    operation is invalid if converting to a mismatching constant dimension.
+    without changing any data elements. The source and destination types must
+    both be tensor types with the same element type. If both are ranked, then
+    shape is required to match. The operation is invalid if converting to a
+    mismatching constant dimension.
   }];
 
   let arguments = (ins F64Tensor:$input);
   let results = (outs F64Tensor:$output);
 
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
-
-  // Set the folder bit so that we can fold redundant cast operations.
-  let hasFolder = 1;
 }
 
 def GenericCallOp : Toy_Op<"generic_call",

diff  --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index 0a3ec29b5707..0d317db1bdff 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -232,6 +232,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
 /// inference interface.
 void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
 
+/// Returns true if the given set of input and result types are compatible with
+/// this cast operation. This is required by the `CastOpInterface` to verify
+/// this operation and provide other additional utilities.
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  // Both the input and the output must be tensors with the same element type.
+  TensorType input = inputs.front().dyn_cast<TensorType>();
+  TensorType output = outputs.front().dyn_cast<TensorType>();
+  if (!input || !output || input.getElementType() != output.getElementType())
+    return false;
+  // Compatible if either type is unranked; otherwise the types must be equal.
+  return !input.hasRank() || !output.hasRank() || input == output;
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 

diff  --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
index 18e77296ecf2..0af4cbfc11f1 100644
--- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
@@ -23,11 +23,6 @@ namespace {
 #include "ToyCombine.inc"
 } // end anonymous namespace
 
-/// Fold simple cast operations that return the same type as the input.
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
-  return mlir::impl::foldCastOp(*this);
-}
-
 /// This is an example of a c++ rewrite pattern for the TransposeOp. It
 /// optimizes the following scenario: transpose(transpose(x)) -> x
 struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {

diff  --git a/mlir/examples/toy/Ch5/CMakeLists.txt b/mlir/examples/toy/Ch5/CMakeLists.txt
index 008e8f608573..7170550d3a7a 100644
--- a/mlir/examples/toy/Ch5/CMakeLists.txt
+++ b/mlir/examples/toy/Ch5/CMakeLists.txt
@@ -33,6 +33,7 @@ target_link_libraries(toyc-ch5
     ${dialect_libs}
     MLIRAnalysis
     MLIRCallInterfaces
+    MLIRCastInterfaces
     MLIRIR
     MLIRParser
     MLIRPass

diff  --git a/mlir/examples/toy/Ch5/include/toy/Dialect.h b/mlir/examples/toy/Ch5/include/toy/Dialect.h
index 3c266cf02d6d..41d20fa7de22 100644
--- a/mlir/examples/toy/Ch5/include/toy/Dialect.h
+++ b/mlir/examples/toy/Ch5/include/toy/Dialect.h
@@ -17,6 +17,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "toy/ShapeInferenceInterface.h"
 

diff  --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td
index ee7d6c4b340d..5e9a0a0f4fd7 100644
--- a/mlir/examples/toy/Ch5/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch5/include/toy/Ops.td
@@ -14,6 +14,7 @@
 #define TOY_OPS
 
 include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "toy/ShapeInferenceInterface.td"
 
@@ -102,25 +103,25 @@ def AddOp : Toy_Op<"add",
   ];
 }
 
-def CastOp : Toy_Op<"cast",
-    [DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
-     SameOperandsAndResultShape]> {
+def CastOp : Toy_Op<"cast", [
+     DeclareOpInterfaceMethods<CastOpInterface>,
+     DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
+     NoSideEffect,
+     SameOperandsAndResultShape
+  ]> {
   let summary = "shape cast operation";
   let description = [{
     The "cast" operation converts a tensor from one type to an equivalent type
-    without changing any data elements. The source and destination types
-    must both be tensor types with the same element type. If both are ranked
-    then the rank should be the same and static dimensions should match. The
-    operation is invalid if converting to a mismatching constant dimension.
+    without changing any data elements. The source and destination types must
+    both be tensor types with the same element type. If both are ranked, then
+    shape is required to match. The operation is invalid if converting to a
+    mismatching constant dimension.
   }];
 
   let arguments = (ins F64Tensor:$input);
   let results = (outs F64Tensor:$output);
 
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
-
-  // Set the folder bit so that we can fold redundant cast operations.
-  let hasFolder = 1;
 }
 
 def GenericCallOp : Toy_Op<"generic_call",

diff  --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index d4356f61f83b..a4ca119a9aaf 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -232,6 +232,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
 /// inference interface.
 void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
 
+/// Returns true if the given set of input and result types are compatible with
+/// this cast operation. This is required by the `CastOpInterface` to verify
+/// this operation and provide other additional utilities.
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  // Both the input and the output must be tensors with the same element type.
+  TensorType input = inputs.front().dyn_cast<TensorType>();
+  TensorType output = outputs.front().dyn_cast<TensorType>();
+  if (!input || !output || input.getElementType() != output.getElementType())
+    return false;
+  // Compatible if either type is unranked; otherwise the types must be equal.
+  return !input.hasRank() || !output.hasRank() || input == output;
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 

diff  --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
index 18e77296ecf2..0af4cbfc11f1 100644
--- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
@@ -23,11 +23,6 @@ namespace {
 #include "ToyCombine.inc"
 } // end anonymous namespace
 
-/// Fold simple cast operations that return the same type as the input.
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
-  return mlir::impl::foldCastOp(*this);
-}
-
 /// This is an example of a c++ rewrite pattern for the TransposeOp. It
 /// optimizes the following scenario: transpose(transpose(x)) -> x
 struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {

diff  --git a/mlir/examples/toy/Ch6/CMakeLists.txt b/mlir/examples/toy/Ch6/CMakeLists.txt
index 3d36f970fc84..99041c13f22b 100644
--- a/mlir/examples/toy/Ch6/CMakeLists.txt
+++ b/mlir/examples/toy/Ch6/CMakeLists.txt
@@ -39,6 +39,7 @@ target_link_libraries(toyc-ch6
     ${conversion_libs}
     MLIRAnalysis
     MLIRCallInterfaces
+    MLIRCastInterfaces
     MLIRExecutionEngine
     MLIRIR
     MLIRLLVMIR

diff  --git a/mlir/examples/toy/Ch6/include/toy/Dialect.h b/mlir/examples/toy/Ch6/include/toy/Dialect.h
index 3c266cf02d6d..41d20fa7de22 100644
--- a/mlir/examples/toy/Ch6/include/toy/Dialect.h
+++ b/mlir/examples/toy/Ch6/include/toy/Dialect.h
@@ -17,6 +17,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "toy/ShapeInferenceInterface.h"
 

diff  --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td
index 6f3998357010..2f9e169ce5ae 100644
--- a/mlir/examples/toy/Ch6/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch6/include/toy/Ops.td
@@ -14,6 +14,7 @@
 #define TOY_OPS
 
 include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "toy/ShapeInferenceInterface.td"
 
@@ -102,25 +103,25 @@ def AddOp : Toy_Op<"add",
   ];
 }
 
-def CastOp : Toy_Op<"cast",
-    [DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
-     SameOperandsAndResultShape]> {
+def CastOp : Toy_Op<"cast", [
+     DeclareOpInterfaceMethods<CastOpInterface>,
+     DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
+     NoSideEffect,
+     SameOperandsAndResultShape
+  ]> {
   let summary = "shape cast operation";
   let description = [{
     The "cast" operation converts a tensor from one type to an equivalent type
-    without changing any data elements. The source and destination types
-    must both be tensor types with the same element type. If both are ranked
-    then the rank should be the same and static dimensions should match. The
-    operation is invalid if converting to a mismatching constant dimension.
+    without changing any data elements. The source and destination types must
+    both be tensor types with the same element type. If both are ranked, then
+    shape is required to match. The operation is invalid if converting to a
+    mismatching constant dimension.
   }];
 
   let arguments = (ins F64Tensor:$input);
   let results = (outs F64Tensor:$output);
 
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
-
-  // Set the folder bit so that we can fold redundant cast operations.
-  let hasFolder = 1;
 }
 
 def GenericCallOp : Toy_Op<"generic_call",

diff  --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index d4356f61f83b..a4ca119a9aaf 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -232,6 +232,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
 /// inference interface.
 void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
 
+/// Returns true if the given set of input and result types are compatible with
+/// this cast operation. This is required by the `CastOpInterface` to verify
+/// this operation and provide other additional utilities.
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  // Both the input and the output must be tensors with the same element type.
+  TensorType input = inputs.front().dyn_cast<TensorType>();
+  TensorType output = outputs.front().dyn_cast<TensorType>();
+  if (!input || !output || input.getElementType() != output.getElementType())
+    return false;
+  // Compatible if either type is unranked; otherwise the types must be equal.
+  return !input.hasRank() || !output.hasRank() || input == output;
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 

diff  --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
index 18e77296ecf2..0af4cbfc11f1 100644
--- a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
@@ -23,11 +23,6 @@ namespace {
 #include "ToyCombine.inc"
 } // end anonymous namespace
 
-/// Fold simple cast operations that return the same type as the input.
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
-  return mlir::impl::foldCastOp(*this);
-}
-
 /// This is an example of a c++ rewrite pattern for the TransposeOp. It
 /// optimizes the following scenario: transpose(transpose(x)) -> x
 struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {

diff  --git a/mlir/examples/toy/Ch7/CMakeLists.txt b/mlir/examples/toy/Ch7/CMakeLists.txt
index 639d377d4b66..c24ad53144e1 100644
--- a/mlir/examples/toy/Ch7/CMakeLists.txt
+++ b/mlir/examples/toy/Ch7/CMakeLists.txt
@@ -39,6 +39,7 @@ target_link_libraries(toyc-ch7
     ${conversion_libs}
     MLIRAnalysis
     MLIRCallInterfaces
+    MLIRCastInterfaces
     MLIRExecutionEngine
     MLIRIR
     MLIRParser

diff  --git a/mlir/examples/toy/Ch7/include/toy/Dialect.h b/mlir/examples/toy/Ch7/include/toy/Dialect.h
index 15b55d03f77e..1b754f3d1089 100644
--- a/mlir/examples/toy/Ch7/include/toy/Dialect.h
+++ b/mlir/examples/toy/Ch7/include/toy/Dialect.h
@@ -17,6 +17,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "toy/ShapeInferenceInterface.h"
 

diff  --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td
index c73a6b94e903..3984a5cfa8bd 100644
--- a/mlir/examples/toy/Ch7/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch7/include/toy/Ops.td
@@ -14,6 +14,7 @@
 #define TOY_OPS
 
 include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "toy/ShapeInferenceInterface.td"
 
@@ -115,25 +116,25 @@ def AddOp : Toy_Op<"add",
   ];
 }
 
-def CastOp : Toy_Op<"cast",
-    [DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
-     SameOperandsAndResultShape]> {
+def CastOp : Toy_Op<"cast", [
+     DeclareOpInterfaceMethods<CastOpInterface>,
+     DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
+     NoSideEffect,
+     SameOperandsAndResultShape
+  ]> {
   let summary = "shape cast operation";
   let description = [{
     The "cast" operation converts a tensor from one type to an equivalent type
-    without changing any data elements. The source and destination types
-    must both be tensor types with the same element type. If both are ranked
-    then the rank should be the same and static dimensions should match. The
-    operation is invalid if converting to a mismatching constant dimension.
+    without changing any data elements. The source and destination types must
+    both be tensor types with the same element type. If both are ranked, then
+    shape is required to match. The operation is invalid if converting to a
+    mismatching constant dimension.
   }];
 
   let arguments = (ins F64Tensor:$input);
   let results = (outs F64Tensor:$output);
 
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
-
-  // Set the folder bit so that we can fold redundant cast operations.
-  let hasFolder = 1;
 }
 
 def GenericCallOp : Toy_Op<"generic_call",

diff  --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 659b82e6d80e..9fe1635f4760 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -284,6 +284,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
 /// inference interface.
 void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
 
+/// Returns true if the given set of input and result types are compatible with
+/// this cast operation. This is required by the `CastOpInterface` to verify
+/// this operation and provide other additional utilities.
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  // Both the input and the output must be tensors with the same element type.
+  TensorType input = inputs.front().dyn_cast<TensorType>();
+  TensorType output = outputs.front().dyn_cast<TensorType>();
+  if (!input || !output || input.getElementType() != output.getElementType())
+    return false;
+  // Compatible if either type is unranked; otherwise the types must be equal.
+  return !input.hasRank() || !output.hasRank() || input == output;
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 

diff  --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
index 60db56bbbc08..bfbd36b40fa0 100644
--- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
@@ -23,11 +23,6 @@ namespace {
 #include "ToyCombine.inc"
 } // end anonymous namespace
 
-/// Fold simple cast operations that return the same type as the input.
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
-  return mlir::impl::foldCastOp(*this);
-}
-
 /// Fold constants.
 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }
 

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 56ff32252fee..30905d3af411 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -19,6 +19,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 6dbb24a4358f..f4caa7da1721 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -17,6 +17,7 @@ include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
@@ -45,9 +46,10 @@ class Std_Op<string mnemonic, list<OpTrait> traits = []> :
 // Base class for standard cast operations. Requires single operand and result,
 // but does not constrain them to specific types.
 class CastOp<string mnemonic, list<OpTrait> traits = []> :
-    Std_Op<mnemonic,
-           !listconcat(traits, [NoSideEffect, SameOperandsAndResultShape])> {
-
+    Std_Op<mnemonic, traits # [
+      NoSideEffect, SameOperandsAndResultShape,
+      DeclareOpInterfaceMethods<CastOpInterface>
+    ]> {
   let results = (outs AnyType);
 
   let builders = [
@@ -62,9 +64,9 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
   let printer = [{
     return printStandardCastOp(this->getOperation(), p);
   }];
-  let verifier = [{ return impl::verifyCastOp(*this, areCastCompatible); }];
 
-  let hasFolder = 1;
+  // Cast operations are fully verified by their traits (incl. CastOpInterface).
+  let verifier = ?;
 }
 
 // Base class for arithmetic cast operations.
@@ -1643,14 +1645,6 @@ def FPExtOp : ArithmeticCastOp<"fpext">, Arguments<(ins AnyType:$in)> {
     The destination type must to be strictly wider than the source type.
     Only scalars are currently supported.
   }];
-
-  let extraClassDeclaration = [{
-    /// Return true if `a` and `b` are valid operand and result pairs for
-    /// the operation.
-    static bool areCastCompatible(Type a, Type b);
-  }];
-
-  let hasFolder = 0;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1663,14 +1657,6 @@ def FPToSIOp : ArithmeticCastOp<"fptosi">, Arguments<(ins AnyType:$in)> {
     Cast from a value interpreted as floating-point to the nearest (rounding
     towards zero) signed integer value.
   }];
-
-  let extraClassDeclaration = [{
-    /// Return true if `a` and `b` are valid operand and result pairs for
-    /// the operation.
-    static bool areCastCompatible(Type a, Type b);
-  }];
-
-  let hasFolder = 0;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1683,14 +1669,6 @@ def FPToUIOp : ArithmeticCastOp<"fptoui">, Arguments<(ins AnyType:$in)> {
     Cast from a value interpreted as floating-point to the nearest (rounding
     towards zero) unsigned integer value.
   }];
-
-  let extraClassDeclaration = [{
-    /// Return true if `a` and `b` are valid operand and result pairs for
-    /// the operation.
-    static bool areCastCompatible(Type a, Type b);
-  }];
-
-  let hasFolder = 0;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1705,14 +1683,6 @@ def FPTruncOp : ArithmeticCastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
     If the value cannot be exactly represented, it is rounded using the default
     rounding mode. Only scalars are currently supported.
   }];
-
-  let extraClassDeclaration = [{
-    /// Return true if `a` and `b` are valid operand and result pairs for
-    /// the operation.
-    static bool areCastCompatible(Type a, Type b);
-  }];
-
-  let hasFolder = 0;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1849,12 +1819,6 @@ def IndexCastOp : ArithmeticCastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
     sign-extended. If casting to a narrower integer, the value is truncated.
   }];
 
-  let extraClassDeclaration = [{
-    /// Return true if `a` and `b` are valid operand and result pairs for
-    /// the operation.
-    static bool areCastCompatible(Type a, Type b);
-  }];
-
   let hasFolder = 1;
 }
 
@@ -2045,14 +2009,7 @@ def MemRefCastOp : CastOp<"memref_cast", [
   let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
   let results = (outs AnyRankedOrUnrankedMemRef);
 
-  let extraClassDeclaration = [{
-    /// Return true if `a` and `b` are valid operand and result pairs for
-    /// the operation.
-    static bool areCastCompatible(Type a, Type b);
-
-    /// The result of a memref_cast is always a memref.
-    Type getType() { return getResult().getType(); }
-  }];
+  let hasFolder = 1;
 }
 
 
@@ -2786,14 +2743,6 @@ def SIToFPOp : ArithmeticCastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
     exactly represented, it is rounded using the default rounding mode. Scalars
     and vector types are currently supported.
   }];
-
-  let extraClassDeclaration = [{
-    /// Return true if `a` and `b` are valid operand and result pairs for
-    /// the operation.
-    static bool areCastCompatible(Type a, Type b);
-  }];
-
-  let hasFolder = 0;
 }
 
 //===----------------------------------------------------------------------===//
@@ -3628,14 +3577,6 @@ def UIToFPOp : ArithmeticCastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
     value cannot be exactly represented, it is rounded using the default
     rounding mode. Scalars and vector types are currently supported.
   }];
-
-  let extraClassDeclaration = [{
-    /// Return true if `a` and `b` are valid operand and result pairs for
-    /// the operation.
-    static bool areCastCompatible(Type a, Type b);
-  }];
-
-  let hasFolder = 0;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 3a1a20835959..98d16c7acfe6 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -13,6 +13,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index e7776c4e8a9b..d45f1f61cdd7 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -10,6 +10,7 @@
 #define TENSOR_OPS
 
 include "mlir/Dialect/Tensor/IR/TensorBase.td"
+include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
@@ -24,7 +25,9 @@ class Tensor_Op<string mnemonic, list<OpTrait> traits = []>
 // CastOp
 //===----------------------------------------------------------------------===//
 
-def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> {
+def Tensor_CastOp : Tensor_Op<"cast", [
+    DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
+  ]> {
   let summary = "tensor cast operation";
   let description = [{
     Convert a tensor from one type to an equivalent type without changing any
@@ -51,19 +54,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> {
   let arguments = (ins AnyTensor:$source);
   let results = (outs AnyTensor:$dest);
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
-  let verifier = "return impl::verifyCastOp(*this, areCastCompatible);";
-
-  let extraClassDeclaration = [{
-    /// Return true if `a` and `b` are valid operand and result pairs for
-    /// the operation.
-    static bool areCastCompatible(Type a, Type b);
-
-    /// The result of a tensor.cast is always a tensor.
-    TensorType getType() { return getResult().getType().cast<TensorType>(); }
-  }];
 
-  let hasFolder = 1;
   let hasCanonicalizer = 1;
+  let verifier = ?;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index 0dbf18284131..1fee314b285c 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -50,6 +50,34 @@ enum class DiagnosticSeverity {
 /// A variant type that holds a single argument for a diagnostic.
 class DiagnosticArgument {
 public:
+  /// Note: The constructors below are only exposed due to problems accessing
+  /// constructors from type traits, they should not be used directly by users.
+  // Construct from an Attribute.
+  explicit DiagnosticArgument(Attribute attr);
+  // Construct from a floating point number.
+  explicit DiagnosticArgument(double val)
+      : kind(DiagnosticArgumentKind::Double), doubleVal(val) {}
+  explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {}
+  // Construct from a signed integer.
+  template <typename T>
+  explicit DiagnosticArgument(
+      T val, typename std::enable_if<std::is_signed<T>::value &&
+                                     std::numeric_limits<T>::is_integer &&
+                                     sizeof(T) <= sizeof(int64_t)>::type * = 0)
+      : kind(DiagnosticArgumentKind::Integer), opaqueVal(int64_t(val)) {}
+  // Construct from an unsigned integer.
+  template <typename T>
+  explicit DiagnosticArgument(
+      T val, typename std::enable_if<std::is_unsigned<T>::value &&
+                                     std::numeric_limits<T>::is_integer &&
+                                     sizeof(T) <= sizeof(uint64_t)>::type * = 0)
+      : kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {}
+  // Construct from a string reference.
+  explicit DiagnosticArgument(StringRef val)
+      : kind(DiagnosticArgumentKind::String), stringVal(val) {}
+  // Construct from a Type.
+  explicit DiagnosticArgument(Type val);
+
   /// Enum that represents the different kinds of diagnostic arguments
   /// supported.
   enum class DiagnosticArgumentKind {
@@ -100,37 +128,6 @@ class DiagnosticArgument {
 private:
   friend class Diagnostic;
 
-  // Construct from an Attribute.
-  explicit DiagnosticArgument(Attribute attr);
-
-  // Construct from a floating point number.
-  explicit DiagnosticArgument(double val)
-      : kind(DiagnosticArgumentKind::Double), doubleVal(val) {}
-  explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {}
-
-  // Construct from a signed integer.
-  template <typename T>
-  explicit DiagnosticArgument(
-      T val, typename std::enable_if<std::is_signed<T>::value &&
-                                     std::numeric_limits<T>::is_integer &&
-                                     sizeof(T) <= sizeof(int64_t)>::type * = 0)
-      : kind(DiagnosticArgumentKind::Integer), opaqueVal(int64_t(val)) {}
-
-  // Construct from an unsigned integer.
-  template <typename T>
-  explicit DiagnosticArgument(
-      T val, typename std::enable_if<std::is_unsigned<T>::value &&
-                                     std::numeric_limits<T>::is_integer &&
-                                     sizeof(T) <= sizeof(uint64_t)>::type * = 0)
-      : kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {}
-
-  // Construct from a string reference.
-  explicit DiagnosticArgument(StringRef val)
-      : kind(DiagnosticArgumentKind::String), stringVal(val) {}
-
-  // Construct from a Type.
-  explicit DiagnosticArgument(Type val);
-
   /// The kind of this argument.
   DiagnosticArgumentKind kind;
 
@@ -189,8 +186,10 @@ class Diagnostic {
 
   /// Stream operator for inserting new diagnostic arguments.
   template <typename Arg>
-  typename std::enable_if<!std::is_convertible<Arg, StringRef>::value,
-                          Diagnostic &>::type
+  typename std::enable_if<
+      !std::is_convertible<Arg, StringRef>::value &&
+          std::is_constructible<DiagnosticArgument, Arg>::value,
+      Diagnostic &>::type
   operator<<(Arg &&val) {
     arguments.push_back(DiagnosticArgument(std::forward<Arg>(val)));
     return *this;
@@ -220,17 +219,17 @@ class Diagnostic {
   }
 
   /// Stream in a range.
-  template <typename T> Diagnostic &operator<<(iterator_range<T> range) {
-    return appendRange(range);
-  }
-  template <typename T> Diagnostic &operator<<(ArrayRef<T> range) {
+  template <typename T, typename ValueT = llvm::detail::ValueOfRange<T>>
+  std::enable_if_t<!std::is_constructible<DiagnosticArgument, T>::value,
+                   Diagnostic &>
+  operator<<(T &&range) {
     return appendRange(range);
   }
 
   /// Append a range to the diagnostic. The default delimiter between elements
   /// is ','.
-  template <typename T, template <typename> class Container>
-  Diagnostic &appendRange(const Container<T> &c, const char *delim = ", ") {
+  template <typename T>
+  Diagnostic &appendRange(const T &c, const char *delim = ", ") {
     llvm::interleave(
         c, [this](const auto &a) { *this << a; }, [&]() { *this << delim; });
     return *this;

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 622952c76289..c021bdc8ee9d 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1822,18 +1822,27 @@ ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
 void printOneResultOp(Operation *op, OpAsmPrinter &p);
 } // namespace impl
 
-// These functions are out-of-line implementations of the methods in CastOp,
-// which avoids them being template instantiated/duplicated.
+// These functions are out-of-line implementations of the methods in
+// CastOpInterface, which avoids them being template instantiated/duplicated.
 namespace impl {
+/// Attempt to fold the given cast operation.
+LogicalResult foldCastInterfaceOp(Operation *op,
+                                  ArrayRef<Attribute> attrOperands,
+                                  SmallVectorImpl<OpFoldResult> &foldResults);
+/// Attempt to verify the given cast operation.
+LogicalResult verifyCastInterfaceOp(
+    Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
+
 // TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
 // need for them, but some older ODS code in `std` still depends on them).
 void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
                  Type destType);
 ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
 void printCastOp(Operation *op, OpAsmPrinter &p);
-// TODO: Create a CastOpInterface with a method areCastCompatible.
-// Also, consider adding functionality to CastOpInterface to be able to perform
-// the ChainedTensorCast canonicalization generically.
+// TODO: These methods are deprecated in favor of CastOpInterface. Remove them
+// when all uses have been updated. Also, consider adding functionality to
+// CastOpInterface to be able to perform the ChainedTensorCast canonicalization
+// generically.
 Value foldCastOp(Operation *op);
 LogicalResult verifyCastOp(Operation *op,
                            function_ref<bool(Type, Type)> areCastCompatible);

diff  --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 65e19f3eec1b..3bb154988b01 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_interface(CallInterfaces)
+add_mlir_interface(CastInterfaces)
 add_mlir_interface(ControlFlowInterfaces)
 add_mlir_interface(CopyOpInterface)
 add_mlir_interface(DerivedAttributeOpInterface)

diff  --git a/mlir/include/mlir/Interfaces/CastInterfaces.h b/mlir/include/mlir/Interfaces/CastInterfaces.h
new file mode 100644
index 000000000000..99a1f2ed7821
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/CastInterfaces.h
@@ -0,0 +1,22 @@
+//===- CastInterfaces.h - Cast Interfaces for MLIR --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definitions of the cast interfaces defined in
+// `CastInterfaces.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_CASTINTERFACES_H
+#define MLIR_INTERFACES_CASTINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/CastInterfaces.h.inc"
+
+#endif // MLIR_INTERFACES_CASTINTERFACES_H

diff  --git a/mlir/include/mlir/Interfaces/CastInterfaces.td b/mlir/include/mlir/Interfaces/CastInterfaces.td
new file mode 100644
index 000000000000..c2a01df42c7f
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/CastInterfaces.td
@@ -0,0 +1,51 @@
+//===- CastInterfaces.td - Cast Interfaces for ops ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of interfaces that can be used to define information
+// related to cast-like operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_CASTINTERFACES
+#define MLIR_INTERFACES_CASTINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def CastOpInterface : OpInterface<"CastOpInterface"> {
+  let description = [{
+    A cast-like operation is one that converts from a set of input types to a
+    set of output types. The arity of the inputs may be from 0-N, whereas the
+    arity of the outputs may be anything from 1-N. Cast-like operations are
+    trivially removable in cases where they produce a no-op, i.e. when the
+    input types and output types match 1-1.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    StaticInterfaceMethod<[{
+        Returns true if this cast operation may convert values of the given
+        input types into values of the given result types.
+      }],
+      "bool", "areCastCompatible",
+      (ins "mlir::TypeRange":$inputs, "mlir::TypeRange":$outputs)
+    >,
+  ];
+
+  let extraTraitClassDeclaration = [{
+    /// Fold the cast away (forward its operands) if it is a 1-1 type no-op.
+    static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
+                                   SmallVectorImpl<OpFoldResult> &results) {
+      return impl::foldCastInterfaceOp(op, operands, results);
+    }
+  }];
+  let verify = [{
+    return impl::verifyCastInterfaceOp($_op, ConcreteOp::areCastCompatible);
+  }];
+}
+
+#endif // MLIR_INTERFACES_CASTINTERFACES

diff  --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
index f8321842db31..cf2bf19c4f6f 100644
--- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRShape
   MLIRShapeOpsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRCastInterfaces
   MLIRControlFlowInterfaces
   MLIRDialect
   MLIRInferTypeOpInterface

diff  --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
index 67f285817a91..058e680ef677 100644
--- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRStandard
 
   LINK_LIBS PUBLIC
   MLIRCallInterfaces
+  MLIRCastInterfaces
   MLIRControlFlowInterfaces
   MLIREDSC
   MLIRIR

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 428006e20d9f..45dd0fd0086a 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -195,7 +195,8 @@ static LogicalResult foldMemRefCast(Operation *op) {
 /// Returns 'true' if the vector types are cast compatible,  and 'false'
 /// otherwise.
 static bool areVectorCastSimpleCompatible(
-    Type a, Type b, function_ref<bool(Type, Type)> areElementsCastCompatible) {
+    Type a, Type b,
+    function_ref<bool(TypeRange, TypeRange)> areElementsCastCompatible) {
   if (auto va = a.dyn_cast<VectorType>())
     if (auto vb = b.dyn_cast<VectorType>())
       return va.getShape().equals(vb.getShape()) &&
@@ -1746,7 +1747,10 @@ LogicalResult DmaWaitOp::verify() {
 // FPExtOp
 //===----------------------------------------------------------------------===//
 
-bool FPExtOp::areCastCompatible(Type a, Type b) {
+bool FPExtOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  Type a = inputs.front(), b = outputs.front();
   if (auto fa = a.dyn_cast<FloatType>())
     if (auto fb = b.dyn_cast<FloatType>())
       return fa.getWidth() < fb.getWidth();
@@ -1757,7 +1761,10 @@ bool FPExtOp::areCastCompatible(Type a, Type b) {
 // FPToSIOp
 //===----------------------------------------------------------------------===//
 
-bool FPToSIOp::areCastCompatible(Type a, Type b) {
+bool FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  Type a = inputs.front(), b = outputs.front();
   if (a.isa<FloatType>() && b.isSignlessInteger())
     return true;
   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
@@ -1767,7 +1774,10 @@ bool FPToSIOp::areCastCompatible(Type a, Type b) {
 // FPToUIOp
 //===----------------------------------------------------------------------===//
 
-bool FPToUIOp::areCastCompatible(Type a, Type b) {
+bool FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  Type a = inputs.front(), b = outputs.front();
   if (a.isa<FloatType>() && b.isSignlessInteger())
     return true;
   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
@@ -1777,7 +1787,10 @@ bool FPToUIOp::areCastCompatible(Type a, Type b) {
 // FPTruncOp
 //===----------------------------------------------------------------------===//
 
-bool FPTruncOp::areCastCompatible(Type a, Type b) {
+bool FPTruncOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  Type a = inputs.front(), b = outputs.front();
   if (auto fa = a.dyn_cast<FloatType>())
     if (auto fb = b.dyn_cast<FloatType>())
       return fa.getWidth() > fb.getWidth();
@@ -1889,7 +1902,10 @@ GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 // Index cast is applicable from index to integer and backwards.
-bool IndexCastOp::areCastCompatible(Type a, Type b) {
+bool IndexCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  Type a = inputs.front(), b = outputs.front();
   if (a.isa<ShapedType>() && b.isa<ShapedType>()) {
     auto aShaped = a.cast<ShapedType>();
     auto bShaped = b.cast<ShapedType>();
@@ -1965,7 +1981,10 @@ void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 
 Value MemRefCastOp::getViewSource() { return source(); }
 
-bool MemRefCastOp::areCastCompatible(Type a, Type b) {
+bool MemRefCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  Type a = inputs.front(), b = outputs.front();
   auto aT = a.dyn_cast<MemRefType>();
   auto bT = b.dyn_cast<MemRefType>();
 
@@ -2036,8 +2055,6 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
 }
 
 OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
-  if (Value folded = impl::foldCastOp(*this))
-    return folded;
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
@@ -2633,7 +2650,10 @@ OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 // sitofp is applicable from integer types to float types.
-bool SIToFPOp::areCastCompatible(Type a, Type b) {
+bool SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  Type a = inputs.front(), b = outputs.front();
   if (a.isSignlessInteger() && b.isa<FloatType>())
     return true;
   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
@@ -2715,7 +2735,10 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 // uitofp is applicable from integer types to float types.
-bool UIToFPOp::areCastCompatible(Type a, Type b) {
+bool UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  Type a = inputs.front(), b = outputs.front();
   if (a.isSignlessInteger() && b.isa<FloatType>())
     return true;
   return areVectorCastSimpleCompatible(a, b, areCastCompatible);

diff  --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index b8fb44a9f4cb..de650995ebb6 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRTensor
   Core
 
   LINK_LIBS PUBLIC
+  MLIRCastInterfaces
   MLIRIR
   MLIRSideEffectInterfaces
   MLIRSupport

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e231a3a3b56e..92115d51476e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -73,7 +73,10 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
   return true;
 }
 
-bool CastOp::areCastCompatible(Type a, Type b) {
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  Type a = inputs.front(), b = outputs.front();
   auto aT = a.dyn_cast<TensorType>();
   auto bT = b.dyn_cast<TensorType>();
   if (!aT || !bT)
@@ -85,10 +88,6 @@ bool CastOp::areCastCompatible(Type a, Type b) {
   return succeeded(verifyCompatibleShape(aT, bT));
 }
 
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
-  return impl::foldCastOp(*this);
-}
-
 /// Compute a TensorType that has the joined shape knowledge of the two
 /// given TensorTypes. The element types need to match.
 static TensorType joinShapes(TensorType one, TensorType two) {

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 4152121dd548..ba1d1e5109cc 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1208,6 +1208,48 @@ void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) {
 // CastOp implementation
 //===----------------------------------------------------------------------===//
 
+/// Fold a cast with operands whose types equal its result types 1-1.
+LogicalResult
+impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands,
+                          SmallVectorImpl<OpFoldResult> &foldResults) {
+  OperandRange operands = op->getOperands();
+  if (operands.empty())
+    return failure();
+  ResultRange results = op->getResults();
+
+  // No-op cast: forward each operand as the corresponding folded result.
+  if (operands.getTypes() == results.getTypes()) {
+    foldResults.append(operands.begin(), operands.end());
+    return success();
+  }
+
+  return failure();
+}
+
+/// Verify a cast: >= 1 result and op-approved operand/result type sets.
+LogicalResult impl::verifyCastInterfaceOp(
+    Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) {
+  auto resultTypes = op->getResultTypes();
+  if (llvm::empty(resultTypes))
+    return op->emitOpError()
+           << "expected at least one result for cast operation";
+
+  auto operandTypes = op->getOperandTypes();
+  if (!areCastCompatible(operandTypes, resultTypes)) {
+    InFlightDiagnostic diag = op->emitOpError("operand type");
+    if (llvm::empty(operandTypes))
+      diag << "s []";
+    else if (llvm::size(operandTypes) == 1)
+      diag << " " << *operandTypes.begin();
+    else
+      diag << "s " << operandTypes;
+    return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ")
+                << resultTypes << " are cast incompatible";
+  }
+
+  return success();
+}
+
 void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source,
                        Type destType) {
   result.addOperands(source);

diff  --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 0a8f75b6f7d9..26b484bc5d78 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -1,5 +1,6 @@
 set(LLVM_OPTIONAL_SOURCES
   CallInterfaces.cpp
+  CastInterfaces.cpp
   ControlFlowInterfaces.cpp
   CopyOpInterface.cpp
   DerivedAttributeOpInterface.cpp
@@ -27,6 +28,7 @@ endfunction(add_mlir_interface_library)
 
 
 add_mlir_interface_library(CallInterfaces)
+add_mlir_interface_library(CastInterfaces)
 add_mlir_interface_library(ControlFlowInterfaces)
 add_mlir_interface_library(CopyOpInterface)
 add_mlir_interface_library(DerivedAttributeOpInterface)

diff  --git a/mlir/lib/Interfaces/CastInterfaces.cpp b/mlir/lib/Interfaces/CastInterfaces.cpp
new file mode 100644
index 000000000000..400c1978cdaa
--- /dev/null
+++ b/mlir/lib/Interfaces/CastInterfaces.cpp
@@ -0,0 +1,17 @@
+//===- CastInterfaces.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/CastInterfaces.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Table-generated class definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/CastInterfaces.cpp.inc"


        


More information about the Mlir-commits mailing list