[Mlir-commits] [mlir] 6ad31c0 - [mlir][vector] Support N-D vector in InsertMap/ExtractMap op

Thomas Raoux llvmlistbot at llvm.org
Fri Nov 13 12:40:44 PST 2020


Author: Thomas Raoux
Date: 2020-11-13T12:40:17-08:00
New Revision: 6ad31c0f4a6192e216ac9137dc093a8268ae11b1

URL: https://github.com/llvm/llvm-project/commit/6ad31c0f4a6192e216ac9137dc093a8268ae11b1
DIFF: https://github.com/llvm/llvm-project/commit/6ad31c0f4a6192e216ac9137dc093a8268ae11b1.diff

LOG: [mlir][vector] Support N-D vector in InsertMap/ExtractMap op

Support multi-dimensional vectors in the InsertMap/ExtractMap ops and update the
transformations. Currently the relation between IDs and dimensions is implicitly
deduced from the types, from which an AffineMap is calculated. In the future the
AffineMap could be part of the operation itself.

Differential Revision: https://reviews.llvm.org/D90995

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Dialect/Vector/vector-distribution.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 2ddd06ccf44f..76b98c1f3f36 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -457,11 +457,21 @@ def Vector_ExtractSlicesOp :
 
 def Vector_ExtractMapOp :
   Vector_Op<"extract_map", [NoSideEffect]>,
-    Arguments<(ins AnyVector:$vector, Index:$id)>,
+    Arguments<(ins AnyVector:$vector, Variadic<Index>:$ids)>,
     Results<(outs AnyVector)> {
   let summary = "vector extract map operation";
   let description = [{
-    Takes an 1-D vector and extracts a sub-part of the vector starting at id.
+    Takes an N-D vector and extracts a sub-part of the vector starting at id
+    along each dimension.
+
+    The dimension associated with each element of `ids` is implicitly
+    deduced from the source and destination types. For each dimension the
+    multiplicity is the source dimension size divided by the destination
+    dimension size; each dimension with a multiplicity greater than 1 is
+    associated with the next id, following the order of `ids`.
+    For example if the source type is `vector<64x4x32xf32>` and the destination
+    type is `vector<4x4x2xf32>`, the first id maps to dimension 0 and the second
+    id to dimension 2.
 
     Similarly to vector.tuple_get, this operation is used for progressive
     lowering and should be folded away before converting to LLVM.
@@ -488,10 +498,14 @@ def Vector_ExtractMapOp :
 
     ```mlir
     %ev = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
+    %ev1 = vector.extract_map %v1[%id1, %id2] : vector<64x4x32xf32>
+      to vector<4x4x2xf32>
     ```
   }];
   let builders = [
-    OpBuilderDAG<(ins "Value":$vector, "Value":$id, "int64_t":$multiplicity)>];
+    OpBuilderDAG<(ins "Value":$vector, "ValueRange":$ids,
+                  "ArrayRef<int64_t>":$multiplicity,
+                  "AffineMap":$map)>];
   let extraClassDeclaration = [{
     VectorType getSourceVectorType() {
       return vector().getType().cast<VectorType>();
@@ -499,13 +513,11 @@ def Vector_ExtractMapOp :
     VectorType getResultType() {
       return getResult().getType().cast<VectorType>();
     }
-    int64_t multiplicity() {
-      return getSourceVectorType().getNumElements() /
-        getResultType().getNumElements();
-    }
+    void getMultiplicity(SmallVectorImpl<int64_t> &multiplicity);
+    AffineMap map();
   }];
   let assemblyFormat = [{
-    $vector `[` $id `]` attr-dict `:` type($vector) `to` type(results)
+    $vector `[` $ids `]` attr-dict `:` type($vector) `to` type(results)
   }];
 
   let hasFolder = 1;
@@ -686,13 +698,19 @@ def Vector_InsertSlicesOp :
 
 def Vector_InsertMapOp :
   Vector_Op<"insert_map", [NoSideEffect, AllTypesMatch<["dest", "result"]>]>,
-    Arguments<(ins AnyVector:$vector, AnyVector:$dest, Index:$id)>,
+    Arguments<(ins AnyVector:$vector, AnyVector:$dest, Variadic<Index>:$ids)>,
     Results<(outs AnyVector:$result)> {
   let summary = "vector insert map operation";
   let description = [{
-    Inserts a 1-D vector and within a larger vector starting at id. The new
+    Inserts an N-D vector within a larger vector starting at id. The new
     vector created will have the same size as the destination operand vector.
 
+    The dimension associated with each element of `ids` is implicitly deduced
+    from the source and destination types (see `ExtractMapOp` for details).
+    For example if the source type is `vector<4x4x2xf32>` and the destination
+    type is `vector<64x4x32xf32>`, the first id maps to dimension 0 and the
+    second id to dimension 2.
+
     Similarly to vector.tuple_get, this operation is used for progressive
     lowering and should be folded away before converting to LLVM.
 
@@ -723,10 +741,12 @@ def Vector_InsertMapOp :
 
     ```mlir
     %v = vector.insert_map %ev %v[%id] : vector<1xf32> into vector<32xf32>
+    %v1 = vector.insert_map %ev1, %v1[%arg0, %arg1] : vector<2x4x1xf32>
+      into vector<64x4x32xf32>
     ```
   }];
   let builders = [OpBuilderDAG<(ins "Value":$vector, "Value":$dest,
-                                "Value":$id, "int64_t":$multiplicity)>];
+                                "ValueRange":$ids)>];
   let extraClassDeclaration = [{
     VectorType getSourceVectorType() {
       return vector().getType().cast<VectorType>();
@@ -734,13 +754,11 @@ def Vector_InsertMapOp :
     VectorType getResultType() {
       return getResult().getType().cast<VectorType>();
     }
-    int64_t multiplicity() {
-      return getResultType().getNumElements() /
-        getSourceVectorType().getNumElements();
-    }
+    // Return a map indicating the dimensions associated with the given ids.
+    AffineMap map();
   }];
   let assemblyFormat = [{
-    $vector `,` $dest `[` $id `]` attr-dict
+    $vector `,` $dest `[` $ids `]` attr-dict
       `:` type($vector) `into` type($result)
   }];
 }

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 73282ad2a4d7..22f580642a2f 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -231,7 +231,7 @@ struct DistributeOps {
   InsertMapOp insert;
 };
 
