[Mlir-commits] [mlir] 0ee4bf1 - [mlir] Add folding of tensor.cast -> subtensor_insert
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Feb 19 09:25:39 PST 2021
Author: Nicolas Vasilache
Date: 2021-02-19T17:24:16Z
New Revision: 0ee4bf151c0985027e82ec0036a655d68b4d6c37
URL: https://github.com/llvm/llvm-project/commit/0ee4bf151c0985027e82ec0036a655d68b4d6c37
DIFF: https://github.com/llvm/llvm-project/commit/0ee4bf151c0985027e82ec0036a655d68b4d6c37.diff
LOG: [mlir] Add folding of tensor.cast -> subtensor_insert
Differential Revision: https://reviews.llvm.org/D97059
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 98d16c7acfe6..830b682c602b 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -61,6 +61,10 @@ namespace tensor {
/// ```
bool canFoldIntoConsumerOp(CastOp castOp);
+/// Rewires, in place, each operand of `op` that a foldable tensor::CastOp
+/// defines to the cast's input. Returns success iff any operand was rewired.
+LogicalResult foldTensorCast(Operation *op);
+
} // namespace tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 046033cc7f9d..081908d38b23 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3790,6 +3790,8 @@ OpFoldResult SubTensorInsertOp::fold(ArrayRef<Attribute>) {
if (getSourceType() == getType() &&
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
return this->source();
+ if (succeeded(tensor::foldTensorCast(*this)))
+ return this->source();
return OpFoldResult();
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 9a42224b1158..3da606131a41 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -73,6 +73,20 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
return true;
}
+/// Makes `op` read the input of every tensor::CastOp operand for which
+/// canFoldIntoConsumerOp holds; succeeds iff at least one operand was changed.
+LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
+  bool folded = false; // Becomes true once any operand is rewired.
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
+    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
+      operand.set(castOp.getOperand()); // Bypass cast; cast op is kept.
+      folded = true;
+    }
+  }
+  return success(folded); // Failure means `op` was left untouched.
+}
+
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index ff5ca24f7587..b887e90e931b 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -237,3 +237,18 @@ func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32x
%1 = subtensor %0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
return %1 : tensor<16x32xi8>
}
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_subtensor_insert_of_cast
+// CHECK-SAME: %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
+// CHECK-SAME: %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+// CHECK: %[[S:.+]] = subtensor_insert %[[A]] into %[[B]][0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<16x32xi8> into tensor<4x6x16x32xi8>
+// The cast of the inserted source is bypassed: %a is inserted directly.
+// CHECK-NOT: tensor.cast
+// CHECK: return %[[S]] : tensor<4x6x16x32xi8>
+func @rank_reducing_subtensor_insert_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
+ %cast = tensor.cast %a : tensor<16x32xi8> to tensor<?x32xi8>
+ %res = subtensor_insert %cast into %b[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
+ return %res : tensor<4x6x16x32xi8>
+}
More information about the Mlir-commits
mailing list