[Mlir-commits] [mlir] 0ee4bf1 - [mlir] Add folding of tensor.cast -> subtensor_insert

Nicolas Vasilache llvmlistbot at llvm.org
Fri Feb 19 09:25:39 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-19T17:24:16Z
New Revision: 0ee4bf151c0985027e82ec0036a655d68b4d6c37

URL: https://github.com/llvm/llvm-project/commit/0ee4bf151c0985027e82ec0036a655d68b4d6c37
DIFF: https://github.com/llvm/llvm-project/commit/0ee4bf151c0985027e82ec0036a655d68b4d6c37.diff

LOG: [mlir] Add folding of tensor.cast -> subtensor_insert

Differential Revision: https://reviews.llvm.org/D97059

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 98d16c7acfe6..830b682c602b 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -61,6 +61,10 @@ namespace tensor {
 /// ```
 bool canFoldIntoConsumerOp(CastOp castOp);
 
+/// Rewires each operand of `op` produced by a foldable tensor::CastOp to the
+/// cast's source; `op`'s result types are unchanged. Succeeds iff any rewired.
+LogicalResult foldTensorCast(Operation *op);
+
 } // namespace tensor
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 046033cc7f9d..081908d38b23 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3790,6 +3790,17 @@ OpFoldResult SubTensorInsertOp::fold(ArrayRef<Attribute>) {
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
+  // Fold a tensor.cast producing `source` into this op. Only `source` is
+  // considered: folding a cast into `dest` would change the dest type while
+  // the result type, which must match it, stays unchanged.
+  if (auto castOp = this->source().getDefiningOp<tensor::CastOp>()) {
+    if (tensor::canFoldIntoConsumerOp(castOp)) {
+      sourceMutable().assign(castOp.getOperand());
+      // Returning the op's own result signals an in-place fold; returning
+      // `source()` would instead ask the folder to replace this op with it.
+      return getResult();
+    }
+  }
   return OpFoldResult();
 }
 

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 9a42224b1158..3da606131a41 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -73,6 +73,20 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
   return true;
 }
 
+/// Rewires each operand of `op` produced by a foldable tensor::CastOp to the
+/// cast's source (result types are left as-is); succeeds iff any changed.
+LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
+  bool anyFolded = false;
+  for (OpOperand &use : op->getOpOperands()) {
+    auto producer = use.get().getDefiningOp<tensor::CastOp>();
+    if (!producer || !tensor::canFoldIntoConsumerOp(producer))
+      continue;
+    use.set(producer.getOperand());
+    anyFolded = true;
+  }
+  return success(anyFolded);
+}
+
 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (inputs.size() != 1 || outputs.size() != 1)
     return false;

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index ff5ca24f7587..b887e90e931b 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -237,3 +237,18 @@ func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32x
   %1 = subtensor %0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
   return %1 : tensor<16x32xi8>
 }
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_subtensor_insert_of_cast
+//  CHECK-SAME:   %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
+//  CHECK-SAME:   %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+// The tensor.cast feeding the source operand must be folded away.
+//   CHECK-NOT:   tensor.cast
+//       CHECK:   %[[S:.+]] = subtensor_insert %[[A]] into %[[B]][0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<16x32xi8> into tensor<4x6x16x32xi8>
+//       CHECK:   return %[[S]] : tensor<4x6x16x32xi8>
+func @rank_reducing_subtensor_insert_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
+  %cast = tensor.cast %a : tensor<16x32xi8> to tensor<?x32xi8>
+  %res = subtensor_insert %cast into %b[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
+  return %res : tensor<4x6x16x32xi8>
+}


        


More information about the Mlir-commits mailing list