[llvm-branch-commits] [mlir] 6ccf2d6 - [mlir] Add an interface for Cast-Like operations
River Riddle via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jan 20 16:33:22 PST 2021
Author: River Riddle
Date: 2021-01-20T16:28:17-08:00
New Revision: 6ccf2d62b4876c88427ae97d0cd3c9ed4330560a
URL: https://github.com/llvm/llvm-project/commit/6ccf2d62b4876c88427ae97d0cd3c9ed4330560a
DIFF: https://github.com/llvm/llvm-project/commit/6ccf2d62b4876c88427ae97d0cd3c9ed4330560a.diff
LOG: [mlir] Add an interface for Cast-Like operations
A cast-like operation converts from a set of input types to a set of output types. The arity of the inputs may be anywhere from 0-N, whereas the arity of the outputs may be anything from 1-N. Cast-like operations are removable when they produce a "no-op", i.e. when the input types and output types match 1-1.
Differential Revision: https://reviews.llvm.org/D94831
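
As a rough illustration of the removal rule described above, here is a minimal sketch (not part of this patch; the helper name `isNoOpCast` is made up) of the check that makes a cast-like operation trivially foldable:

```c++
#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLExtras.h"

/// A cast-like operation is a no-op exactly when it has at least one operand
/// to forward and its operand types match its result types 1-1.
static bool isNoOpCast(mlir::Operation *op) {
  return op->getNumOperands() != 0 &&
         llvm::equal(op->getOperandTypes(), op->getResultTypes());
}
```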
Added:
mlir/include/mlir/Interfaces/CastInterfaces.h
mlir/include/mlir/Interfaces/CastInterfaces.td
mlir/lib/Interfaces/CastInterfaces.cpp
Modified:
mlir/docs/Tutorials/Toy/Ch-4.md
mlir/examples/toy/Ch4/CMakeLists.txt
mlir/examples/toy/Ch4/include/toy/Dialect.h
mlir/examples/toy/Ch4/include/toy/Ops.td
mlir/examples/toy/Ch4/mlir/Dialect.cpp
mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
mlir/examples/toy/Ch5/CMakeLists.txt
mlir/examples/toy/Ch5/include/toy/Dialect.h
mlir/examples/toy/Ch5/include/toy/Ops.td
mlir/examples/toy/Ch5/mlir/Dialect.cpp
mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
mlir/examples/toy/Ch6/CMakeLists.txt
mlir/examples/toy/Ch6/include/toy/Dialect.h
mlir/examples/toy/Ch6/include/toy/Ops.td
mlir/examples/toy/Ch6/mlir/Dialect.cpp
mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
mlir/examples/toy/Ch7/CMakeLists.txt
mlir/examples/toy/Ch7/include/toy/Dialect.h
mlir/examples/toy/Ch7/include/toy/Ops.td
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/IR/Diagnostics.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/Interfaces/CMakeLists.txt
mlir/lib/Dialect/Shape/IR/CMakeLists.txt
mlir/lib/Dialect/StandardOps/CMakeLists.txt
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/Interfaces/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md
index dc1314419320..c454a762e2b5 100644
--- a/mlir/docs/Tutorials/Toy/Ch-4.md
+++ b/mlir/docs/Tutorials/Toy/Ch-4.md
@@ -182,26 +182,50 @@ to add a new operation to the Toy dialect, `ToyCastOp`(toy.cast), to represent
casts between two different shapes.
```tablegen
-def CastOp : Toy_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
+def CastOp : Toy_Op<"cast", [
+ DeclareOpInterfaceMethods<CastOpInterface>,
+ NoSideEffect,
+ SameOperandsAndResultShape]
+ > {
let summary = "shape cast operation";
let description = [{
The "cast" operation converts a tensor from one type to an equivalent type
without changing any data elements. The source and destination types
- must both be tensor types with the same element type. If both are ranked
- then the rank should be the same and static dimensions should match. The
- operation is invalid if converting to a mismatching constant dimension.
+ must both be tensor types with the same element type. If both are ranked,
+ then shape is required to match. The operation is invalid if converting
+ to a mismatching constant dimension.
}];
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
+}
+```
+
+Note that the definition of this cast operation adds a `CastOpInterface` to the
+traits list. This interface provides several utilities for cast-like operations,
+such as folding identity casts and verification. We hook into this interface by
+providing a definition for the `areCastCompatible` method:
- // Set the folder bit so that we can fold redundant cast operations.
- let hasFolder = 1;
+```c++
+/// Returns true if the given set of input and result types are compatible with
+/// this cast operation. This is required by the `CastOpInterface` to verify
+/// this operation and provide other additional utilities.
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ // The inputs must be Tensors with the same element type.
+ TensorType input = inputs.front().dyn_cast<TensorType>();
+ TensorType output = outputs.front().dyn_cast<TensorType>();
+ if (!input || !output || input.getElementType() != output.getElementType())
+ return false;
+ // The shape is required to match if both types are ranked.
+ return !input.hasRank() || !output.hasRank() || input == output;
}
+
```
-We can then override the necessary hook on the ToyInlinerInterface to insert
-this for us when necessary:
+With a proper cast operation, we can now override the necessary hook on the
+ToyInlinerInterface to insert it for us when necessary:
```c++
struct ToyInlinerInterface : public DialectInlinerInterface {
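  // Sketch only: the hook body is not shown in this excerpt. Assuming the
  // DialectInlinerInterface hook `materializeCallConversion` (an assumption,
  // not confirmed by the hunk above), the inliner can materialize a toy.cast
  // whenever the type at a call site differs from the callee's signature:
  Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    return builder.create<CastOp>(conversionLoc, resultType, input);
  }
  // (other inliner hooks omitted)
};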
diff --git a/mlir/examples/toy/Ch4/CMakeLists.txt b/mlir/examples/toy/Ch4/CMakeLists.txt
index f303e2e40afd..468a46718d34 100644
--- a/mlir/examples/toy/Ch4/CMakeLists.txt
+++ b/mlir/examples/toy/Ch4/CMakeLists.txt
@@ -29,6 +29,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
target_link_libraries(toyc-ch4
PRIVATE
MLIRAnalysis
+ MLIRCastInterfaces
MLIRCallInterfaces
MLIRIR
MLIRParser
diff --git a/mlir/examples/toy/Ch4/include/toy/Dialect.h b/mlir/examples/toy/Ch4/include/toy/Dialect.h
index 3c266cf02d6d..41d20fa7de22 100644
--- a/mlir/examples/toy/Ch4/include/toy/Dialect.h
+++ b/mlir/examples/toy/Ch4/include/toy/Dialect.h
@@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"
diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td
index adf8ab2aec80..8ba8f1a69e33 100644
--- a/mlir/examples/toy/Ch4/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch4/include/toy/Ops.td
@@ -14,6 +14,7 @@
#define TOY_OPS
include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "toy/ShapeInferenceInterface.td"
@@ -102,25 +103,25 @@ def AddOp : Toy_Op<"add",
];
}
-def CastOp : Toy_Op<"cast",
- [DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
- SameOperandsAndResultShape]> {
+def CastOp : Toy_Op<"cast", [
+ DeclareOpInterfaceMethods<CastOpInterface>,
+ DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
+ NoSideEffect,
+ SameOperandsAndResultShape
+ ]> {
let summary = "shape cast operation";
let description = [{
The "cast" operation converts a tensor from one type to an equivalent type
- without changing any data elements. The source and destination types
- must both be tensor types with the same element type. If both are ranked
- then the rank should be the same and static dimensions should match. The
- operation is invalid if converting to a mismatching constant dimension.
+ without changing any data elements. The source and destination types must
+ both be tensor types with the same element type. If both are ranked, then
+ shape is required to match. The operation is invalid if converting to a
+ mismatching constant dimension.
}];
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
-
- // Set the folder bit so that we can fold redundant cast operations.
- let hasFolder = 1;
}
def GenericCallOp : Toy_Op<"generic_call",
diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index 0a3ec29b5707..0d317db1bdff 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -232,6 +232,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
/// inference interface.
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
+/// Returns true if the given set of input and result types are compatible with
+/// this cast operation. This is required by the `CastOpInterface` to verify
+/// this operation and provide other additional utilities.
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ // The inputs must be Tensors with the same element type.
+ TensorType input = inputs.front().dyn_cast<TensorType>();
+ TensorType output = outputs.front().dyn_cast<TensorType>();
+ if (!input || !output || input.getElementType() != output.getElementType())
+ return false;
+ // The shape is required to match if both types are ranked.
+ return !input.hasRank() || !output.hasRank() || input == output;
+}
+
//===----------------------------------------------------------------------===//
// GenericCallOp
diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
index 18e77296ecf2..0af4cbfc11f1 100644
--- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
@@ -23,11 +23,6 @@ namespace {
#include "ToyCombine.inc"
} // end anonymous namespace
-/// Fold simple cast operations that return the same type as the input.
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
- return mlir::impl::foldCastOp(*this);
-}
-
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
/// optimizes the following scenario: transpose(transpose(x)) -> x
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
diff --git a/mlir/examples/toy/Ch5/CMakeLists.txt b/mlir/examples/toy/Ch5/CMakeLists.txt
index 008e8f608573..7170550d3a7a 100644
--- a/mlir/examples/toy/Ch5/CMakeLists.txt
+++ b/mlir/examples/toy/Ch5/CMakeLists.txt
@@ -33,6 +33,7 @@ target_link_libraries(toyc-ch5
${dialect_libs}
MLIRAnalysis
MLIRCallInterfaces
+ MLIRCastInterfaces
MLIRIR
MLIRParser
MLIRPass
diff --git a/mlir/examples/toy/Ch5/include/toy/Dialect.h b/mlir/examples/toy/Ch5/include/toy/Dialect.h
index 3c266cf02d6d..41d20fa7de22 100644
--- a/mlir/examples/toy/Ch5/include/toy/Dialect.h
+++ b/mlir/examples/toy/Ch5/include/toy/Dialect.h
@@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"
diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td
index ee7d6c4b340d..5e9a0a0f4fd7 100644
--- a/mlir/examples/toy/Ch5/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch5/include/toy/Ops.td
@@ -14,6 +14,7 @@
#define TOY_OPS
include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "toy/ShapeInferenceInterface.td"
@@ -102,25 +103,25 @@ def AddOp : Toy_Op<"add",
];
}
-def CastOp : Toy_Op<"cast",
- [DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
- SameOperandsAndResultShape]> {
+def CastOp : Toy_Op<"cast", [
+ DeclareOpInterfaceMethods<CastOpInterface>,
+ DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
+ NoSideEffect,
+ SameOperandsAndResultShape
+ ]> {
let summary = "shape cast operation";
let description = [{
The "cast" operation converts a tensor from one type to an equivalent type
- without changing any data elements. The source and destination types
- must both be tensor types with the same element type. If both are ranked
- then the rank should be the same and static dimensions should match. The
- operation is invalid if converting to a mismatching constant dimension.
+ without changing any data elements. The source and destination types must
+ both be tensor types with the same element type. If both are ranked, then
+ shape is required to match. The operation is invalid if converting to a
+ mismatching constant dimension.
}];
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
-
- // Set the folder bit so that we can fold redundant cast operations.
- let hasFolder = 1;
}
def GenericCallOp : Toy_Op<"generic_call",
diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index d4356f61f83b..a4ca119a9aaf 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -232,6 +232,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
/// inference interface.
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
+/// Returns true if the given set of input and result types are compatible with
+/// this cast operation. This is required by the `CastOpInterface` to verify
+/// this operation and provide other additional utilities.
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ // The inputs must be Tensors with the same element type.
+ TensorType input = inputs.front().dyn_cast<TensorType>();
+ TensorType output = outputs.front().dyn_cast<TensorType>();
+ if (!input || !output || input.getElementType() != output.getElementType())
+ return false;
+ // The shape is required to match if both types are ranked.
+ return !input.hasRank() || !output.hasRank() || input == output;
+}
+
//===----------------------------------------------------------------------===//
// GenericCallOp
diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
index 18e77296ecf2..0af4cbfc11f1 100644
--- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
@@ -23,11 +23,6 @@ namespace {
#include "ToyCombine.inc"
} // end anonymous namespace
-/// Fold simple cast operations that return the same type as the input.
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
- return mlir::impl::foldCastOp(*this);
-}
-
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
/// optimizes the following scenario: transpose(transpose(x)) -> x
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
diff --git a/mlir/examples/toy/Ch6/CMakeLists.txt b/mlir/examples/toy/Ch6/CMakeLists.txt
index 3d36f970fc84..99041c13f22b 100644
--- a/mlir/examples/toy/Ch6/CMakeLists.txt
+++ b/mlir/examples/toy/Ch6/CMakeLists.txt
@@ -39,6 +39,7 @@ target_link_libraries(toyc-ch6
${conversion_libs}
MLIRAnalysis
MLIRCallInterfaces
+ MLIRCastInterfaces
MLIRExecutionEngine
MLIRIR
MLIRLLVMIR
diff --git a/mlir/examples/toy/Ch6/include/toy/Dialect.h b/mlir/examples/toy/Ch6/include/toy/Dialect.h
index 3c266cf02d6d..41d20fa7de22 100644
--- a/mlir/examples/toy/Ch6/include/toy/Dialect.h
+++ b/mlir/examples/toy/Ch6/include/toy/Dialect.h
@@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"
diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td
index 6f3998357010..2f9e169ce5ae 100644
--- a/mlir/examples/toy/Ch6/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch6/include/toy/Ops.td
@@ -14,6 +14,7 @@
#define TOY_OPS
include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "toy/ShapeInferenceInterface.td"
@@ -102,25 +103,25 @@ def AddOp : Toy_Op<"add",
];
}
-def CastOp : Toy_Op<"cast",
- [DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
- SameOperandsAndResultShape]> {
+def CastOp : Toy_Op<"cast", [
+ DeclareOpInterfaceMethods<CastOpInterface>,
+ DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
+ NoSideEffect,
+ SameOperandsAndResultShape
+ ]> {
let summary = "shape cast operation";
let description = [{
The "cast" operation converts a tensor from one type to an equivalent type
- without changing any data elements. The source and destination types
- must both be tensor types with the same element type. If both are ranked
- then the rank should be the same and static dimensions should match. The
- operation is invalid if converting to a mismatching constant dimension.
+ without changing any data elements. The source and destination types must
+ both be tensor types with the same element type. If both are ranked, then
+ shape is required to match. The operation is invalid if converting to a
+ mismatching constant dimension.
}];
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
-
- // Set the folder bit so that we can fold redundant cast operations.
- let hasFolder = 1;
}
def GenericCallOp : Toy_Op<"generic_call",
diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index d4356f61f83b..a4ca119a9aaf 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -232,6 +232,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
/// inference interface.
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
+/// Returns true if the given set of input and result types are compatible with
+/// this cast operation. This is required by the `CastOpInterface` to verify
+/// this operation and provide other additional utilities.
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ // The inputs must be Tensors with the same element type.
+ TensorType input = inputs.front().dyn_cast<TensorType>();
+ TensorType output = outputs.front().dyn_cast<TensorType>();
+ if (!input || !output || input.getElementType() != output.getElementType())
+ return false;
+ // The shape is required to match if both types are ranked.
+ return !input.hasRank() || !output.hasRank() || input == output;
+}
+
//===----------------------------------------------------------------------===//
// GenericCallOp
diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
index 18e77296ecf2..0af4cbfc11f1 100644
--- a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
@@ -23,11 +23,6 @@ namespace {
#include "ToyCombine.inc"
} // end anonymous namespace
-/// Fold simple cast operations that return the same type as the input.
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
- return mlir::impl::foldCastOp(*this);
-}
-
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
/// optimizes the following scenario: transpose(transpose(x)) -> x
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
diff --git a/mlir/examples/toy/Ch7/CMakeLists.txt b/mlir/examples/toy/Ch7/CMakeLists.txt
index 639d377d4b66..c24ad53144e1 100644
--- a/mlir/examples/toy/Ch7/CMakeLists.txt
+++ b/mlir/examples/toy/Ch7/CMakeLists.txt
@@ -39,6 +39,7 @@ target_link_libraries(toyc-ch7
${conversion_libs}
MLIRAnalysis
MLIRCallInterfaces
+ MLIRCastInterfaces
MLIRExecutionEngine
MLIRIR
MLIRParser
diff --git a/mlir/examples/toy/Ch7/include/toy/Dialect.h b/mlir/examples/toy/Ch7/include/toy/Dialect.h
index 15b55d03f77e..1b754f3d1089 100644
--- a/mlir/examples/toy/Ch7/include/toy/Dialect.h
+++ b/mlir/examples/toy/Ch7/include/toy/Dialect.h
@@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"
diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td
index c73a6b94e903..3984a5cfa8bd 100644
--- a/mlir/examples/toy/Ch7/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch7/include/toy/Ops.td
@@ -14,6 +14,7 @@
#define TOY_OPS
include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "toy/ShapeInferenceInterface.td"
@@ -115,25 +116,25 @@ def AddOp : Toy_Op<"add",
];
}
-def CastOp : Toy_Op<"cast",
- [DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
- SameOperandsAndResultShape]> {
+def CastOp : Toy_Op<"cast", [
+ DeclareOpInterfaceMethods<CastOpInterface>,
+ DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
+ NoSideEffect,
+ SameOperandsAndResultShape
+ ]> {
let summary = "shape cast operation";
let description = [{
The "cast" operation converts a tensor from one type to an equivalent type
- without changing any data elements. The source and destination types
- must both be tensor types with the same element type. If both are ranked
- then the rank should be the same and static dimensions should match. The
- operation is invalid if converting to a mismatching constant dimension.
+ without changing any data elements. The source and destination types must
+ both be tensor types with the same element type. If both are ranked, then
+ shape is required to match. The operation is invalid if converting to a
+ mismatching constant dimension.
}];
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
-
- // Set the folder bit so that we can fold redundant cast operations.
- let hasFolder = 1;
}
def GenericCallOp : Toy_Op<"generic_call",
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 659b82e6d80e..9fe1635f4760 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -284,6 +284,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
/// inference interface.
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
+/// Returns true if the given set of input and result types are compatible with
+/// this cast operation. This is required by the `CastOpInterface` to verify
+/// this operation and provide other additional utilities.
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ // The inputs must be Tensors with the same element type.
+ TensorType input = inputs.front().dyn_cast<TensorType>();
+ TensorType output = outputs.front().dyn_cast<TensorType>();
+ if (!input || !output || input.getElementType() != output.getElementType())
+ return false;
+ // The shape is required to match if both types are ranked.
+ return !input.hasRank() || !output.hasRank() || input == output;
+}
+
//===----------------------------------------------------------------------===//
// GenericCallOp
diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
index 60db56bbbc08..bfbd36b40fa0 100644
--- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
@@ -23,11 +23,6 @@ namespace {
#include "ToyCombine.inc"
} // end anonymous namespace
-/// Fold simple cast operations that return the same type as the input.
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
- return mlir::impl::foldCastOp(*this);
-}
-
/// Fold constants.
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 56ff32252fee..30905d3af411 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -19,6 +19,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 6dbb24a4358f..f4caa7da1721 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -17,6 +17,7 @@ include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
@@ -45,9 +46,10 @@ class Std_Op<string mnemonic, list<OpTrait> traits = []> :
// Base class for standard cast operations. Requires single operand and result,
// but does not constrain them to specific types.
class CastOp<string mnemonic, list<OpTrait> traits = []> :
- Std_Op<mnemonic,
- !listconcat(traits, [NoSideEffect, SameOperandsAndResultShape])> {
-
+ Std_Op<mnemonic, traits # [
+ NoSideEffect, SameOperandsAndResultShape,
+ DeclareOpInterfaceMethods<CastOpInterface>
+ ]> {
let results = (outs AnyType);
let builders = [
@@ -62,9 +64,9 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
let printer = [{
return printStandardCastOp(this->getOperation(), p);
}];
- let verifier = [{ return impl::verifyCastOp(*this, areCastCompatible); }];
- let hasFolder = 1;
+  // Cast operations are fully verified by their traits.
+ let verifier = ?;
}
// Base class for arithmetic cast operations.
@@ -1643,14 +1645,6 @@ def FPExtOp : ArithmeticCastOp<"fpext">, Arguments<(ins AnyType:$in)> {
The destination type must to be strictly wider than the source type.
Only scalars are currently supported.
}];
-
- let extraClassDeclaration = [{
- /// Return true if `a` and `b` are valid operand and result pairs for
- /// the operation.
- static bool areCastCompatible(Type a, Type b);
- }];
-
- let hasFolder = 0;
}
//===----------------------------------------------------------------------===//
@@ -1663,14 +1657,6 @@ def FPToSIOp : ArithmeticCastOp<"fptosi">, Arguments<(ins AnyType:$in)> {
Cast from a value interpreted as floating-point to the nearest (rounding
towards zero) signed integer value.
}];
-
- let extraClassDeclaration = [{
- /// Return true if `a` and `b` are valid operand and result pairs for
- /// the operation.
- static bool areCastCompatible(Type a, Type b);
- }];
-
- let hasFolder = 0;
}
//===----------------------------------------------------------------------===//
@@ -1683,14 +1669,6 @@ def FPToUIOp : ArithmeticCastOp<"fptoui">, Arguments<(ins AnyType:$in)> {
Cast from a value interpreted as floating-point to the nearest (rounding
towards zero) unsigned integer value.
}];
-
- let extraClassDeclaration = [{
- /// Return true if `a` and `b` are valid operand and result pairs for
- /// the operation.
- static bool areCastCompatible(Type a, Type b);
- }];
-
- let hasFolder = 0;
}
//===----------------------------------------------------------------------===//
@@ -1705,14 +1683,6 @@ def FPTruncOp : ArithmeticCastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
If the value cannot be exactly represented, it is rounded using the default
rounding mode. Only scalars are currently supported.
}];
-
- let extraClassDeclaration = [{
- /// Return true if `a` and `b` are valid operand and result pairs for
- /// the operation.
- static bool areCastCompatible(Type a, Type b);
- }];
-
- let hasFolder = 0;
}
//===----------------------------------------------------------------------===//
@@ -1849,12 +1819,6 @@ def IndexCastOp : ArithmeticCastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
sign-extended. If casting to a narrower integer, the value is truncated.
}];
- let extraClassDeclaration = [{
- /// Return true if `a` and `b` are valid operand and result pairs for
- /// the operation.
- static bool areCastCompatible(Type a, Type b);
- }];
-
let hasFolder = 1;
}
@@ -2045,14 +2009,7 @@ def MemRefCastOp : CastOp<"memref_cast", [
let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
let results = (outs AnyRankedOrUnrankedMemRef);
- let extraClassDeclaration = [{
- /// Return true if `a` and `b` are valid operand and result pairs for
- /// the operation.
- static bool areCastCompatible(Type a, Type b);
-
- /// The result of a memref_cast is always a memref.
- Type getType() { return getResult().getType(); }
- }];
+ let hasFolder = 1;
}
@@ -2786,14 +2743,6 @@ def SIToFPOp : ArithmeticCastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
exactly represented, it is rounded using the default rounding mode. Scalars
and vector types are currently supported.
}];
-
- let extraClassDeclaration = [{
- /// Return true if `a` and `b` are valid operand and result pairs for
- /// the operation.
- static bool areCastCompatible(Type a, Type b);
- }];
-
- let hasFolder = 0;
}
//===----------------------------------------------------------------------===//
@@ -3628,14 +3577,6 @@ def UIToFPOp : ArithmeticCastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
value cannot be exactly represented, it is rounded using the default
rounding mode. Scalars and vector types are currently supported.
}];
-
- let extraClassDeclaration = [{
- /// Return true if `a` and `b` are valid operand and result pairs for
- /// the operation.
- static bool areCastCompatible(Type a, Type b);
- }];
-
- let hasFolder = 0;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 3a1a20835959..98d16c7acfe6 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -13,6 +13,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index e7776c4e8a9b..d45f1f61cdd7 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -10,6 +10,7 @@
#define TENSOR_OPS
include "mlir/Dialect/Tensor/IR/TensorBase.td"
+include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -24,7 +25,9 @@ class Tensor_Op<string mnemonic, list<OpTrait> traits = []>
// CastOp
//===----------------------------------------------------------------------===//
-def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> {
+def Tensor_CastOp : Tensor_Op<"cast", [
+ DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
+ ]> {
let summary = "tensor cast operation";
let description = [{
Convert a tensor from one type to an equivalent type without changing any
@@ -51,19 +54,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> {
let arguments = (ins AnyTensor:$source);
let results = (outs AnyTensor:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
- let verifier = "return impl::verifyCastOp(*this, areCastCompatible);";
-
- let extraClassDeclaration = [{
- /// Return true if `a` and `b` are valid operand and result pairs for
- /// the operation.
- static bool areCastCompatible(Type a, Type b);
-
- /// The result of a tensor.cast is always a tensor.
- TensorType getType() { return getResult().getType().cast<TensorType>(); }
- }];
- let hasFolder = 1;
let hasCanonicalizer = 1;
+ let verifier = ?;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index 0dbf18284131..1fee314b285c 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -50,6 +50,34 @@ enum class DiagnosticSeverity {
/// A variant type that holds a single argument for a diagnostic.
class DiagnosticArgument {
public:
+ /// Note: The constructors below are only exposed due to problems accessing
+  /// constructors from type traits; they should not be used directly by users.
+ // Construct from an Attribute.
+ explicit DiagnosticArgument(Attribute attr);
+ // Construct from a floating point number.
+ explicit DiagnosticArgument(double val)
+ : kind(DiagnosticArgumentKind::Double), doubleVal(val) {}
+ explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {}
+ // Construct from a signed integer.
+ template <typename T>
+ explicit DiagnosticArgument(
+ T val, typename std::enable_if<std::is_signed<T>::value &&
+ std::numeric_limits<T>::is_integer &&
+ sizeof(T) <= sizeof(int64_t)>::type * = 0)
+ : kind(DiagnosticArgumentKind::Integer), opaqueVal(int64_t(val)) {}
+ // Construct from an unsigned integer.
+ template <typename T>
+ explicit DiagnosticArgument(
+ T val, typename std::enable_if<std::is_unsigned<T>::value &&
+ std::numeric_limits<T>::is_integer &&
+ sizeof(T) <= sizeof(uint64_t)>::type * = 0)
+ : kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {}
+ // Construct from a string reference.
+ explicit DiagnosticArgument(StringRef val)
+ : kind(DiagnosticArgumentKind::String), stringVal(val) {}
+ // Construct from a Type.
+ explicit DiagnosticArgument(Type val);
+
/// Enum that represents the different kinds of diagnostic arguments
/// supported.
enum class DiagnosticArgumentKind {
@@ -100,37 +128,6 @@ class DiagnosticArgument {
private:
friend class Diagnostic;
- // Construct from an Attribute.
- explicit DiagnosticArgument(Attribute attr);
-
- // Construct from a floating point number.
- explicit DiagnosticArgument(double val)
- : kind(DiagnosticArgumentKind::Double), doubleVal(val) {}
- explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {}
-
- // Construct from a signed integer.
- template <typename T>
- explicit DiagnosticArgument(
- T val, typename std::enable_if<std::is_signed<T>::value &&
- std::numeric_limits<T>::is_integer &&
- sizeof(T) <= sizeof(int64_t)>::type * = 0)
- : kind(DiagnosticArgumentKind::Integer), opaqueVal(int64_t(val)) {}
-
- // Construct from an unsigned integer.
- template <typename T>
- explicit DiagnosticArgument(
- T val, typename std::enable_if<std::is_unsigned<T>::value &&
- std::numeric_limits<T>::is_integer &&
- sizeof(T) <= sizeof(uint64_t)>::type * = 0)
- : kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {}
-
- // Construct from a string reference.
- explicit DiagnosticArgument(StringRef val)
- : kind(DiagnosticArgumentKind::String), stringVal(val) {}
-
- // Construct from a Type.
- explicit DiagnosticArgument(Type val);
-
/// The kind of this argument.
DiagnosticArgumentKind kind;
@@ -189,8 +186,10 @@ class Diagnostic {
/// Stream operator for inserting new diagnostic arguments.
template <typename Arg>
- typename std::enable_if<!std::is_convertible<Arg, StringRef>::value,
- Diagnostic &>::type
+ typename std::enable_if<
+ !std::is_convertible<Arg, StringRef>::value &&
+ std::is_constructible<DiagnosticArgument, Arg>::value,
+ Diagnostic &>::type
operator<<(Arg &&val) {
arguments.push_back(DiagnosticArgument(std::forward<Arg>(val)));
return *this;
@@ -220,17 +219,17 @@ class Diagnostic {
}
/// Stream in a range.
- template <typename T> Diagnostic &operator<<(iterator_range<T> range) {
- return appendRange(range);
- }
- template <typename T> Diagnostic &operator<<(ArrayRef<T> range) {
+ template <typename T, typename ValueT = llvm::detail::ValueOfRange<T>>
+ std::enable_if_t<!std::is_constructible<DiagnosticArgument, T>::value,
+ Diagnostic &>
+ operator<<(T &&range) {
return appendRange(range);
}
/// Append a range to the diagnostic. The default delimiter between elements
/// is ','.
- template <typename T, template <typename> class Container>
- Diagnostic &appendRange(const Container<T> &c, const char *delim = ", ") {
+ template <typename T>
+ Diagnostic &appendRange(const T &c, const char *delim = ", ") {
llvm::interleave(
c, [this](const auto &a) { *this << a; }, [&]() { *this << delim; });
return *this;
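
A side note on why the `DiagnosticArgument` constructors move into the public section above: the new `std::is_constructible<DiagnosticArgument, Arg>` constraint on `operator<<` only sees accessible constructors, so keeping them private would make the trait report false even for befriended callers. A tiny standalone sketch (type names made up, not from this patch) of that behavior:

```c++
#include <type_traits>

class Arg {
  explicit Arg(const char *); // private: ignored by std::is_constructible
public:
  explicit Arg(double);       // public: visible to std::is_constructible
};

static_assert(!std::is_constructible<Arg, const char *>::value,
              "inaccessible constructors do not count");
static_assert(std::is_constructible<Arg, double>::value,
              "accessible constructors do count");
```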
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 622952c76289..c021bdc8ee9d 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1822,18 +1822,27 @@ ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
void printOneResultOp(Operation *op, OpAsmPrinter &p);
} // namespace impl
-// These functions are out-of-line implementations of the methods in CastOp,
-// which avoids them being template instantiated/duplicated.
+// These functions are out-of-line implementations of the methods in
+// CastOpInterface, which avoids them being template instantiated/duplicated.
namespace impl {
+/// Attempt to fold the given cast operation.
+LogicalResult foldCastInterfaceOp(Operation *op,
+ ArrayRef<Attribute> attrOperands,
+ SmallVectorImpl<OpFoldResult> &foldResults);
+/// Attempt to verify the given cast operation.
+LogicalResult verifyCastInterfaceOp(
+ Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
+
// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
// need for them, but some older ODS code in `std` still depends on them).
void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
Type destType);
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
void printCastOp(Operation *op, OpAsmPrinter &p);
-// TODO: Create a CastOpInterface with a method areCastCompatible.
-// Also, consider adding functionality to CastOpInterface to be able to perform
-// the ChainedTensorCast canonicalization generically.
+// TODO: These methods are deprecated in favor of CastOpInterface. Remove them
+// when all uses have been updated. Also, consider adding functionality to
+// CastOpInterface to be able to perform the ChainedTensorCast canonicalization
+// generically.
Value foldCastOp(Operation *op);
LogicalResult verifyCastOp(Operation *op,
function_ref<bool(Type, Type)> areCastCompatible);
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 65e19f3eec1b..3bb154988b01 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_interface(CallInterfaces)
+add_mlir_interface(CastInterfaces)
add_mlir_interface(ControlFlowInterfaces)
add_mlir_interface(CopyOpInterface)
add_mlir_interface(DerivedAttributeOpInterface)
diff --git a/mlir/include/mlir/Interfaces/CastInterfaces.h b/mlir/include/mlir/Interfaces/CastInterfaces.h
new file mode 100644
index 000000000000..99a1f2ed7821
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/CastInterfaces.h
@@ -0,0 +1,22 @@
+//===- CastInterfaces.h - Cast Interfaces for MLIR --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definitions of the cast interfaces defined in
+// `CastInterfaces.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_CASTINTERFACES_H
+#define MLIR_INTERFACES_CASTINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/CastInterfaces.h.inc"
+
+#endif // MLIR_INTERFACES_CASTINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/CastInterfaces.td b/mlir/include/mlir/Interfaces/CastInterfaces.td
new file mode 100644
index 000000000000..c2a01df42c7f
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/CastInterfaces.td
@@ -0,0 +1,51 @@
+//===- CastInterfaces.td - Cast Interfaces for ops ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of interfaces that can be used to define information
+// related to cast-like operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_CASTINTERFACES
+#define MLIR_INTERFACES_CASTINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def CastOpInterface : OpInterface<"CastOpInterface"> {
+ let description = [{
+ A cast-like operation is one that converts from a set of input types to a
+ set of output types. The arity of the inputs may be from 0-N, whereas the
+ arity of the outputs may be anything from 1-N. Cast-like operations are
+    trivially removable in cases where they produce a no-op, i.e. when the
+ input types and output types match 1-1.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ StaticInterfaceMethod<[{
+ Returns true if the given set of input and result types are compatible
+ to cast using this cast operation.
+ }],
+ "bool", "areCastCompatible",
+ (ins "mlir::TypeRange":$inputs, "mlir::TypeRange":$outputs)
+ >,
+ ];
+
+ let extraTraitClassDeclaration = [{
+ /// Attempt to fold the given cast operation.
+ static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ return impl::foldCastInterfaceOp(op, operands, results);
+ }
+ }];
+ let verify = [{
+ return impl::verifyCastInterfaceOp($_op, ConcreteOp::areCastCompatible);
+ }];
+}
+
+#endif // MLIR_INTERFACES_CASTINTERFACES
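
Since `areCastCompatible` takes type ranges rather than single types, operations with more than one input or output can implement the interface as well. A hypothetical sketch (the helper name `areTupleUnpackCastCompatible` and the tuple-expanding cast it describes are made up, not part of this patch) of how such an op might check compatibility:

```c++
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

/// Hypothetical: what `areCastCompatible` could look like for a cast that
/// expands one tuple-typed value into one result per tuple element.
static bool areTupleUnpackCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.empty())
    return false;
  auto tupleType = inputs.front().dyn_cast<TupleType>();
  // Compatible when each result type matches the corresponding tuple element.
  return tupleType && llvm::equal(tupleType.getTypes(), outputs);
}
```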
diff --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
index f8321842db31..cf2bf19c4f6f 100644
--- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRShape
MLIRShapeOpsIncGen
LINK_LIBS PUBLIC
+ MLIRCastInterfaces
MLIRControlFlowInterfaces
MLIRDialect
MLIRInferTypeOpInterface
diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
index 67f285817a91..058e680ef677 100644
--- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRStandard
LINK_LIBS PUBLIC
MLIRCallInterfaces
+ MLIRCastInterfaces
MLIRControlFlowInterfaces
MLIREDSC
MLIRIR
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 428006e20d9f..45dd0fd0086a 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -195,7 +195,8 @@ static LogicalResult foldMemRefCast(Operation *op) {
/// Returns 'true' if the vector types are cast compatible, and 'false'
/// otherwise.
static bool areVectorCastSimpleCompatible(
- Type a, Type b, function_ref<bool(Type, Type)> areElementsCastCompatible) {
+ Type a, Type b,
+ function_ref<bool(TypeRange, TypeRange)> areElementsCastCompatible) {
if (auto va = a.dyn_cast<VectorType>())
if (auto vb = b.dyn_cast<VectorType>())
return va.getShape().equals(vb.getShape()) &&
@@ -1746,7 +1747,10 @@ LogicalResult DmaWaitOp::verify() {
// FPExtOp
//===----------------------------------------------------------------------===//
-bool FPExtOp::areCastCompatible(Type a, Type b) {
+bool FPExtOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ Type a = inputs.front(), b = outputs.front();
if (auto fa = a.dyn_cast<FloatType>())
if (auto fb = b.dyn_cast<FloatType>())
return fa.getWidth() < fb.getWidth();
@@ -1757,7 +1761,10 @@ bool FPExtOp::areCastCompatible(Type a, Type b) {
// FPToSIOp
//===----------------------------------------------------------------------===//
-bool FPToSIOp::areCastCompatible(Type a, Type b) {
+bool FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ Type a = inputs.front(), b = outputs.front();
if (a.isa<FloatType>() && b.isSignlessInteger())
return true;
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
@@ -1767,7 +1774,10 @@ bool FPToSIOp::areCastCompatible(Type a, Type b) {
// FPToUIOp
//===----------------------------------------------------------------------===//
-bool FPToUIOp::areCastCompatible(Type a, Type b) {
+bool FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ Type a = inputs.front(), b = outputs.front();
if (a.isa<FloatType>() && b.isSignlessInteger())
return true;
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
@@ -1777,7 +1787,10 @@ bool FPToUIOp::areCastCompatible(Type a, Type b) {
// FPTruncOp
//===----------------------------------------------------------------------===//
-bool FPTruncOp::areCastCompatible(Type a, Type b) {
+bool FPTruncOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ Type a = inputs.front(), b = outputs.front();
if (auto fa = a.dyn_cast<FloatType>())
if (auto fb = b.dyn_cast<FloatType>())
return fa.getWidth() > fb.getWidth();
@@ -1889,7 +1902,10 @@ GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
// Index cast is applicable from index to integer and backwards.
-bool IndexCastOp::areCastCompatible(Type a, Type b) {
+bool IndexCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ Type a = inputs.front(), b = outputs.front();
if (a.isa<ShapedType>() && b.isa<ShapedType>()) {
auto aShaped = a.cast<ShapedType>();
auto bShaped = b.cast<ShapedType>();
@@ -1965,7 +1981,10 @@ void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
Value MemRefCastOp::getViewSource() { return source(); }
-bool MemRefCastOp::areCastCompatible(Type a, Type b) {
+bool MemRefCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ Type a = inputs.front(), b = outputs.front();
auto aT = a.dyn_cast<MemRefType>();
auto bT = b.dyn_cast<MemRefType>();
@@ -2036,8 +2055,6 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
}
OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
- if (Value folded = impl::foldCastOp(*this))
- return folded;
return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}
@@ -2633,7 +2650,10 @@ OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
// sitofp is applicable from integer types to float types.
-bool SIToFPOp::areCastCompatible(Type a, Type b) {
+bool SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ Type a = inputs.front(), b = outputs.front();
if (a.isSignlessInteger() && b.isa<FloatType>())
return true;
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
@@ -2715,7 +2735,10 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
// uitofp is applicable from integer types to float types.
-bool UIToFPOp::areCastCompatible(Type a, Type b) {
+bool UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ Type a = inputs.front(), b = outputs.front();
if (a.isSignlessInteger() && b.isa<FloatType>())
return true;
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index b8fb44a9f4cb..de650995ebb6 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRTensor
Core
LINK_LIBS PUBLIC
+ MLIRCastInterfaces
MLIRIR
MLIRSideEffectInterfaces
MLIRSupport
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e231a3a3b56e..92115d51476e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -73,7 +73,10 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
return true;
}
-bool CastOp::areCastCompatible(Type a, Type b) {
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ Type a = inputs.front(), b = outputs.front();
auto aT = a.dyn_cast<TensorType>();
auto bT = b.dyn_cast<TensorType>();
if (!aT || !bT)
@@ -85,10 +88,6 @@ bool CastOp::areCastCompatible(Type a, Type b) {
return succeeded(verifyCompatibleShape(aT, bT));
}
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
- return impl::foldCastOp(*this);
-}
-
/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
static TensorType joinShapes(TensorType one, TensorType two) {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 4152121dd548..ba1d1e5109cc 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1208,6 +1208,48 @@ void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) {
// CastOp implementation
//===----------------------------------------------------------------------===//
+/// Attempt to fold the given cast operation.
+LogicalResult
+impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands,
+ SmallVectorImpl<OpFoldResult> &foldResults) {
+ OperandRange operands = op->getOperands();
+ if (operands.empty())
+ return failure();
+ ResultRange results = op->getResults();
+
+ // Check for the case where the input and output types match 1-1.
+ if (operands.getTypes() == results.getTypes()) {
+ foldResults.append(operands.begin(), operands.end());
+ return success();
+ }
+
+ return failure();
+}
+
+/// Attempt to verify the given cast operation.
+LogicalResult impl::verifyCastInterfaceOp(
+ Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) {
+ auto resultTypes = op->getResultTypes();
+ if (llvm::empty(resultTypes))
+ return op->emitOpError()
+ << "expected at least one result for cast operation";
+
+ auto operandTypes = op->getOperandTypes();
+ if (!areCastCompatible(operandTypes, resultTypes)) {
+ InFlightDiagnostic diag = op->emitOpError("operand type");
+ if (llvm::empty(operandTypes))
+ diag << "s []";
+ else if (llvm::size(operandTypes) == 1)
+ diag << " " << *operandTypes.begin();
+ else
+ diag << "s " << operandTypes;
+ return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ")
+ << resultTypes << " are cast incompatible";
+ }
+
+ return success();
+}
+
void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source,
Type destType) {
result.addOperands(source);
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 0a8f75b6f7d9..26b484bc5d78 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
CallInterfaces.cpp
+ CastInterfaces.cpp
ControlFlowInterfaces.cpp
CopyOpInterface.cpp
DerivedAttributeOpInterface.cpp
@@ -27,6 +28,7 @@ endfunction(add_mlir_interface_library)
add_mlir_interface_library(CallInterfaces)
+add_mlir_interface_library(CastInterfaces)
add_mlir_interface_library(ControlFlowInterfaces)
add_mlir_interface_library(CopyOpInterface)
add_mlir_interface_library(DerivedAttributeOpInterface)
diff --git a/mlir/lib/Interfaces/CastInterfaces.cpp b/mlir/lib/Interfaces/CastInterfaces.cpp
new file mode 100644
index 000000000000..400c1978cdaa
--- /dev/null
+++ b/mlir/lib/Interfaces/CastInterfaces.cpp
@@ -0,0 +1,17 @@
+//===- CastInterfaces.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/CastInterfaces.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Table-generated class definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/CastInterfaces.cpp.inc"