[Mlir-commits] [mlir] dd14e58 - [mlir][vector] First step of vector distribution transformation
Thomas Raoux
llvmlistbot at llvm.org
Wed Sep 30 13:16:18 PDT 2020
Author: Thomas Raoux
Date: 2020-09-30T13:14:55-07:00
New Revision: dd14e5825209386129770296f9bc3a14ab0b4592
URL: https://github.com/llvm/llvm-project/commit/dd14e5825209386129770296f9bc3a14ab0b4592
DIFF: https://github.com/llvm/llvm-project/commit/dd14e5825209386129770296f9bc3a14ab0b4592.diff
LOG: [mlir][vector] First step of vector distribution transformation
This is the first of several steps to support distributing large vectors. This
adds instructions extract_map and insert_map that allow us to do incremental
lowering. Right now the transformation only applies to simple pointwise
operations with a vector size matching the multiplicity of the IDs used to
distribute the vector.
This can be used to distribute large vectors to loops or SPMD.
Differential Revision: https://reviews.llvm.org/D88341
Added:
mlir/test/Dialect/Vector/vector-distribution.mlir
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index f74c8687bf53..42e947071403 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -454,6 +454,71 @@ def Vector_ExtractSlicesOp :
}];
}
+// Op extracting the sub-part of a 1-D vector that is mapped to a given id.
+// Used together with vector.insert_map for progressive lowering.
+def Vector_ExtractMapOp :
+  Vector_Op<"extract_map", [NoSideEffect]>,
+    Arguments<(ins AnyVector:$vector, Index:$id, I64Attr:$multiplicity)>,
+    Results<(outs AnyVector)> {
+  let summary = "vector extract map operation";
+  let description = [{
+    Takes a 1-D vector and extracts a sub-part of the vector starting at id with
+    a size of `vector size / multiplicity`. This maps a given multiplicity of
+    the vector to a Value such as a loop induction variable or an SPMD id.
+
+    Similarly to vector.tuple_get, this operation is used for progressive
+    lowering and should be folded away before converting to LLVM.
+
+
+    For instance, the following code:
+    ```mlir
+    %a = vector.transfer_read %A[%c0]: memref<32xf32>, vector<32xf32>
+    %b = vector.transfer_read %B[%c0]: memref<32xf32>, vector<32xf32>
+    %c = addf %a, %b: vector<32xf32>
+    vector.transfer_write %c, %C[%c0]: memref<32xf32>, vector<32xf32>
+    ```
+    can be rewritten to:
+    ```mlir
+    %a = vector.transfer_read %A[%c0]: memref<32xf32>, vector<32xf32>
+    %b = vector.transfer_read %B[%c0]: memref<32xf32>, vector<32xf32>
+    %ea = vector.extract_map %a[%id : 32] : vector<32xf32> to vector<1xf32>
+    %eb = vector.extract_map %b[%id : 32] : vector<32xf32> to vector<1xf32>
+    %ec = addf %ea, %eb : vector<1xf32>
+    %c = vector.insert_map %ec, %id, 32 : vector<1xf32> to vector<32xf32>
+    vector.transfer_write %c, %C[%c0]: memref<32xf32>, vector<32xf32>
+    ```
+
+    Where %id can be an induction variable or an SPMD id going from 0 to 31.
+
+    It can then be rewritten to:
+    ```mlir
+    %a = vector.transfer_read %A[%id]: memref<32xf32>, vector<1xf32>
+    %b = vector.transfer_read %B[%id]: memref<32xf32>, vector<1xf32>
+    %c = addf %a, %b: vector<1xf32>
+    vector.transfer_write %c, %C[%id]: memref<32xf32>, vector<1xf32>
+    ```
+
+    Example:
+
+    ```mlir
+    %ev = vector.extract_map %v[%id:32] : vector<32xf32> to vector<1xf32>
+    ```
+  }];
+  // Convenience builder deriving the result type from the source vector type
+  // and the multiplicity.
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, " #
+    "Value vector, Value id, int64_t multiplicity">];
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+    VectorType getResultType() {
+      return getResult().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = [{
+    $vector `[` $id `:` $multiplicity `]` attr-dict `:` type($vector) `to`
+    type(results)
+  }];
+}
+
def Vector_FMAOp :
Op<Vector_Dialect, "fma", [NoSideEffect,
AllTypesMatch<["lhs", "rhs", "acc", "result"]>]>,
@@ -626,6 +691,46 @@ def Vector_InsertSlicesOp :
}];
}
+// Op inserting a 1-D vector into the sub-part of a larger vector that is
+// mapped to a given id. Counterpart of vector.extract_map.
+def Vector_InsertMapOp :
+  Vector_Op<"insert_map", [NoSideEffect]>,
+    Arguments<(ins AnyVector:$vector, Index:$id, I64Attr:$multiplicity)>,
+    Results<(outs AnyVector)> {
+  let summary = "vector insert map operation";
+  let description = [{
+    Inserts a 1-D vector within a larger vector starting at id. The new
+    vector created will have a size of `vector size * multiplicity`. This
+    represents how a sub-part of the vector is written for a given Value such as
+    a loop induction variable or an SPMD id.
+
+    Similarly to vector.tuple_get, this operation is used for progressive
+    lowering and should be folded away before converting to LLVM.
+
+    This operation is meant to be used in combination with vector.extract_map.
+    See the example in the vector.extract_map description.
+
+    Example:
+
+    ```mlir
+    %v = vector.insert_map %ev, %id, 32 : vector<1xf32> to vector<32xf32>
+    ```
+  }];
+  // Convenience builder deriving the result type from the source vector type
+  // and the multiplicity.
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, " #
+    "Value vector, Value id, int64_t multiplicity">];
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+    VectorType getResultType() {
+      return getResult().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = [{
+    $vector `,` $id `,` $multiplicity attr-dict `:` type($vector) `to`
+    type(results)
+  }];
+}
+
def Vector_InsertStridedSliceOp :
Vector_Op<"insert_strided_slice", [NoSideEffect,
PredOpTrait<"operand #0 and result have same element type",
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 9587c56c0255..da9650c67efb 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -172,6 +172,47 @@ struct VectorTransferFullPartialRewriter : public RewritePattern {
FilterConstraintType filter;
};
+/// Pair of operations created by distributPointwiseVectorOp when it
+/// distributes the result of an operation.
+struct DistributeOps {
+ // The vector.extract_map created on the result of the distributed op.
+ ExtractMapOp extract;
+ // The vector.insert_map created from the result of `extract`.
+ InsertMapOp insert;
+};
+
+/// Distribute a 1D vector pointwise operation over a range of given IDs taking
+/// *all* values in [0 .. multiplicity - 1] (e.g. loop induction variable or
+/// SPMD id). This transformation only inserts
+/// vector.extract_map/vector.insert_map. It is meant to be used with
+/// canonicalization patterns to propagate and fold the vector
+/// insert_map/extract_map operations.
+/// Transforms:
+///  %v = addf %a, %b : vector<32xf32>
+/// to:
+///  %v = addf %a, %b : vector<32xf32>
+///  %ev = vector.extract_map %v[%id : 32] : vector<32xf32> to vector<1xf32>
+///  %nv = vector.insert_map %ev, %id, 32 : vector<1xf32> to vector<32xf32>
+/// The two created operations are returned to the caller; rewiring the uses
+/// of %v to %nv is left to the caller.
+Optional<DistributeOps> distributPointwiseVectorOp(OpBuilder &builder,
+ Operation *op, Value id,
+ int64_t multiplicity);
+/// Canonicalize an extract_map using the result of a pointwise operation.
+/// Transforms:
+///  %v = addf %a, %b : vector<32xf32>
+///  %dv = vector.extract_map %v[%id : 32] : vector<32xf32> to vector<1xf32>
+/// to:
+///  %da = vector.extract_map %a[%id : 32] : vector<32xf32> to vector<1xf32>
+///  %db = vector.extract_map %b[%id : 32] : vector<32xf32> to vector<1xf32>
+///  %dv = addf %da, %db : vector<1xf32>
+struct PointwiseExtractPattern : public OpRewritePattern<ExtractMapOp> {
+ // Callback type for an optional user-provided constraint on the ops.
+ using FilterConstraintType = std::function<LogicalResult(ExtractMapOp op)>;
+ // The default constraint returns success() for every op.
+ PointwiseExtractPattern(
+ MLIRContext *context, FilterConstraintType constraint =
+ [](ExtractMapOp op) { return success(); })
+ : OpRewritePattern<ExtractMapOp>(context), filter(constraint) {}
+ LogicalResult matchAndRewrite(ExtractMapOp extract,
+ PatternRewriter &rewriter) const override;
+
+private:
+ // Constraint passed to the constructor.
+ FilterConstraintType filter;
+};
+
} // namespace vector
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 348ccf841308..1a83c556d47b 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -900,6 +900,29 @@ void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
populateFromInt64AttrArray(strides(), results);
}
+//===----------------------------------------------------------------------===//
+// ExtractMapOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a vector.extract_map whose result type is inferred: a 1-D vector
+/// with the source element type and `source element count / multiplicity`
+/// elements.
+void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
+                         Value vector, Value id, int64_t multiplicity) {
+  auto sourceType = vector.getType().cast<VectorType>();
+  // Each id owns an equal share of the source vector.
+  int64_t extractedSize = sourceType.getNumElements() / multiplicity;
+  Type elementType = sourceType.getElementType();
+  ExtractMapOp::build(builder, result,
+                      VectorType::get(extractedSize, elementType), vector, id,
+                      multiplicity);
+}
+
+/// Verifies a vector.extract_map: source and result must both be rank-1
+/// vectors with the same element type, and the source must hold exactly
+/// `multiplicity` times as many elements as the result.
+static LogicalResult verify(ExtractMapOp op) {
+  VectorType sourceType = op.getSourceVectorType();
+  VectorType resultType = op.getResultType();
+  if (sourceType.getRank() != 1 || resultType.getRank() != 1)
+    return op.emitOpError("expects source and destination vectors of rank 1");
+  // The op only selects elements of the source, so it must not change the
+  // element type.
+  if (sourceType.getElementType() != resultType.getElementType())
+    return op.emitOpError(
+        "expects source and destination vectors of same element type");
+  if (resultType.getNumElements() * (int64_t)op.multiplicity() !=
+      sourceType.getNumElements())
+    return op.emitOpError("vector sizes mismatch. Source size must be equal "
+                          "to destination size * multiplicity");
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
@@ -1122,6 +1145,30 @@ void InsertSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
populateFromInt64AttrArray(strides(), results);
}
+//===----------------------------------------------------------------------===//
+// InsertMapOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a vector.insert_map whose result type is inferred: a 1-D vector
+/// with the source element type and `source element count * multiplicity`
+/// elements.
+void InsertMapOp::build(OpBuilder &builder, OperationState &result,
+                        Value vector, Value id, int64_t multiplicity) {
+  auto sourceType = vector.getType().cast<VectorType>();
+  // The result holds one source-sized share per id.
+  int64_t insertedSize = sourceType.getNumElements() * multiplicity;
+  Type elementType = sourceType.getElementType();
+  InsertMapOp::build(builder, result,
+                     VectorType::get(insertedSize, elementType), vector, id,
+                     multiplicity);
+}
+
+/// Verifies a vector.insert_map: source and result must both be rank-1
+/// vectors with the same element type, and the result must hold exactly
+/// `multiplicity` times as many elements as the source.
+static LogicalResult verify(InsertMapOp op) {
+  VectorType sourceType = op.getSourceVectorType();
+  VectorType resultType = op.getResultType();
+  if (sourceType.getRank() != 1 || resultType.getRank() != 1)
+    return op.emitOpError("expected source and destination vectors of rank 1");
+  // The op only places elements of the source, so it must not change the
+  // element type.
+  if (sourceType.getElementType() != resultType.getElementType())
+    return op.emitOpError(
+        "expected source and destination vectors of same element type");
+  if ((int64_t)op.multiplicity() * sourceType.getNumElements() !=
+      resultType.getNumElements())
+    return op.emitOpError(
+        "vector sizes mismatch. Destination size must be equal "
+        "to source size * multiplicity");
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 5bf7857a66e8..6a244a454e06 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2418,6 +2418,40 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
return failure();
}
+/// Propagates a vector.extract_map through the pointwise operation defining
+/// its source: every operand of that operation is wrapped in an extract_map
+/// with the same id and multiplicity, and the operation is re-created on the
+/// extracted operands with the result type of the original extract_map.
+/// Fails to match when the user-provided filter rejects the op, when the
+/// source is not produced by a single-result operation, or when that operation
+/// is not an AddFOp.
+LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite(
+    ExtractMapOp extract, PatternRewriter &rewriter) const {
+  // Honor the constraint given at construction before creating any IR. An
+  // empty std::function is treated as "no constraint".
+  if (filter && failed(filter(extract)))
+    return failure();
+  Operation *definedOp = extract.vector().getDefiningOp();
+  if (!definedOp || definedOp->getNumResults() != 1)
+    return failure();
+  // TODO: Create an interfaceOp for elementwise operations.
+  if (!isa<AddFOp>(definedOp))
+    return failure();
+  Location loc = extract.getLoc();
+  SmallVector<Value, 4> extractOperands;
+  for (OpOperand &operand : definedOp->getOpOperands())
+    extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
+        loc, operand.get(), extract.id(), extract.multiplicity()));
+  // Re-create the pointwise op on the extracted operands; it now produces the
+  // small vector directly.
+  Operation *newOp = cloneOpWithOperandsAndTypes(
+      rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
+  rewriter.replaceOp(extract, newOp->getResult(0));
+  return success();
+}
+
+/// Inserts, right after `op`, a vector.extract_map of the result of `op`
+/// followed by a vector.insert_map of the extracted value, both using `id` and
+/// `multiplicity`. Returns the two created operations, or llvm::None when `op`
+/// cannot be distributed: it does not have exactly one result, the result is
+/// not a 1-D vector, `multiplicity` is not positive, or the vector size is not
+/// a multiple of `multiplicity`. The uses of the result of `op` are not
+/// updated here. The builder's insertion point is restored on return.
+Optional<mlir::vector::DistributeOps>
+mlir::vector::distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
+                                         Value id, int64_t multiplicity) {
+  // Reject inputs for which the extract_map/insert_map builders would assert
+  // or divide by zero.
+  if (!op || op->getNumResults() != 1 || multiplicity <= 0)
+    return llvm::None;
+  Value result = op->getResult(0);
+  auto resultType = result.getType().dyn_cast<VectorType>();
+  if (!resultType || resultType.getRank() != 1 ||
+      resultType.getNumElements() % multiplicity != 0)
+    return llvm::None;
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointAfter(op);
+  Location loc = op->getLoc();
+  DistributeOps ops;
+  ops.extract =
+      builder.create<vector::ExtractMapOp>(loc, result, id, multiplicity);
+  ops.insert =
+      builder.create<vector::InsertMapOp>(loc, ops.extract, id, multiplicity);
+  return ops;
+}
+
// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
// TODO: Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3a081231fc7d..25e002fed35a 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1328,3 +1328,31 @@ func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %va
// expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}}
vector.compressstore %base, %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
}
+
+// -----
+
+// vector.extract_map must reject source/result vectors that are not rank 1.
+func @extract_map_rank(%v: vector<2x32xf32>, %id : index) {
+  // expected-error@+1 {{'vector.extract_map' op expects source and destination vectors of rank 1}}
+  %0 = vector.extract_map %v[%id : 32] : vector<2x32xf32> to vector<2x1xf32>
+}
+
+// -----
+
+// vector.extract_map must reject sizes where source != result * multiplicity.
+func @extract_map_size(%v: vector<63xf32>, %id : index) {
+  // expected-error@+1 {{'vector.extract_map' op vector sizes mismatch. Source size must be equal to destination size * multiplicity}}
+  %0 = vector.extract_map %v[%id : 32] : vector<63xf32> to vector<2xf32>
+}
+
+// -----
+
+// vector.insert_map must reject source/result vectors that are not rank 1.
+func @insert_map_rank(%v: vector<2x1xf32>, %id : index) {
+  // expected-error@+1 {{'vector.insert_map' op expected source and destination vectors of rank 1}}
+  %0 = vector.insert_map %v, %id, 32 : vector<2x1xf32> to vector<2x32xf32>
+}
+
+// -----
+
+// vector.insert_map must reject sizes where result != source * multiplicity.
+func @insert_map_size(%v: vector<1xf32>, %id : index) {
+  // expected-error@+1 {{'vector.insert_map' op vector sizes mismatch. Destination size must be equal to source size * multiplicity}}
+  %0 = vector.insert_map %v, %id, 32 : vector<1xf32> to vector<64xf32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 2a62be94e01b..7315d2189c67 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -432,3 +432,14 @@ func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru:
vector.compressstore %base, %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}
+
+// Round-trip test for the custom assembly of vector.extract_map and
+// vector.insert_map, using a multiplicity of 16 on a 32-element vector.
+// CHECK-LABEL: @extract_insert_map
+func @extract_insert_map(%v: vector<32xf32>, %id : index) -> vector<32xf32> {
+ // CHECK: %[[V:.*]] = vector.extract_map %{{.*}}[%{{.*}} : 16] : vector<32xf32> to vector<2xf32>
+ %vd = vector.extract_map %v[%id : 16] : vector<32xf32> to vector<2xf32>
+ // CHECK: %[[R:.*]] = vector.insert_map %[[V]], %{{.*}}, 16 : vector<2xf32> to vector<32xf32>
+ %r = vector.insert_map %vd, %id, 16 : vector<2xf32> to vector<32xf32>
+ // CHECK: return %[[R]] : vector<32xf32>
+ return %r : vector<32xf32>
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-distribution.mlir b/mlir/test/Dialect/Vector/vector-distribution.mlir
new file mode 100644
index 000000000000..0216a017d7af
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-distribution.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -test-vector-distribute-patterns | FileCheck %s
+
+// An addf on vector<32xf32> is expected to be distributed over the first
+// function argument with multiplicity 32: both operands are extracted to
+// vector<1xf32>, the addf is done on the extracted values, and the result is
+// inserted back into a vector<32xf32>.
+// CHECK-LABEL: func @distribute_vector_add
+// CHECK-SAME: (%[[ID:.*]]: index
+// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]] : 32] : vector<32xf32> to vector<1xf32>
+// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]] : 32] : vector<32xf32> to vector<1xf32>
+// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32>
+// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ID]], 32 : vector<1xf32> to vector<32xf32>
+// CHECK-NEXT: return %[[INS]] : vector<32xf32>
+func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> {
+ %0 = addf %A, %B : vector<32xf32>
+ return %0: vector<32xf32>
+}
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index ab8460318b49..2ffe10bc1682 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -125,6 +125,28 @@ struct TestVectorUnrollingPatterns
}
};
+/// Test pass for vector distribution: wraps the result of every AddFOp in the
+/// function in an extract_map/insert_map pair, then applies
+/// PointwiseExtractPattern greedily to push the extract_map through the AddFOp.
+struct TestVectorDistributePatterns
+    : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<VectorDialect>();
+  }
+  void runOnFunction() override {
+    MLIRContext *ctx = &getContext();
+    OwningRewritePatternList patterns;
+    FuncOp func = getFunction();
+    func.walk([&](AddFOp op) {
+      OpBuilder builder(op);
+      // The first function argument is used as the distribution id, with a
+      // hard-coded multiplicity of 32.
+      Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
+          builder, op.getOperation(), func.getArgument(0), 32);
+      assert(ops.hasValue());
+      // Redirect every use of the addf result to the insert_map, except the
+      // extract_map that feeds it.
+      SmallPtrSet<Operation *, 1> extractOp({ops->extract});
+      op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
+    });
+    patterns.insert<PointwiseExtractPattern>(ctx);
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
+  }
+};
+
struct TestVectorTransferFullPartialSplitPatterns
: public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
FunctionPass> {
@@ -178,5 +200,9 @@ void registerTestVectorConversions() {
vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
"Test conversion patterns to split "
"transfer ops via scf.if + linalg ops");
+ PassRegistration<TestVectorDistributePatterns> distributePass(
+ "test-vector-distribute-patterns",
+ "Test conversion patterns to distribute vector ops in the vector "
+ "dialect");
}
} // namespace mlir
More information about the Mlir-commits
mailing list