[Mlir-commits] [mlir] 67b1053 - [mlir][Vector] Allow a 0-d form for vector transfer ops.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Oct 12 05:10:19 PDT 2021


Author: Nicolas Vasilache
Date: 2021-10-12T11:48:42Z
New Revision: 67b10532c637b22c0926517d27f84759893a7258

URL: https://github.com/llvm/llvm-project/commit/67b10532c637b22c0926517d27f84759893a7258
DIFF: https://github.com/llvm/llvm-project/commit/67b10532c637b22c0926517d27f84759893a7258.diff

LOG: [mlir][Vector] Allow a 0-d form for vector transfer ops.

This revision updates the op semantics, printer, parser and verifier to allow 0-d transfers.
Until 0-d vectors are available, such transfers have a special form that transits through vector<1xt>.
This is a stepping stone towards the longer term work of adding 0-d vectors and will help significantly reduce corner cases in vectorization.

Transformations and lowerings do not yet support this form; extensions will follow.

Differential Revision: https://reviews.llvm.org/D111559

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/Interfaces/VectorInterfaces.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorUtils.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index c334773d6654e..f48ef35cdab07 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1275,6 +1275,11 @@ def Vector_TransferReadOp :
     %4 = vector.transfer_read %arg1[%c3, %c3], %vf0
       {permutation_map = (d0, d1)->(d0, d1)}
         : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+
+    // Special encoding for 0-d transfer with 0-d tensor/memref, vector shape
+    // {1} and permutation_map () -> (0).
+    %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<()->(0)>} :
+      tensor<f32>, vector<1xf32>
     ```
   }];
 
@@ -1402,6 +1407,11 @@ def Vector_TransferWriteOp :
     %5 = vector.transfer_write %4, %arg1[%c3, %c3]
       {permutation_map = (d0, d1)->(d0, d1)}
         : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
+
+    // Special encoding for 0-d transfer with 0-d tensor/memref, vector shape
+    // {1} and permutation_map () -> (0).
+    %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} :
+      vector<1xf32>, tensor<f32>
     ```
   }];
 

diff  --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 27c302aa37682..c713f1806d1f1 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -114,6 +114,29 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*methodBody=*/"return $_op.permutation_map();"
       /*defaultImplementation=*/
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if op involves a 0-d tensor/memref and a vector
+        of shape {1}. This is temporary until we have 0-d vectors.
+        // TODO: turn this into 0-d vectors + empty permutation_map.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isZeroD",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        if (getShapedType().getRank() > 0)
+          return false;
+        if (getVectorType().getShape() != ArrayRef<int64_t>{1})
+          return false;
+        AffineMap map = AffineMap::get(
+          /*numDims=*/0, /*numSymbols=*/0,
+          getAffineConstantExpr(0, $_op->getContext()));
+        if ($_op.permutation_map() != map)
+          return false;
+        return true;
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{ Returns true if the specified dimension is a broadcast. }],
       /*retTy=*/"bool",
@@ -134,6 +157,9 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
+        // 0-d transfers are not considered broadcasts but they need to be 
+        // represented with a vector<1xt> until we have 0-d vectors.
+        if ($_op.isZeroD()) return false;
         for (unsigned i = 0; i < $_op.permutation_map().getNumResults(); ++i) {
           if ($_op.isBroadcastDim(i))
             return true;

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 879996a041bf9..09d1bdc5349c2 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2292,11 +2292,14 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
   return success();
 }
 
