[Mlir-commits] [mlir] 39c8065 - [mlir][vector] Convert extract_strided_slice to extract & insert chain
Lei Zhang
llvmlistbot at llvm.org
Wed Nov 9 16:42:14 PST 2022
Author: Lei Zhang
Date: 2022-11-09T19:42:07-05:00
New Revision: 39c80656fef68fcaf9707857bf67f643378e6cc8
URL: https://github.com/llvm/llvm-project/commit/39c80656fef68fcaf9707857bf67f643378e6cc8
DIFF: https://github.com/llvm/llvm-project/commit/39c80656fef68fcaf9707857bf67f643378e6cc8.diff
LOG: [mlir][vector] Convert extract_strided_slice to extract & insert chain
This is useful for breaking down extract_strided_slice so that it can
potentially cancel with other extract / insert ops before or after.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D137471
Added:
mlir/test/Dialect/Vector/vector-extract-strided-slice-lowering.mlir
Modified:
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index e7169b60285a9..15756617792ac 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -285,6 +285,17 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
void populateVectorInsertExtractStridedSliceDecompositionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
+/// Populate `patterns` with a pattern to break down 1-D extract_strided_slice
+/// ops into a chain of Extract ops to extract each element from the source, and
+/// then a chain of Insert ops to insert to the target vector.
+///
+/// If `controlFn` is not nullptr, the pattern is only applied to ops for which
+/// `controlFn` returns true. Otherwise it runs on all ops.
+void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+ RewritePatternSet &patterns,
+ std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
+ PatternBenefit benefit = 1);
+
/// Populate `patterns` with the following patterns.
///
/// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ad6cf85b62533..313a3f9a9c090 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::vector;
@@ -231,6 +232,53 @@ class Convert1DExtractStridedSliceIntoShuffle
}
};
+/// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops
+/// to extract each element from the source, and then a chain of Insert ops
+/// to insert to the target vector.
+class Convert1DExtractStridedSliceIntoExtractInsertChain final
+ : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+ Convert1DExtractStridedSliceIntoExtractInsertChain(
+ MLIRContext *context,
+ std::function<bool(ExtractStridedSliceOp)> controlFn,
+ PatternBenefit benefit)
+ : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+ LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ if (controlFn && !controlFn(op))
+ return failure();
+
+ // Only handle 1-D cases.
+ if (op.getOffsets().getValue().size() != 1)
+ return failure();
+
+ int64_t offset =
+ op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
+ int64_t size =
+ op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
+ int64_t stride =
+ op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
+
+ Location loc = op.getLoc();
+ SmallVector<Value> elements;
+ elements.reserve(size);
+ for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
+ elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i));
+
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(op.getType()));
+ for (int64_t i = 0; i < size; ++i)
+ result = rewriter.create<InsertOp>(loc, elements[i], result, i);
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+
+private:
+ std::function<bool(ExtractStridedSliceOp)> controlFn;
+};
+
/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
@@ -285,14 +333,22 @@ class DecomposeNDExtractStridedSlice
}
};
-void mlir::vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<DecomposeDifferentRankInsertStridedSlice,
DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
}
+void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+ RewritePatternSet &patterns,
+ std::function<bool(ExtractStridedSliceOp)> controlFn,
+ PatternBenefit benefit) {
+ patterns.add<Convert1DExtractStridedSliceIntoExtractInsertChain>(
+ patterns.getContext(), std::move(controlFn), benefit);
+}
+
/// Populate the given list with patterns that convert from Vector to LLVM.
-void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
+void vector::populateVectorInsertExtractStridedSliceTransforms(
RewritePatternSet &patterns, PatternBenefit benefit) {
populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns,
benefit);
diff --git a/mlir/test/Dialect/Vector/vector-extract-strided-slice-lowering.mlir b/mlir/test/Dialect/Vector/vector-extract-strided-slice-lowering.mlir
new file mode 100644
index 0000000000000..ca14dee32255c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-extract-strided-slice-lowering.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt -split-input-file -test-vector-extract-strided-slice-lowering %s | FileCheck %s
+
+// CHECK-LABEL: func.func @extract_strided_slice_1D
+// CHECK-SAME: (%[[INPUT:.+]]: vector<8xf16>)
+func.func @extract_strided_slice_1D(%input: vector<8xf16>) -> vector<4xf16> {
+ %0 = vector.extract_strided_slice %input {offsets = [1], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+ return %0: vector<4xf16>
+}
+
+// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<4xf16>
+// CHECK: %[[E0:.+]] = vector.extract %[[INPUT]][1] : vector<8xf16>
+// CHECK: %[[E1:.+]] = vector.extract %[[INPUT]][2] : vector<8xf16>
+// CHECK: %[[E2:.+]] = vector.extract %[[INPUT]][3] : vector<8xf16>
+// CHECK: %[[E3:.+]] = vector.extract %[[INPUT]][4] : vector<8xf16>
+// CHECK: %[[I0:.+]] = vector.insert %[[E0]], %[[INIT]] [0] : f16 into vector<4xf16>
+// CHECK: %[[I1:.+]] = vector.insert %[[E1]], %[[I0]] [1] : f16 into vector<4xf16>
+// CHECK: %[[I2:.+]] = vector.insert %[[E2]], %[[I1]] [2] : f16 into vector<4xf16>
+// CHECK: %[[I3:.+]] = vector.insert %[[E3]], %[[I2]] [3] : f16 into vector<4xf16>
+// CHECK: return %[[I3]]
+
+
+// -----
+
+// CHECK-LABEL: func.func @extract_strided_slice_2D
+func.func @extract_strided_slice_2D(%input: vector<1x8xf16>) -> vector<1x4xf16> {
+ // CHECK: vector.extract_strided_slice
+ %0 = vector.extract_strided_slice %input {offsets = [0, 1], sizes = [1, 4], strides = [1, 1]} : vector<1x8xf16> to vector<1x4xf16>
+ return %0: vector<1x4xf16>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index de29fc2a66423..4f44d43cdb317 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -785,6 +786,26 @@ struct TestVectorDistribution
}
};
+struct TestVectorExtractStridedSliceLowering
+ : public PassWrapper<TestVectorExtractStridedSliceLowering,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestVectorExtractStridedSliceLowering)
+
+ StringRef getArgument() const final {
+ return "test-vector-extract-strided-slice-lowering";
+ }
+ StringRef getDescription() const final {
+ return "Test lowering patterns that converts vector.extract_strided_slice "
+ "into a chain of vector.extract and vector.insert ops";
+ }
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
} // namespace
namespace mlir {
@@ -819,6 +840,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorScanLowering>();
PassRegistration<TestVectorDistribution>();
+
+ PassRegistration<TestVectorExtractStridedSliceLowering>();
}
} // namespace test
} // namespace mlir
More information about the Mlir-commits
mailing list