[Mlir-commits] [mlir] 8e01e2e - [mlir][Vector] Fold tensor_cast + vector.transfer_read
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Feb 18 12:50:16 PST 2021
Author: Nicolas Vasilache
Date: 2021-02-18T20:47:16Z
New Revision: 8e01e2ec0f3e7a87f49fd3bfc5f18b48e0958213
URL: https://github.com/llvm/llvm-project/commit/8e01e2ec0f3e7a87f49fd3bfc5f18b48e0958213
DIFF: https://github.com/llvm/llvm-project/commit/8e01e2ec0f3e7a87f49fd3bfc5f18b48e0958213.diff
LOG: [mlir][Vector] Fold tensor_cast + vector.transfer_read
Differential Revision: https://reviews.llvm.org/D96988
Added:
Modified:
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index af884f9d6ce6..06f39ecf1e84 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
@@ -2408,6 +2409,18 @@ static LogicalResult foldMemRefCast(Operation *op) {
return success(folded);
}
+/// Rewires each operand of `op` that is produced by a `tensor.cast` which
+/// `tensor::canFoldIntoConsumerOp` accepts, so that the operand reads the
+/// cast's source value directly. Every operand is visited (no early exit).
+/// Returns success iff at least one operand was rewired.
+static LogicalResult foldTensorCast(Operation *op) {
+  bool changed = false;
+  for (OpOperand &use : op->getOpOperands()) {
+    auto cast = use.get().getDefiningOp<tensor::CastOp>();
+    // Skip operands not produced by a cast, or casts the consumer cannot absorb.
+    if (!cast || !tensor::canFoldIntoConsumerOp(cast))
+      continue;
+    use.set(cast.getOperand());
+    changed = true;
+  }
+  return success(changed);
+}
+
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
// TODO: support more aggressive createOrFold on:
@@ -2460,6 +2473,8 @@ OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
return getResult();
if (succeeded(foldMemRefCast(*this)))
return getResult();
+ if (succeeded(foldTensorCast(*this)))
+ return getResult();
return OpFoldResult();
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d665a2b49d5e..d427cb952f09 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -267,6 +267,20 @@ func @cast_transfers(%A: memref<4x8xf32>) -> (vector<4x8xf32>) {
// -----
+// Checks that a shape-erasing tensor.cast feeding a transfer_read is folded
+// into the read: the read must consume the function argument directly and the
+// cast must no longer appear in the canonicalized output.
+// CHECK-LABEL: func @cast_transfers(
+//  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: tensor<4x8xf32>
+//   CHECK-NOT:   tensor.cast
+//       CHECK:   vector.transfer_read %[[A]]{{.*}} {masked = [false, false]} : tensor<4x8xf32>, vector<4x8xf32>
+func @cast_transfers(%A: tensor<4x8xf32>) -> (vector<4x8xf32>) {
+  %c0 = constant 0 : index
+  %f0 = constant 0.0 : f32
+  %0 = tensor.cast %A : tensor<4x8xf32> to tensor<?x?xf32>
+
+  // After folding, the read sees the static tensor<4x8xf32> type and is
+  // expected to be fully in bounds (masked = [false, false]).
+  %1 = vector.transfer_read %0[%c0, %c0], %f0 : tensor<?x?xf32>, vector<4x8xf32>
+
+  return %1 : vector<4x8xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @insert_extract_transpose_2d(
// CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3xf32>,
// CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: f32,
More information about the Mlir-commits mailing list