[Mlir-commits] [mlir] f69175b - [mlir][vector] Add unrolling pattern for multidim_reduce op
Thomas Raoux
llvmlistbot at llvm.org
Mon Mar 14 08:22:50 PDT 2022
Author: Thomas Raoux
Date: 2022-03-14T15:22:24Z
New Revision: f69175b1e6aba63ad349644256c58c0e3b3316f1
URL: https://github.com/llvm/llvm-project/commit/f69175b1e6aba63ad349644256c58c0e3b3316f1
DIFF: https://github.com/llvm/llvm-project/commit/f69175b1e6aba63ad349644256c58c0e3b3316f1.diff
LOG: [mlir][vector] Add unrolling pattern for multidim_reduce op
Implement the VectorUnrollOpInterface for MultiDimReductionOp and add a
pattern to do the unrolling, following the same interface as the other
vector unroll patterns.
Differential Revision: https://reviews.llvm.org/D121263
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
mlir/test/Dialect/Vector/vector-unroll-options.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2e7f06903824f..b0012924e5bae 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -314,7 +314,9 @@ def Vector_MultiDimReductionOp :
Vector_Op<"multi_reduction", [NoSideEffect,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
- DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface,
+ ["getShapeForUnroll"]>]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
AnyVector:$source,
I64ArrayAttr:$reduction_dims)>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6547e46d5418..61b5c7aac9f6f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -371,6 +371,10 @@ OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
return {};
}
+Optional<SmallVector<int64_t, 4>> MultiDimReductionOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(getSourceVectorType().getShape());
+}
+
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index 1da965e3fafc9..7ec69f006dba5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -207,23 +207,23 @@ struct UnrollTransferWritePattern
vector::UnrollVectorOptions options;
};
-struct UnrollContractionPattern
- : public OpRewritePattern<vector::ContractionOp> {
- struct OffsetMapInfo {
- static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
+struct OffsetMapInfo {
+ static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
- static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
+ static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
- static unsigned getHashValue(const SmallVector<int64_t> &v) {
- return static_cast<unsigned>(
- llvm::hash_combine_range(v.begin(), v.end()));
- }
+ static unsigned getHashValue(const SmallVector<int64_t> &v) {
+ return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
+ }
- static bool isEqual(const SmallVector<int64_t> &lhs,
- const SmallVector<int64_t> &rhs) {
- return lhs == rhs;
- }
- };
+ static bool isEqual(const SmallVector<int64_t> &lhs,
+ const SmallVector<int64_t> &rhs) {
+ return lhs == rhs;
+ }
+};
+
+struct UnrollContractionPattern
+ : public OpRewritePattern<vector::ContractionOp> {
UnrollContractionPattern(MLIRContext *context,
const vector::UnrollVectorOptions &options)
: OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
@@ -320,6 +320,74 @@ struct UnrollContractionPattern
vector::UnrollVectorOptions options;
};
+struct UnrollMultiReductionPattern
+ : public OpRewritePattern<vector::MultiDimReductionOp> {
+ UnrollMultiReductionPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options)
+ : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
+ PatternRewriter &rewriter) const override {
+ Optional<SmallVector<int64_t, 4>> targetShape =
+ getTargetShape(options, reductionOp);
+ if (!targetShape)
+ return failure();
+ SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
+ SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+ llvm::MapVector<
+ SmallVector<int64_t>, Value,
+ llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
+ accCache;
+ // Number of target-shaped slices to visit, derived from the shape ratio.
+ int64_t sliceCount = computeMaxLinearIndex(ratio);
+ Location loc = reductionOp.getLoc();
+ for (int64_t i = 0; i < sliceCount; i++) {
+ SmallVector<int64_t, 4> offsets =
+ getVectorOffset(originalSize, *targetShape, i);
+
+ SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
+ Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, reductionOp.getOperand(), offsets, *targetShape, operandStrides);
+
+ SmallVector<int64_t> dstShape;
+ SmallVector<int64_t> destOffset;
+ for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
+ if (!reductionOp.isReducedDim(i)) {
+ destOffset.push_back(offsets[i]);
+ dstShape.push_back((*targetShape)[i]);
+ }
+ }
+ auto targetType = VectorType::get(
+ dstShape, reductionOp.getSourceVectorType().getElementType());
+ Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
+ slicedOperand, targetType);
+ Value result = newOp->getResult(0);
+ // Save the accumulated value until all the loops are unrolled since
+ // reduction loop keeps updating the accumulator.
+ auto accIt = accCache.find(destOffset);
+ if (accIt != accCache.end())
+ result = makeArithReduction(rewriter, loc, reductionOp.kind(), result,
+ accIt->second);
+ accCache[destOffset] = result;
+ }
+ // Assemble back the accumulator into a single vector.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, reductionOp.getDestType(),
+ rewriter.getZeroAttr(reductionOp.getDestType()));
+ for (const auto &it : accCache) {
+ SmallVector<int64_t> dstStrides(it.first.size(), 1);
+ result = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, it.second, result, it.first, dstStrides);
+ }
+ rewriter.replaceOp(reductionOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
struct UnrollElementwisePattern : public RewritePattern {
UnrollElementwisePattern(MLIRContext *context,
const vector::UnrollVectorOptions &options)
@@ -568,8 +636,8 @@ struct TransferWriteInsertPattern
void mlir::vector::populateVectorUnrollPatterns(
RewritePatternSet &patterns, const UnrollVectorOptions &options) {
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
- UnrollContractionPattern, UnrollElementwisePattern>(
- patterns.getContext(), options);
+ UnrollContractionPattern, UnrollElementwisePattern,
+ UnrollMultiReductionPattern>(patterns.getContext(), options);
}
void mlir::vector::populatePropagateVectorDistributionPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 581039c48cb73..dd1a6fd781e47 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -80,3 +80,29 @@ func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf32>)
}
// CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
+
+func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
+ %0 = vector.multi_reduction #vector.kind<add>, %v [1] : vector<4x6xf32> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+// CHECK-LABEL: func @vector_multi_reduction
+// CHECK: %[[V0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK: %[[R0:.*]] = vector.multi_reduction <add>, %[[E0]] [1] : vector<2x2xf32> to vector<2xf32>
+// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK: %[[R1:.*]] = vector.multi_reduction <add>, %[[E1]] [1] : vector<2x2xf32> to vector<2xf32>
+// CHECK: %[[A0:.*]] = arith.addf %[[R1]], %[[R0]] : vector<2xf32>
+// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK: %[[R2:.*]] = vector.multi_reduction <add>, %[[E2]] [1] : vector<2x2xf32> to vector<2xf32>
+// CHECK: %[[A1:.*]] = arith.addf %[[R2]], %[[A0]] : vector<2xf32>
+// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK: %[[R3:.*]] = vector.multi_reduction <add>, %[[E3]] [1] : vector<2x2xf32> to vector<2xf32>
+// CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK: %[[R4:.*]] = vector.multi_reduction <add>, %[[E4]] [1] : vector<2x2xf32> to vector<2xf32>
+// CHECK: %[[A2:.*]] = arith.addf %[[R4]], %[[R3]] : vector<2xf32>
+// CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK: %[[R5:.*]] = vector.multi_reduction <add>, %[[E5]] [1] : vector<2x2xf32> to vector<2xf32>
+// CHECK: %[[A3:.*]] = arith.addf %[[R5]], %[[A2]] : vector<2xf32>
+// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[A1]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: return %[[V2]] : vector<4xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 67e33d3aa0b4a..2bf5e3f1a8e7d 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -265,7 +265,8 @@ struct TestVectorUnrollingPatterns
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2, 2})
.setFilterConstraint([](Operation *op) {
- return success(isa<arith::AddFOp, vector::FMAOp>(op));
+ return success(isa<arith::AddFOp, vector::FMAOp,
+ vector::MultiDimReductionOp>(op));
}));
if (unrollBasedOnType) {
More information about the Mlir-commits
mailing list