[Mlir-commits] [mlir] b56bf30 - [mlir][Vector] Add folding of memref_cast into vector_transfer ops
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Jun 5 10:30:56 PDT 2020
Author: Nicolas Vasilache
Date: 2020-06-05T13:27:00-04:00
New Revision: b56bf30d3cc15896956061fdbeb6d078b63ec91f
URL: https://github.com/llvm/llvm-project/commit/b56bf30d3cc15896956061fdbeb6d078b63ec91f
DIFF: https://github.com/llvm/llvm-project/commit/b56bf30d3cc15896956061fdbeb6d078b63ec91f.diff
LOG: [mlir][Vector] Add folding of memref_cast into vector_transfer ops
Summary:
This revision adds a folding pattern that is becoming common on vector
transfer ops: a memref_cast feeding a transfer op is folded into that op.
Differential Revision: https://reviews.llvm.org/D81281
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 365795fb9cab..9ae1c74df9e9 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1061,6 +1061,8 @@ def Vector_TransferReadOp :
return impl::getTransferMinorIdentityMap(memRefType, vectorType);
}
}];
+
+ let hasFolder = 1;
}
def Vector_TransferWriteOp :
@@ -1150,6 +1152,8 @@ def Vector_TransferWriteOp :
return impl::getTransferMinorIdentityMap(memRefType, vectorType);
}
}];
+
+ let hasFolder = 1;
}
def Vector_ShapeCastOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 21b62ceaa689..019f5fd94621 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1498,6 +1498,30 @@ static LogicalResult verify(TransferReadOp op) {
[&op](Twine t) { return op.emitOpError(t); });
}
+/// Common helper (a function, not a class) for folds of the form
+/// ```
+/// someop(memrefcast) -> someop
+/// ```
+/// Operands whose memref_cast passes canFoldIntoConsumerOp are reset to the cast source.
+static LogicalResult foldMemRefCast(Operation *op) {
+ bool folded = false;
+ for (OpOperand &operand : op->getOpOperands()) {
+ auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
+ if (castOp && canFoldIntoConsumerOp(castOp)) {
+ operand.set(castOp.getOperand());
+ folded = true;
+ }
+ }
+ return success(folded); // Succeeds only if at least one operand was rewired.
+}
+
+OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
+ // transfer_read(memrefcast) -> transfer_read; operands are updated in place, so return our own result.
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+ return OpFoldResult();
+}
+
//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//
@@ -1583,6 +1607,11 @@ static LogicalResult verify(TransferWriteOp op) {
[&op](Twine t) { return op.emitOpError(t); });
}
+LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ return foldMemRefCast(*this); // transfer_write(memrefcast) -> transfer_write
+}
+
//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index db504bf66ca0..5e4ba39895ed 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -159,3 +159,19 @@ func @transpose_3D_sequence(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
// CHECK-NEXT: return [[ADD]]
return %7 : vector<4x3x2xf32>
}
+
+// -----
+
+// CHECK-LABEL: cast_transfers
+func @cast_transfers(%A: memref<4x8xf32>) -> (vector<4x8xf32>) {
+ %c0 = constant 0 : index
+ %f0 = constant 0.0 : f32
+ %0 = memref_cast %A : memref<4x8xf32> to memref<?x?xf32> // the checks below expect both transfers to use the static type again
+
+ // CHECK: vector.transfer_read %{{.*}} : memref<4x8xf32>, vector<4x8xf32>
+ %1 = vector.transfer_read %0[%c0, %c0], %f0 : memref<?x?xf32>, vector<4x8xf32>
+
+ // CHECK: vector.transfer_write %{{.*}} : vector<4x8xf32>, memref<4x8xf32>
+ vector.transfer_write %1, %0[%c0, %c0] : vector<4x8xf32>, memref<?x?xf32>
+ return %1 : vector<4x8xf32>
+}
More information about the Mlir-commits
mailing list