[Mlir-commits] [mlir] 4c901bf - [mlir] Match Arithmetic::ConstantOp and Tensor::ExtractSliceOp.
Okwan Kwon
llvmlistbot at llvm.org
Mon Feb 28 15:09:15 PST 2022
Author: Okwan Kwon
Date: 2022-02-28T23:09:03Z
New Revision: 4c901bf4471932852bc872040788c84c0f26d4a6
URL: https://github.com/llvm/llvm-project/commit/4c901bf4471932852bc872040788c84c0f26d4a6
DIFF: https://github.com/llvm/llvm-project/commit/4c901bf4471932852bc872040788c84c0f26d4a6.diff
LOG: [mlir] Match Arithmetic::ConstantOp and Tensor::ExtractSliceOp.
Add a pattern matcher for ExtractSliceOp when its source is a constant.
The matching heuristics can be governed by a control function, since
generating a new constant is not always beneficial.
Differential Revision: https://reviews.llvm.org/D119605
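
A minimal sketch of how a downstream pass might wire up the new entry point,
mirroring the TestTensorTransforms change in this commit. The helper name and
the 256-element budget below are illustrative assumptions, not part of the
commit:

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: apply the constant/extract_slice folding on a
// function, but only when the sliced result stays reasonably small.
static void foldSmallConstantSlices(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  tensor::ControlConstantExtractSliceFusionFn controlFn =
      [](tensor::ExtractSliceOp op) {
        // Illustrative cap: pick a budget that fits your compile-time and
        // constant-storage constraints.
        auto resultType = op.result().getType().cast<ShapedType>();
        return resultType.getNumElements() <= 256;
      };
  tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
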
Added:
mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
Modified:
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index c8a84be219b2e..7b0d62e502e58 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -103,6 +103,19 @@ Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc,
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc,
Value tensor, Value dest);
+/// Function to control the folding of constant and extract slice
+using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
+
+/// Patterns to fold the extract slice op with its constant operand
+void populateFoldConstantExtractSlicePatterns(
+ RewritePatternSet &patterns,
+ const ControlConstantExtractSliceFusionFn &controlFn =
+ [](ExtractSliceOp op) {
+ // Disable by default because the folding can generate a large
+ // constant tensor, which would affect the compile time and storage.
+ return false;
+ });
+
} // namespace tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 70aa7b5fe57f6..0c5d182c82d3e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -6,9 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
-#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -16,7 +14,6 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
@@ -1158,8 +1155,134 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
return success();
}
};
+
+/// Slice elements from `values` into `outValues`. `counts` represents the
+/// number of elements to stride in the original values for each dimension.
+/// The output values can be used to construct a DenseElementsAttr.
+template <typename IterTy, typename ElemTy>
+static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
+ ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
+ ArrayRef<int64_t> strides,
+ llvm::SmallVectorImpl<ElemTy> *outValues) {
+ assert(offsets.size() == sizes.size());
+ assert(offsets.size() == strides.size());
+ if (offsets.empty())
+ return;
+
+ int64_t offset = offsets.front();
+ int64_t size = sizes.front();
+ int64_t stride = strides.front();
+ if (offsets.size() == 1) {
+ for (int64_t i = 0; i < size; ++i, offset += stride)
+ outValues->push_back(*(values + offset));
+
+ return;
+ }
+
+ for (int64_t i = 0; i < size; ++i, offset += stride) {
+ auto begin = values + offset * counts.front();
+ sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
+ offsets.drop_front(), sizes.drop_front(),
+ strides.drop_front(), outValues);
+ }
+}
+
+/// Fold arith.constant and tensor.extract_slice into arith.constant. The
+/// folded operation might introduce more constant data; users can control
+/// the heuristics through the control function.
+class ConstantOpExtractSliceFolder final
+ : public OpRewritePattern<ExtractSliceOp> {
+public:
+ using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
+
+ ConstantOpExtractSliceFolder(MLIRContext *context,
+ ControlConstantExtractSliceFusionFn controlFn)
+ : OpRewritePattern<ExtractSliceOp>(context),
+ controlFn(std::move(controlFn)) {}
+
+ LogicalResult matchAndRewrite(ExtractSliceOp op,
+ PatternRewriter &rewriter) const override {
+ DenseElementsAttr attr;
+ if (!matchPattern(op.source(), m_Constant(&attr)))
+ return failure();
+
+ // A constant splat is handled by fold().
+ if (attr.isSplat())
+ return failure();
+
+ // Dynamic result shape is not supported.
+ auto sourceType = op.source().getType().cast<ShapedType>();
+ auto resultType = op.result().getType().cast<ShapedType>();
+ if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
+ return failure();
+
+ // Customized control over the folding.
+ if (!controlFn(op))
+ return failure();
+
+ int64_t count = sourceType.getNumElements();
+ if (count == 0)
+ return failure();
+
+ // Check if there are any dynamic parts, which are not supported.
+ auto offsets = extractFromI64ArrayAttr(op.static_offsets());
+ if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
+ return failure();
+ auto sizes = extractFromI64ArrayAttr(op.static_sizes());
+ if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
+ return failure();
+ auto strides = extractFromI64ArrayAttr(op.static_strides());
+ if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
+ return failure();
+
+ // Compute the stride for each dimension.
+ SmallVector<int64_t> counts;
+ ArrayRef<int64_t> shape = sourceType.getShape();
+ counts.reserve(shape.size());
+ for (int64_t v : shape) {
+ count = count / v;
+ counts.push_back(count);
+ }
+
+ // New attribute constructed by the sliced values.
+ DenseElementsAttr newAttr;
+
+ if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
+ SmallVector<APInt> outValues;
+ outValues.reserve(sourceType.getNumElements());
+ sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
+ elems.begin(), counts, offsets, sizes, strides, &outValues);
+ newAttr = DenseElementsAttr::get(resultType, outValues);
+ } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
+ SmallVector<APFloat> outValues;
+ outValues.reserve(sourceType.getNumElements());
+ sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
+ elems.begin(), counts, offsets, sizes, strides, &outValues);
+ newAttr = DenseElementsAttr::get(resultType, outValues);
+ }
+
+ if (newAttr) {
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
+ return success();
+ }
+
+ return failure();
+ }
+
+private:
+ /// This additionally controls whether the fold happens or not. Users can
+ /// impose their heuristics in the function.
+ ControlConstantExtractSliceFusionFn controlFn;
+};
+
} // namespace
+void mlir::tensor::populateFoldConstantExtractSlicePatterns(
+ RewritePatternSet &patterns,
+ const ControlConstantExtractSliceFusionFn &controlFn) {
+ patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
+}
+
/// Return the canonical type of the result of an extract_slice op.
struct SliceReturnTypeCanonicalizer {
RankedTensorType operator()(ExtractSliceOp op,
@@ -1238,6 +1361,7 @@ OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
return this->source();
if (Value slice = foldExtractAfterInsertSlice(*this))
return slice;
+
return OpFoldResult();
}
diff --git a/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
new file mode 100644
index 0000000000000..03c6195d40374
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-constant-extract-slice %s | FileCheck %s
+
+// CHECK-LABEL: func @slice_constant
+// CHECK-NOT: tensor.extract_slice
+// CHECK: %[[CONST:.+]] = arith.constant dense<1.000000e+01> : tensor<1x1xf32>
+// CHECK: return %[[CONST]] : tensor<1x1xf32>
+func @slice_constant(%arg0 : tensor<2x1xf32>) -> tensor<1x1xf32>
+{
+ %cst = arith.constant dense<[[10.0], [11.0]]> : tensor<2x1xf32>
+ %slice = tensor.extract_slice %cst[0, 0] [1, 1] [1, 1] : tensor<2x1xf32> to tensor<1x1xf32>
+ return %slice : tensor<1x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @slice_constant_3x4
+// CHECK-NOT: tensor.extract_slice
+// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.000000e+01, 9.000000e+00], [1.100000e+01, 1.200000e+01]]> : tensor<2x2xf32>
+// CHECK: return %[[CONST]] : tensor<2x2xf32>
+func @slice_constant_3x4(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
+{
+ %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
+ %slice = tensor.extract_slice %cst[0, 0] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
+ return %slice : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @slice_constant_3x4_offsets
+// CHECK-NOT: tensor.extract_slice
+// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.200000e+01, 1.300000e+01], [3.000000e+00, 5.000000e+00]]> : tensor<2x2xf32>
+// CHECK: return %[[CONST]] : tensor<2x2xf32>
+func @slice_constant_3x4_offsets(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
+{
+ %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
+ %slice = tensor.extract_slice %cst[1, 1] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
+ return %slice : tensor<2x2xf32>
+}
+
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index c720ca1e3a235..7fdc185946e15 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -41,6 +42,11 @@ struct TestTensorTransforms
*this, "test-split-padding-patterns",
llvm::cl::desc("Test patterns to split tensor.pad ops"),
llvm::cl::init(false)};
+
+ Option<bool> testFoldConstantExtractSlice{
+ *this, "test-fold-constant-extract-slice",
+ llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
+ llvm::cl::init(false)};
};
} // namespace
@@ -50,10 +56,31 @@ static void applySplitPaddingPatterns(FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
+static void applyFoldConstantExtractSlicePatterns(FuncOp funcOp) {
+ RewritePatternSet patterns(funcOp.getContext());
+ tensor::ControlConstantExtractSliceFusionFn controlFn =
+ [](tensor::ExtractSliceOp op) {
+ if (!op.source().hasOneUse())
+ return false;
+
+ auto resultType = op.result().getType().cast<ShapedType>();
+ constexpr int64_t kConstantFoldingMaxNumElements = 1024;
+ if (resultType.getNumElements() > kConstantFoldingMaxNumElements)
+ return false;
+
+ return true;
+ };
+
+ tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
void TestTensorTransforms::runOnOperation() {
FuncOp func = getOperation();
if (testSplitPaddingPatterns)
applySplitPaddingPatterns(func);
+ if (testFoldConstantExtractSlice)
+ applyFoldConstantExtractSlicePatterns(func);
}
namespace mlir {