-static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
-                                      VectorType vectorType,
-                                      VectorType maskType,
-                                      AffineMap permutationMap,
-                                      ArrayAttr inBounds) {
+static LogicalResult
+verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
+                 VectorType vectorType, VectorType maskType,
+                 AffineMap permutationMap, ArrayAttr inBounds) {
+  if (shapedType.getRank() == 0 && !op.isZeroD())
+    return op->emitOpError("0-d transfer requires vector<1xt> shape and () -> "
+                           "(0) permutation_map");
+
   if (op->hasAttr("masked")) {
     return op->emitOpError("masked attribute has been removed. "
                            "Use in_bounds instead.");
@@ -2358,7 +2361,8 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
 
   if (permutationMap.getNumSymbols() != 0)
     return op->emitOpError("requires permutation_map without symbols");
-  if (permutationMap.getNumInputs() != shapedType.getRank())
+  // TODO: implement 0-d vector corner cases.
+  if (!op.isZeroD() && permutationMap.getNumInputs() != shapedType.getRank())
     return op->emitOpError("requires a permutation_map with input dims of the "
                            "same rank as the source type");
 
@@ -2534,9 +2538,10 @@ static LogicalResult verify(TransferReadOp op) {
   if (static_cast<int64_t>(op.indices().size()) != shapedType.getRank())
     return op.emitOpError("requires ") << shapedType.getRank() << " indices";
 
-  if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
-                              maskType, permutationMap,
-                              op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
+  if (failed(
+          verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
+                           shapedType, vectorType, maskType, permutationMap,
+                           op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
     return failure();
 
   if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
@@ -2609,6 +2614,9 @@ static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
 
 template <typename TransferOp>
 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
+  // TODO: Be less conservative once we have 0-d vectors.
+  if (op.isZeroD())
+    return failure();
   AffineMap permutationMap = op.permutation_map();
   bool changed = false;
   SmallVector<bool, 4> newInBounds;
@@ -2885,9 +2893,10 @@ static LogicalResult verify(TransferWriteOp op) {
   if (op.hasBroadcastDim())
     return op.emitOpError("should not have broadcast dimensions");
 
-  if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
-                              maskType, permutationMap,
-                              op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
+  if (failed(
+          verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
+                           shapedType, vectorType, maskType, permutationMap,
+                           op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
     return failure();
 
   return verifyPermutationMap(permutationMap,

diff  --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index 82018c2ccdd3b..d76b93f872918 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -239,6 +239,13 @@ AffineMap mlir::getTransferMinorIdentityMap(ShapedType shapedType,
       shapedType.getElementType().dyn_cast<VectorType>();
   if (elementVectorType)
     elementVectorRank += elementVectorType.getRank();
+  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
+  // TODO: replace once we have 0-d vectors.
+  if (shapedType.getRank() == 0 &&
+      vectorType.getShape() == ArrayRef<int64_t>{1})
+    return AffineMap::get(
+        /*numDims=*/0, /*numSymbols=*/0,
+        getAffineConstantExpr(0, shapedType.getContext()));
   return AffineMap::getMinorIdentityMap(
       shapedType.getRank(), vectorType.getRank() - elementVectorRank,
       shapedType.getContext());

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 26845172e1a6d..4e811a768d70e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1394,3 +1394,16 @@ func @insert_map_id(%v: vector<2x1xf32>, %v1: vector<4x32xf32>, %id : index) {
  // expected-error @+1 {{'vector.insert_map' op expected number of ids must match the number of dimensions distributed}}
   %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<4x32xf32>
 }
+
+// -----
+
+func @vector_transfer_ops_0d(%arg0: tensor<f32>)
+  -> tensor<f32> {
+    %f0 = constant 0.0 : f32
+    // expected-error @+1 {{0-d transfer requires vector<1xt> shape and () -> (0) permutation_map}}
+    %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<(d0)->(d0)>} :
+      tensor<f32>, vector<1xf32>
+    %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} :
+      vector<1xf32>, tensor<f32>
+    return %1: tensor<f32>
+}

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 6f715ce95ba2f..a8ec95e533f6a 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1,5 +1,20 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
+// CHECK-LABEL: func @vector_transfer_ops_0d(
+func @vector_transfer_ops_0d(%arg0: tensor<f32>, %arg1: memref<f32>)
+  -> tensor<f32> {
+    %f0 = constant 0.0 : f32
+    %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<()->(0)>} :
+      tensor<f32>, vector<1xf32>
+    %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} :
+      vector<1xf32>, tensor<f32>
+    %2 = vector.transfer_read %arg1[], %f0 {permutation_map = affine_map<()->(0)>} :
+      memref<f32>, vector<1xf32>
+    vector.transfer_write %2, %arg1[] {permutation_map = affine_map<()->(0)>} :
+      vector<1xf32>, memref<f32>
+    return %1: tensor<f32>
+}
+
 // CHECK-LABEL: func @vector_transfer_ops(
 func @vector_transfer_ops(%arg0: memref<?x?xf32>,
                           %arg1 : memref<?x?xvector<4x3xf32>>,


        


More information about the Mlir-commits mailing list