[Mlir-commits] [mlir] cc83c24 - [mlir][vector] Add canonicalization extract + splat
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 13 08:08:52 PDT 2021
Author: thomasraoux
Date: 2021-10-13T08:08:46-07:00
New Revision: cc83c2444f8a3bdb7f4b92f52a14b53792346057
URL: https://github.com/llvm/llvm-project/commit/cc83c2444f8a3bdb7f4b92f52a14b53792346057
DIFF: https://github.com/llvm/llvm-project/commit/cc83c2444f8a3bdb7f4b92f52a14b53792346057.diff
LOG: [mlir][vector] Add canonicalization extract + splat
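Make the canonicalizations that handle vector.broadcast also handle splat: vector.extract of a splat now folds to the splatted scalar, and vector.extract_strided_slice of a splat is rewritten into a splat of the slice type.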
Differential Revision: https://reviews.llvm.org/D111690
Added:
Modified:
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index b1168d084b802..343bc3e3e5100 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1094,17 +1094,18 @@ static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
return Value();
}
-/// Fold extractOp with scalar result coming from BroadcastOp.
+/// Fold an extractOp whose vector operand is defined by a BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
- auto broadcastOp = extractOp.vector().getDefiningOp<vector::BroadcastOp>();
- if (!broadcastOp)
+ Operation *defOp = extractOp.vector().getDefiningOp();
+ if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
return Value();
- if (extractOp.getType() == broadcastOp.getSourceType())
- return broadcastOp.source();
+ Value source = defOp->getOperand(0);
+ if (extractOp.getType() == source.getType())
+ return source;
auto getRank = [](Type type) {
return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
};
- unsigned broadcasrSrcRank = getRank(broadcastOp.getSourceType());
+ unsigned broadcasrSrcRank = getRank(source.getType());
unsigned extractResultRank = getRank(extractOp.getType());
if (extractResultRank < broadcasrSrcRank) {
auto extractPos = extractVector<int64_t>(extractOp.position());
@@ -1112,7 +1113,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
extractPos.erase(
extractPos.begin(),
std::next(extractPos.begin(), extractPos.size() - rankDiff));
- extractOp.setOperand(broadcastOp.source());
+ extractOp.setOperand(source);
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp->setAttr(ExtractOp::getPositionAttrName(),
@@ -2259,6 +2260,21 @@ class StridedSliceBroadcast final
}
};
+/// Rewrite a strided slice of a splat as a splat of the slice's result type.
+class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+ using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ auto splat = op.vector().getDefiningOp<SplatOp>(); // Match splat only.
+ if (!splat)
+ return failure(); // Not a splat: leave it to the other patterns.
+ rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.input());
+ return success(); // Offsets/sizes/strides are never read here.
+ }
+};
+
} // end anonymous namespace
void ExtractStridedSliceOp::getCanonicalizationPatterns(
@@ -2266,7 +2282,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
- StridedSliceBroadcast>(context);
+ StridedSliceBroadcast, StridedSliceSplat>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d6e0a16ca4f69..6c232cc0f7ab5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -462,6 +462,17 @@ func @fold_extract_broadcast(%a : f32) -> f32 {
// -----
+// CHECK-LABEL: fold_extract_splat
+// CHECK-SAME: %[[A:.*]]: f32
+// CHECK: return %[[A]] : f32
+func @fold_extract_splat(%a : f32) -> f32 {
+ %b = splat %a : vector<1x2x4xf32>
+ %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32> // every element of %b is %a, so this folds to %a
+ return %r : f32
+}
+
+// -----
+
// CHECK-LABEL: fold_extract_broadcast_vector
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
// CHECK: return %[[A]] : vector<4xf32>
@@ -1047,3 +1058,16 @@ func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: vector<
// CHECK: return %[[SOURCE]]
return %0: vector<16x16xf16>
}
+
+// -----
+
+// CHECK-LABEL: extract_strided_splat
+// CHECK: %[[B:.*]] = splat %{{.*}} : vector<2x4xf16>
+// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
+func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
+ %0 = splat %arg0 : vector<16x4xf16>
+ %1 = vector.extract_strided_slice %0
+ {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
+ vector<16x4xf16> to vector<2x4xf16> // expected to become a splat of %arg0 at the 2x4 slice type
+ return %1 : vector<2x4xf16>
+}
More information about the Mlir-commits
mailing list