-/// Distribute a 1D vector pointwise operation over a range of given IDs taking
+/// Distribute an N-D vector pointwise operation over a range of ids taking
 /// *all* values in [0 .. multiplicity - 1] (e.g. loop induction variable or
 /// SPMD id). This transformation only inserts
 /// vector.extract_map/vector.insert_map. It is meant to be used with
@@ -243,9 +243,10 @@ struct DistributeOps {
 /// %v = addf %a, %b : vector<32xf32>
 /// %ev = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32>
 /// %nv = vector.insert_map %ev, %id, 32 : vector<1xf32> into vector<32xf32>
-Optional<DistributeOps> distributPointwiseVectorOp(OpBuilder &builder,
-                                                   Operation *op, Value id,
-                                                   int64_t multiplicity);
+Optional<DistributeOps>
+distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
+                           ArrayRef<Value> id, ArrayRef<int64_t> multiplicity,
+                           const AffineMap &map);
 /// Canonicalize an extra element using the result of a pointwise operation.
 /// Transforms:
 /// %v = addf %a, %b : vector32xf32>

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 04b8b757a14b..39aed718de0a 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -999,33 +999,79 @@ void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
 //===----------------------------------------------------------------------===//
 
 void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
-                         Value vector, Value id, int64_t multiplicity) {
+                         Value vector, ValueRange ids,
+                         ArrayRef<int64_t> multiplicity,
+                         AffineMap permutationMap) {
+  assert(ids.size() == multiplicity.size() &&
+         ids.size() == permutationMap.getNumResults());
+  assert(permutationMap.isProjectedPermutation());
   VectorType type = vector.getType().cast<VectorType>();
-  VectorType resultType = VectorType::get(type.getNumElements() / multiplicity,
-                                          type.getElementType());
-  ExtractMapOp::build(builder, result, resultType, vector, id);
+  SmallVector<int64_t, 4> newShape(type.getShape().begin(),
+                                   type.getShape().end());
+  for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) {
+    AffineExpr expr = permutationMap.getResult(i);
+    auto dim = expr.cast<AffineDimExpr>();
+    newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i];
+  }
+  VectorType resultType = VectorType::get(newShape, type.getElementType());
+  ExtractMapOp::build(builder, result, resultType, vector, ids);
 }
 
 static LogicalResult verify(ExtractMapOp op) {
-  if (op.getSourceVectorType().getShape().size() != 1 ||
-      op.getResultType().getShape().size() != 1)
-    return op.emitOpError("expects source and destination vectors of rank 1");
-  if (op.getSourceVectorType().getNumElements() %
-          op.getResultType().getNumElements() !=
-      0)
+  if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
     return op.emitOpError(
-        "source vector size must be a multiple of destination vector size");
+        "expected source and destination vectors of same rank");
+  unsigned numId = 0;
+  for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) {
+    if (op.getSourceVectorType().getDimSize(i) %
+            op.getResultType().getDimSize(i) !=
+        0)
+      return op.emitOpError("source vector dimensions must be a multiple of "
+                            "destination vector dimensions");
+    if (op.getSourceVectorType().getDimSize(i) !=
+        op.getResultType().getDimSize(i))
+      numId++;
+  }
+  if (numId != op.ids().size())
+    return op.emitOpError("expected number of ids must match the number of "
+                          "dimensions distributed");
   return success();
 }
 
 OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
   auto insert = vector().getDefiningOp<vector::InsertMapOp>();
