[Mlir-commits] [mlir] 129d6e5 - [mlir] Move `std.tensor_cast` -> `tensor.cast`.

Sean Silva llvmlistbot at llvm.org
Thu Dec 17 16:07:46 PST 2020


Author: Sean Silva
Date: 2020-12-17T16:06:56-08:00
New Revision: 129d6e554e7a0dba3443ffd8f1df185b90cc6fd5

URL: https://github.com/llvm/llvm-project/commit/129d6e554e7a0dba3443ffd8f1df185b90cc6fd5
DIFF: https://github.com/llvm/llvm-project/commit/129d6e554e7a0dba3443ffd8f1df185b90cc6fd5.diff

LOG: [mlir] Move `std.tensor_cast` -> `tensor.cast`.

This is almost entirely mechanical.

Differential Revision: https://reviews.llvm.org/D93357

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Shape/IR/CMakeLists.txt
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
    mlir/lib/IR/Operation.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Standard/bufferize.mlir
    mlir/test/Dialect/Standard/canonicalize.mlir
    mlir/test/Dialect/Tensor/bufferize.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir
    mlir/test/Dialect/Tensor/invalid.mlir
    mlir/test/Dialect/Tensor/ops.mlir
    mlir/test/IR/core-ops.mlir
    mlir/test/Transforms/canonicalize.mlir
    mlir/test/Transforms/cse.mlir
    mlir/utils/vim/syntax/mlir.vim

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 3df609f295cc..2ef32cfe378b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -34,7 +34,7 @@ namespace linalg {
 class LinalgDependenceGraph;
 
 /// A struct containing the Linalg producer before and after fusion.
-/// When operating on tensors, `fusedProducer` may feed into a `tensor_cast` op
+/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op
 /// before the consumer Linalg op, until enough canonicalizations have applied.
 struct FusionInfo {
   LinalgOp originalProducer;

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 7302bd486657..56ff32252fee 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -354,31 +354,6 @@ computeRankReductionMask(ArrayRef<int64_t> originalShape,
 /// ```
 bool canFoldIntoConsumerOp(MemRefCastOp castOp);
 
-/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
-/// Determines whether TensorCastOp casts to a more dynamic version of the
-/// source tensor. This is useful to fold a tensor_cast into a consuming op and
-/// implement canonicalization patterns for ops in different dialects that may
-/// consume the results of tensor_cast operations. Such foldable tensor_cast
-/// operations are typically inserted as `subtensor` ops and are canonicalized,
-/// to preserve the type compatibility of their uses.
-///
-/// Returns true when all conditions are met:
-/// 1. source and result are ranked tensors with same element type and rank.
-/// 2. the tensor type has more static information than the result
-///
-/// Example:
-/// ```mlir
-///   %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
-///   %2 = consumer %1 ... : tensor<?x?xf32> ...
-/// ```
-///
-/// folds into:
-///
-/// ```mlir
-///   %2 = consumer %0 ... : tensor<8x16xf32> ...
-/// ```
-bool canFoldIntoConsumerOp(TensorCastOp castOp);
-
 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
 /// comparison predicates.
 bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 481dfaf4b34d..7af44f8435ff 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -62,7 +62,7 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
   let printer = [{
     return printStandardCastOp(this->getOperation(), p);
   }];
-  let verifier = [{ return ::verifyCastOp(*this); }];
+  let verifier = [{ return impl::verifyCastOp(*this, areCastCompatible); }];
 
   let hasFolder = 1;
 }
@@ -3428,56 +3428,6 @@ def TanhOp : FloatUnaryOp<"tanh"> {
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// TensorCastOp
-//===----------------------------------------------------------------------===//
-
-def TensorCastOp : CastOp<"tensor_cast"> {
-  let summary = "tensor cast operation";
-  let description = [{
-    Syntax:
-
-    ```
-    operation ::= ssa-id `=` `std.tensor_cast` ssa-use `:` type `to` type
-    ```
-
-    Convert a tensor from one type to an equivalent type without changing any
-    data elements. The source and destination types must both be tensor types
-    with the same element type. If both are ranked, then the rank should be the
-    same and static dimensions should match. The operation is invalid if
-    converting to a mismatching constant dimension.
-
-    Example:
-
-    ```mlir
-    // Convert from unknown rank to rank 2 with unknown dimension sizes.
-    %2 = "std.tensor_cast"(%1) : (tensor<*xf32>) -> tensor<?x?xf32>
-    %2 = tensor_cast %1 : tensor<*xf32> to tensor<?x?xf32>
-
-    // Convert to a type with more known dimensions.
-    %3 = "std.tensor_cast"(%2) : (tensor<?x?xf32>) -> tensor<4x?xf32>
-
-    // Discard static dimension and rank information.
-    %4 = "std.tensor_cast"(%3) : (tensor<4x?xf32>) -> tensor<?x?xf32>
-    %5 = "std.tensor_cast"(%4) : (tensor<?x?xf32>) -> tensor<*xf32>
-    ```
-  }];
-
-  let arguments = (ins AnyTensor:$source);
-  let results = (outs AnyTensor);
-
-  let extraClassDeclaration = [{
-    /// Return true if `a` and `b` are valid operand and result pairs for
-    /// the operation.
-    static bool areCastCompatible(Type a, Type b);
-
-    /// The result of a tensor_cast is always a tensor.
-    TensorType getType() { return getResult().getType().cast<TensorType>(); }
-  }];
-
-  let hasCanonicalizer = 1;
-}
-
 //===----------------------------------------------------------------------===//
 // TensorLoadOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index ee517de3fca0..53980db64dc0 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -28,4 +28,38 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Tensor/IR/TensorOps.h.inc"
 
+//===----------------------------------------------------------------------===//
+// Tensor Dialect Helpers
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace tensor {
+
+/// Determines whether tensor::CastOp casts to a more dynamic version of the
+/// source tensor. This is useful to fold a tensor.cast into a consuming op and
+/// implement canonicalization patterns for ops in different dialects that may
+/// consume the results of tensor.cast operations. Such foldable tensor.cast
+/// operations are typically inserted as `subtensor` ops and are canonicalized,
+/// to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked tensors with same element type and rank.
+/// 2. the tensor type has more static information than the result
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = consumer %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : tensor<8x16xf32> ...
+/// ```
+bool canFoldIntoConsumerOp(CastOp castOp);
+
+} // namespace tensor
+} // namespace mlir
+
 #endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 4eb989b2f3b5..e0500b8fcfa6 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -19,6 +19,52 @@ class Tensor_Op<string mnemonic, list<OpTrait> traits = []>
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
+//===----------------------------------------------------------------------===//
+// CastOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> {
+  let summary = "tensor cast operation";
+  let description = [{
+    Convert a tensor from one type to an equivalent type without changing any
+    data elements. The source and destination types must both be tensor types
+    with the same element type. If both are ranked, then the rank should be the
+    same and static dimensions should match. The operation is invalid if
+    converting to a mismatching constant dimension.
+
+    Example:
+
+    ```mlir
+    // Convert from unknown rank to rank 2 with unknown dimension sizes.
+    %2 = tensor.cast %1 : tensor<*xf32> to tensor<?x?xf32>
+
+    // Convert to a type with more known dimensions.
+    %3 = tensor.cast %2 : tensor<?x?xf32> to tensor<4x?xf32>
+
+    // Discard static dimension and rank information.
+    %4 = tensor.cast %3 : tensor<4x?xf32> to tensor<?x?xf32>
+    %5 = tensor.cast %4 : tensor<?x?xf32> to tensor<*xf32>
+    ```
+  }];
+
+  let arguments = (ins AnyTensor:$source);
+  let results = (outs AnyTensor:$dest);
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+  let verifier = "return impl::verifyCastOp(*this, areCastCompatible);";
+
+  let extraClassDeclaration = [{
+    /// Return true if `a` and `b` are valid operand and result pairs for
+    /// the operation.
+    static bool areCastCompatible(Type a, Type b);
+
+    /// The result of a tensor.cast is always a tensor.
+    TensorType getType() { return getResult().getType().cast<TensorType>(); }
+  }];
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 9e4da63c3618..beb45ebfe08c 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1775,11 +1775,18 @@ void printOneResultOp(Operation *op, OpAsmPrinter &p);
 // These functions are out-of-line implementations of the methods in CastOp,
 // which avoids them being template instantiated/duplicated.
 namespace impl {
+// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
+// need for them, but some older ODS code in `std` still depends on them).
 void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
                  Type destType);
 ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
 void printCastOp(Operation *op, OpAsmPrinter &p);
+// TODO: Create a CastOpInterface with a method areCastCompatible.
+// Also, consider adding functionality to CastOpInterface to be able to perform
+// the ChainedTensorCast canonicalization generically.
 Value foldCastOp(Operation *op);
+LogicalResult verifyCastOp(Operation *op,
+                           function_ref<bool(Type, Type)> areCastCompatible);
 } // namespace impl
 } // end namespace mlir
 

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
index 81fcc70e91be..f18b26506447 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -tensor-constant-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -tensor-constant-bufferize -linalg-bufferize -tensor-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s
@@ -8,7 +8,7 @@ func @main() {
   %b = constant dense<[10.0, 20.0, 30.0]> : tensor<3xf32>
 
   %addf = addf %a, %b : tensor<3xf32>
-  %addf_unranked = tensor_cast %addf : tensor<3xf32> to tensor<*xf32>
+  %addf_unranked = tensor.cast %addf : tensor<3xf32> to tensor<*xf32>
   call @print_memref_f32(%addf_unranked) : (tensor<*xf32>) -> ()
   // CHECK: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [3] strides = [1] data =
   // CHECK-NEXT: [11,  22,  33]

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
index ac90fb28517b..f327bc97d7e0 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
+// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \
+// RUN: -finalizing-bufferize \
 // RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@@ -15,14 +17,14 @@ func @main() {
   %inserted_at_position_0 = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32>
   %inserted_at_position_1 = subtensor_insert %insert_val into %const[1][1][1] : tensor<1xf32> into tensor<2xf32>
 
-  %unranked_at_position_0 = tensor_cast %inserted_at_position_0 : tensor<2xf32> to tensor<*xf32>
+  %unranked_at_position_0 = tensor.cast %inserted_at_position_0 : tensor<2xf32> to tensor<*xf32>
   call @print_memref_f32(%unranked_at_position_0) : (tensor<*xf32>) -> ()
 
   //      CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
   // CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [1] data =
   // CHECK-NEXT: [20, 10]
 
-  %unranked_at_position_1 = tensor_cast %inserted_at_position_1 : tensor<2xf32> to tensor<*xf32>
+  %unranked_at_position_1 = tensor.cast %inserted_at_position_1 : tensor<2xf32> to tensor<*xf32>
   call @print_memref_f32(%unranked_at_position_1) : (tensor<*xf32>) -> ()
 
   //      CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir
index e568d6acf9ee..6bc886a5d4c4 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
+// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \
+// RUN: -finalizing-bufferize \
 // RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@@ -9,7 +11,7 @@ func @main() {
   %insert_val = constant dense<20.0> : tensor<1xf32>
   %inserted = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32>
 
-  %unranked = tensor_cast %inserted : tensor<2xf32> to tensor<*xf32>
+  %unranked = tensor.cast %inserted : tensor<2xf32> to tensor<*xf32>
   call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
 
   //      CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
index 61fc05f8c20b..71e27733c83f 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s -tensor-constant-bufferize -std-bufferize -linalg-bufferize \
-// RUN: -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
+// RUN: -tensor-bufferize -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
 // RUN: -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@@ -19,7 +19,7 @@ func @main() {
   // Note that this is skipping a step and we would need at least some function
   // attribute to declare that this conversion is valid (e.g. when we statically
   // know that things will play nicely at the C ABI boundary).
-  %unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32>
+  %unranked = tensor.cast %0 : tensor<4xf32> to tensor<*xf32>
   call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
 
   //      CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
index e535febcf7dc..38d97332f0d7 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
@@ -1,12 +1,13 @@
 // RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize \
-// RUN: -func-bufferize -finalizing-bufferize  -convert-linalg-to-loops \
+// RUN: -tensor-bufferize -func-bufferize -finalizing-bufferize  -convert-linalg-to-loops \
 // RUN: -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s
 
 // RUN: mlir-opt %s  -linalg-tile="linalg-tile-sizes=1,2,3" -linalg-bufferize \
-// RUN: -scf-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
+// RUN: -scf-bufferize -std-bufferize -tensor-constant-bufferize -tensor-bufferize \
+// RUN: -func-bufferize \
 // RUN: -finalizing-bufferize -convert-linalg-to-loops -convert-scf-to-std \
 // RUN: -convert-linalg-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
@@ -23,7 +24,7 @@ func @main() {
   %D = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>)
                      init(%C: tensor<2x4xf32>) -> tensor<2x4xf32>
 
-  %unranked = tensor_cast %D : tensor<2x4xf32> to tensor<*xf32>
+  %unranked = tensor.cast %D : tensor<2x4xf32> to tensor<*xf32>
   call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
 
   //      CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}

diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 7189c30e766a..0d87d4f10975 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -103,9 +103,9 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
   auto erasedRankType =
       RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
   Value rankErasedLhs =
-      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
+      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
   Value rankErasedRhs =
-      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
+      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
   Value lesserRankOperand =
       rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
   Value greaterRankOperand =
@@ -186,7 +186,7 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
   Value tensor =
       rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
   Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
-  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
   return success();
 }
 
