[Mlir-commits] [mlir] afba867 - [mlir][vector] Add fold for ExtractStridedSlice(non-splat ConstantOp)
Jakub Kuderski
llvmlistbot at llvm.org
Fri Nov 25 10:45:17 PST 2022
Author: Jakub Kuderski
Date: 2022-11-25T13:42:56-05:00
New Revision: afba86709fc5c2d9c6b34bd4fedff4ea1deeed23
URL: https://github.com/llvm/llvm-project/commit/afba86709fc5c2d9c6b34bd4fedff4ea1deeed23
DIFF: https://github.com/llvm/llvm-project/commit/afba86709fc5c2d9c6b34bd4fedff4ea1deeed23.diff
LOG: [mlir][vector] Add fold for ExtractStridedSlice(non-splat ConstantOp)
This allows us to better canonicalize/clean-up code created by the Wide
Integer Emulation pass.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D138606
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b71c2a0f06112..af81c156e2d79 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -31,9 +31,13 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
+
+#include <cassert>
#include <numeric>
#include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"
@@ -2680,28 +2684,117 @@ class StridedSliceConstantMaskFolder final
};
// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
-class StridedSliceConstantFolder final
+class StridedSliceSplatConstantFolder final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
PatternRewriter &rewriter) const override {
- // Return if 'extractStridedSliceOp' operand is not defined by a
+ // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
// ConstantOp.
- auto constantOp =
- extractStridedSliceOp.getVector().getDefiningOp<arith::ConstantOp>();
- if (!constantOp)
+ Value sourceVector = extractStridedSliceOp.getVector();
+ Attribute vectorCst;
+ if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
return failure();
- auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
- if (!dense)
+
+ auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
+ if (!splat)
+ return failure();
+
+ auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
+ splat.getSplatValue<Attribute>());
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
+ newAttr);
+ return success();
+ }
+};
+
+// Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
+// ConstantOp.
+class StridedSliceNonSplatConstantFolder final
+ : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
+ PatternRewriter &rewriter) const override {
+ // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
+ // ConstantOp.
+ Value sourceVector = extractStridedSliceOp.getVector();
+ Attribute vectorCst;
+ if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
+ return failure();
+
+ // The splat case is handled by `StridedSliceSplatConstantFolder`.
+ auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
+ if (!dense || dense.isSplat())
+ return failure();
+
+ // TODO: Handle non-unit strides when they become available.
+ if (extractStridedSliceOp.hasNonUnitStrides())
return failure();
- auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
- dense.getSplatValue<Attribute>());
+
+ auto sourceVecTy = sourceVector.getType().cast<VectorType>();
+ ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
+ SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
+
+ VectorType sliceVecTy = extractStridedSliceOp.getType();
+ ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
+ int64_t sliceRank = sliceVecTy.getRank();
+
+ // Expand offsets and sizes to match the vector rank.
+ SmallVector<int64_t, 4> offsets(sliceRank, 0);
+ llvm::copy(getI64SubArray(extractStridedSliceOp.getOffsets()),
+ offsets.begin());
+
+ SmallVector<int64_t, 4> sizes(sourceShape.begin(), sourceShape.end());
+ llvm::copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
+
+    // Calculate the slice elements by enumerating all slice positions and
+ // linearizing them. The enumeration order is lexicographic which yields a
+ // sequence of monotonically increasing linearized position indices.
+ auto denseValuesBegin = dense.value_begin<Attribute>();
+ SmallVector<Attribute> sliceValues;
+ sliceValues.reserve(sliceVecTy.getNumElements());
+ SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
+ do {
+ int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
+ assert(linearizedPosition < sourceVecTy.getNumElements() &&
+ "Invalid index");
+ sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
+ } while (succeeded(incPosition(currSlicePosition, sliceShape, offsets)));
+
+ assert(static_cast<int64_t>(sliceValues.size()) ==
+ sliceVecTy.getNumElements() &&
+ "Invalid number of slice elements");
+ auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
newAttr);
return success();
}
+
+private:
+ // Calculate the next `position` in the n-D vector of size `shape`,
+ // applying an offset `offsets`. Modifies the `position` in place.
+ // Returns a failure when `position` becomes the end position.
+ static LogicalResult incPosition(MutableArrayRef<int64_t> position,
+ ArrayRef<int64_t> shape,
+ ArrayRef<int64_t> offsets) {
+ assert(position.size() == shape.size());
+ assert(position.size() == offsets.size());
+ for (auto [posInDim, dimSize, offsetInDim] :
+ llvm::reverse(llvm::zip(position, shape, offsets))) {
+ ++posInDim;
+ if (posInDim < dimSize + offsetInDim)
+ return success();
+
+ // Carry the overflow to the next loop iteration.
+ posInDim = offsetInDim;
+ }
+
+ return failure();
+ }
};
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
@@ -2770,8 +2863,9 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
- results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
- StridedSliceBroadcast, StridedSliceSplat>(context);
+ results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
+ StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
+ StridedSliceSplat>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index eb1fb247bb269..19a06af5c9f8b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1533,6 +1533,63 @@ func.func @extract_splat_vector_3d_constant() -> (vector<2xi32>, vector<2xi32>,
// -----
+// CHECK-LABEL: func.func @extract_strided_slice_1d_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xi32>
+// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[1, 2]> : vector<2xi32>
+// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<2> : vector<1xi32>
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]] : vector<3xi32>, vector<2xi32>, vector<1xi32>
+func.func @extract_strided_slice_1d_constant() -> (vector<3xi32>, vector<2xi32>, vector<1xi32>) {
+ %cst = arith.constant dense<[0, 1, 2]> : vector<3xi32>
+ %a = vector.extract_strided_slice %cst
+ {offsets = [0], sizes = [3], strides = [1]} : vector<3xi32> to vector<3xi32>
+ %b = vector.extract_strided_slice %cst
+ {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to vector<2xi32>
+ %c = vector.extract_strided_slice %cst
+ {offsets = [2], sizes = [1], strides = [1]} : vector<3xi32> to vector<1xi32>
+ return %a, %b, %c : vector<3xi32>, vector<2xi32>, vector<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_strided_slice_2d_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<0> : vector<1x1xi32>
+// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[4, 5\]\]}}> : vector<1x2xi32>
+// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[1, 2\], \[4, 5\]\]}}> : vector<2x2xi32>
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]] : vector<1x1xi32>, vector<1x2xi32>, vector<2x2xi32>
+func.func @extract_strided_slice_2d_constant() -> (vector<1x1xi32>, vector<1x2xi32>, vector<2x2xi32>) {
+ %cst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : vector<2x3xi32>
+ %a = vector.extract_strided_slice %cst
+ {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x3xi32> to vector<1x1xi32>
+ %b = vector.extract_strided_slice %cst
+ {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32>
+ %c = vector.extract_strided_slice %cst
+ {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]} : vector<2x3xi32> to vector<2x2xi32>
+ return %a, %b, %c : vector<1x1xi32>, vector<1x2xi32>, vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_strided_slice_3d_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<{{\[\[\[8, 9\], \[10, 11\]\]\]}}> : vector<1x2x2xi32>
+// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[\[2, 3\]\]\]}}> : vector<1x1x2xi32>
+// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[\[6, 7\]\], \[\[10, 11\]\]\]}}> : vector<2x1x2xi32>
+// CHECK-DAG: %[[DCST:.*]] = arith.constant dense<11> : vector<1x1x1xi32>
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]]
+func.func @extract_strided_slice_3d_constant() -> (vector<1x2x2xi32>, vector<1x1x2xi32>, vector<2x1x2xi32>, vector<1x1x1xi32>) {
+ %cst = arith.constant dense<[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]]> : vector<3x2x2xi32>
+ %a = vector.extract_strided_slice %cst
+ {offsets = [2], sizes = [1], strides = [1]} : vector<3x2x2xi32> to vector<1x2x2xi32>
+ %b = vector.extract_strided_slice %cst
+ {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<3x2x2xi32> to vector<1x1x2xi32>
+ %c = vector.extract_strided_slice %cst
+ {offsets = [1, 1, 0], sizes = [2, 1, 2], strides = [1, 1, 1]} : vector<3x2x2xi32> to vector<2x1x2xi32>
+ %d = vector.extract_strided_slice %cst
+ {offsets = [2, 1, 1], sizes = [1, 1, 1], strides = [1, 1, 1]} : vector<3x2x2xi32> to vector<1x1x1xi32>
+ return %a, %b, %c, %d : vector<1x2x2xi32>, vector<1x1x2xi32>, vector<2x1x2xi32>, vector<1x1x1xi32>
+}
+
+// -----
+
// CHECK-LABEL: extract_extract_strided
// CHECK-SAME: %[[A:.*]]: vector<32x16x4xf16>
// CHECK: %[[V:.*]] = vector.extract %[[A]][9, 7] : vector<32x16x4xf16>
More information about the Mlir-commits
mailing list