[Mlir-commits] [mlir] dd14e58 - [mlir][vector] First step of vector distribution transformation

Thomas Raoux llvmlistbot at llvm.org
Wed Sep 30 13:16:18 PDT 2020


Author: Thomas Raoux
Date: 2020-09-30T13:14:55-07:00
New Revision: dd14e5825209386129770296f9bc3a14ab0b4592

URL: https://github.com/llvm/llvm-project/commit/dd14e5825209386129770296f9bc3a14ab0b4592
DIFF: https://github.com/llvm/llvm-project/commit/dd14e5825209386129770296f9bc3a14ab0b4592.diff

LOG: [mlir][vector] First step of vector distribution transformation

This is the first of several steps to support distributing large vectors. This
adds instructions extract_map and insert_map that allow us to do incremental
lowering. Right now the transformation only applies to simple pointwise operations
with a vector size matching the multiplicity of the IDs used to distribute the
vector.
This can be used to distribute large vectors to loops or SPMD.

Differential Revision: https://reviews.llvm.org/D88341

Added: 
    mlir/test/Dialect/Vector/vector-distribution.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index f74c8687bf53..42e947071403 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -454,6 +454,71 @@ def Vector_ExtractSlicesOp :
   }];
 }
 
+def Vector_ExtractMapOp :
+  Vector_Op<"extract_map", [NoSideEffect]>,
+    Arguments<(ins AnyVector:$vector, Index:$id, I64Attr:$multiplicity)>,
+    Results<(outs AnyVector)> {
+  let summary = "vector extract map operation";
+  let description = [{
+    Takes a 1-D vector and extracts a sub-part of the vector starting at id with
+    a size of `vector size / multiplicity`. This maps a given multiplicity of
+    the vector to a Value such as a loop induction variable or an SPMD id.
+
+    Similarly to vector.tuple_get, this operation is used for progressive
+    lowering and should be folded away before converting to LLVM.
+
+
+    For instance, the following code:
+    ```mlir
+    %a = vector.transfer_read %A[%c0]: memref<32xf32>, vector<32xf32>
+    %b = vector.transfer_read %B[%c0]: memref<32xf32>, vector<32xf32>
+    %c = addf %a, %b: vector<32xf32>
+    vector.transfer_write %c, %C[%c0]: memref<32xf32>, vector<32xf32>
+    ```
+    can be rewritten to:
+    ```mlir
+    %a = vector.transfer_read %A[%c0]: memref<32xf32>, vector<32xf32>
+    %b = vector.transfer_read %B[%c0]: memref<32xf32>, vector<32xf32>
+    %ea = vector.extract_map %a[%id : 32] : vector<32xf32> to vector<1xf32>
+    %eb = vector.extract_map %b[%id : 32] : vector<32xf32> to vector<1xf32>
+    %ec = addf %ea, %eb : vector<1xf32>
+    %c = vector.insert_map %ec, %id, 32 : vector<1xf32> to vector<32xf32>
+    vector.transfer_write %c, %C[%c0]: memref<32xf32>, vector<32xf32>
+    ```
+
+    Where %id can be an induction variable or an SPMD id going from 0 to 31.
+
+    And then be rewritten to:
+    ```mlir
+    %a = vector.transfer_read %A[%id]: memref<32xf32>, vector<1xf32>
+    %b = vector.transfer_read %B[%id]: memref<32xf32>, vector<1xf32>
+    %c = addf %a, %b: vector<1xf32>
+    vector.transfer_write %c, %C[%id]: memref<32xf32>, vector<1xf32>
+    ```
+
+    Example:
+
+    ```mlir
+    %ev = vector.extract_map %v[%id:32] : vector<32xf32> to vector<1xf32>
+    ```
+  }];
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, " #
+    "Value vector, Value id, int64_t multiplicity">];
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+    VectorType getResultType() {
+      return getResult().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = [{
+    $vector `[` $id `:` $multiplicity `]` attr-dict `:` type($vector) `to`
+    type(results)
+  }];
+}
+
 def Vector_FMAOp :
   Op<Vector_Dialect, "fma", [NoSideEffect,
                              AllTypesMatch<["lhs", "rhs", "acc", "result"]>]>,
