[Mlir-commits] [mlir] 67b1053 - [mlir][Vector] Allow a 0-d form for vector transfer ops.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Oct 12 05:10:19 PDT 2021
Author: Nicolas Vasilache
Date: 2021-10-12T11:48:42Z
New Revision: 67b10532c637b22c0926517d27f84759893a7258
URL: https://github.com/llvm/llvm-project/commit/67b10532c637b22c0926517d27f84759893a7258
DIFF: https://github.com/llvm/llvm-project/commit/67b10532c637b22c0926517d27f84759893a7258.diff
LOG: [mlir][Vector] Allow a 0-d form for vector transfer ops.
This revision updates the op semantics, printer, parser and verifier to allow 0-d transfers.
Until 0-d vectors are available, such transfers use a special encoding that goes through vector<1xt>: a 0-d tensor/memref source, a vector<1xt> value, and the permutation_map () -> (0).
This is a stepping stone towards the longer term work of adding 0-d vectors and will help significantly reduce corner cases in vectorization.
Transformations and lowerings do not yet support this form; extensions will follow.
Differential Revision: https://reviews.llvm.org/D111559
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/include/mlir/Interfaces/VectorInterfaces.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorUtils.cpp
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index c334773d6654e..f48ef35cdab07 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1275,6 +1275,11 @@ def Vector_TransferReadOp :
%4 = vector.transfer_read %arg1[%c3, %c3], %vf0
{permutation_map = (d0, d1)->(d0, d1)}
: tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+
+ // Special encoding for 0-d transfer with 0-d tensor/memref, vector shape
+ // {1} and permutation_map () -> (0).
+ %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<()->(0)>} :
+ tensor<f32>, vector<1xf32>
```
}];
@@ -1402,6 +1407,11 @@ def Vector_TransferWriteOp :
%5 = vector.transfer_write %4, %arg1[%c3, %c3]
{permutation_map = (d0, d1)->(d0, d1)}
: vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
+
+ // Special encoding for 0-d transfer with 0-d tensor/memref, vector shape
+ // {1} and permutation_map () -> (0).
+ %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} :
+ vector<1xf32>, tensor<f32>
```
}];
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 27c302aa37682..c713f1806d1f1 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -114,6 +114,29 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*methodBody=*/"return $_op.permutation_map();"
/*defaultImplementation=*/
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns true if op involves a 0-d tensor/memref and a vector
+ of shape {1}. This is temporary until we have 0-d vectors.
+ // TODO: turn this into 0-d vectors + empty permutation_map.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isZeroD",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ if (getShapedType().getRank() > 0)
+ return false;
+ if (getVectorType().getShape() != ArrayRef<int64_t>{1})
+ return false;
+ AffineMap map = AffineMap::get(
+ /*numDims=*/0, /*numSymbols=*/0,
+ getAffineConstantExpr(0, $_op->getContext()));
+ if ($_op.permutation_map() != map)
+ return false;
+ return true;
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{ Returns true if the specified dimension is a broadcast. }],
/*retTy=*/"bool",
@@ -134,6 +157,9 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
+ // 0-d transfers are not considered broadcasts but they need to be
+ // represented with a vector<1xt> until we have 0-d vectors.
+ if ($_op.isZeroD()) return false;
for (unsigned i = 0; i < $_op.permutation_map().getNumResults(); ++i) {
if ($_op.isBroadcastDim(i))
return true;
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 879996a041bf9..09d1bdc5349c2 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2292,11 +2292,14 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
return success();
}
-static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
- VectorType vectorType,
- VectorType maskType,
- AffineMap permutationMap,
- ArrayAttr inBounds) {
+static LogicalResult
+verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
+ VectorType vectorType, VectorType maskType,
+ AffineMap permutationMap, ArrayAttr inBounds) {
+ if (shapedType.getRank() == 0 && !op.isZeroD())
+ return op->emitOpError("0-d transfer requires vector<1xt> shape and () -> "
+ "(0) permutation_map");
+
if (op->hasAttr("masked")) {
return op->emitOpError("masked attribute has been removed. "
"Use in_bounds instead.");
@@ -2358,7 +2361,8 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
if (permutationMap.getNumSymbols() != 0)
return op->emitOpError("requires permutation_map without symbols");
- if (permutationMap.getNumInputs() != shapedType.getRank())
+ // TODO: implement 0-d vector corner cases.
+ if (!op.isZeroD() && permutationMap.getNumInputs() != shapedType.getRank())
return op->emitOpError("requires a permutation_map with input dims of the "
"same rank as the source type");
@@ -2534,9 +2538,10 @@ static LogicalResult verify(TransferReadOp op) {
if (static_cast<int64_t>(op.indices().size()) != shapedType.getRank())
return op.emitOpError("requires ") << shapedType.getRank() << " indices";
- if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
- maskType, permutationMap,
- op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
+ if (failed(
+ verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
+ shapedType, vectorType, maskType, permutationMap,
+ op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
return failure();
if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
@@ -2609,6 +2614,9 @@ static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
+ // TODO: Be less conservative once we have 0-d vectors.
+ if (op.isZeroD())
+ return failure();
AffineMap permutationMap = op.permutation_map();
bool changed = false;
SmallVector<bool, 4> newInBounds;
@@ -2885,9 +2893,10 @@ static LogicalResult verify(TransferWriteOp op) {
if (op.hasBroadcastDim())
return op.emitOpError("should not have broadcast dimensions");
- if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
- maskType, permutationMap,
- op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
+ if (failed(
+ verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
+ shapedType, vectorType, maskType, permutationMap,
+ op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
return failure();
return verifyPermutationMap(permutationMap,
diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index 82018c2ccdd3b..d76b93f872918 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -239,6 +239,13 @@ AffineMap mlir::getTransferMinorIdentityMap(ShapedType shapedType,
shapedType.getElementType().dyn_cast<VectorType>();
if (elementVectorType)
elementVectorRank += elementVectorType.getRank();
+ // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
+ // TODO: replace once we have 0-d vectors.
+ if (shapedType.getRank() == 0 &&
+ vectorType.getShape() == ArrayRef<int64_t>{1})
+ return AffineMap::get(
+ /*numDims=*/0, /*numSymbols=*/0,
+ getAffineConstantExpr(0, shapedType.getContext()));
return AffineMap::getMinorIdentityMap(
shapedType.getRank(), vectorType.getRank() - elementVectorRank,
shapedType.getContext());
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 26845172e1a6d..4e811a768d70e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1394,3 +1394,16 @@ func @insert_map_id(%v: vector<2x1xf32>, %v1: vector<4x32xf32>, %id : index) {
// expected-error at +1 {{'vector.insert_map' op expected number of ids must match the number of dimensions distributed}}
%0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<4x32xf32>
}
+
+// -----
+
+func @vector_transfer_ops_0d(%arg0: tensor<f32>)
+ -> tensor<f32> {
+ %f0 = constant 0.0 : f32
+ // expected-error at +1 {{0-d transfer requires vector<1xt> shape and () -> (0) permutation_map}}
+ %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<(d0)->(d0)>} :
+ tensor<f32>, vector<1xf32>
+ %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} :
+ vector<1xf32>, tensor<f32>
+ return %1: tensor<f32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 6f715ce95ba2f..a8ec95e533f6a 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1,5 +1,20 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// CHECK-LABEL: func @vector_transfer_ops_0d(
+func @vector_transfer_ops_0d(%arg0: tensor<f32>, %arg1: memref<f32>)
+ -> tensor<f32> {
+ %f0 = constant 0.0 : f32
+ %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<()->(0)>} :
+ tensor<f32>, vector<1xf32>
+ %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} :
+ vector<1xf32>, tensor<f32>
+ %2 = vector.transfer_read %arg1[], %f0 {permutation_map = affine_map<()->(0)>} :
+ memref<f32>, vector<1xf32>
+ vector.transfer_write %2, %arg1[] {permutation_map = affine_map<()->(0)>} :
+ vector<1xf32>, memref<f32>
+ return %1: tensor<f32>
+}
+
// CHECK-LABEL: func @vector_transfer_ops(
func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%arg1 : memref<?x?xvector<4x3xf32>>,
More information about the Mlir-commits
mailing list