[Mlir-commits] [mlir] 8e01e2e - [mlir][Vector] Fold tensor_cast + vector.transfer_read

Nicolas Vasilache llvmlistbot at llvm.org
Thu Feb 18 12:50:16 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-18T20:47:16Z
New Revision: 8e01e2ec0f3e7a87f49fd3bfc5f18b48e0958213

URL: https://github.com/llvm/llvm-project/commit/8e01e2ec0f3e7a87f49fd3bfc5f18b48e0958213
DIFF: https://github.com/llvm/llvm-project/commit/8e01e2ec0f3e7a87f49fd3bfc5f18b48e0958213.diff

LOG: [mlir][Vector] Fold tensor_cast + vector.transfer_read

Differential Revision: https://reviews.llvm.org/D96988

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index af884f9d6ce6..06f39ecf1e84 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -2408,6 +2409,18 @@ static LogicalResult foldMemRefCast(Operation *op) {
   return success(folded);
 }
 
+static LogicalResult foldTensorCast(Operation *op) { // Bypass foldable tensor.cast producers of op's operands.
+  bool folded = false; // Becomes true once any operand has been rewired.
+  for (OpOperand &operand : op->getOpOperands()) { // Inspect every operand of op.
+    auto castOp = operand.get().getDefiningOp<tensor::CastOp>(); // Null unless the value comes from a tensor.cast.
+    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) { // NOTE(review): foldability rule lives in Tensor.h; confirm it only admits static->dynamic casts.
+      operand.set(castOp.getOperand()); // Consume the cast's source value directly.
+      folded = true;
+    }
+  }
+  return success(folded); // Success iff at least one operand was rewired.
+}
+
 template <typename TransferOp>
 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
   // TODO: support more aggressive createOrFold on:
@@ -2460,6 +2473,8 @@ OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
     return getResult();
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
+  if (succeeded(foldTensorCast(*this)))
+    return getResult();
   return OpFoldResult();
 }
 

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d665a2b49d5e..d427cb952f09 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -267,6 +267,20 @@ func @cast_transfers(%A: memref<4x8xf32>) -> (vector<4x8xf32>) {
 
 // -----
 
+// CHECK-LABEL: cast_transfers
+func @cast_transfers(%A: tensor<4x8xf32>) -> (vector<4x8xf32>) {
+  %c0 = constant 0 : index
+  %f0 = constant 0.0 : f32
+  %0 = tensor.cast %A : tensor<4x8xf32> to tensor<?x?xf32>
+  // The cast only erases the static 4x8 shape, so the read below should be rewired onto %A.
+  // CHECK: vector.transfer_read %{{.*}} {masked = [false, false]} : tensor<4x8xf32>, vector<4x8xf32>
+  %1 = vector.transfer_read %0[%c0, %c0], %f0 : tensor<?x?xf32>, vector<4x8xf32>
+  // Exercised through TransferReadOp::fold (foldTensorCast) during canonicalization.
+  return %1 : vector<4x8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @insert_extract_transpose_2d(
 //  CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3xf32>,
 //  CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: f32,


        


More information about the Mlir-commits mailing list