@@ -626,6 +691,46 @@ def Vector_InsertSlicesOp :
   }];
 }
 
+def Vector_InsertMapOp :
+  Vector_Op<"insert_map", [NoSideEffect]>,
+    Arguments<(ins AnyVector:$vector, Index:$id, I64Attr:$multiplicity)>,
+    Results<(outs AnyVector)> {
+  let summary = "vector insert map operation";
+  let description = [{
+    Inserts a 1-D vector within a larger vector starting at id. The new
+    vector created will have a size of `vector size * multiplicity`. This
+    represents how a sub-part of the vector is written for a given Value such as
+    a loop induction variable or an SPMD id.
+
+    Similarly to vector.tuple_get, this operation is used for progressive
+    lowering and should be folded away before converting to LLVM.
+
+    This operation is meant to be used in combination with vector.extract_map.
+    See the example in the vector.extract_map description.
+
+    Example:
+
+    ```mlir
+    %v = vector.insert_map %ev, %id, 32 : vector<1xf32> to vector<32xf32>
+    ```
+  }];
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, " #
+    "Value vector, Value id, int64_t multiplicity">];
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+    VectorType getResultType() {
+      return getResult().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = [{
+    $vector `,` $id `,` $multiplicity attr-dict `:` type($vector) `to`
+    type(results)
+  }];
+}
+
 def Vector_InsertStridedSliceOp :
   Vector_Op<"insert_strided_slice", [NoSideEffect,
     PredOpTrait<"operand #0 and result have same element type",

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 9587c56c0255..da9650c67efb 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -172,6 +172,47 @@ struct VectorTransferFullPartialRewriter : public RewritePattern {
   FilterConstraintType filter;
 };
 
+struct DistributeOps {
+  ExtractMapOp extract;
+  InsertMapOp insert;
+};
+
+/// Distribute a 1D vector pointwise operation over a range of given IDs taking
+/// *all* values in [0 .. multiplicity - 1] (e.g. loop induction variable or
+/// SPMD id). This transformation only inserts
+/// vector.extract_map/vector.insert_map. It is meant to be used with
+/// canonicalization patterns to propagate and fold the vector
+/// insert_map/extract_map operations.
+/// Transforms:
+/// %v = addf %a, %b : vector<32xf32>
+/// to:
+/// %v = addf %a, %b : vector<32xf32>
+/// %ev = vector.extract_map %v[%id : 32] : vector<32xf32> to vector<1xf32>
+/// %nv = vector.insert_map %ev, %id, 32 : vector<1xf32> to vector<32xf32>
+Optional<DistributeOps> distributPointwiseVectorOp(OpBuilder &builder,
+                                                   Operation *op, Value id,
+                                                   int64_t multiplicity);
+/// Canonicalize an extract_map whose source is the result of a pointwise op.
+/// Transforms:
+/// %v = addf %a, %b : vector<32xf32>
+/// %dv = vector.extract_map %v[%id : 32] : vector<32xf32> to vector<1xf32>
+/// to:
+/// %da = vector.extract_map %a[%id : 32] : vector<32xf32> to vector<1xf32>
+/// %db = vector.extract_map %b[%id : 32] : vector<32xf32> to vector<1xf32>
+/// %dv = addf %da, %db : vector<1xf32>
+struct PointwiseExtractPattern : public OpRewritePattern<ExtractMapOp> {
+  using FilterConstraintType = std::function<LogicalResult(ExtractMapOp op)>;
+  PointwiseExtractPattern(
+      MLIRContext *context, FilterConstraintType constraint =
+                                [](ExtractMapOp op) { return success(); })
+      : OpRewritePattern<ExtractMapOp>(context), filter(constraint) {}
+  LogicalResult matchAndRewrite(ExtractMapOp extract,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  FilterConstraintType filter;
+};
+
 } // namespace vector
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 348ccf841308..1a83c556d47b 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -900,6 +900,29 @@ void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
   populateFromInt64AttrArray(strides(), results);
 }
 
