[Mlir-commits] [mlir] b56bf30 - [mlir][Vector] Add folding of memref_cast into vector_transfer ops

Nicolas Vasilache llvmlistbot at llvm.org
Fri Jun 5 10:30:56 PDT 2020


Author: Nicolas Vasilache
Date: 2020-06-05T13:27:00-04:00
New Revision: b56bf30d3cc15896956061fdbeb6d078b63ec91f

URL: https://github.com/llvm/llvm-project/commit/b56bf30d3cc15896956061fdbeb6d078b63ec91f
DIFF: https://github.com/llvm/llvm-project/commit/b56bf30d3cc15896956061fdbeb6d078b63ec91f.diff

LOG: [mlir][Vector] Add folding of memref_cast into vector_transfer ops

Summary:
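This revision adds folders to vector.transfer_read and vector.transfer_write
that fold a producing memref_cast into the op (when the cast can be folded
into its consumer), so the transfer uses the cast's source memref directly.
This is a common folding pattern that is starting to appear on vector
transfer ops.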

Differential Revision: https://reviews.llvm.org/D81281

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 365795fb9cab..9ae1c74df9e9 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1061,6 +1061,8 @@ def Vector_TransferReadOp :
         return impl::getTransferMinorIdentityMap(memRefType, vectorType);
       }
   }];
+
+  let hasFolder = 1;
 }
 
 def Vector_TransferWriteOp :
@@ -1150,6 +1152,8 @@ def Vector_TransferWriteOp :
         return impl::getTransferMinorIdentityMap(memRefType, vectorType);
       }
   }];
+
+  let hasFolder = 1;
 }
 
 def Vector_ShapeCastOp :

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 21b62ceaa689..019f5fd94621 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1498,6 +1498,30 @@ static LogicalResult verify(TransferReadOp op) {
                               [&op](Twine t) { return op.emitOpError(t); });
 }
 
+/// This is a common helper used for patterns of the form
+/// ```
+///    someop(memrefcast) -> someop
+/// ```
+/// It folds the source of the memref_cast into the root operation directly.
+static LogicalResult foldMemRefCast(Operation *op) {
+  bool folded = false;
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
+    if (castOp && canFoldIntoConsumerOp(castOp)) {
+      operand.set(castOp.getOperand());
+      folded = true;
+    }
+  }
+  return success(folded);
+}
+
+OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
+  /// transfer_read(memrefcast) -> transfer_read
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
@@ -1583,6 +1607,11 @@ static LogicalResult verify(TransferWriteOp op) {
                               [&op](Twine t) { return op.emitOpError(t); });
 }
 
+LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
+                                    SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index db504bf66ca0..5e4ba39895ed 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -159,3 +159,19 @@ func @transpose_3D_sequence(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
   // CHECK-NEXT: return [[ADD]]
   return %7 : vector<4x3x2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: cast_transfers
+func @cast_transfers(%A: memref<4x8xf32>) -> (vector<4x8xf32>) {
+  %c0 = constant 0 : index
+  %f0 = constant 0.0 : f32
+  %0 = memref_cast %A : memref<4x8xf32> to memref<?x?xf32>
+
+  // CHECK: vector.transfer_read %{{.*}} : memref<4x8xf32>, vector<4x8xf32>
+  %1 = vector.transfer_read %0[%c0, %c0], %f0 : memref<?x?xf32>, vector<4x8xf32>
+
+  // CHECK: vector.transfer_write %{{.*}} : vector<4x8xf32>, memref<4x8xf32>
+  vector.transfer_write %1, %0[%c0, %c0] : vector<4x8xf32>, memref<?x?xf32>
+  return %1 : vector<4x8xf32>
+}


        


More information about the Mlir-commits mailing list