[Mlir-commits] [mlir] 627733b - [mlir][vector] Extend vector distribution to all elementwise and contract
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 30 16:36:08 PDT 2021
Author: thomasraoux
Date: 2021-06-30T16:22:31-07:00
New Revision: 627733b5f045e870577e5abf70944d3ffac7a6fb
URL: https://github.com/llvm/llvm-project/commit/627733b5f045e870577e5abf70944d3ffac7a6fb
DIFF: https://github.com/llvm/llvm-project/commit/627733b5f045e870577e5abf70944d3ffac7a6fb.diff
LOG: [mlir][vector] Extend vector distribution to all elementwise and contract
Uses the elementwise interface to generalize the canonicalization pattern and
adds a new pattern for the vector.contract case.
Differential Revision: https://reviews.llvm.org/D104343
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-distribution.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 4b5f5ce8e035f..1fc5f1b45ceef 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -105,6 +105,10 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
// a sequence of vector.reduction ops.
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns);
+/// Collect a set of patterns to propagate insert_map/extract_map in the SSA
+/// chain.
+void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
+
/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
class CombiningKindAttr
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 173183ee9b7c2..8419f49c5385f 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -251,26 +251,6 @@ Optional<DistributeOps>
distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
ArrayRef<Value> id, ArrayRef<int64_t> multiplicity,
const AffineMap &map);
-/// Canonicalize an extra element using the result of a pointwise operation.
-/// Transforms:
-/// %v = addf %a, %b : vector32xf32>
-/// %dv = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32>
-/// to:
-/// %da = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
-/// %db = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
-/// %dv = addf %da, %db : vector<1xf32>
-struct PointwiseExtractPattern : public OpRewritePattern<ExtractMapOp> {
- using FilterConstraintType = std::function<LogicalResult(ExtractMapOp op)>;
- PointwiseExtractPattern(
- MLIRContext *context, FilterConstraintType constraint =
- [](ExtractMapOp op) { return success(); })
- : OpRewritePattern<ExtractMapOp>(context), filter(constraint) {}
- LogicalResult matchAndRewrite(ExtractMapOp extract,
- PatternRewriter &rewriter) const override;
-
-private:
- FilterConstraintType filter;
-};
/// Implements transfer op write to read forwarding and dead transfer write
/// optimizations.
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 3a10fb3de6416..33d57623f6be5 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2793,25 +2793,6 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
return failure();
}
-LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite(
- ExtractMapOp extract, PatternRewriter &rewriter) const {
- Operation *definedOp = extract.vector().getDefiningOp();
- if (!definedOp || definedOp->getNumResults() != 1)
- return failure();
- // TODO: Create an interfaceOp for elementwise operations.
- if (!isa<AddFOp>(definedOp))
- return failure();
- Location loc = extract.getLoc();
- SmallVector<Value, 4> extractOperands;
- for (OpOperand &operand : definedOp->getOpOperands())
- extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
- loc, extract.getResultType(), operand.get(), extract.ids()));
- Operation *newOp = cloneOpWithOperandsAndTypes(
- rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
- rewriter.replaceOp(extract, newOp->getResult(0));
- return success();
-}
-
Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
ArrayRef<int64_t> multiplicity, const AffineMap &map) {
@@ -2843,6 +2824,91 @@ Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
return ops;
}
+/// Canonicalize an extract_map using the result of a pointwise operation.
+/// Transforms:
+/// %v = addf %a, %b : vector<32xf32>
+/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
+/// to:
+/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
+/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
+/// %dv = addf %da, %db : vector<1xf32>
+struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
+ using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
+ PatternRewriter &rewriter) const override {
+ Operation *definedOp = extract.vector().getDefiningOp();
+ if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
+ definedOp->getNumResults() != 1)
+ return failure();
+ Location loc = extract.getLoc();
+ SmallVector<Value, 4> extractOperands;
+ for (OpOperand &operand : definedOp->getOpOperands()) {
+ auto vecType = operand.get().getType().template dyn_cast<VectorType>();
+ if (!vecType) {
+ extractOperands.push_back(operand.get());
+ continue;
+ }
+ extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
+ loc,
+ VectorType::get(extract.getResultType().getShape(),
+ vecType.getElementType()),
+ operand.get(), extract.ids()));
+ }
+ Operation *newOp = cloneOpWithOperandsAndTypes(
+ rewriter, loc, definedOp, extractOperands, extract.getResultType());
+ rewriter.replaceOp(extract, newOp->getResult(0));
+ return success();
+ }
+};
+
+/// Canonicalize an extract_map using the result of a contract operation.
+/// This propagates the extract_map to the operands.
+struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
+ using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
+ PatternRewriter &rewriter) const override {
+ Operation *definedOp = extract.vector().getDefiningOp();
+ auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
+ if (!contract)
+ return failure();
+ Location loc = contract.getLoc();
+ unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
+ AffineMap affineMap = contract.getIndexingMaps()[accIndex];
+ // Create a map of the dimensions distributed based on the acc affine map.
+ // Only parallel dimensions are being distributed, reduction dimensions are
+ // untouched.
+ DenseMap<int64_t, int64_t> map;
+ for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
+ map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
+ SmallVector<Value, 4> extractOperands;
+ for (auto it : llvm::enumerate(contract.getIndexingMaps())) {
+      // For each operand, calculate the new vector type after distribution.
+ Value operand = contract->getOperand(it.index());
+ auto vecType = operand.getType().cast<VectorType>();
+ SmallVector<int64_t> operandShape(vecType.getShape().begin(),
+ vecType.getShape().end());
+ for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
+ unsigned dim = it.value().getDimPosition(i);
+ auto distributedDim = map.find(dim);
+ // If the dimension is not in the map it means it is a reduction and
+ // doesn't get distributed.
+ if (distributedDim == map.end())
+ continue;
+ operandShape[i] = distributedDim->second;
+ }
+ VectorType newVecType =
+ VectorType::get(operandShape, vecType.getElementType());
+ extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
+ loc, newVecType, operand, extract.ids()));
+ }
+ Operation *newOp =
+ cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
+ extract.getResult().getType());
+ rewriter.replaceOp(extract, newOp->getResult(0));
+ return success();
+ }
+};
+
/// Converts TransferRead op used by ExtractMap op into a smaller dimension
/// TransferRead.
/// Example:
@@ -4100,8 +4166,7 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
// TODO: Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
RewritePatternSet &patterns) {
- patterns.add<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp,
- TransferReadExtractPattern, TransferWriteInsertPattern>(
+ patterns.add<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp>(
patterns.getContext());
}
@@ -4112,6 +4177,13 @@ void mlir::vector::populateSplitVectorTransferPatterns(
ignoreFilter);
}
+void mlir::vector::populatePropagateVectorDistributionPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<PointwiseExtractPattern, ContractExtractPattern,
+ TransferReadExtractPattern, TransferWriteInsertPattern>(
+ patterns.getContext());
+}
+
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
RewritePatternSet &patterns) {
patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
diff --git a/mlir/test/Dialect/Vector/vector-distribution.mlir b/mlir/test/Dialect/Vector/vector-distribution.mlir
index 0ad46d1b204e1..603493e7456b6 100644
--- a/mlir/test/Dialect/Vector/vector-distribution.mlir
+++ b/mlir/test/Dialect/Vector/vector-distribution.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,1,32 -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,4 -split-input-file | FileCheck %s --check-prefix=CHECK2D
// CHECK-LABEL: func @distribute_vector_add
// CHECK-SAME: (%[[ID:.*]]: index
@@ -15,6 +16,24 @@ func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>)
// -----
+// CHECK-LABEL: func @distribute_vector_add_exp
+// CHECK-SAME: (%[[ID:.*]]: index
+// CHECK-NEXT: %[[EXPV:.*]] = math.exp %{{.*}} : vector<32xf32>
+// CHECK-NEXT: %[[ADDV:.*]] = addf %[[EXPV]], %{{.*}} : vector<32xf32>
+// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
+// CHECK-NEXT: %[[EXC:.*]] = math.exp %[[EXA]] : vector<1xf32>
+// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
+// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXC]], %[[EXB]] : vector<1xf32>
+// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID]]] : vector<1xf32> into vector<32xf32>
+// CHECK-NEXT: return %[[INS]] : vector<32xf32>
+func @distribute_vector_add_exp(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> {
+ %C = math.exp %A : vector<32xf32>
+ %0 = addf %C, %B : vector<32xf32>
+ return %0: vector<32xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @vector_add_read_write
// CHECK-SAME: (%[[ID:.*]]: index
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
@@ -154,3 +173,32 @@ func @vector_add_transfer_permutation(%id0 : index, %id1 : index, %A: memref<?x?
vector.transfer_write %acc, %C[%c0, %c0, %c0, %c0] {permutation_map = #map2}: vector<64x4x32xf32>, memref<?x?x?x?xf32>
return
}
+
+// -----
+
+// CHECK2D-LABEL: vector_add_contract
+// CHECK2D: %[[A:.+]] = vector.transfer_read %arg2[%0, %c0], %cst : memref<?x?xf32>, vector<2x4xf32>
+// CHECK2D: %[[B:.+]] = vector.transfer_read %arg3[%2, %c0], %cst : memref<?x?xf32>, vector<16x4xf32>
+// CHECK2D: %[[C:.+]] = vector.transfer_read %arg4[%4, %5], %cst : memref<?x?xf32>, vector<2x16xf32>
+// CHECK2D: %[[E:.+]] = vector.transfer_read %arg5[%7, %8], %cst : memref<?x?xf32>, vector<2x16xf32>
+// CHECK2D: %[[D:.+]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]] : vector<2x4xf32>, vector<16x4xf32> into vector<2x16xf32>
+// CHECK2D: %[[R:.+]] = addf %[[D]], %[[E]] : vector<2x16xf32>
+// CHECK2D: vector.transfer_write %[[R]], {{.*}} : vector<2x16xf32>, memref<?x?xf32>
+func @vector_add_contract(%id0 : index, %id1 : index, %A: memref<?x?xf32>,
+ %B: memref<?x?xf32>, %C: memref<?x?xf32>, %D: memref<?x?xf32>) {
+ %c0 = constant 0 : index
+ %cf0 = constant 0.0 : f32
+ %a = vector.transfer_read %A[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
+ %b = vector.transfer_read %B[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
+ %c = vector.transfer_read %C[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
+ %d = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %a, %b, %c : vector<64x4xf32>, vector<64x4xf32> into vector<64x64xf32>
+ %e = vector.transfer_read %D[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
+ %r = addf %d, %e : vector<64x64xf32>
+ vector.transfer_write %r, %C[%c0, %c0] : vector<64x64xf32>, memref<?x?xf32>
+ return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 95d292612c661..9b0a6aea21b66 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -275,8 +275,7 @@ struct TestVectorDistributePatterns
}
}
});
- patterns.add<PointwiseExtractPattern>(ctx);
- populateVectorToVectorTransformationPatterns(patterns);
+ populatePropagateVectorDistributionPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
@@ -339,8 +338,7 @@ struct TestVectorToLoopPatterns
}
return mlir::WalkResult::interrupt();
});
- patterns.add<PointwiseExtractPattern>(ctx);
- populateVectorToVectorTransformationPatterns(patterns);
+ populatePropagateVectorDistributionPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
More information about the Mlir-commits
mailing list