-  if (insert == nullptr || multiplicity() != insert.multiplicity() ||
-      id() != insert.id())
+  if (insert == nullptr || getType() != insert.vector().getType() ||
+      ids() != insert.ids())
     return {};
   return insert.vector();
 }
 
+void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) {
+  assert(multiplicity.empty());
+  for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) {
+    if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
+      multiplicity.push_back(getSourceVectorType().getDimSize(i) /
+                             getResultType().getDimSize(i));
+  }
+}
+
+template <typename MapOp>
+AffineMap calculateImplicitMap(MapOp op) {
+  SmallVector<AffineExpr, 4> perm;
+  // Check which dimensions have a multiplicity greater than 1 and associate
+  // them with the ids in order.
+  for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) {
+    if (op.getSourceVectorType().getDimSize(i) !=
+        op.getResultType().getDimSize(i))
+      perm.push_back(getAffineDimExpr(i, op.getContext()));
+  }
+  auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm,
+                            op.getContext());
+  return map;
+}
+
+AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); }
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
@@ -1253,26 +1299,33 @@ void InsertSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
 //===----------------------------------------------------------------------===//
 
 void InsertMapOp::build(OpBuilder &builder, OperationState &result,
-                        Value vector, Value dest, Value id,
-                        int64_t multiplicity) {
-  VectorType type = vector.getType().cast<VectorType>();
-  VectorType resultType = VectorType::get(type.getNumElements() * multiplicity,
-                                          type.getElementType());
-  InsertMapOp::build(builder, result, resultType, vector, dest, id);
+                        Value vector, Value dest, ValueRange ids) {
+  InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids);
 }
 
 static LogicalResult verify(InsertMapOp op) {
-  if (op.getSourceVectorType().getShape().size() != 1 ||
-      op.getResultType().getShape().size() != 1)
-    return op.emitOpError("expected source and destination vectors of rank 1");
-  if (op.getResultType().getNumElements() %
-          op.getSourceVectorType().getNumElements() !=
-      0)
+  if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
     return op.emitOpError(
-        "destination vector size must be a multiple of source vector size");
+        "expected source and destination vectors of same rank");
+  unsigned numId = 0;
+  for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) {
+    if (op.getResultType().getDimSize(i) %
+            op.getSourceVectorType().getDimSize(i) !=
+        0)
+      return op.emitOpError(
+          "destination vector size must be a multiple of source vector size");
+    if (op.getResultType().getDimSize(i) !=
+        op.getSourceVectorType().getDimSize(i))
+      numId++;
+  }
+  if (numId != op.ids().size())
+    return op.emitOpError("expected number of ids must match the number of "
+                          "dimensions distributed");
   return success();
 }
 
+AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); }
+
 //===----------------------------------------------------------------------===//
 // InsertStridedSliceOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index c24a1d5b85ec..49865fddba4c 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2483,16 +2483,16 @@ LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite(
   SmallVector<Value, 4> extractOperands;
   for (OpOperand &operand : definedOp->getOpOperands())
     extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
-        loc, operand.get(), extract.id(), extract.multiplicity()));
+        loc, extract.getResultType(), operand.get(), extract.ids()));
   Operation *newOp = cloneOpWithOperandsAndTypes(
       rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
   rewriter.replaceOp(extract, newOp->getResult(0));
   return success();
 }
 
-Optional<mlir::vector::DistributeOps>
-mlir::vector::distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
-                                         Value id, int64_t multiplicity) {
+Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
+    OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
+    ArrayRef<int64_t> multiplicity, const AffineMap &map) {
   OpBuilder::InsertionGuard guard(builder);
   builder.setInsertionPointAfter(op);
   Location loc = op->getLoc();
@@ -2500,15 +2500,24 @@ mlir::vector::distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
     return {};
   Value result = op->getResult(0);
   VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
-  // Currently only support distributing 1-D vectors of size multiple of the
-  // given multiplicty. To handle more sizes we would need to support masking.
-  if (!type || type.getRank() != 1 || type.getNumElements() % multiplicity != 0)
+  if (!type || map.getNumResults() != multiplicity.size())
     return {};
+  // For each dimension being distributed check that the size is a multiple of
+  // the multiplicity. To handle more sizes we would need to support masking.
+  unsigned multiplictyCount = 0;
+  for (auto exp : map.getResults()) {
+    auto affinExp = exp.dyn_cast<AffineDimExpr>();
+    if (!affinExp || affinExp.getPosition() >= type.getRank() ||
+        type.getDimSize(affinExp.getPosition()) %
+                multiplicity[multiplictyCount++] !=
+            0)
+      return {};
+  }
   DistributeOps ops;
   ops.extract =
-      builder.create<vector::ExtractMapOp>(loc, result, id, multiplicity);
-  ops.insert = builder.create<vector::InsertMapOp>(loc, ops.extract, result, id,
-                                                   multiplicity);
+      builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
+  ops.insert =
+      builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
   return ops;
 }
 
@@ -2529,17 +2538,22 @@ struct TransferReadExtractPattern
     using mlir::edsc::op::operator*;
     using namespace mlir::edsc::intrinsics;
     SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
-    indices.back() =
-        indices.back() +
-        (extract.id() *
-         std_constant_index(extract.getResultType().getDimSize(0)));
+    AffineMap map = extract.map();
+    unsigned idCount = 0;
+    for (auto expr : map.getResults()) {
+      unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+      indices[pos] =
+          indices[pos] +
+          extract.ids()[idCount++] *
+              std_constant_index(extract.getResultType().getDimSize(pos));
+    }
     Value newRead = vector_transfer_read(extract.getType(), read.memref(),
                                          indices, read.permutation_map(),
-                                         read.padding(), ArrayAttr());
+                                         read.padding(), read.maskedAttr());
     Value dest = rewriter.create<ConstantOp>(
         read.getLoc(), read.getType(), rewriter.getZeroAttr(read.getType()));
-    newRead = rewriter.create<vector::InsertMapOp>(
-        read.getLoc(), newRead, dest, extract.id(), extract.multiplicity());
+    newRead = rewriter.create<vector::InsertMapOp>(read.getLoc(), newRead, dest,
+                                                   extract.ids());
     rewriter.replaceOp(read, newRead);
     return success();
   }
@@ -2560,12 +2574,17 @@ struct TransferWriteInsertPattern
     using namespace mlir::edsc::intrinsics;
     SmallVector<Value, 4> indices(write.indices().begin(),
                                   write.indices().end());