@@ -246,9 +246,9 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
   auto erasedRankType =
       RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
   Value rankErasedLhs =
-      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
+      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
   Value rankErasedRhs =
-      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
+      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
   Value lesserRankOperand =
       rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
   Value greaterRankOperand =
@@ -528,8 +528,8 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
     // Materialize extent tensor.
     Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
         loc, rewriter.getIndexType(), extentValues);
-    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
-                                              op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
+                                                staticExtentTensor);
     return success();
   }
 
@@ -561,8 +561,8 @@ class ToExtentTensorOpConversion
     if (!adaptor.input().getType().isa<RankedTensorType>())
       return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
 
-    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
-                                              op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
+                                                adaptor.input());
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index 15e29d749e65..3ed79a554b31 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -16,4 +16,5 @@ add_mlir_dialect_library(MLIRLinalg
   MLIRSideEffectInterfaces
   MLIRViewLikeInterface
   MLIRStandard
+  MLIRTensor
   )

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f2b05448dbd0..3a7249df8e79 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -651,7 +651,7 @@ struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
     auto newOp =
         rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
                                       rewriter.getI64ArrayAttr(staticSizes));
-    rewriter.replaceOpWithNewOp<TensorCastOp>(op, op.getType(), newOp);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
     return success();
   }
 };
