[llvm-branch-commits] [mlir] 129d6e5 - [mlir] Move `std.tensor_cast` -> `tensor.cast`.
Sean Silva via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Dec 17 16:12:44 PST 2020
Author: Sean Silva
Date: 2020-12-17T16:06:56-08:00
New Revision: 129d6e554e7a0dba3443ffd8f1df185b90cc6fd5
URL: https://github.com/llvm/llvm-project/commit/129d6e554e7a0dba3443ffd8f1df185b90cc6fd5
DIFF: https://github.com/llvm/llvm-project/commit/129d6e554e7a0dba3443ffd8f1df185b90cc6fd5.diff
LOG: [mlir] Move `std.tensor_cast` -> `tensor.cast`.
This is almost entirely mechanical.
Differential Revision: https://reviews.llvm.org/D93357
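For readers skimming the diff: the user-visible change is only the op's spelling, moving it from the `std` dialect into the new `tensor` dialect. A minimal before/after sketch (value names illustrative, types taken from the examples in the patch):

```mlir
// Before this commit (std dialect):
%1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>

// After this commit (tensor dialect):
%1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
```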
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/IR/OpDefinition.h
mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Shape/IR/CMakeLists.txt
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
mlir/lib/IR/Operation.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Standard/bufferize.mlir
mlir/test/Dialect/Standard/canonicalize.mlir
mlir/test/Dialect/Tensor/bufferize.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
mlir/test/Dialect/Tensor/invalid.mlir
mlir/test/Dialect/Tensor/ops.mlir
mlir/test/IR/core-ops.mlir
mlir/test/Transforms/canonicalize.mlir
mlir/test/Transforms/cse.mlir
mlir/utils/vim/syntax/mlir.vim
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 3df609f295cc..2ef32cfe378b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -34,7 +34,7 @@ namespace linalg {
class LinalgDependenceGraph;
/// A struct containing the Linalg producer before and after fusion.
-/// When operating on tensors, `fusedProducer` may feed into a `tensor_cast` op
+/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op
/// before the consumer Linalg op, until enough canonicalizations have applied.
struct FusionInfo {
LinalgOp originalProducer;
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 7302bd486657..56ff32252fee 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -354,31 +354,6 @@ computeRankReductionMask(ArrayRef<int64_t> originalShape,
/// ```
bool canFoldIntoConsumerOp(MemRefCastOp castOp);
-/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
-/// Determines whether TensorCastOp casts to a more dynamic version of the
-/// source tensor. This is useful to fold a tensor_cast into a consuming op and
-/// implement canonicalization patterns for ops in different dialects that may
-/// consume the results of tensor_cast operations. Such foldable tensor_cast
-/// operations are typically inserted as `subtensor` ops and are canonicalized,
-/// to preserve the type compatibility of their uses.
-///
-/// Returns true when all conditions are met:
-/// 1. source and result are ranked tensors with same element type and rank.
-/// 2. the tensor type has more static information than the result
-///
-/// Example:
-/// ```mlir
-/// %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
-/// %2 = consumer %1 ... : tensor<?x?xf32> ...
-/// ```
-///
-/// folds into:
-///
-/// ```mlir
-/// %2 = consumer %0 ... : tensor<8x16xf32> ...
-/// ```
-bool canFoldIntoConsumerOp(TensorCastOp castOp);
-
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 481dfaf4b34d..7af44f8435ff 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -62,7 +62,7 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
let printer = [{
return printStandardCastOp(this->getOperation(), p);
}];
- let verifier = [{ return ::verifyCastOp(*this); }];
+ let verifier = [{ return impl::verifyCastOp(*this, areCastCompatible); }];
let hasFolder = 1;
}
@@ -3428,56 +3428,6 @@ def TanhOp : FloatUnaryOp<"tanh"> {
}];
}
-//===----------------------------------------------------------------------===//
-// TensorCastOp
-//===----------------------------------------------------------------------===//
-
-def TensorCastOp : CastOp<"tensor_cast"> {
- let summary = "tensor cast operation";
- let description = [{
- Syntax:
-
- ```
- operation ::= ssa-id `=` `std.tensor_cast` ssa-use `:` type `to` type
- ```
-
- Convert a tensor from one type to an equivalent type without changing any
- data elements. The source and destination types must both be tensor types
- with the same element type. If both are ranked, then the rank should be the
- same and static dimensions should match. The operation is invalid if
- converting to a mismatching constant dimension.
-
- Example:
-
- ```mlir
- // Convert from unknown rank to rank 2 with unknown dimension sizes.
- %2 = "std.tensor_cast"(%1) : (tensor<*xf32>) -> tensor<?x?xf32>
- %2 = tensor_cast %1 : tensor<*xf32> to tensor<?x?xf32>
-
- // Convert to a type with more known dimensions.
- %3 = "std.tensor_cast"(%2) : (tensor<?x?xf32>) -> tensor<4x?xf32>
-
- // Discard static dimension and rank information.
- %4 = "std.tensor_cast"(%3) : (tensor<4x?xf32>) -> tensor<?x?xf32>
- %5 = "std.tensor_cast"(%4) : (tensor<?x?xf32>) -> tensor<*xf32>
- ```
- }];
-
- let arguments = (ins AnyTensor:$source);
- let results = (outs AnyTensor);
-
- let extraClassDeclaration = [{
- /// Return true if `a` and `b` are valid operand and result pairs for
- /// the operation.
- static bool areCastCompatible(Type a, Type b);
-
- /// The result of a tensor_cast is always a tensor.
- TensorType getType() { return getResult().getType().cast<TensorType>(); }
- }];
-
- let hasCanonicalizer = 1;
-}
-
//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index ee517de3fca0..53980db64dc0 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -28,4 +28,38 @@
#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.h.inc"
+//===----------------------------------------------------------------------===//
+// Tensor Dialect Helpers
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace tensor {
+
+/// Determines whether tensor::CastOp casts to a more dynamic version of the
+/// source tensor. This is useful to fold a tensor.cast into a consuming op and
+/// implement canonicalization patterns for ops in different dialects that may
+/// consume the results of tensor.cast operations. Such foldable tensor.cast
+/// operations are typically inserted as `subtensor` ops and are canonicalized,
+/// to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked tensors with same element type and rank.
+/// 2. the tensor type has more static information than the result
+///
+/// Example:
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+/// %2 = consumer %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = consumer %0 ... : tensor<8x16xf32> ...
+/// ```
+bool canFoldIntoConsumerOp(CastOp castOp);
+
+} // namespace tensor
+} // namespace mlir
+
#endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_
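As a concrete illustration of the fold this helper guards, a `tensor.cast` that merely erases static shape information can be absorbed by a consuming op. A condensed sketch of the `mlir/test/Dialect/Linalg/canonicalize.mlir` case updated later in this patch (`%a`, `%b`, and `%tc` are function arguments in the full test):

```mlir
// Casts that only discard static sizes...
%ta = tensor.cast %a : tensor<3x4xf32> to tensor<?x?xf32>
%tb = tensor.cast %b : tensor<4x?xf32> to tensor<?x?xf32>
%0 = linalg.matmul ins(%ta, %tb : tensor<?x?xf32>, tensor<?x?xf32>)
                   init(%tc : tensor<?x?xf32>) -> tensor<?x?xf32>

// ...fold away, and the consumer keeps the more static operand types:
// linalg.matmul ins(%a, %b : tensor<3x4xf32>, tensor<4x?xf32>) ...
```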
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 4eb989b2f3b5..e0500b8fcfa6 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -19,6 +19,52 @@ class Tensor_Op<string mnemonic, list<OpTrait> traits = []>
let parser = [{ return ::parse$cppClass(parser, result); }];
}
+//===----------------------------------------------------------------------===//
+// CastOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> {
+ let summary = "tensor cast operation";
+ let description = [{
+ Convert a tensor from one type to an equivalent type without changing any
+ data elements. The source and destination types must both be tensor types
+ with the same element type. If both are ranked, then the rank should be the
+ same and static dimensions should match. The operation is invalid if
+ converting to a mismatching constant dimension.
+
+ Example:
+
+ ```mlir
+ // Convert from unknown rank to rank 2 with unknown dimension sizes.
+ %2 = tensor.cast %1 : tensor<*xf32> to tensor<?x?xf32>
+
+ // Convert to a type with more known dimensions.
+ %3 = tensor.cast %2 : tensor<?x?xf32> to tensor<4x?xf32>
+
+ // Discard static dimension and rank information.
+ %4 = tensor.cast %3 : tensor<4x?xf32> to tensor<?x?xf32>
+ %5 = tensor.cast %4 : tensor<?x?xf32> to tensor<*xf32>
+ ```
+ }];
+
+ let arguments = (ins AnyTensor:$source);
+ let results = (outs AnyTensor:$dest);
+ let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+ let verifier = "return impl::verifyCastOp(*this, areCastCompatible);";
+
+ let extraClassDeclaration = [{
+ /// Return true if `a` and `b` are valid operand and result pairs for
+ /// the operation.
+ static bool areCastCompatible(Type a, Type b);
+
+ /// The result of a tensor.cast is always a tensor.
+ TensorType getType() { return getResult().getType().cast<TensorType>(); }
+ }];
+
+ let hasFolder = 1;
+ let hasCanonicalizer = 1;
+}
+
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 9e4da63c3618..beb45ebfe08c 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1775,11 +1775,18 @@ void printOneResultOp(Operation *op, OpAsmPrinter &p);
// These functions are out-of-line implementations of the methods in CastOp,
// which avoids them being template instantiated/duplicated.
namespace impl {
+// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
+// need for them, but some older ODS code in `std` still depends on them).
void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
Type destType);
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
void printCastOp(Operation *op, OpAsmPrinter &p);
+// TODO: Create a CastOpInterface with a method areCastCompatible.
+// Also, consider adding functionality to CastOpInterface to be able to perform
+// the ChainedTensorCast canonicalization generically.
Value foldCastOp(Operation *op);
+LogicalResult verifyCastOp(Operation *op,
+ function_ref<bool(Type, Type)> areCastCompatible);
} // namespace impl
} // end namespace mlir
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
index 81fcc70e91be..f18b26506447 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -tensor-constant-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -tensor-constant-bufferize -linalg-bufferize -tensor-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
@@ -8,7 +8,7 @@ func @main() {
%b = constant dense<[10.0, 20.0, 30.0]> : tensor<3xf32>
%addf = addf %a, %b : tensor<3xf32>
- %addf_unranked = tensor_cast %addf : tensor<3xf32> to tensor<*xf32>
+ %addf_unranked = tensor.cast %addf : tensor<3xf32> to tensor<*xf32>
call @print_memref_f32(%addf_unranked) : (tensor<*xf32>) -> ()
// CHECK: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [3] strides = [1] data =
// CHECK-NEXT: [11, 22, 33]
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
index ac90fb28517b..f327bc97d7e0 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
+// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \
+// RUN: -finalizing-bufferize \
// RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@@ -15,14 +17,14 @@ func @main() {
%inserted_at_position_0 = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32>
%inserted_at_position_1 = subtensor_insert %insert_val into %const[1][1][1] : tensor<1xf32> into tensor<2xf32>
- %unranked_at_position_0 = tensor_cast %inserted_at_position_0 : tensor<2xf32> to tensor<*xf32>
+ %unranked_at_position_0 = tensor.cast %inserted_at_position_0 : tensor<2xf32> to tensor<*xf32>
call @print_memref_f32(%unranked_at_position_0) : (tensor<*xf32>) -> ()
// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
// CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [1] data =
// CHECK-NEXT: [20, 10]
- %unranked_at_position_1 = tensor_cast %inserted_at_position_1 : tensor<2xf32> to tensor<*xf32>
+ %unranked_at_position_1 = tensor.cast %inserted_at_position_1 : tensor<2xf32> to tensor<*xf32>
call @print_memref_f32(%unranked_at_position_1) : (tensor<*xf32>) -> ()
// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir
index e568d6acf9ee..6bc886a5d4c4 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
+// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \
+// RUN: -finalizing-bufferize \
// RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@@ -9,7 +11,7 @@ func @main() {
%insert_val = constant dense<20.0> : tensor<1xf32>
%inserted = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32>
- %unranked = tensor_cast %inserted : tensor<2xf32> to tensor<*xf32>
+ %unranked = tensor.cast %inserted : tensor<2xf32> to tensor<*xf32>
call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
index 61fc05f8c20b..71e27733c83f 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s -tensor-constant-bufferize -std-bufferize -linalg-bufferize \
-// RUN: -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
+// RUN: -tensor-bufferize -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
// RUN: -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@@ -19,7 +19,7 @@ func @main() {
// Note that this is skipping a step and we would need at least some function
// attribute to declare that this conversion is valid (e.g. when we statically
// know that things will play nicely at the C ABI boundary).
- %unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32>
+ %unranked = tensor.cast %0 : tensor<4xf32> to tensor<*xf32>
call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
index e535febcf7dc..38d97332f0d7 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
@@ -1,12 +1,13 @@
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize \
-// RUN: -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
+// RUN: -tensor-bufferize -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
// RUN: -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=1,2,3" -linalg-bufferize \
-// RUN: -scf-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
+// RUN: -scf-bufferize -std-bufferize -tensor-constant-bufferize -tensor-bufferize \
+// RUN: -func-bufferize \
// RUN: -finalizing-bufferize -convert-linalg-to-loops -convert-scf-to-std \
// RUN: -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
@@ -23,7 +24,7 @@ func @main() {
%D = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>)
init(%C: tensor<2x4xf32>) -> tensor<2x4xf32>
- %unranked = tensor_cast %D : tensor<2x4xf32> to tensor<*xf32>
+ %unranked = tensor.cast %D : tensor<2x4xf32> to tensor<*xf32>
call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 7189c30e766a..0d87d4f10975 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -103,9 +103,9 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
auto erasedRankType =
RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
Value rankErasedLhs =
- rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
+ rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
Value rankErasedRhs =
- rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
+ rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
Value lesserRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
Value greaterRankOperand =
@@ -186,7 +186,7 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
Value tensor =
rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
- rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
return success();
}
@@ -246,9 +246,9 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
auto erasedRankType =
RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
Value rankErasedLhs =
- rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
+ rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
Value rankErasedRhs =
- rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
+ rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
Value lesserRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
Value greaterRankOperand =
@@ -528,8 +528,8 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
// Materialize extent tensor.
Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
loc, rewriter.getIndexType(), extentValues);
- rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
- op.getType());
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
+ staticExtentTensor);
return success();
}
@@ -561,8 +561,8 @@ class ToExtentTensorOpConversion
if (!adaptor.input().getType().isa<RankedTensorType>())
return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
- rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
- op.getType());
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
+ adaptor.input());
return success();
}
};
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index 15e29d749e65..3ed79a554b31 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -16,4 +16,5 @@ add_mlir_dialect_library(MLIRLinalg
MLIRSideEffectInterfaces
MLIRViewLikeInterface
MLIRStandard
+ MLIRTensor
)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f2b05448dbd0..3a7249df8e79 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -651,7 +651,7 @@ struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
auto newOp =
rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
rewriter.getI64ArrayAttr(staticSizes));
- rewriter.replaceOpWithNewOp<TensorCastOp>(op, op.getType(), newOp);
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
return success();
}
};
@@ -1815,12 +1815,12 @@ struct FoldTensorCastOp : public RewritePattern {
if (!linalgOp)
return failure();
- // If no operand comes from a TensorCastOp and can be folded then fail.
+ // If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(linalgOp.getShapedOperands(), [&](Value v) {
if (v.isa<BlockArgument>())
return false;
- auto castOp = v.getDefiningOp<TensorCastOp>();
+ auto castOp = v.getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
});
if (!hasTensorCastOperand)
@@ -1832,7 +1832,7 @@ struct FoldTensorCastOp : public RewritePattern {
newOperands.reserve(op->getNumOperands());
// Inputs may fold.
for (Value v : linalgOp.getInputs()) {
- auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
+ auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
newOperands.push_back(
canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
}
@@ -1841,7 +1841,7 @@ struct FoldTensorCastOp : public RewritePattern {
linalgOp.getOutputBuffers().end());
// Init tensors may fold, in which case the resultType must also change.
for (Value v : linalgOp.getInitTensors()) {
- auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
+ auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
newResultTypes.push_back(newOperands.back().getType());
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index c77398def98c..ba31ca5a034b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -59,6 +59,7 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
void mlir::linalg::LinalgDialect::initialize() {
getContext()->getOrLoadDialect("std");
+ getContext()->getOrLoadDialect("tensor");
addTypes<RangeType>();
addOperations<
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 6de4ce6ac341..42e2d4dcd244 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -36,6 +36,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRStandard
MLIRStandardOpsTransforms
MLIRStandardToLLVM
+ MLIRTensor
MLIRTransforms
MLIRTransformUtils
MLIRVector
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 681309687d18..d9ea7d8ccb29 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
@@ -517,13 +518,13 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
// Replace use.
// Canonicalizations are not guaranteed to have happened before constructing
// `fusedProducer`. In the tensor case this can result in temporary type
- // mismatches. Insert a `tensor_cast` op to propagate the transformation
+ // mismatches. Insert a `tensor.cast` op to propagate the transformation
// invariant that types are compatible.
Value def = fusedProducer->getResult(producerIdx);
OpOperand &use = consumer->getOpOperand(consumerIdx);
Type consumerType = use.get().getType();
if (consumerType != def.getType())
- def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def);
+ def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
use.set(def);
return FusionInfo{producerOp, fusedProducer};
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 50a18d4fb01c..423d687c1eb8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
@@ -569,7 +570,7 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
SubTensorOp::getCanonicalizationPatterns(patterns, ctx);
SubViewOp::getCanonicalizationPatterns(patterns, ctx);
- TensorCastOp::getCanonicalizationPatterns(patterns, ctx);
+ tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
ViewOp::getCanonicalizationPatterns(patterns, ctx);
CanonicalizationPatternList<
#define GET_OP_LIST
diff --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
index 1ac5b3b1e856..f8321842db31 100644
--- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
@@ -18,4 +18,5 @@ add_mlir_dialect_library(MLIRShape
MLIRIR
MLIRSideEffectInterfaces
MLIRStandard
+ MLIRTensor
)
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index ef29ddc510ae..0478cb7872cc 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index 4e6d062a232f..45c699baece3 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -1,5 +1,6 @@
include "mlir/Dialect/Shape/IR/ShapeOps.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
+include "mlir/Dialect/Tensor/IR/TensorOps.td"
def AllInputShapesEq : Constraint<CPred< [{
llvm::all_of($0, [&](mlir::Value val) {
@@ -32,7 +33,7 @@ def SizeToIndexToSizeCanonicalization : Pat<
(Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)),
(replaceWithValue $arg)>;
-// Fold tensor_cast(const_shape) to const_shape. This changes the type of
+// Fold tensor.cast(const_shape) to const_shape. This changes the type of
// const_shape to the destination type of the cast.
def TensorCastConstShape : Pat <
- (TensorCastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;
+ (Tensor_CastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 7ed9cffa8806..c0af06314086 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -141,18 +141,6 @@ static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
<< op->getResult(0).getType();
}
-/// A custom cast operation verifier.
-template <typename T>
-static LogicalResult verifyCastOp(T op) {
- auto opType = op.getOperand().getType();
- auto resType = op.getType();
- if (!T::areCastCompatible(opType, resType))
- return op.emitError("operand type ") << opType << " and result type "
- << resType << " are cast incompatible";
-
- return success();
-}
-
void StandardOpsDialect::initialize() {
getContext()->loadDialect<tensor::TensorDialect>();
addOperations<DmaStartOp, DmaWaitOp,
@@ -1494,7 +1482,7 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<DimOfMemRefReshape, DimOfCastOp<TensorCastOp>>(context);
+ results.insert<DimOfMemRefReshape, DimOfCastOp<tensor::CastOp>>(context);
}
// ---------------------------------------------------------------------------
@@ -1870,8 +1858,8 @@ struct StaticDynamicTensorFromElements
newOperands);
rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
newOp.body().begin());
- rewriter.replaceOpWithNewOp<TensorCastOp>(tensorFromElements, resultType,
- newOp);
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
+ newOp);
return success();
}
};
@@ -1913,7 +1901,7 @@ struct ExtractFromDynamicTensorFromElements
/// Canonicalizes the pattern of the form
///
-/// %val = tensor_cast %source : : tensor<?xi32> to tensor<2xi32>
+/// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
@@ -1924,7 +1912,7 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
PatternRewriter &rewriter) const final {
- auto tensorCast = extract.tensor().getDefiningOp<TensorCastOp>();
+ auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
if (!tensorCast)
return failure();
@@ -3395,7 +3383,7 @@ static void replaceWithNewOp(PatternRewriter &rewriter, SubViewOp op,
static void replaceWithNewOp(PatternRewriter &rewriter, SubTensorOp op,
SubTensorOp newOp) {
- rewriter.replaceOpWithNewOp<TensorCastOp>(op, newOp, op.getType());
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
}
/// Pattern to rewrite a subview op with constant arguments.
@@ -3536,60 +3524,6 @@ bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
return true;
}
-/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
-/// Determines whether TensorCastOp casts to a more dynamic version of the
-/// source tensor. This is useful to fold a tensor_cast into a consuming op and
-/// implement canonicalization patterns for ops in different dialects that may
-/// consume the results of tensor_cast operations. Such foldable tensor_cast
-/// operations are typically inserted as `subtensor` ops and are canonicalized,
-/// to preserve the type compatibility of their uses.
-///
-/// Returns true when all conditions are met:
-/// 1. source and result are ranked tensors with same element type and rank.
-/// 2. the tensor type has more static information than the result
-///
-/// Example:
-/// ```mlir
-/// %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
-/// %2 = consumer %1 ... : tensor<?x?xf32> ...
-/// ```
-///
-/// folds into:
-///
-/// ```mlir
-/// %2 = consumer %0 ... : tensor<8x16xf32> ...
-/// ```
-bool mlir::canFoldIntoConsumerOp(TensorCastOp castOp) {
- if (!castOp)
- return false;
-
- RankedTensorType sourceType =
- castOp.source().getType().dyn_cast<RankedTensorType>();
- RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
-
- // Requires RankedTensorType.
- if (!sourceType || !resultType)
- return false;
-
- // Requires same elemental type.
- if (sourceType.getElementType() != resultType.getElementType())
- return false;
-
- // Requires same rank.
- if (sourceType.getRank() != resultType.getRank())
- return false;
-
- // If cast is towards more static sizes along any dimension, don't fold.
- for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
- auto ss = std::get<0>(it), st = std::get<1>(it);
- if (ss != st)
- if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
- return false;
- }
-
- return true;
-}
-
namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref_cast past its consuming subview when
@@ -3857,107 +3791,6 @@ static LogicalResult verify(SubTensorInsertOp op) {
return success();
}
-//===----------------------------------------------------------------------===//
-// TensorCastOp
-//===----------------------------------------------------------------------===//
-
-bool TensorCastOp::areCastCompatible(Type a, Type b) {
- auto aT = a.dyn_cast<TensorType>();
- auto bT = b.dyn_cast<TensorType>();
- if (!aT || !bT)
- return false;
-
- if (aT.getElementType() != bT.getElementType())
- return false;
-
- return succeeded(verifyCompatibleShape(aT, bT));
-}
-
-OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
- return impl::foldCastOp(*this);
-}
-
-/// Compute a TensorType that has the joined shape knowledge of the two
-/// given TensorTypes. The element types need to match.
-static TensorType joinShapes(TensorType one, TensorType two) {
- assert(one.getElementType() == two.getElementType());
-
- if (!one.hasRank())
- return two;
- if (!two.hasRank())
- return one;
-
- int64_t rank = one.getRank();
- if (rank != two.getRank())
- return {};
-
- SmallVector<int64_t, 4> join;
- join.reserve(rank);
- for (int64_t i = 0; i < rank; ++i) {
- if (one.isDynamicDim(i)) {
- join.push_back(two.getDimSize(i));
- continue;
- }
- if (two.isDynamicDim(i)) {
- join.push_back(one.getDimSize(i));
- continue;
- }
- if (one.getDimSize(i) != two.getDimSize(i))
- return {};
- join.push_back(one.getDimSize(i));
- }
- return RankedTensorType::get(join, one.getElementType());
-}
-
-namespace {
-
-/// Replaces chains of two tensor_cast operations by a single tensor_cast
-/// operation if doing so does not remove runtime constraints.
-struct ChainedTensorCast : public OpRewritePattern<TensorCastOp> {
- using OpRewritePattern<TensorCastOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(TensorCastOp tensorCast,
- PatternRewriter &rewriter) const final {
- auto tensorCastOperand =
- tensorCast.getOperand().getDefiningOp<TensorCastOp>();
-
- if (!tensorCastOperand)
- return failure();
-
- auto sourceType =
- tensorCastOperand.getOperand().getType().cast<TensorType>();
- auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
- auto resultType = tensorCast.getType().cast<TensorType>();
-
- // We can remove the intermediate cast if joining all three produces the
- // same result as just joining the source and result shapes.
- auto firstJoin =
- joinShapes(joinShapes(sourceType, intermediateType), resultType);
-
- // The join might not exist if the cast sequence would fail at runtime.
- if (!firstJoin)
- return failure();
-
- // The newJoin always exists if the above join exists, it might just contain
- // less information. If so, we cannot drop the intermediate cast, as doing
- // so would remove runtime checks.
- auto newJoin = joinShapes(sourceType, resultType);
- if (firstJoin != newJoin)
- return failure();
-
- rewriter.replaceOpWithNewOp<TensorCastOp>(tensorCast, resultType,
- tensorCastOperand.getOperand());
- return success();
- }
-};
-
-} // namespace
-
-void TensorCastOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<ChainedTensorCast>(context);
-}
-
//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index a84934b0ebb8..98792838deff 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -117,20 +117,6 @@ class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
};
} // namespace
-namespace {
-class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto resultType = getTypeConverter()->convertType(op.getType());
- rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, operands[0]);
- return success();
- }
-};
-} // namespace
-
namespace {
class BufferizeTensorFromElementsOp
: public OpConversionPattern<TensorFromElementsOp> {
@@ -162,7 +148,6 @@ void mlir::populateStdBufferizePatterns(MLIRContext *context,
BufferizeDimOp,
BufferizeDynamicTensorFromElementsOp,
BufferizeSelectOp,
- BufferizeTensorCastOp,
BufferizeTensorFromElementsOp
// clang-format on
>(typeConverter, context);
@@ -180,8 +165,7 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
target.addLegalDialect<scf::SCFDialect>();
populateStdBufferizePatterns(context, typeConverter, patterns);
- target.addIllegalOp<DynamicTensorFromElementsOp, TensorCastOp,
- TensorFromElementsOp>();
+ target.addIllegalOp<DynamicTensorFromElementsOp, TensorFromElementsOp>();
// We only bufferize the case of tensor selected type and scalar condition,
// as that boils down to a select over memref descriptors (don't need to
// touch the data).
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index bb944b21e3c3..aaae7fbf807c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -8,12 +8,165 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::tensor;
+//===----------------------------------------------------------------------===//
+// CastOp
+//===----------------------------------------------------------------------===//
+
+/// Determines whether tensor::CastOp casts to a more dynamic version of the
+/// source tensor. This is useful to fold a tensor.cast into a consuming op and
+/// implement canonicalization patterns for ops in different dialects that may
+/// consume the results of tensor.cast operations. Such foldable tensor.cast
+/// operations are typically inserted as `subtensor` ops and are canonicalized,
+/// to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked tensors with same element type and rank.
+/// 2. the tensor type has more static information than the result
+///
+/// Example:
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+/// %2 = consumer %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = consumer %0 ... : tensor<8x16xf32> ...
+/// ```
+bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
+ if (!castOp)
+ return false;
+
+ RankedTensorType sourceType =
+ castOp.source().getType().dyn_cast<RankedTensorType>();
+ RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
+
+ // Requires RankedTensorType.
+ if (!sourceType || !resultType)
+ return false;
+
+ // Requires same elemental type.
+ if (sourceType.getElementType() != resultType.getElementType())
+ return false;
+
+ // Requires same rank.
+ if (sourceType.getRank() != resultType.getRank())
+ return false;
+
+ // If cast is towards more static sizes along any dimension, don't fold.
+ for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+ if (ShapedType::isDynamic(std::get<0>(t)) &&
+ !ShapedType::isDynamic(std::get<1>(t)))
+ return false;
+ }
+
+ return true;
+}
+
+bool CastOp::areCastCompatible(Type a, Type b) {
+ auto aT = a.dyn_cast<TensorType>();
+ auto bT = b.dyn_cast<TensorType>();
+ if (!aT || !bT)
+ return false;
+
+ if (aT.getElementType() != bT.getElementType())
+ return false;
+
+ return succeeded(verifyCompatibleShape(aT, bT));
+}
+
+OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
+ return impl::foldCastOp(*this);
+}
+
+/// Compute a TensorType that has the joined shape knowledge of the two
+/// given TensorTypes. The element types need to match.
+static TensorType joinShapes(TensorType one, TensorType two) {
+ assert(one.getElementType() == two.getElementType());
+
+ if (!one.hasRank())
+ return two;
+ if (!two.hasRank())
+ return one;
+
+ int64_t rank = one.getRank();
+ if (rank != two.getRank())
+ return {};
+
+ SmallVector<int64_t, 4> join;
+ join.reserve(rank);
+ for (int64_t i = 0; i < rank; ++i) {
+ if (one.isDynamicDim(i)) {
+ join.push_back(two.getDimSize(i));
+ continue;
+ }
+ if (two.isDynamicDim(i)) {
+ join.push_back(one.getDimSize(i));
+ continue;
+ }
+ if (one.getDimSize(i) != two.getDimSize(i))
+ return {};
+ join.push_back(one.getDimSize(i));
+ }
+ return RankedTensorType::get(join, one.getElementType());
+}
+
+namespace {
+
+/// Replaces chains of two tensor.cast operations by a single tensor.cast
+/// operation if doing so does not remove runtime constraints.
+struct ChainedTensorCast : public OpRewritePattern<CastOp> {
+ using OpRewritePattern<CastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CastOp tensorCast,
+ PatternRewriter &rewriter) const final {
+ auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
+
+ if (!tensorCastOperand)
+ return failure();
+
+ auto sourceType =
+ tensorCastOperand.getOperand().getType().cast<TensorType>();
+ auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
+ auto resultType = tensorCast.getType().cast<TensorType>();
+
+ // We can remove the intermediate cast if joining all three produces the
+ // same result as just joining the source and result shapes.
+ auto firstJoin =
+ joinShapes(joinShapes(sourceType, intermediateType), resultType);
+
+ // The join might not exist if the cast sequence would fail at runtime.
+ if (!firstJoin)
+ return failure();
+
+ // The newJoin always exists if the above join exists, it might just contain
+ // less information. If so, we cannot drop the intermediate cast, as doing
+ // so would remove runtime checks.
+ auto newJoin = joinShapes(sourceType, resultType);
+ if (firstJoin != newJoin)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
+ tensorCastOperand.getOperand());
+ return success();
+ }
+};
+
+} // namespace
+
+void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<ChainedTensorCast>(context);
+}
+
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
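The `ChainedTensorCast` pattern carried over here collapses a pair of casts only when doing so loses no runtime shape check, which is exactly what the `joinShapes` comparison above decides. A minimal sketch of both sides (value names illustrative):

```mlir
// Foldable: joining source, intermediate, and result shapes gives the
// same result as joining source and result directly.
%1 = tensor.cast %0 : tensor<4x4xf32> to tensor<?x?xf32>
%2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x?xf32>
// -> %2 = tensor.cast %0 : tensor<4x4xf32> to tensor<4x?xf32>

// Not foldable: erasing the intermediate cast would drop the runtime
// check that dimension 0 of %3 equals 4.
%4 = tensor.cast %3 : tensor<?x?xf32> to tensor<4x?xf32>
%5 = tensor.cast %4 : tensor<4x?xf32> to tensor<?x?xf32>
```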
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index 9e6b3dba74a8..05ff96fb8d69 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -19,6 +19,20 @@
using namespace mlir;
+namespace {
+class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(tensor::CastOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resultType = getTypeConverter()->convertType(op.getType());
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, operands[0]);
+ return success();
+ }
+};
+} // namespace
+
namespace {
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
public:
@@ -37,7 +51,7 @@ class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
void mlir::populateTensorBufferizePatterns(
MLIRContext *context, BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns.insert<BufferizeExtractOp>(typeConverter, context);
+ patterns.insert<BufferizeCastOp, BufferizeExtractOp>(typeConverter, context);
}
namespace {
@@ -49,7 +63,7 @@ struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
ConversionTarget target(*context);
populateTensorBufferizePatterns(context, typeConverter, patterns);
- target.addIllegalOp<tensor::ExtractOp>();
+ target.addIllegalOp<tensor::CastOp, tensor::ExtractOp>();
target.addLegalDialect<StandardOpsDialect>();
if (failed(
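The relocated `BufferizeCastOp` pattern lowers `tensor.cast` to a `memref_cast` between the converted buffer types. The tests removed from `mlir/test/Dialect/Standard/bufferize.mlir` below show the resulting IR; condensed, `-tensor-bufferize` rewrites roughly as follows:

```mlir
func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
  %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<2xindex>
  return %0 : tensor<2xindex>
}
// becomes, modulo materialization casts:
//   %m = tensor_to_memref %arg0 : memref<?xindex>
//   %c = memref_cast %m : memref<?xindex> to memref<2xindex>
//   %r = tensor_load %c : memref<2xindex>
//   return %r : tensor<2xindex>
```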
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 1b8d2875f3c8..c84a11bcec3f 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1213,6 +1213,19 @@ Value impl::foldCastOp(Operation *op) {
return nullptr;
}
+LogicalResult
+impl::verifyCastOp(Operation *op,
+ function_ref<bool(Type, Type)> areCastCompatible) {
+ auto opType = op->getOperand(0).getType();
+ auto resType = op->getResult(0).getType();
+ if (!areCastCompatible(opType, resType))
+ return op->emitError("operand type ")
+ << opType << " and result type " << resType
+ << " are cast incompatible";
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Misc. utils
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index b7663a328986..c2ab39f338b1 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -95,7 +95,7 @@ func @const_shape() -> tensor<?xindex> {
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[C3:.*]] = constant 3 : index
// CHECK: %[[TENSOR3:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]]
- // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
+ // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
// CHECK: return %[[RESULT]] : tensor<?xindex>
%shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
return %shape : tensor<?xindex>
@@ -108,7 +108,7 @@ func @const_shape() -> tensor<?xindex> {
// CHECK-SAME: () -> tensor<?xindex>
func @const_shape_zero_elements() -> tensor<?xindex> {
// CHECK: %[[TENSOR:.*]] = tensor_from_elements : tensor<0xindex>
- // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
+ // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
// CHECK: return %[[RESULT]] : tensor<?xindex>
%shape = shape.const_shape [] : tensor<?xindex>
return %shape : tensor<?xindex>
@@ -152,13 +152,13 @@ func @const_size() -> index {
// -----
-// Lower `to_extent_tensor` to `std.tensor_cast`
+// Lower `to_extent_tensor` to `tensor.cast`
// Fold to_extent_tensor when already on tensor.
// CHECK-LABEL: @to_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xindex>
func @to_extent_tensor(%arg: tensor<?xindex>) -> tensor<3xindex> {
// CHECK-NOT: to_extent_tensor
- // CHECK: %[[RES:.*]] = tensor_cast %[[ARG]] : tensor<?xindex> to tensor<3xindex
+ // CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<?xindex> to tensor<3xindex
%casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<3xindex>
// CHECK: return %[[RES]]
return %casted : tensor<3xindex>
@@ -316,8 +316,8 @@ func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
- // CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
- // CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+ // CHECK: %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
+ // CHECK: %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
@@ -356,8 +356,8 @@ func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xinde
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
- // CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
- // CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
+ // CHECK: %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
+ // CHECK: %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
@@ -400,8 +400,8 @@ func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
// CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
-// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
+// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
@@ -438,8 +438,8 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
// CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
-// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 96ab1aa93355..6c12070e07f1 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -317,20 +317,20 @@ func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>
// -----
-// CHECK-LABEL: func @tensor_cast(
-func @tensor_cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
+// CHECK-LABEL: func @tensor.cast(
+func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
-> tensor<3x?xf32>
{
- %ta = tensor_cast %a : tensor<3x4xf32> to tensor<?x?xf32>
- %tb = tensor_cast %b : tensor<4x?xf32> to tensor<?x?xf32>
- %tc = tensor_cast %c : tensor<3x?xf32> to tensor<?x?xf32>
+ %ta = tensor.cast %a : tensor<3x4xf32> to tensor<?x?xf32>
+ %tb = tensor.cast %b : tensor<4x?xf32> to tensor<?x?xf32>
+ %tc = tensor.cast %c : tensor<3x?xf32> to tensor<?x?xf32>
// CHECK: linalg.matmul ins({{.*}}tensor<3x4xf32>, tensor<4x?xf32>)
// CHECK-SAME: init({{.*}}tensor<3x?xf32>) -> tensor<3x?xf32>
%0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
init(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>
- %1 = tensor_cast %0 : tensor<?x?xf32> to tensor<3x?xf32>
+ %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<3x?xf32>
return %1: tensor<3x?xf32>
}
@@ -360,7 +360,7 @@ func @init_tensor_canonicalize() -> (tensor<4x5x?xf32>) {
}
// CHECK: func @init_tensor_canonicalize
// CHECK: %[[T0:.+]] = linalg.init_tensor [4, 5, 6] : tensor<4x5x6xf32>
-// CHECK: %[[T1:.+]] = tensor_cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
+// CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
// CHECK: return %[[T1]]
// -----
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index aa43f515f753..c893ee118a86 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -872,24 +872,24 @@ func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
// -----
-// Verify that tensor_cast folding uses the correct type
-// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned
-func @fold_tensor_cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
+// Verify that tensor.cast folding uses the correct type
+// CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned
+func @fold_tensor.cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
// CHECK: constant dense<2> : tensor<1xindex>
- // CHECK-NOT: tensor_cast
+ // CHECK-NOT: tensor.cast
%0 = shape.const_shape [2] : tensor<?xindex>
- %1 = tensor_cast %0 : tensor<?xindex> to tensor<1xindex>
+ %1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
return %1 : tensor<1xindex>
}
// -----
-// Verify that tensor_cast folding uses the correct type
-// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned_dynamic
-func @fold_tensor_cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
+// Verify that tensor.cast folding uses the correct type
+// CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned_dynamic
+func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
// CHECK: shape.const_shape [2] : tensor<?xindex>
- // CHECK-NOT: tensor_cast
+ // CHECK-NOT: tensor.cast
%0 = shape.const_shape [2] : tensor<1xindex>
- %1 = tensor_cast %0 : tensor<1xindex> to tensor<?xindex>
+ %1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
return %1 : tensor<?xindex>
}
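
Schematically, the fold verified above materializes the constant with the
cast's result type rather than the const_shape's type (a sketch consistent
with the CHECK lines, not part of this diff):

  // Input:
  %0 = shape.const_shape [2] : tensor<?xindex>
  %1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
  // After canonicalization, %1 is replaced by:
  %c = constant dense<2> : tensor<1xindex>
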
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index 8ae10ccf0f3b..4e8f1282c36b 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -75,39 +75,6 @@ func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
return %0 : tensor<f32>
}
-// CHECK-LABEL: func @tensor_cast(
-// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
-// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]
-// CHECK: %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
-// CHECK: %[[RET:.*]] = tensor_load %[[CASTED]]
-// CHECK: return %[[RET]] : tensor<2xindex>
-func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
- %0 = tensor_cast %arg0 : tensor<?xindex> to tensor<2xindex>
- return %0 : tensor<2xindex>
-}
-
-// CHECK-LABEL: func @tensor_cast_from_unranked(
-// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
-// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<*xf32>
-// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
-// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<2xf32>
-// CHECK: return %[[RET]] : tensor<2xf32>
-func @tensor_cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
- %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<2xf32>
- return %0 : tensor<2xf32>
-}
-
-// CHECK-LABEL: func @tensor_cast_to_unranked(
-// CHECK-SAME: %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
-// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<2xf32>
-// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
-// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<*xf32>
-// CHECK: return %[[RET]] : tensor<*xf32>
-func @tensor_cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
- %0 = tensor_cast %arg0 : tensor<2xf32> to tensor<*xf32>
- return %0 : tensor<*xf32>
-}
-
// CHECK-LABEL: func @tensor_from_elements(
// CHECK-SAME: %[[ELEM0:.*]]: index,
// CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> {
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index af67453a1f3c..f3b7ccddd1ff 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -116,17 +116,17 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
return %1 : index
}
-// Test case: Folding dim(tensor_cast %0, %idx) -> dim %0, %idx
-// CHECK-LABEL: func @fold_dim_of_tensor_cast
+// Test case: Folding dim(tensor.cast %0, %idx) -> dim %0, %idx
+// CHECK-LABEL: func @fold_dim_of_tensor.cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK: %[[T0:.+]] = dim %[[ARG0]], %[[C1]]
// CHECK-NEXT: return %[[C4]], %[[T0]]
-func @fold_dim_of_tensor_cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
+func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
- %0 = tensor_cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
+ %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
%1 = dim %0, %c0 : tensor<?x?xf32>
%2 = dim %0, %c1 : tensor<?x?xf32>
return %1, %2: index, index
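
The fold exercised above looks through the cast: a dimension that is static
in the source type folds to a constant, while a dynamic one folds to a dim
on the cast's operand (a sketch, not part of this diff):

  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
  %d0 = dim %0, %c0 : tensor<?x?xf32>  // folds to: constant 4
  %d1 = dim %0, %c1 : tensor<?x?xf32>  // folds to: dim %arg0, %c1
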
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index bb23322ca659..0e55040ec116 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -1,5 +1,38 @@
// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
+// CHECK-LABEL: func @tensor.cast(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
+// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]
+// CHECK: %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
+// CHECK: %[[RET:.*]] = tensor_load %[[CASTED]]
+// CHECK: return %[[RET]] : tensor<2xindex>
+func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
+ %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<2xindex>
+ return %0 : tensor<2xindex>
+}
+
+// CHECK-LABEL: func @tensor.cast_from_unranked(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
+// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<*xf32>
+// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
+// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<2xf32>
+// CHECK: return %[[RET]] : tensor<2xf32>
+func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
+ %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @tensor.cast_to_unranked(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
+// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<2xf32>
+// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
+// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<*xf32>
+// CHECK: return %[[RET]] : tensor<*xf32>
+func @tensor.cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
+ %0 = tensor.cast %arg0 : tensor<2xf32> to tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
// CHECK-LABEL: func @extract(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> f32 {
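
Bufferization of the moved op is unchanged apart from the name: a
tensor.cast lowers to a memref_cast bracketed by tensor_to_memref and
tensor_load, as the CHECK lines in the tests above spell out. In isolation
(a sketch, not part of this diff):

  %m = tensor_to_memref %t : memref<?xindex>
  %c = memref_cast %m : memref<?xindex> to memref<2xindex>
  %r = tensor_load %c : memref<2xindex>
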
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 86cb3ee7388a..9dcd4da13cc5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1,4 +1,66 @@
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+
+// Checks that NOP casts are removed.
+// CHECK-LABEL: cast_values
+func @cast_values(%arg0: tensor<*xi32>) -> tensor<2xi32> {
+ // NOP cast
+ %0 = tensor.cast %arg0 : tensor<*xi32> to tensor<*xi32>
+ // CHECK-NEXT: %[[RET:.*]] = tensor.cast %arg0 : tensor<*xi32> to tensor<2xi32>
+ %2 = tensor.cast %0 : tensor<*xi32> to tensor<2xi32>
+ // NOP cast
+ %4 = tensor.cast %2 : tensor<2xi32> to tensor<2xi32>
+ // CHECK-NEXT: return %[[RET]] : tensor<2xi32>
+ return %4 : tensor<2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor.cast_chain_ok
+// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
+func @tensor.cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
+ // CHECK-NEXT: %[[RES:.*]] = tensor.cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
+ %0 = tensor.cast %input : tensor<*xi32> to tensor<4x?xi32>
+ %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
+ // CHECK-NEXT: return %[[RES]]
+ return %1 : tensor<4x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor.cast_chain_regain
+// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
+func @tensor.cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
+ %0 = tensor.cast %input : tensor<4xi32> to tensor<?xi32>
+ %1 = tensor.cast %0 : tensor<?xi32> to tensor<4xi32>
+ // CHECK-NEXT: return %[[IN]]
+ return %1 : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor.cast_chain_keep
+// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
+func @tensor.cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
+ // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
+ %0 = tensor.cast %input : tensor<?x?xi32> to tensor<4x?xi32>
+ // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
+ %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
+ // CHECK-NEXT: return %[[C2]]
+ return %1 : tensor<?x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor.cast_chain_invalid
+// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
+func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
+ // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
+ %0 = tensor.cast %input : tensor<4x8xi32> to tensor<?x?xi32>
+ // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
+ %1 = tensor.cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
+ // CHECK-NEXT: return %[[C2]]
+ return %1 : tensor<8x4xi32>
+}
// -----
@@ -31,3 +93,17 @@ func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) {
// CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]]
return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
}
+
+// -----
+
+// CHECK-LABEL: func @extract_from_tensor.cast
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
+func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 {
+ // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+ %c0 = constant 0 : index
+ // CHECK-NOT: tensor.cast
+ %casted = tensor.cast %tensor : tensor<*xf32> to tensor<?xf32>
+ // CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
+ %result = tensor.extract %casted[%c0] : tensor<?xf32>
+ return %result : f32
+}
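
Taken together, the chain tests moved in above pin down the folding rule:
two casts collapse into one when the composed cast loses no static
information, disappear entirely when the final type equals the source type,
and stay as a pair when each cast refines a different dimension (a sketch,
not part of this diff):

  // Collapses: the final 4x8 subsumes the intermediate 4x?.
  %a = tensor.cast %in : tensor<*xi32> to tensor<4x?xi32>
  %b = tensor.cast %a : tensor<4x?xi32> to tensor<4x8xi32>
  // Kept as a pair: a single ?x? -> ?x8 cast would drop the static 4.
  %c = tensor.cast %in2 : tensor<?x?xi32> to tensor<4x?xi32>
  %d = tensor.cast %c : tensor<4x?xi32> to tensor<?x8xi32>
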
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 3ddb84365381..cb38ac884bc3 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -1,4 +1,10 @@
-// RUN: mlir-opt <%s -verify-diagnostics
+// RUN: mlir-opt <%s -split-input-file -verify-diagnostics
+
+func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) {
+  // expected-error @+1 {{operand type 'tensor<1xf32>' and result type 'tensor<2xf32>' are cast incompatible}}
+ %0 = tensor.cast %arg0 : tensor<1xf32> to tensor<2xf32>
+ return
+}
// -----
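
The diagnostic above comes from the cast compatibility check: two tensor
types are compatible only if their element types match and every dimension
that is static in both agrees (a sketch, not part of this diff):

  // OK: erasing or refining a dynamic dimension.
  %ok = tensor.cast %arg0 : tensor<1xf32> to tensor<?xf32>
  // Rejected: 1 and 2 are both static and disagree.
  %bad = tensor.cast %arg0 : tensor<1xf32> to tensor<2xf32>
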
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 4d89b155f2a1..06db2bb237cd 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -1,5 +1,18 @@
// RUN: mlir-opt <%s | mlir-opt | FileCheck %s
+// CHECK-LABEL: func @cast(
+func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?xf32>) {
+ // CHECK: tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
+ %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
+ // CHECK: tensor.cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
+ %1 = tensor.cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
+ // CHECK: tensor.cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
+ %2 = tensor.cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
+ // CHECK: tensor.cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
+ %3 = tensor.cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
+ return
+}
+
// CHECK-LABEL: func @extract(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?x?x?xf32>,
// CHECK-SAME: %[[INDEX:.*]]: index) {
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 502e7fb358fb..fc712d4939ba 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -696,23 +696,6 @@ func @tensor_from_elements() {
return
}
-// CHECK-LABEL: func @tensor_cast(%arg0
-func @tensor_cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?xf32>) {
- // CHECK: %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
- %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
-
- // CHECK: %1 = tensor_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
- %1 = tensor_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
-
- // CHECK: %2 = tensor_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
- %2 = tensor_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
-
- // CHECK: %3 = tensor_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
- %3 = tensor_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
-
- return
-}
-
// CHECK-LABEL: func @memref_cast(%arg0
func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>) {
// CHECK: %0 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 95812acd10a3..f2296161ed7a 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -661,23 +661,15 @@ func @lowered_affine_ceildiv() -> (index, index) {
// Checks that NOP casts are removed.
// CHECK-LABEL: cast_values
-func @cast_values(%arg0: tensor<*xi32>, %arg1: memref<?xi32>) -> (tensor<2xi32>, memref<2xi32>) {
-
- // NOP casts
- %0 = tensor_cast %arg0 : tensor<*xi32> to tensor<*xi32>
- %1 = memref_cast %arg1 : memref<?xi32> to memref<?xi32>
-
- // CHECK-NEXT: %0 = tensor_cast %arg0 : tensor<*xi32> to tensor<2xi32>
- // CHECK-NEXT: %1 = memref_cast %arg1 : memref<?xi32> to memref<2xi32>
- %2 = tensor_cast %0 : tensor<*xi32> to tensor<2xi32>
+func @cast_values(%arg0: memref<?xi32>) -> memref<2xi32> {
+ // NOP cast
+ %1 = memref_cast %arg0 : memref<?xi32> to memref<?xi32>
+ // CHECK-NEXT: %[[RET:.*]] = memref_cast %arg0 : memref<?xi32> to memref<2xi32>
%3 = memref_cast %1 : memref<?xi32> to memref<2xi32>
-
- // NOP casts
- %4 = tensor_cast %2 : tensor<2xi32> to tensor<2xi32>
+ // NOP cast
%5 = memref_cast %3 : memref<2xi32> to memref<2xi32>
-
- // CHECK-NEXT: return %0, %1 : tensor<2xi32>, memref<2xi32>
- return %4, %5 : tensor<2xi32>, memref<2xi32>
+ // CHECK-NEXT: return %[[RET]] : memref<2xi32>
+ return %5 : memref<2xi32>
}
// -----
@@ -1121,61 +1113,12 @@ func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tenso
yield %1 : index
// CHECK: : tensor<3x?x5x7x?xindex>
} : tensor<3x?x?x7x?xindex>
- // CHECK: tensor_cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
+ // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
return %0 : tensor<3x?x?x7x?xindex>
}
// -----
-// CHECK-LABEL: @tensor_cast_chain_ok
-// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
-func @tensor_cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
- // CHECK-NEXT: %[[RES:.*]] = tensor_cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
- %0 = tensor_cast %input : tensor<*xi32> to tensor<4x?xi32>
- %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
- // CHECK-NEXT: return %[[RES]]
- return %1 : tensor<4x8xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @tensor_cast_chain_regain
-// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
-func @tensor_cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
- %0 = tensor_cast %input : tensor<4xi32> to tensor<?xi32>
- %1 = tensor_cast %0 : tensor<?xi32> to tensor<4xi32>
- // CHECK-NEXT: return %[[IN]]
- return %1 : tensor<4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @tensor_cast_chain_keep
-// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
-func @tensor_cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
- // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
- %0 = tensor_cast %input : tensor<?x?xi32> to tensor<4x?xi32>
- // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
- %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
- // CHECK-NEXT: return %[[C2]]
- return %1 : tensor<?x8xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @tensor_cast_chain_invalid
-// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
-func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
- // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
- %0 = tensor_cast %input : tensor<4x8xi32> to tensor<?x?xi32>
- // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
- %1 = tensor_cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
- // CHECK-NEXT: return %[[C2]]
- return %1 : tensor<8x4xi32>
-}
-
-// -----
-
// CHECK-LABEL: func @subtensor
// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: index, %[[ARG1:[0-9a-z]*]]: index
func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
@@ -1189,30 +1132,16 @@ func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
// CHECK: subtensor %{{.*}}[0, 0, 0] [7, 11, 2] [1, 1, 1] :
// CHECK-SAME: tensor<8x16x4xf32> to tensor<7x11x2xf32>
- // CHECK: tensor_cast %{{.*}} : tensor<7x11x2xf32> to tensor<?x?x?xf32>
+ // CHECK: tensor.cast %{{.*}} : tensor<7x11x2xf32> to tensor<?x?x?xf32>
%1 = subtensor %t[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
: tensor<8x16x4xf32> to tensor<?x?x?xf32>
// Test: subtensor with one dynamic operand can also be folded.
// CHECK: subtensor %{{.*}}[0, 0, 0] [2, %[[ARG0]], 2] [1, 1, 1] :
// CHECK-SAME: tensor<?x?x?xf32> to tensor<2x?x2xf32>
- // CHECK: tensor_cast %{{.*}} : tensor<2x?x2xf32> to tensor<?x?x?xf32>
+ // CHECK: tensor.cast %{{.*}} : tensor<2x?x2xf32> to tensor<?x?x?xf32>
%2 = subtensor %1[%c0, %c0, %c0] [%c2, %arg0, %c2] [%c1, %c1, %c1]
: tensor<?x?x?xf32> to tensor<?x?x?xf32>
return %2 : tensor<?x?x?xf32>
}
-
-// -----
-
-// CHECK-LABEL: func @extract_from_tensor_cast
-// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
-func @extract_from_tensor_cast(%tensor: tensor<*xf32>) -> f32 {
- // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
- %c0 = constant 0 : index
- // CHECK-NOT: tensor_cast
- %casted = tensor_cast %tensor : tensor<*xf32> to tensor<?xf32>
- // CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
- %result = tensor.extract %casted[%c0] : tensor<?xf32>
- return %result : f32
-}
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index b6ac5b8e65fc..4ee10ef62295 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -68,10 +68,10 @@ func @different_ops() -> (i32, i32) {
/// types.
// CHECK-LABEL: @different_results
func @different_results(%arg0: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<4x?xf32>) {
- // CHECK: %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
- // CHECK-NEXT: %1 = tensor_cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
- %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
- %1 = tensor_cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
+ // CHECK: %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
+ // CHECK-NEXT: %1 = tensor.cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
+ %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
+ %1 = tensor.cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
// CHECK-NEXT: return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32>
return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32>
diff --git a/mlir/utils/vim/syntax/mlir.vim b/mlir/utils/vim/syntax/mlir.vim
index 1db630c0223f..88f54eaee293 100644
--- a/mlir/utils/vim/syntax/mlir.vim
+++ b/mlir/utils/vim/syntax/mlir.vim
@@ -40,7 +40,7 @@ syn keyword mlirOps alloc alloca addf addi and call call_indirect cmpf cmpi
syn keyword mlirOps constant dealloc divf dma_start dma_wait dim exp
syn keyword mlirOps getTensor index_cast load log memref_cast
syn keyword mlirOps memref_shape_cast mulf muli negf powf prefetch rsqrt sitofp
-syn keyword mlirOps splat store select sqrt subf subi subview tanh tensor_cast
+syn keyword mlirOps splat store select sqrt subf subi subview tanh
syn keyword mlirOps view
" Affine ops.