-    indices.back() =
-        indices.back() +
-        (insert.id() *
-         std_constant_index(insert.getSourceVectorType().getDimSize(0)));
+    AffineMap map = insert.map();
+    unsigned idCount = 0;
+    for (auto expr : map.getResults()) {
+      unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+      indices[pos] =
+          indices[pos] +
+          insert.ids()[idCount++] *
+              std_constant_index(insert.getSourceVectorType().getDimSize(pos));
+    }
     vector_transfer_write(insert.vector(), write.memref(), indices,
-                          write.permutation_map(), ArrayAttr());
+                          write.permutation_map(), write.maskedAttr());
     rewriter.eraseOp(write);
     return success();
   }

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 6c0dc1a9aa7f..73b1f9e1e06e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1331,23 +1331,30 @@ func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %va
 
 // -----
 
-func @extract_map_rank(%v: vector<2x32xf32>, %id : index) {
-  // expected-error at +1 {{'vector.extract_map' op expects source and destination vectors of rank 1}}
-  %0 = vector.extract_map %v[%id] : vector<2x32xf32> to vector<2x1xf32>
+func @extract_map_rank(%v: vector<32xf32>, %id : index) {
+  // expected-error at +1 {{'vector.extract_map' op expected source and destination vectors of same rank}}
+  %0 = vector.extract_map %v[%id] : vector<32xf32> to vector<2x1xf32>
 }
 
 // -----
 
 func @extract_map_size(%v: vector<63xf32>, %id : index) {
-  // expected-error at +1 {{'vector.extract_map' op source vector size must be a multiple of destination vector size}}
+  // expected-error at +1 {{'vector.extract_map' op source vector dimensions must be a multiple of destination vector dimensions}}
   %0 = vector.extract_map %v[%id] : vector<63xf32> to vector<2xf32>
 }
 
 // -----
 
-func @insert_map_rank(%v: vector<2x1xf32>, %v1: vector<2x32xf32>, %id : index) {
-  // expected-error at +1 {{'vector.insert_map' op expected source and destination vectors of rank 1}}
-  %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<2x32xf32>
+func @extract_map_id(%v: vector<2x32xf32>, %id : index) {
+  // expected-error at +1 {{'vector.extract_map' op expected number of ids must match the number of dimensions distributed}}
+  %0 = vector.extract_map %v[%id] : vector<2x32xf32> to vector<1x1xf32>
+}
+
+// -----
+
+func @insert_map_rank(%v: vector<2x1xf32>, %v1: vector<32xf32>, %id : index) {
+  // expected-error at +1 {{'vector.insert_map' op expected source and destination vectors of same rank}}
+  %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<32xf32>
 }
 
 // -----
@@ -1356,3 +1363,10 @@ func @insert_map_size(%v: vector<3xf32>, %v1: vector<64xf32>, %id : index) {
   // expected-error at +1 {{'vector.insert_map' op destination vector size must be a multiple of source vector size}}
   %0 = vector.insert_map %v, %v1[%id] : vector<3xf32> into vector<64xf32>
 }
+
+// -----
+
+func @insert_map_id(%v: vector<2x1xf32>, %v1: vector<4x32xf32>, %id : index) {
+  // expected-error at +1 {{'vector.insert_map' op expected number of ids must match the number of dimensions distributed}}
+  %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<4x32xf32>
+}

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 6a5e6ec671a1..aab6cabf759d 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -432,12 +432,17 @@ func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru:
 }
 
 // CHECK-LABEL: @extract_insert_map
-func @extract_insert_map(%v: vector<32xf32>, %id : index) -> vector<32xf32> {
+func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>,
+  %id0 : index, %id1 : index) -> (vector<32xf32>, vector<16x32xf32>) {
   // CHECK: %[[V:.*]] = vector.extract_map %{{.*}}[%{{.*}}] : vector<32xf32> to vector<2xf32>
-  %vd = vector.extract_map %v[%id] : vector<32xf32> to vector<2xf32>
+  %vd = vector.extract_map %v[%id0] : vector<32xf32> to vector<2xf32>
+  // CHECK: %[[V1:.*]] = vector.extract_map %{{.*}}[%{{.*}}, %{{.*}}] : vector<16x32xf32> to vector<4x2xf32>
+  %vd2 = vector.extract_map %v2[%id0, %id1] : vector<16x32xf32> to vector<4x2xf32>
   // CHECK: %[[R:.*]] = vector.insert_map %[[V]], %{{.*}}[%{{.*}}] : vector<2xf32> into vector<32xf32>
-  %r = vector.insert_map %vd, %v[%id] : vector<2xf32> into vector<32xf32>
-  // CHECK: return %[[R]] : vector<32xf32>
-  return %r : vector<32xf32>
+  %r = vector.insert_map %vd, %v[%id0] : vector<2xf32> into vector<32xf32>
+  // CHECK: %[[R1:.*]] = vector.insert_map %[[V1]], %{{.*}}[%{{.*}}, %{{.*}}] : vector<4x2xf32> into vector<16x32xf32>
+  %r2 = vector.insert_map %vd2, %v2[%id0, %id1] : vector<4x2xf32> into vector<16x32xf32>
+  // CHECK: return %[[R]], %[[R1]] : vector<32xf32>, vector<16x32xf32>
+  return %r, %r2 : vector<32xf32>, vector<16x32xf32>
 }
 

diff  --git a/mlir/test/Dialect/Vector/vector-distribution.mlir b/mlir/test/Dialect/Vector/vector-distribution.mlir
index 0fece6617e09..950786e86caa 100644
--- a/mlir/test/Dialect/Vector/vector-distribution.mlir
+++ b/mlir/test/Dialect/Vector/vector-distribution.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32 -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,1,32 -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @distribute_vector_add
 //  CHECK-SAME: (%[[ID:.*]]: index
@@ -22,7 +22,7 @@ func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>)
 //  CHECK-NEXT:    %[[ADD1:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32>
 //  CHECK-NEXT:    %[[EXC:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
 //  CHECK-NEXT:    %[[ADD2:.*]] = addf %[[ADD1]], %[[EXC]] : vector<1xf32>