+//===----------------------------------------------------------------------===//
+// ExtractMapOp
+//===----------------------------------------------------------------------===//
+
+void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
+                         Value vector, Value id, int64_t multiplicity) {
+  VectorType type = vector.getType().cast<VectorType>();
+  VectorType resultType = VectorType::get(type.getNumElements() / multiplicity,
+                                          type.getElementType());
+  ExtractMapOp::build(builder, result, resultType, vector, id, multiplicity);
+}
+
+static LogicalResult verify(ExtractMapOp op) {
+  if (op.getSourceVectorType().getShape().size() != 1 ||
+      op.getResultType().getShape().size() != 1)
+    return op.emitOpError("expects source and destination vectors of rank 1");
+  if (op.getResultType().getNumElements() * (int64_t)op.multiplicity() !=
+      op.getSourceVectorType().getNumElements())
+    return op.emitOpError("vector sizes mismatch. Source size must be equal "
+                          "to destination size * multiplicity");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
@@ -1122,6 +1145,30 @@ void InsertSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
   populateFromInt64AttrArray(strides(), results);
 }
 
+//===----------------------------------------------------------------------===//
+// InsertMapOp
+//===----------------------------------------------------------------------===//
+
+void InsertMapOp::build(OpBuilder &builder, OperationState &result,
+                        Value vector, Value id, int64_t multiplicity) {
+  VectorType type = vector.getType().cast<VectorType>();
+  VectorType resultType = VectorType::get(type.getNumElements() * multiplicity,
+                                          type.getElementType());
+  InsertMapOp::build(builder, result, resultType, vector, id, multiplicity);
+}
+
+static LogicalResult verify(InsertMapOp op) {
+  if (op.getSourceVectorType().getShape().size() != 1 ||
+      op.getResultType().getShape().size() != 1)
+    return op.emitOpError("expected source and destination vectors of rank 1");
+  if ((int64_t)op.multiplicity() * op.getSourceVectorType().getNumElements() !=
+      op.getResultType().getNumElements())
+    return op.emitOpError(
+        "vector sizes mismatch. Destination size must be equal "
+        "to source size * multiplicity");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // InsertStridedSliceOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 5bf7857a66e8..6a244a454e06 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2418,6 +2418,40 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
   return failure();
 }
 
+LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite(
+    ExtractMapOp extract, PatternRewriter &rewriter) const {
+  Operation *definedOp = extract.vector().getDefiningOp();
+  if (!definedOp || definedOp->getNumResults() != 1)
+    return failure();
+  // TODO: Create an interfaceOp for elementwise operations.
+  if (!isa<AddFOp>(definedOp))
+    return failure();
+  Location loc = extract.getLoc();
+  SmallVector<Value, 4> extractOperands;
+  for (OpOperand &operand : definedOp->getOpOperands())
+    extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
+        loc, operand.get(), extract.id(), extract.multiplicity()));
+  Operation *newOp = cloneOpWithOperandsAndTypes(
+      rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
+  rewriter.replaceOp(extract, newOp->getResult(0));
+  return success();
+}
+
+Optional<mlir::vector::DistributeOps>
+mlir::vector::distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
+                                         Value id, int64_t multiplicity) {
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointAfter(op);
+  Location loc = op->getLoc();
+  Value result = op->getResult(0);
+  DistributeOps ops;
+  ops.extract =
+      builder.create<vector::ExtractMapOp>(loc, result, id, multiplicity);
+  ops.insert =
+      builder.create<vector::InsertMapOp>(loc, ops.extract, id, multiplicity);
+  return ops;
+}
+
 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3a081231fc7d..25e002fed35a 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1328,3 +1328,31 @@ func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %va
   // expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}}
   vector.compressstore %base, %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
 }