@@ -1815,12 +1815,12 @@ struct FoldTensorCastOp : public RewritePattern {
     if (!linalgOp)
       return failure();
 
-    // If no operand comes from a TensorCastOp and can be folded then fail.
+    // If no operand comes from a tensor::CastOp and can be folded then fail.
     bool hasTensorCastOperand =
         llvm::any_of(linalgOp.getShapedOperands(), [&](Value v) {
           if (v.isa<BlockArgument>())
             return false;
-          auto castOp = v.getDefiningOp<TensorCastOp>();
+          auto castOp = v.getDefiningOp<tensor::CastOp>();
           return castOp && canFoldIntoConsumerOp(castOp);
         });
     if (!hasTensorCastOperand)
@@ -1832,7 +1832,7 @@ struct FoldTensorCastOp : public RewritePattern {
     newOperands.reserve(op->getNumOperands());
     // Inputs may fold.
     for (Value v : linalgOp.getInputs()) {
-      auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
+      auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
       newOperands.push_back(
           canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
     }
@@ -1841,7 +1841,7 @@ struct FoldTensorCastOp : public RewritePattern {
                        linalgOp.getOutputBuffers().end());
     // Init tensors may fold, in which case the resultType must also change.
     for (Value v : linalgOp.getInitTensors()) {
-      auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
+      auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
       bool fold = canFoldIntoConsumerOp(tensorCastOp);
       newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
       newResultTypes.push_back(newOperands.back().getType());

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index c77398def98c..ba31ca5a034b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -59,6 +59,7 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
 
 void mlir::linalg::LinalgDialect::initialize() {
   getContext()->getOrLoadDialect("std");
+  getContext()->getOrLoadDialect("tensor");
 
   addTypes<RangeType>();
   addOperations<

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 6de4ce6ac341..42e2d4dcd244 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -36,6 +36,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRStandard
   MLIRStandardOpsTransforms
   MLIRStandardToLLVM
+  MLIRTensor
   MLIRTransforms
   MLIRTransformUtils
   MLIRVector

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 681309687d18..d9ea7d8ccb29 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Dominance.h"
@@ -517,13 +518,13 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
   // Replace use.
   // Canonicalizations are not guaranteed to have happened before constructing
   // `fusedProducer`. In the tensor case this can result in temporary type
-  // mismatches. Insert a `tensor_cast` op to propagate the transformation
+  // mismatches. Insert a `tensor.cast` op to propagate the transformation
   // invariant that types are compatible.
   Value def = fusedProducer->getResult(producerIdx);
   OpOperand &use = consumer->getOpOperand(consumerIdx);
   Type consumerType = use.get().getType();
   if (consumerType != def.getType())
-    def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def);
+    def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
   use.set(def);
   return FusionInfo{producerOp, fusedProducer};
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 50a18d4fb01c..423d687c1eb8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/EDSC/Builders.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
@@ -569,7 +570,7 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
   ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
   SubTensorOp::getCanonicalizationPatterns(patterns, ctx);
   SubViewOp::getCanonicalizationPatterns(patterns, ctx);
-  TensorCastOp::getCanonicalizationPatterns(patterns, ctx);
+  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
   ViewOp::getCanonicalizationPatterns(patterns, ctx);
   CanonicalizationPatternList<
 #define GET_OP_LIST

diff  --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
index 1ac5b3b1e856..f8321842db31 100644
--- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
@@ -18,4 +18,5 @@ add_mlir_dialect_library(MLIRShape
   MLIRIR
   MLIRSideEffectInterfaces
   MLIRStandard
+  MLIRTensor
   )

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index ef29ddc510ae..0478cb7872cc 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Shape/IR/Shape.h"
 
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"

diff  --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index 4e6d062a232f..45c699baece3 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -1,5 +1,6 @@
 include "mlir/Dialect/Shape/IR/ShapeOps.td"
 include "mlir/Dialect/StandardOps/IR/Ops.td"
+include "mlir/Dialect/Tensor/IR/TensorOps.td"
 
 def AllInputShapesEq : Constraint<CPred< [{
   llvm::all_of($0, [&](mlir::Value val) {
@@ -32,7 +33,7 @@ def SizeToIndexToSizeCanonicalization : Pat<
   (Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)),
   (replaceWithValue $arg)>;
 
-// Fold tensor_cast(const_shape) to const_shape. This changes the type of
+// Fold tensor.cast(const_shape) to const_shape. This changes the type of
 // const_shape to the destination type of the cast.
 def TensorCastConstShape : Pat <
-  (TensorCastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;
+  (Tensor_CastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 7ed9cffa8806..c0af06314086 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -141,18 +141,6 @@ static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
     << op->getResult(0).getType();
 }
 
-/// A custom cast operation verifier.
-template <typename T>
-static LogicalResult verifyCastOp(T op) {
-  auto opType = op.getOperand().getType();
-  auto resType = op.getType();
-  if (!T::areCastCompatible(opType, resType))
-    return op.emitError("operand type ") << opType << " and result type "
-                                         << resType << " are cast incompatible";
-
-  return success();
-}
-
 void StandardOpsDialect::initialize() {
   getContext()->loadDialect<tensor::TensorDialect>();
   addOperations<DmaStartOp, DmaWaitOp,
@@ -1494,7 +1482,7 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
 
 void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
-  results.insert<DimOfMemRefReshape, DimOfCastOp<TensorCastOp>>(context);
+  results.insert<DimOfMemRefReshape, DimOfCastOp<tensor::CastOp>>(context);
 }
 
 // ---------------------------------------------------------------------------
@@ -1870,8 +1858,8 @@ struct StaticDynamicTensorFromElements
         newOperands);
     rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
                                 newOp.body().begin());
-    rewriter.replaceOpWithNewOp<TensorCastOp>(tensorFromElements, resultType,
-                                              newOp);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
+                                                newOp);
     return success();
   }
 };
@@ -1913,7 +1901,7 @@ struct ExtractFromDynamicTensorFromElements
 
 /// Canonicalizes the pattern of the form
 ///
-/// %val = tensor_cast %source : : tensor<?xi32> to tensor<2xi32>
+/// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32>
 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
 ///
 /// to
@@ -1924,7 +1912,7 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
 
   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                 PatternRewriter &rewriter) const final {
-    auto tensorCast = extract.tensor().getDefiningOp<TensorCastOp>();
+    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
     if (!tensorCast)
       return failure();
 
@@ -3395,7 +3383,7 @@ static void replaceWithNewOp(PatternRewriter &rewriter, SubViewOp op,
 
 static void replaceWithNewOp(PatternRewriter &rewriter, SubTensorOp op,
                              SubTensorOp newOp) {
-  rewriter.replaceOpWithNewOp<TensorCastOp>(op, newOp, op.getType());
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
 }
 
 /// Pattern to rewrite a subview op with constant arguments.
@@ -3536,60 +3524,6 @@ bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
   return true;
 }
 
-/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
-/// Determines whether TensorCastOp casts to a more dynamic version of the
-/// source tensor. This is useful to fold a tensor_cast into a consuming op and
-/// implement canonicalization patterns for ops in different dialects that may
-/// consume the results of tensor_cast operations. Such foldable tensor_cast
-/// operations are typically inserted as `subtensor` ops and are canonicalized,
-/// to preserve the type compatibility of their uses.
-///
-/// Returns true when all conditions are met:
-/// 1. source and result are ranked tensors with same element type and rank.
-/// 2. the tensor type has more static information than the result
-///
-/// Example:
-/// ```mlir
-///   %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
-///   %2 = consumer %1 ... : tensor<?x?xf32> ...
-/// ```
-///
-/// folds into:
-///
-/// ```mlir
-///   %2 = consumer %0 ... : tensor<8x16xf32> ...
-/// ```
-bool mlir::canFoldIntoConsumerOp(TensorCastOp castOp) {
-  if (!castOp)
-    return false;
-
-  RankedTensorType sourceType =
-      castOp.source().getType().dyn_cast<RankedTensorType>();
-  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
-
-  // Requires RankedTensorType.
-  if (!sourceType || !resultType)
-    return false;
-
-  // Requires same elemental type.
-  if (sourceType.getElementType() != resultType.getElementType())
-    return false;
-
-  // Requires same rank.
-  if (sourceType.getRank() != resultType.getRank())
-    return false;
-
-  // If cast is towards more static sizes along any dimension, don't fold.
-  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
-    auto ss = std::get<0>(it), st = std::get<1>(it);
-    if (ss != st)
-      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
-        return false;
-  }
-
-  return true;
-}
-
 namespace {
 /// Pattern to rewrite a subview op with MemRefCast arguments.
 /// This essentially pushes memref_cast past its consuming subview when
@@ -3857,107 +3791,6 @@ static LogicalResult verify(SubTensorInsertOp op) {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// TensorCastOp
-//===----------------------------------------------------------------------===//
-
-bool TensorCastOp::areCastCompatible(Type a, Type b) {
-  auto aT = a.dyn_cast<TensorType>();
-  auto bT = b.dyn_cast<TensorType>();
-  if (!aT || !bT)
-    return false;
-
-  if (aT.getElementType() != bT.getElementType())
-    return false;
-
-  return succeeded(verifyCompatibleShape(aT, bT));
-}
-
-OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
-  return impl::foldCastOp(*this);
-}
-
-/// Compute a TensorType that has the joined shape knowledge of the two
-/// given TensorTypes. The element types need to match.
-static TensorType joinShapes(TensorType one, TensorType two) {
-  assert(one.getElementType() == two.getElementType());
-
-  if (!one.hasRank())
-    return two;
-  if (!two.hasRank())
-    return one;
-
-  int64_t rank = one.getRank();
-  if (rank != two.getRank())
-    return {};
-
-  SmallVector<int64_t, 4> join;
-  join.reserve(rank);
-  for (int64_t i = 0; i < rank; ++i) {
-    if (one.isDynamicDim(i)) {
-      join.push_back(two.getDimSize(i));
-      continue;
-    }
-    if (two.isDynamicDim(i)) {
-      join.push_back(one.getDimSize(i));
-      continue;
-    }
-    if (one.getDimSize(i) != two.getDimSize(i))
-      return {};
-    join.push_back(one.getDimSize(i));
-  }
-  return RankedTensorType::get(join, one.getElementType());
-}
-
-namespace {
-
-/// Replaces chains of two tensor_cast operations by a single tensor_cast
-/// operation if doing so does not remove runtime constraints.
-struct ChainedTensorCast : public OpRewritePattern<TensorCastOp> {
-  using OpRewritePattern<TensorCastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TensorCastOp tensorCast,
-                                PatternRewriter &rewriter) const final {
-    auto tensorCastOperand =
-        tensorCast.getOperand().getDefiningOp<TensorCastOp>();
-
-    if (!tensorCastOperand)
-      return failure();
-
-    auto sourceType =
-        tensorCastOperand.getOperand().getType().cast<TensorType>();
-    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
-    auto resultType = tensorCast.getType().cast<TensorType>();
-
-    // We can remove the intermediate cast if joining all three produces the
-    // same result as just joining the source and result shapes.
-    auto firstJoin =
-        joinShapes(joinShapes(sourceType, intermediateType), resultType);
-
-    // The join might not exist if the cast sequence would fail at runtime.
-    if (!firstJoin)
-      return failure();
-
-    // The newJoin always exists if the above join exists, it might just contain
-    // less information. If so, we cannot drop the intermediate cast, as doing
-    // so would remove runtime checks.
-    auto newJoin = joinShapes(sourceType, resultType);
-    if (firstJoin != newJoin)
-      return failure();
-
-    rewriter.replaceOpWithNewOp<TensorCastOp>(tensorCast, resultType,
-                                              tensorCastOperand.getOperand());
-    return success();
-  }
-};
-
-} // namespace
-
-void TensorCastOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<ChainedTensorCast>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // TensorLoadOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index a84934b0ebb8..98792838deff 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -117,20 +117,6 @@ class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
 };
 } // namespace
 