-//  CHECK-NEXT:    vector.transfer_write %[[ADD2]], %{{.*}}[%[[ID]]] : vector<1xf32>, memref<32xf32>
+//  CHECK-NEXT:    vector.transfer_write %[[ADD2]], %{{.*}}[%[[ID]]] {{.*}} : vector<1xf32>, memref<32xf32>
 //  CHECK-NEXT:    return
 func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>, %C: memref<32xf32>, %D: memref<32xf32>) {
   %c0 = constant 0 : index
@@ -38,7 +38,7 @@ func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>,
 
 // -----
 
-// CHECK-DAG: #[[MAP0:map[0-9]*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>
 
 //       CHECK: func @vector_add_cycle
 //  CHECK-SAME: (%[[ID:.*]]: index
@@ -48,7 +48,7 @@ func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>,
 //  CHECK-NEXT:    %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID2]]], %{{.*}} : memref<64xf32>, vector<2xf32>
 //  CHECK-NEXT:    %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2xf32>
 //  CHECK-NEXT:    %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID]]]
-//  CHECK-NEXT:    vector.transfer_write %[[ADD]], %{{.*}}[%[[ID3]]] : vector<2xf32>, memref<64xf32>
+//  CHECK-NEXT:    vector.transfer_write %[[ADD]], %{{.*}}[%[[ID3]]] {{.*}} : vector<2xf32>, memref<64xf32>
 //  CHECK-NEXT:    return
 func @vector_add_cycle(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) {
   %c0 = constant 0 : index
@@ -81,4 +81,46 @@ func @vector_negative_test(%id : index, %A: memref<64xf32>, %B: memref<64xf32>,
   return
 }
 
+// -----
+
+// CHECK-LABEL: func @distribute_vector_add_3d
+//  CHECK-SAME: (%[[ID0:.*]]: index, %[[ID1:.*]]: index
+//  CHECK-NEXT:    %[[ADDV:.*]] = addf %{{.*}}, %{{.*}} : vector<64x4x32xf32>
+//  CHECK-NEXT:    %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID0]], %[[ID1]]] : vector<64x4x32xf32> to vector<2x4x1xf32>
+//  CHECK-NEXT:    %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID0]], %[[ID1]]] : vector<64x4x32xf32> to vector<2x4x1xf32>
+//  CHECK-NEXT:    %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32>
+//  CHECK-NEXT:    %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID0]], %[[ID1]]] : vector<2x4x1xf32> into vector<64x4x32xf32>
+//  CHECK-NEXT:    return %[[INS]] : vector<64x4x32xf32>
+func @distribute_vector_add_3d(%id0 : index, %id1 : index,
+  %A: vector<64x4x32xf32>, %B: vector<64x4x32xf32>) -> vector<64x4x32xf32> {
+  %0 = addf %A, %B : vector<64x4x32xf32>
+  return %0: vector<64x4x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>
+
+//       CHECK: func @vector_add_transfer_3d
+//  CHECK-SAME: (%[[ID_0:.*]]: index, %[[ID_1:.*]]: index
+//       CHECK:    %[[C0:.*]] = constant 0 : index
+//       CHECK:    %[[ID1:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
+//  CHECK-NEXT:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID1]], %[[C0]], %[[ID_1]]], %{{.*}} : memref<64x64x64xf32>, vector<2x4x1xf32>
+//  CHECK-NEXT:    %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
+//  CHECK-NEXT:    %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID2]], %[[C0]], %[[ID_1]]], %{{.*}} : memref<64x64x64xf32>, vector<2x4x1xf32>
+//  CHECK-NEXT:    %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32>
+//  CHECK-NEXT:    %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
+//  CHECK-NEXT:    vector.transfer_write %[[ADD]], %{{.*}}[%[[ID3]], %[[C0]], %[[ID_1]]] {{.*}} : vector<2x4x1xf32>, memref<64x64x64xf32>
+//  CHECK-NEXT:    return
+func @vector_add_transfer_3d(%id0 : index, %id1 : index, %A: memref<64x64x64xf32>,
+  %B: memref<64x64x64xf32>, %C: memref<64x64x64xf32>) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0: memref<64x64x64xf32>, vector<64x4x32xf32>
+  %b = vector.transfer_read %B[%c0, %c0, %c0], %cf0: memref<64x64x64xf32>, vector<64x4x32xf32>
+  %acc = addf %a, %b: vector<64x4x32xf32>
+  vector.transfer_write %acc, %C[%c0, %c0, %c0]: vector<64x4x32xf32>, memref<64x64x64xf32>
+  return
+}
+
 

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 8daed18649d3..484e78f2b596 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -163,21 +163,40 @@ struct TestVectorDistributePatterns
     registry.insert<VectorDialect>();
     registry.insert<AffineDialect>();
   }
-  Option<int32_t> multiplicity{
-      *this, "distribution-multiplicity",
-      llvm::cl::desc("Set the multiplicity used for distributing vector"),
-      llvm::cl::init(32)};
+  ListOption<int32_t> multiplicity{
+      *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
+      llvm::cl::desc("Set the multiplicity used for distributing vector")};
+
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
     OwningRewritePatternList patterns;
     FuncOp func = getFunction();
     func.walk([&](AddFOp op) {
       OpBuilder builder(op);
-      Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
-          builder, op.getOperation(), func.getArgument(0), multiplicity);
-      if (ops.hasValue()) {
-        SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
-        op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
+      if (auto vecType = op.getType().dyn_cast<VectorType>()) {
+        SmallVector<int64_t, 2> mul;
+        SmallVector<AffineExpr, 2> perm;
+        SmallVector<Value, 2> ids;
+        unsigned count = 0;
+        // Keep only dimensions whose multiplicity is not 1 and divides the
+        // dimension size, and build the affine map from them.
+        SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
+        for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
+          if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
+            mul.push_back(m[i]);
+            ids.push_back(func.getArgument(count++));
+            perm.push_back(getAffineDimExpr(i, ctx));
+          }
+        }
+        auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
+                                  perm, ctx);
+        Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
+            builder, op.getOperation(), ids, mul, map);
+        if (ops.hasValue()) {
+          SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
+          op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
+                                              extractOp);
+        }
       }
     });
     patterns.insert<PointwiseExtractPattern>(ctx);
@@ -229,9 +248,11 @@ struct TestVectorToLoopPatterns
       for (Operation *it : dependentOps) {
         it->moveBefore(forOp.getBody()->getTerminator());
       }
+      auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
       // break up the original op and let the patterns propagate.
       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
-          builder, op.getOperation(), forOp.getInductionVar(), multiplicity);
+          builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
+          map);
       if (ops.hasValue()) {
         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);


        


More information about the Mlir-commits mailing list