+
+// -----
+
+func @extract_map_rank(%v: vector<2x32xf32>, %id : index) {
+  // expected-error@+1 {{'vector.extract_map' op expects source and destination vectors of rank 1}}
+  %0 = vector.extract_map %v[%id : 32] : vector<2x32xf32> to vector<2x1xf32>
+}
+
+// -----
+
+func @extract_map_size(%v: vector<63xf32>, %id : index) {
+  // expected-error@+1 {{'vector.extract_map' op vector sizes mismatch. Source size must be equal to destination size * multiplicity}}
+  %0 = vector.extract_map %v[%id : 32] : vector<63xf32> to vector<2xf32>
+}
+
+// -----
+
+func @insert_map_rank(%v: vector<2x1xf32>, %id : index) {
+  // expected-error@+1 {{'vector.insert_map' op expected source and destination vectors of rank 1}}
+  %0 = vector.insert_map %v, %id, 32 : vector<2x1xf32> to vector<2x32xf32>
+}
+
+// -----
+
+func @insert_map_size(%v: vector<1xf32>, %id : index) {
+  // expected-error@+1 {{'vector.insert_map' op vector sizes mismatch. Destination size must be equal to source size * multiplicity}}
+  %0 = vector.insert_map %v, %id, 32 : vector<1xf32> to vector<64xf32>
+}

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 2a62be94e01b..7315d2189c67 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -432,3 +432,14 @@ func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru:
   vector.compressstore %base, %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
   return
 }
+
+// CHECK-LABEL: @extract_insert_map
+func @extract_insert_map(%v: vector<32xf32>, %id : index) -> vector<32xf32> {
+  // CHECK: %[[V:.*]] = vector.extract_map %{{.*}}[%{{.*}} : 16] : vector<32xf32> to vector<2xf32>
+  %vd = vector.extract_map %v[%id : 16] : vector<32xf32> to vector<2xf32>
+  // CHECK: %[[R:.*]] = vector.insert_map %[[V]], %{{.*}}, 16 : vector<2xf32> to vector<32xf32>
+  %r = vector.insert_map %vd, %id, 16 : vector<2xf32> to vector<32xf32>
+  // CHECK: return %[[R]] : vector<32xf32>
+  return %r : vector<32xf32>
+}
+

diff  --git a/mlir/test/Dialect/Vector/vector-distribution.mlir b/mlir/test/Dialect/Vector/vector-distribution.mlir
new file mode 100644
index 000000000000..0216a017d7af
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-distribution.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -test-vector-distribute-patterns | FileCheck %s
+
+// CHECK-LABEL: func @distribute_vector_add
+//  CHECK-SAME: (%[[ID:.*]]: index
+//  CHECK-NEXT:    %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]] : 32] : vector<32xf32> to vector<1xf32>
+//  CHECK-NEXT:    %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]] : 32] : vector<32xf32> to vector<1xf32>
+//  CHECK-NEXT:    %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32>
+//  CHECK-NEXT:    %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ID]], 32 : vector<1xf32> to vector<32xf32>
+//  CHECK-NEXT:    return %[[INS]] : vector<32xf32>
+func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> {
+  %0 = addf %A, %B : vector<32xf32>
+  return %0: vector<32xf32>
+}

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index ab8460318b49..2ffe10bc1682 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -125,6 +125,28 @@ struct TestVectorUnrollingPatterns
   }
 };
 
+struct TestVectorDistributePatterns
+    : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<VectorDialect>();
+  }
+  void runOnFunction() override {
+    MLIRContext *ctx = &getContext();
+    OwningRewritePatternList patterns;
+    FuncOp func = getFunction();
+    func.walk([&](AddFOp op) {
+      OpBuilder builder(op);
+      Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
+          builder, op.getOperation(), func.getArgument(0), 32);
+      assert(ops.hasValue());
+      SmallPtrSet<Operation *, 1> extractOp({ops->extract});
+      op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
+    });
+    patterns.insert<PointwiseExtractPattern>(ctx);
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
+  }
+};
+
 struct TestVectorTransferFullPartialSplitPatterns
     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
                          FunctionPass> {
@@ -178,5 +200,9 @@ void registerTestVectorConversions() {
       vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
                                      "Test conversion patterns to split "
                                      "transfer ops via scf.if + linalg ops");
+  PassRegistration<TestVectorDistributePatterns> distributePass(
+      "test-vector-distribute-patterns",
+      "Test conversion patterns to distribute vector ops in the vector "
+      "dialect");
 }
 } // namespace mlir


        


More information about the Mlir-commits mailing list