-namespace {
-class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto resultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, operands[0]);
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class BufferizeTensorFromElementsOp
     : public OpConversionPattern<TensorFromElementsOp> {
@@ -162,7 +148,6 @@ void mlir::populateStdBufferizePatterns(MLIRContext *context,
       BufferizeDimOp,
       BufferizeDynamicTensorFromElementsOp,
       BufferizeSelectOp,
-      BufferizeTensorCastOp,
       BufferizeTensorFromElementsOp
       // clang-format on
       >(typeConverter, context);
@@ -180,8 +165,7 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
     target.addLegalDialect<scf::SCFDialect>();
 
     populateStdBufferizePatterns(context, typeConverter, patterns);
-    target.addIllegalOp<DynamicTensorFromElementsOp, TensorCastOp,
-                        TensorFromElementsOp>();
+    target.addIllegalOp<DynamicTensorFromElementsOp, TensorFromElementsOp>();
     // We only bufferize the case of tensor selected type and scalar condition,
     // as that boils down to a select over memref descriptors (don't need to
     // touch the data).

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index bb944b21e3c3..aaae7fbf807c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -8,12 +8,165 @@
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
 
+//===----------------------------------------------------------------------===//
+// CastOp
+//===----------------------------------------------------------------------===//
+
+/// Determines whether tensor::CastOp casts to a more dynamic version of the
+/// source tensor. This is useful to fold a tensor.cast into a consuming op and
+/// implement canonicalization patterns for ops in different dialects that may
+/// consume the results of tensor.cast operations. Such foldable tensor.cast
+/// operations are typically inserted as `subtensor` ops and are canonicalized,
+/// to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked tensors with same element type and rank.
+/// 2. the tensor type has more static information than the result
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = consumer %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : tensor<8x16xf32> ...
+/// ```
+bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
+  if (!castOp)
+    return false;
+
+  RankedTensorType sourceType =
+      castOp.source().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
+
+  // Requires RankedTensorType.
+  if (!sourceType || !resultType)
+    return false;
+
+  // Requires same elemental type.
+  if (sourceType.getElementType() != resultType.getElementType())
+    return false;
+
+  // Requires same rank.
+  if (sourceType.getRank() != resultType.getRank())
+    return false;
+
+  // If cast is towards more static sizes along any dimension, don't fold.
+  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+    if (ShapedType::isDynamic(std::get<0>(t)) &&
+        !ShapedType::isDynamic(std::get<1>(t)))
+      return false;
+  }
+
+  return true;
+}
+
+bool CastOp::areCastCompatible(Type a, Type b) {
+  auto aT = a.dyn_cast<TensorType>();
+  auto bT = b.dyn_cast<TensorType>();
+  if (!aT || !bT)
+    return false;
+
+  if (aT.getElementType() != bT.getElementType())
+    return false;
+
+  return succeeded(verifyCompatibleShape(aT, bT));
+}
+
+OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
+  return impl::foldCastOp(*this);
+}
+
+/// Compute a TensorType that has the joined shape knowledge of the two
+/// given TensorTypes. The element types need to match.
+static TensorType joinShapes(TensorType one, TensorType two) {
+  assert(one.getElementType() == two.getElementType());
+
+  if (!one.hasRank())
+    return two;
+  if (!two.hasRank())
+    return one;
+
+  int64_t rank = one.getRank();
+  if (rank != two.getRank())
+    return {};
+
+  SmallVector<int64_t, 4> join;
+  join.reserve(rank);
+  for (int64_t i = 0; i < rank; ++i) {
+    if (one.isDynamicDim(i)) {
+      join.push_back(two.getDimSize(i));
+      continue;
+    }
+    if (two.isDynamicDim(i)) {
+      join.push_back(one.getDimSize(i));
+      continue;
+    }
+    if (one.getDimSize(i) != two.getDimSize(i))
+      return {};
+    join.push_back(one.getDimSize(i));
+  }
+  return RankedTensorType::get(join, one.getElementType());
+}
+
+namespace {
+
+/// Replaces chains of two tensor.cast operations by a single tensor.cast
+/// operation if doing so does not remove runtime constraints.
+struct ChainedTensorCast : public OpRewritePattern<CastOp> {
+  using OpRewritePattern<CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CastOp tensorCast,
+                                PatternRewriter &rewriter) const final {
+    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
+
+    if (!tensorCastOperand)
+      return failure();
+
+    auto sourceType =
+        tensorCastOperand.getOperand().getType().cast<TensorType>();
+    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
+    auto resultType = tensorCast.getType().cast<TensorType>();
+
+    // We can remove the intermediate cast if joining all three produces the
+    // same result as just joining the source and result shapes.
+    auto firstJoin =
+        joinShapes(joinShapes(sourceType, intermediateType), resultType);
+
+    // The join might not exist if the cast sequence would fail at runtime.
+    if (!firstJoin)
+      return failure();
+
+    // The newJoin always exists if the above join exists, it might just contain
+    // less information. If so, we cannot drop the intermediate cast, as doing
+    // so would remove runtime checks.
+    auto newJoin = joinShapes(sourceType, resultType);
+    if (firstJoin != newJoin)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
+                                        tensorCastOperand.getOperand());
+    return success();
+  }
+};
+
+} // namespace
+
+void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                         MLIRContext *context) {
+  results.insert<ChainedTensorCast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index 9e6b3dba74a8..05ff96fb8d69 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -19,6 +19,20 @@
 
 using namespace mlir;
 
+namespace {
+class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::CastOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto resultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, operands[0]);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
 public:
@@ -37,7 +51,7 @@ class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
 void mlir::populateTensorBufferizePatterns(
     MLIRContext *context, BufferizeTypeConverter &typeConverter,
     OwningRewritePatternList &patterns) {
-  patterns.insert<BufferizeExtractOp>(typeConverter, context);
+  patterns.insert<BufferizeCastOp, BufferizeExtractOp>(typeConverter, context);
 }
 
 namespace {
@@ -49,7 +63,7 @@ struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
     ConversionTarget target(*context);
 
     populateTensorBufferizePatterns(context, typeConverter, patterns);
-    target.addIllegalOp<tensor::ExtractOp>();
+    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp>();
     target.addLegalDialect<StandardOpsDialect>();
 
     if (failed(

diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 1b8d2875f3c8..c84a11bcec3f 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1213,6 +1213,19 @@ Value impl::foldCastOp(Operation *op) {
   return nullptr;
 }
 
+LogicalResult
+impl::verifyCastOp(Operation *op,
+                   function_ref<bool(Type, Type)> areCastCompatible) {
+  auto opType = op->getOperand(0).getType();
+  auto resType = op->getResult(0).getType();
+  if (!areCastCompatible(opType, resType))
+    return op->emitError("operand type ")
+           << opType << " and result type " << resType
+           << " are cast incompatible";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Misc. utils
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index b7663a328986..c2ab39f338b1 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -95,7 +95,7 @@ func @const_shape() -> tensor<?xindex> {
   // CHECK: %[[C2:.*]] = constant 2 : index
   // CHECK: %[[C3:.*]] = constant 3 : index
   // CHECK: %[[TENSOR3:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]]
-  // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
   // CHECK: return %[[RESULT]] : tensor<?xindex>
   %shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
   return %shape : tensor<?xindex>
@@ -108,7 +108,7 @@ func @const_shape() -> tensor<?xindex> {
 // CHECK-SAME: () -> tensor<?xindex>
 func @const_shape_zero_elements() -> tensor<?xindex> {
   // CHECK: %[[TENSOR:.*]] = tensor_from_elements : tensor<0xindex>
-  // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
+  // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
   // CHECK: return %[[RESULT]] : tensor<?xindex>
   %shape = shape.const_shape [] : tensor<?xindex>
   return %shape : tensor<?xindex>
@@ -152,13 +152,13 @@ func @const_size() -> index {
 
 // -----
 
-// Lower `to_extent_tensor` to `std.tensor_cast`
+// Lower `to_extent_tensor` to `tensor.cast`
 // Fold to_extent_tensor when already on tensor.
 // CHECK-LABEL: @to_extent_tensor
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?xindex>
 func @to_extent_tensor(%arg: tensor<?xindex>) -> tensor<3xindex> {
   // CHECK-NOT: to_extent_tensor
-  // CHECK: %[[RES:.*]] = tensor_cast %[[ARG]] : tensor<?xindex> to tensor<3xindex
+  // CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<?xindex> to tensor<3xindex
   %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<3xindex>
   // CHECK: return %[[RES]]
   return %casted : tensor<3xindex>
@@ -316,8 +316,8 @@ func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
   // CHECK:           %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
   // CHECK:           %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
   // CHECK:           %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-  // CHECK:           %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
-  // CHECK:           %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK:           %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK:           %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
   // CHECK:           %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
   // CHECK:           %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
   // CHECK:           %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
@@ -356,8 +356,8 @@ func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xinde
   // CHECK:           %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
   // CHECK:           %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
   // CHECK:           %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-  // CHECK:           %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
-  // CHECK:           %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK:           %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
+  // CHECK:           %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
   // CHECK:           %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
   // CHECK:           %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
   // CHECK:           %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
@@ -400,8 +400,8 @@ func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
 // CHECK:           %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
 // CHECK:           %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
 // CHECK:           %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-// CHECK:           %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
-// CHECK:           %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK:           %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
+// CHECK:           %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
 // CHECK:           %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
 // CHECK:           %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
 // CHECK:           %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
@@ -438,8 +438,8 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
 // CHECK:           %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
 // CHECK:           %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
 // CHECK:           %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-// CHECK:           %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
-// CHECK:           %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK:           %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK:           %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
 // CHECK:           %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
 // CHECK:           %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
 // CHECK:           %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 96ab1aa93355..6c12070e07f1 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -317,20 +317,20 @@ func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>
 
 // -----
 
-// CHECK-LABEL: func @tensor_cast(
-func @tensor_cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
+// CHECK-LABEL: func @tensor.cast(
+func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
   -> tensor<3x?xf32>
 {
-  %ta = tensor_cast %a : tensor<3x4xf32> to tensor<?x?xf32>
-  %tb = tensor_cast %b : tensor<4x?xf32> to tensor<?x?xf32>
-  %tc = tensor_cast %c : tensor<3x?xf32> to tensor<?x?xf32>
+  %ta = tensor.cast %a : tensor<3x4xf32> to tensor<?x?xf32>
+  %tb = tensor.cast %b : tensor<4x?xf32> to tensor<?x?xf32>
+  %tc = tensor.cast %c : tensor<3x?xf32> to tensor<?x?xf32>
 
   //      CHECK:  linalg.matmul ins({{.*}}tensor<3x4xf32>, tensor<4x?xf32>)
   // CHECK-SAME:    init({{.*}}tensor<3x?xf32>) -> tensor<3x?xf32>
   %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
                init(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>
 
-  %1 = tensor_cast %0 : tensor<?x?xf32> to tensor<3x?xf32>
+  %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<3x?xf32>
 
   return %1: tensor<3x?xf32>
 }
@@ -360,7 +360,7 @@ func @init_tensor_canonicalize() -> (tensor<4x5x?xf32>) {
 }
 // CHECK: func @init_tensor_canonicalize
 // CHECK:   %[[T0:.+]] = linalg.init_tensor [4, 5, 6] : tensor<4x5x6xf32>
-// CHECK:   %[[T1:.+]] = tensor_cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
+// CHECK:   %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
 // CHECK:   return %[[T1]]
 
 // -----

diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index aa43f515f753..c893ee118a86 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -872,24 +872,24 @@ func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
 
 // -----
 
-// Verify that tensor_cast folding uses the correct type
-// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned
-func @fold_tensor_cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
+// Verify that tensor.cast folding uses the correct type
+// CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned
+func @fold_tensor.cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
   // CHECK: constant dense<2> : tensor<1xindex>
-  // CHECK-NOT: tensor_cast
+  // CHECK-NOT: tensor.cast
   %0 = shape.const_shape [2] : tensor<?xindex>
-  %1 = tensor_cast %0 : tensor<?xindex> to tensor<1xindex>
+  %1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
   return %1 : tensor<1xindex>
 }
 
 // -----
 
-// Verify that tensor_cast folding uses the correct type
-// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned_dynamic
-func @fold_tensor_cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
+// Verify that tensor.cast folding uses the correct type
+// CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned_dynamic
+func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
   // CHECK: shape.const_shape [2] : tensor<?xindex>
-  // CHECK-NOT: tensor_cast
+  // CHECK-NOT: tensor.cast
   %0 = shape.const_shape [2] : tensor<1xindex>
-  %1 = tensor_cast %0 : tensor<1xindex> to tensor<?xindex>
+  %1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
   return %1 : tensor<?xindex>
 }

diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index 8ae10ccf0f3b..4e8f1282c36b 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -75,39 +75,6 @@ func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
   return %0 : tensor<f32>
 }
 
-// CHECK-LABEL:   func @tensor_cast(
-// CHECK-SAME:                      %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
-// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]
-// CHECK:           %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
-// CHECK:           %[[RET:.*]] = tensor_load %[[CASTED]]
-// CHECK:           return %[[RET]] : tensor<2xindex>
-func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
-  %0 = tensor_cast %arg0 : tensor<?xindex> to tensor<2xindex>
-  return %0 : tensor<2xindex>
-}
-
-// CHECK-LABEL:   func @tensor_cast_from_unranked(
-// CHECK-SAME:                                    %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
-// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<*xf32>
-// CHECK:           %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
-// CHECK:           %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<2xf32>
-// CHECK:           return %[[RET]] : tensor<2xf32>
-func @tensor_cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
-  %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<2xf32>
-  return %0 : tensor<2xf32>
-}
-
-// CHECK-LABEL:   func @tensor_cast_to_unranked(
-// CHECK-SAME:                                  %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
-// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<2xf32>
-// CHECK:           %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
-// CHECK:           %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<*xf32>
-// CHECK:           return %[[RET]] : tensor<*xf32>
-func @tensor_cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
-  %0 = tensor_cast %arg0 : tensor<2xf32> to tensor<*xf32>
-  return %0 : tensor<*xf32>
-}
-
 // CHECK-LABEL:   func @tensor_from_elements(
 // CHECK-SAME:                               %[[ELEM0:.*]]: index,
 // CHECK-SAME:                               %[[ELEM1:.*]]: index) -> tensor<2xindex> {

diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index af67453a1f3c..f3b7ccddd1ff 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -116,17 +116,17 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
   return %1 : index
 }
 
-// Test case: Folding dim(tensor_cast %0, %idx) -> dim %0, %idx
-// CHECK-LABEL: func @fold_dim_of_tensor_cast
+// Test case: Folding dim(tensor.cast %0, %idx) -> dim %0, %idx
+// CHECK-LABEL: func @fold_dim_of_tensor.cast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
 //   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
 //   CHECK-DAG:   %[[C4:.+]] = constant 4 : index
 //       CHECK:   %[[T0:.+]] = dim %[[ARG0]], %[[C1]]
 //  CHECK-NEXT:   return %[[C4]], %[[T0]]
-func @fold_dim_of_tensor_cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
+func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
-  %0 = tensor_cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
+  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
   %1 = dim %0, %c0 : tensor<?x?xf32>
   %2 = dim %0, %c1 : tensor<?x?xf32>
   return %1, %2: index, index

diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index bb23322ca659..0e55040ec116 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -1,5 +1,38 @@
 // RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
 
+// CHECK-LABEL:   func @tensor.cast(
+// CHECK-SAME:                      %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]
+// CHECK:           %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
+// CHECK:           %[[RET:.*]] = tensor_load %[[CASTED]]
+// CHECK:           return %[[RET]] : tensor<2xindex>
+func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
+  %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<2xindex>
+  return %0 : tensor<2xindex>
+}
+
+// CHECK-LABEL:   func @tensor.cast_from_unranked(
+// CHECK-SAME:                                    %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<*xf32>
+// CHECK:           %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
+// CHECK:           %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<2xf32>
+// CHECK:           return %[[RET]] : tensor<2xf32>
+func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
+  %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL:   func @tensor.cast_to_unranked(
+// CHECK-SAME:                                  %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<2xf32>
+// CHECK:           %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
+// CHECK:           %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<*xf32>
+// CHECK:           return %[[RET]] : tensor<*xf32>
+func @tensor.cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
+  %0 = tensor.cast %arg0 : tensor<2xf32> to tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
 // CHECK-LABEL:   func @extract(
 // CHECK-SAME:                  %[[TENSOR:.*]]: tensor<?xf32>,
 // CHECK-SAME:                  %[[IDX:.*]]: index) -> f32 {

diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 86cb3ee7388a..9dcd4da13cc5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1,4 +1,66 @@
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+
+// Checks that NOP casts are removed.
+// CHECK-LABEL: cast_values
+func @cast_values(%arg0: tensor<*xi32>) -> tensor<2xi32> {
+  // NOP cast
+  %0 = tensor.cast %arg0 : tensor<*xi32> to tensor<*xi32>
+  // CHECK-NEXT: %[[RET:.*]] = tensor.cast %arg0 : tensor<*xi32> to tensor<2xi32>
+  %2 = tensor.cast %0 : tensor<*xi32> to tensor<2xi32>
+  // NOP cast
+  %4 = tensor.cast %2 : tensor<2xi32> to tensor<2xi32>
+  // CHECK-NEXT: return %[[RET]] : tensor<2xi32>
+  return %4 : tensor<2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor.cast_chain_ok
+// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
+func @tensor.cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
+  // CHECK-NEXT: %[[RES:.*]] = tensor.cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
+  %0 = tensor.cast %input : tensor<*xi32> to tensor<4x?xi32>
+  %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
+  // CHECK-NEXT: return %[[RES]]
+  return %1 : tensor<4x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor.cast_chain_regain
+// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
+func @tensor.cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
+  %0 = tensor.cast %input : tensor<4xi32> to tensor<?xi32>
+  %1 = tensor.cast %0 : tensor<?xi32> to tensor<4xi32>
+  // CHECK-NEXT: return %[[IN]]
+  return %1 : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor.cast_chain_keep
+// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
+func @tensor.cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
+  // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
+  %0 = tensor.cast %input : tensor<?x?xi32> to tensor<4x?xi32>
+  // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
+  %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
+  // CHECK-NEXT: return %[[C2]]
+  return %1 : tensor<?x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor.cast_chain_invalid
+// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
+func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
+  // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
+  %0 = tensor.cast %input : tensor<4x8xi32> to tensor<?x?xi32>
+  // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
+  %1 = tensor.cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
+  // CHECK-NEXT: return %[[C2]]
+  return %1 : tensor<8x4xi32>
+}
 
 // -----
 
@@ -31,3 +93,17 @@ func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) {
   // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]]
   return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_from_tensor.cast
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
+func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 {
+  // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+  %c0 = constant 0 : index
+  // CHECK-NOT: tensor.cast
+  %casted = tensor.cast %tensor : tensor<*xf32> to tensor<?xf32>
+  // CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
+  %result = tensor.extract %casted[%c0] : tensor<?xf32>
+  return %result : f32
+}

diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 3ddb84365381..cb38ac884bc3 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -1,4 +1,10 @@
-// RUN: mlir-opt <%s -verify-diagnostics
+// RUN: mlir-opt <%s -split-input-file -verify-diagnostics
+
+func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) {
+  // expected-error@+1 {{operand type 'tensor<1xf32>' and result type 'tensor<2xf32>' are cast incompatible}}
+  %0 = tensor.cast %arg0 : tensor<1xf32> to tensor<2xf32>
+  return
+}
 
 // -----
 

diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 4d89b155f2a1..06db2bb237cd 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -1,5 +1,18 @@
 // RUN: mlir-opt <%s | mlir-opt | FileCheck %s
 
+// CHECK-LABEL: func @cast(
+func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?xf32>) {
+  // CHECK: tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
+  %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
+  // CHECK: tensor.cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
+  %1 = tensor.cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
+  // CHECK: tensor.cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
+  %2 = tensor.cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
+  // CHECK: tensor.cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
+  %3 = tensor.cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
+  return
+}
+
 // CHECK-LABEL:   func @extract(
 // CHECK-SAME:                  %[[TENSOR:.*]]: tensor<?x?x?xf32>,
 // CHECK-SAME:                  %[[INDEX:.*]]: index) {

diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 502e7fb358fb..fc712d4939ba 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -696,23 +696,6 @@ func @tensor_from_elements() {
   return
 }
 
-// CHECK-LABEL: func @tensor_cast(%arg0
-func @tensor_cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?xf32>) {
-  // CHECK: %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
-  %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
-
-  // CHECK: %1 = tensor_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
-  %1 = tensor_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
-
-  // CHECK: %2 = tensor_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
-  %2 = tensor_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
-
-  // CHECK: %3 = tensor_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
-  %3 = tensor_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
-
-  return
-}
-
 // CHECK-LABEL: func @memref_cast(%arg0
 func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>) {
   // CHECK: %0 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>

diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 95812acd10a3..f2296161ed7a 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -661,23 +661,15 @@ func @lowered_affine_ceildiv() -> (index, index) {
 
 // Checks that NOP casts are removed.
 // CHECK-LABEL: cast_values
-func @cast_values(%arg0: tensor<*xi32>, %arg1: memref<?xi32>) -> (tensor<2xi32>, memref<2xi32>) {
-
-  // NOP casts
-  %0 = tensor_cast %arg0 : tensor<*xi32> to tensor<*xi32>
-  %1 = memref_cast %arg1 : memref<?xi32> to memref<?xi32>
-
-  // CHECK-NEXT: %0 = tensor_cast %arg0 : tensor<*xi32> to tensor<2xi32>
-  // CHECK-NEXT: %1 = memref_cast %arg1 : memref<?xi32> to memref<2xi32>
-  %2 = tensor_cast %0 : tensor<*xi32> to tensor<2xi32>
+func @cast_values(%arg0: memref<?xi32>) -> memref<2xi32> {
+  // NOP cast
+  %1 = memref_cast %arg0 : memref<?xi32> to memref<?xi32>
+  // CHECK-NEXT: %[[RET:.*]] = memref_cast %arg0 : memref<?xi32> to memref<2xi32>
   %3 = memref_cast %1 : memref<?xi32> to memref<2xi32>
-
-  // NOP casts
-  %4 = tensor_cast %2 : tensor<2xi32> to tensor<2xi32>
+  // NOP cast
   %5 = memref_cast %3 : memref<2xi32> to memref<2xi32>
-
-  // CHECK-NEXT: return %0, %1 : tensor<2xi32>, memref<2xi32>
-  return %4, %5 : tensor<2xi32>, memref<2xi32>
+  // CHECK-NEXT: return %[[RET]] : memref<2xi32>
+  return %5 : memref<2xi32>
 }
 
 // -----
@@ -1121,61 +1113,12 @@ func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tenso
     yield %1 : index
   // CHECK: : tensor<3x?x5x7x?xindex>
   } : tensor<3x?x?x7x?xindex>
-  // CHECK: tensor_cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
+  // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
   return %0 : tensor<3x?x?x7x?xindex>
 }
 
 // -----
 
-// CHECK-LABEL: @tensor_cast_chain_ok
-// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
-func @tensor_cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
-  // CHECK-NEXT: %[[RES:.*]] = tensor_cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
-  %0 = tensor_cast %input : tensor<*xi32> to tensor<4x?xi32>
-  %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
-  // CHECK-NEXT: return %[[RES]]
-  return %1 : tensor<4x8xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @tensor_cast_chain_regain
-// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
-func @tensor_cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
-  %0 = tensor_cast %input : tensor<4xi32> to tensor<?xi32>
-  %1 = tensor_cast %0 : tensor<?xi32> to tensor<4xi32>
-  // CHECK-NEXT: return %[[IN]]
-  return %1 : tensor<4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @tensor_cast_chain_keep
-// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
-func @tensor_cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
-  // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
-  %0 = tensor_cast %input : tensor<?x?xi32> to tensor<4x?xi32>
-  // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
-  %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
-  // CHECK-NEXT: return %[[C2]]
-  return %1 : tensor<?x8xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @tensor_cast_chain_invalid
-// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
-func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
-  // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
-  %0 = tensor_cast %input : tensor<4x8xi32> to tensor<?x?xi32>
-  // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
-  %1 = tensor_cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
-  // CHECK-NEXT: return %[[C2]]
-  return %1 : tensor<8x4xi32>
-}
-
-// -----
-
 // CHECK-LABEL: func @subtensor
 // CHECK-SAME: %[[ARG0:[0-9a-z]*]]: index, %[[ARG1:[0-9a-z]*]]: index
 func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index) 
