[Mlir-commits] [mlir] 4f5eb53 - Revert "[mlir] Fold Arithmetic::ConstantOp and Tensor::ExtractSliceOp."
Okwan Kwon
llvmlistbot at llvm.org
Mon Feb 28 11:14:13 PST 2022
Author: Okwan Kwon
Date: 2022-02-28T19:14:05Z
New Revision: 4f5eb53e68b1da47a211a97bd2fe4ea26b590e58
URL: https://github.com/llvm/llvm-project/commit/4f5eb53e68b1da47a211a97bd2fe4ea26b590e58
DIFF: https://github.com/llvm/llvm-project/commit/4f5eb53e68b1da47a211a97bd2fe4ea26b590e58.diff
LOG: Revert "[mlir] Fold Arithmetic::ConstantOp and Tensor::ExtractSliceOp."
This reverts commit 3104994104f0c2f274acf5e01eb6cc82e9cca06b.
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
Removed:
mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index a4f78750368de..e6267e9cf02e5 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -9,7 +9,6 @@
#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
@@ -21,19 +20,6 @@ namespace tensor {
void populateSplitPaddingPatterns(RewritePatternSet &patterns,
PatternBenefit baseBenefit = 1);
-/// Function to control the folding of constant and extract slice
-using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
-
-/// Patterns to fold the extract slice op with its constant operand
-void populateFoldConstantExtractSlicePatterns(
- RewritePatternSet &patterns,
- const ControlConstantExtractSliceFusionFn &controlFn =
- [](ExtractSliceOp op) {
- // Disable by default because the folding can generate a large
- // constant tensor, which would affect the compile time and storage.
- return false;
- });
-
} // namespace tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d6a4bb460a7f6..70aa7b5fe57f6 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -6,14 +6,17 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
-#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
@@ -1155,134 +1158,8 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
return success();
}
};
-
-/// Slice elements from `values` into `outValues`. `counts` represents the
-/// numbers of elements to stride in the original values for each dimension.
-/// The output values can be used to construct a DenseElementsAttr.
-template <typename IterTy, typename ElemTy>
-static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
- ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
- ArrayRef<int64_t> strides,
- llvm::SmallVectorImpl<ElemTy> *outValues) {
- assert(offsets.size() == sizes.size());
- assert(offsets.size() == strides.size());
- if (offsets.empty())
- return;
-
- int64_t offset = offsets.front();
- int64_t size = sizes.front();
- int64_t stride = strides.front();
- if (offsets.size() == 1) {
- for (int64_t i = 0; i < size; ++i, offset += stride)
- outValues->push_back(*(values + offset));
-
- return;
- }
-
- for (int64_t i = 0; i < size; ++i, offset += stride) {
- auto begin = values + offset * counts.front();
- sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
- offsets.drop_front(), sizes.drop_front(),
- strides.drop_front(), outValues);
- }
-}
-
-/// Fold arith.constant and tensor.extract_slice into arith.constant. The folded
-/// operation might introduce more constant data; Users can control their
-/// heuristics by the control function.
-class ConstantOpExtractSliceFolder final
- : public OpRewritePattern<ExtractSliceOp> {
-public:
- using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
-
- ConstantOpExtractSliceFolder(MLIRContext *context,
- ControlConstantExtractSliceFusionFn controlFn)
- : OpRewritePattern<ExtractSliceOp>(context),
- controlFn(std::move(controlFn)) {}
-
- LogicalResult matchAndRewrite(ExtractSliceOp op,
- PatternRewriter &rewriter) const override {
- DenseElementsAttr attr;
- if (!matchPattern(op.source(), m_Constant(&attr)))
- return failure();
-
- // A constant splat is handled by fold().
- if (attr.isSplat())
- return failure();
-
- // Dynamic result shape is not supported.
- auto sourceType = op.source().getType().cast<ShapedType>();
- auto resultType = op.result().getType().cast<ShapedType>();
- if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
- return failure();
-
- // Customized control over the folding.
- if (!controlFn(op))
- return failure();
-
- int64_t count = sourceType.getNumElements();
- if (count == 0)
- return failure();
-
- // Check if there are any dynamic parts, which are not supported.
- auto offsets = extractFromI64ArrayAttr(op.static_offsets());
- if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
- return failure();
- auto sizes = extractFromI64ArrayAttr(op.static_sizes());
- if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
- return failure();
- auto strides = extractFromI64ArrayAttr(op.static_strides());
- if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
- return failure();
-
- // Compute the stride for each dimension.
- SmallVector<int64_t> counts;
- ArrayRef<int64_t> shape = sourceType.getShape();
- counts.reserve(shape.size());
- for (int64_t v : shape) {
- count = count / v;
- counts.push_back(count);
- }
-
- // New attribute constructed by the sliced values.
- DenseElementsAttr newAttr;
-
- if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
- SmallVector<APInt> outValues;
- outValues.reserve(sourceType.getNumElements());
- sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
- elems.begin(), counts, offsets, sizes, strides, &outValues);
- newAttr = DenseElementsAttr::get(resultType, outValues);
- } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
- SmallVector<APFloat> outValues;
- outValues.reserve(sourceType.getNumElements());
- sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
- elems.begin(), counts, offsets, sizes, strides, &outValues);
- newAttr = DenseElementsAttr::get(resultType, outValues);
- }
-
- if (newAttr) {
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
- return success();
- }
-
- return failure();
- }
-
-private:
- /// This additionally controls whether the fold happens or not. Users can
- /// impose their heuristics in the function.
- ControlConstantExtractSliceFusionFn controlFn;
-};
-
} // namespace
-void mlir::tensor::populateFoldConstantExtractSlicePatterns(
- RewritePatternSet &patterns,
- const ControlConstantExtractSliceFusionFn &controlFn) {
- patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
-}
-
/// Return the canonical type of the result of an extract_slice op.
struct SliceReturnTypeCanonicalizer {
RankedTensorType operator()(ExtractSliceOp op,
@@ -1361,7 +1238,6 @@ OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
return this->source();
if (Value slice = foldExtractAfterInsertSlice(*this))
return slice;
-
return OpFoldResult();
}
diff --git a/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
deleted file mode 100644
index 03c6195d40374..0000000000000
--- a/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
+++ /dev/null
@@ -1,39 +0,0 @@
-// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-constant-extract-slice %s | FileCheck %s
-
-// CHECK-LABEL: func @slice_constant
-// CHECK-NOT: tensor.extract_slice
-// CHECK: %[[CONST:.+]] = arith.constant dense<1.000000e+01> : tensor<1x1xf32>
-// CHECK: return %[[CONST]] : tensor<1x1xf32>
-func @slice_constant(%arg0 : tensor<2x1xf32>) -> tensor<1x1xf32>
-{
- %cst = arith.constant dense<[[10.0], [11.0]]> : tensor<2x1xf32>
- %slice = tensor.extract_slice %cst[0, 0] [1, 1] [1, 1] : tensor<2x1xf32> to tensor<1x1xf32>
- return %slice : tensor<1x1xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @slice_constant_3x4
-// CHECK-NOT: tensor.extract_slice
-// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.000000e+01, 9.000000e+00], [1.100000e+01, 1.200000e+01]]> : tensor<2x2xf32>
-// CHECK: return %[[CONST]] : tensor<2x2xf32>
-func @slice_constant_3x4(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
-{
- %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
- %slice = tensor.extract_slice %cst[0, 0] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
- return %slice : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @slice_constant_3x4_offsets
-// CHECK-NOT: tensor.extract_slice
-// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.200000e+01, 1.300000e+01], [3.000000e+00, 5.000000e+00]]> : tensor<2x2xf32>
-// CHECK: return %[[CONST]] : tensor<2x2xf32>
-func @slice_constant_3x4_offsets(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
-{
- %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
- %slice = tensor.extract_slice %cst[1, 1] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
- return %slice : tensor<2x2xf32>
-}
-
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 4d947ef3ee534..c720ca1e3a235 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -41,11 +41,6 @@ struct TestTensorTransforms
*this, "test-split-padding-patterns",
llvm::cl::desc("Test patterns to split tensor.pad ops"),
llvm::cl::init(false)};
-
- Option<bool> testFoldConstantExtractSlice{
- *this, "test-fold-constant-extract-slice",
- llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
- llvm::cl::init(false)};
};
} // namespace
@@ -55,31 +50,10 @@ static void applySplitPaddingPatterns(FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
-static void applyFoldConstantExtractSlicePatterns(FuncOp funcOp) {
- RewritePatternSet patterns(funcOp.getContext());
- tensor::ControlConstantExtractSliceFusionFn controlFn =
- [](tensor::ExtractSliceOp op) {
- if (!op.source().hasOneUse())
- return false;
-
- auto resultType = op.result().getType().cast<ShapedType>();
- constexpr int64_t kConstantFoldingMaxNumElements = 1024;
- if (resultType.getNumElements() > kConstantFoldingMaxNumElements)
- return false;
-
- return true;
- };
-
- tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-}
-
void TestTensorTransforms::runOnOperation() {
FuncOp func = getOperation();
if (testSplitPaddingPatterns)
applySplitPaddingPatterns(func);
- if (testFoldConstantExtractSlice)
- applyFoldConstantExtractSlicePatterns(func);
}
namespace mlir {
More information about the Mlir-commits mailing list