@@ -1189,30 +1132,16 @@ func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
 
   // CHECK: subtensor %{{.*}}[0, 0, 0] [7, 11, 2] [1, 1, 1] :
   // CHECK-SAME: tensor<8x16x4xf32> to tensor<7x11x2xf32>
-  // CHECK: tensor_cast %{{.*}} : tensor<7x11x2xf32> to tensor<?x?x?xf32>
+  // CHECK: tensor.cast %{{.*}} : tensor<7x11x2xf32> to tensor<?x?x?xf32>
   %1 = subtensor %t[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
     : tensor<8x16x4xf32> to tensor<?x?x?xf32>
 
   // Test: subtensor with one dynamic operand can also be folded.
   // CHECK: subtensor %{{.*}}[0, 0, 0] [2, %[[ARG0]], 2] [1, 1, 1] :
   // CHECK-SAME: tensor<?x?x?xf32> to tensor<2x?x2xf32>
-  // CHECK: tensor_cast %{{.*}} : tensor<2x?x2xf32> to tensor<?x?x?xf32>
+  // CHECK: tensor.cast %{{.*}} : tensor<2x?x2xf32> to tensor<?x?x?xf32>
   %2 = subtensor %1[%c0, %c0, %c0] [%c2, %arg0, %c2] [%c1, %c1, %c1]
     : tensor<?x?x?xf32> to tensor<?x?x?xf32>
 
   return %2 : tensor<?x?x?xf32>
 }
-
-// -----
-
-// CHECK-LABEL: func @extract_from_tensor_cast
-// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
-func @extract_from_tensor_cast(%tensor: tensor<*xf32>) -> f32 {
-  // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
-  %c0 = constant 0 : index
-  // CHECK-NOT: tensor_cast
-  %casted = tensor_cast %tensor : tensor<*xf32> to tensor<?xf32>
-  // CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
-  %result = tensor.extract %casted[%c0] : tensor<?xf32>
-  return %result : f32
-}

diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index b6ac5b8e65fc..4ee10ef62295 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -68,10 +68,10 @@ func @different_ops() -> (i32, i32) {
 /// types.
 // CHECK-LABEL: @different_results
 func @different_results(%arg0: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<4x?xf32>) {
-  // CHECK: %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
-  // CHECK-NEXT: %1 = tensor_cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
-  %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
-  %1 = tensor_cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
+  // CHECK: %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
+  // CHECK-NEXT: %1 = tensor.cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
+  %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
+  %1 = tensor.cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
 
   // CHECK-NEXT: return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32>
   return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32>

diff --git a/mlir/utils/vim/syntax/mlir.vim b/mlir/utils/vim/syntax/mlir.vim
index 1db630c0223f..88f54eaee293 100644
--- a/mlir/utils/vim/syntax/mlir.vim
+++ b/mlir/utils/vim/syntax/mlir.vim
@@ -40,7 +40,7 @@ syn keyword mlirOps alloc alloca addf addi and call call_indirect cmpf cmpi
 syn keyword mlirOps constant dealloc divf dma_start dma_wait dim exp
 syn keyword mlirOps getTensor index_cast load log memref_cast
 syn keyword mlirOps memref_shape_cast mulf muli negf powf prefetch rsqrt sitofp
-syn keyword mlirOps splat store select sqrt subf subi subview tanh tensor_cast
+syn keyword mlirOps splat store select sqrt subf subi subview tanh
 syn keyword mlirOps view
 
 " Affine ops.


        


More information about the Mlir-commits mailing list