[Mlir-commits] [mlir] 7418688 - [mlir][vector] Add more vector Ops canonicalization
Thomas Raoux
llvmlistbot at llvm.org
Wed Dec 23 11:25:38 PST 2020
Author: Thomas Raoux
Date: 2020-12-23T11:25:01-08:00
New Revision: 74186880ba99b37c0375e9d87df818beee8b4ff2
URL: https://github.com/llvm/llvm-project/commit/74186880ba99b37c0375e9d87df818beee8b4ff2
DIFF: https://github.com/llvm/llvm-project/commit/74186880ba99b37c0375e9d87df818beee8b4ff2.diff
LOG: [mlir][vector] Add more vector Ops canonicalization
Add canonicalization for BroadcastOp, ExtractStridedSliceOp and ShapeCastOp
Differential Revision: https://reviews.llvm.org/D93120
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 13aba2076ee9..e031f87cfb8e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -271,6 +271,7 @@ def Vector_BroadcastOp :
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
def Vector_ShuffleOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index a3ad355d30b2..539e00d58dbf 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1110,6 +1110,36 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
return {};
}
+namespace {
+
+// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
+// the degenerate case where the broadcast only adds dimensions of size 1, it
+// can be replaced by a ShapeCastOp. This canonicalization checks whether the
+// total number of elements is the same before and after the broadcast; if so,
+// the only change in the vector type is the addition of dimensions of size 1.
+class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
+public:
+ using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
+ if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
+ srcVecType.getNumElements())
+ return failure(); // Source is not a vector, or the element count changes.
+ rewriter.replaceOpWithNewOp<ShapeCastOp>(
+ broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
+ return success();
+ }
+};
+
+} // namespace
+
+void BroadcastOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<BroadcastToShapeCast>(context); // Registers the BroadcastOp -> ShapeCastOp rewrite.
+}
+
//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//
@@ -1768,7 +1798,8 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
namespace {
-// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
+// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
+// ConstantMaskOp.
class StridedSliceConstantMaskFolder final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
@@ -1847,14 +1878,70 @@ class StridedSliceConstantFolder final
}
};
+// Helper that returns the elements of `arrayAttr`, minus the first `dropFront` and the last `dropBack` entries, as a vector of int64_t.
+static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
+ unsigned dropFront = 0,
+ unsigned dropBack = 0) {
+ assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); // At least one element must remain.
+ auto range = arrayAttr.getAsRange<IntegerAttr>();
+ SmallVector<int64_t, 4> res;
+ res.reserve(arrayAttr.size() - dropFront - dropBack);
+ for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
+ it != eit; ++it)
+ res.push_back((*it).getValue().getSExtValue()); // Each entry is read as a sign-extended 64-bit integer.
+ return res;
+}
+
+// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
+// BroadcastOp(ExtractStridedSliceOp).
+class StridedSliceBroadcast final
+ : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+ using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ auto broadcast = op.vector().getDefiningOp<BroadcastOp>();
+ if (!broadcast)
+ return failure();
+ auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>();
+ unsigned srcRrank = srcVecType ? srcVecType.getRank() : 0; // A non-vector source is treated as rank 0.
+ auto dstVecType = op.getType().cast<VectorType>();
+ unsigned dstRank = dstVecType.getRank();
+ unsigned rankDiff = dstRank - srcRrank; // Leading result dimensions with no counterpart in the broadcast source.
+ // Check if the innermost dimensions of the source of the broadcast are the
+ // same as the destination of the extract. If this is the case we can just
+ // use a broadcast as the original dimensions are untouched.
+ bool lowerDimMatch = true;
+ for (unsigned i = 0; i < srcRrank; i++) {
+ if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
+ lowerDimMatch = false;
+ break;
+ }
+ }
+ Value source = broadcast.source();
+ if (!lowerDimMatch) {
+ // The inner dimensions don't match, it means we need to extract from the
+ // source of the original broadcast and then broadcast the extracted value. NOTE(review): a mismatch also occurs when a source dimension of size 1 was stretched by the broadcast; the trailing sizes would then appear to exceed the source shape -- confirm this case is handled or rejected.
+ source = rewriter.create<ExtractStridedSliceOp>(
+ op->getLoc(), source,
+ getI64SubArray(op.offsets(), /* dropFront=*/rankDiff),
+ getI64SubArray(op.sizes(), /* dropFront=*/rankDiff),
+ getI64SubArray(op.strides(), /* dropFront=*/rankDiff));
+ }
+ rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
+ return success();
+ }
+};
+
} // end anonymous namespace
void ExtractStridedSliceOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
- results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
- context);
+ results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
+ StridedSliceBroadcast>(context);
}
//===----------------------------------------------------------------------===//
@@ -2652,10 +2739,12 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
return source();
// Canceling shape casts.
- if (auto otherOp = source().getDefiningOp<ShapeCastOp>())
+ if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) {
if (result().getType() == otherOp.source().getType())
return otherOp.source();
-
+ setOperand(otherOp.source());
+ return getResult();
+ }
return {};
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f07285d7d98c..f94c3bcce5be 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -613,4 +613,51 @@ func @extract_strided_constant() -> (vector<12x2xf32>, vector<2x13x3xi32>) {
return %0, %1 : vector<12x2xf32>, vector<2x13x3xi32>
}
+// -----
+
+// CHECK-LABEL: extract_strided_broadcast
+// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : vector<4xf16> to vector<2x4xf16>
+// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
+func @extract_strided_broadcast(%arg0: vector<4xf16>) -> vector<2x4xf16> {
+ %0 = vector.broadcast %arg0 : vector<4xf16> to vector<16x4xf16>
+ %1 = vector.extract_strided_slice %0
+ {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} :
+ vector<16x4xf16> to vector<2x4xf16>
+ return %1 : vector<2x4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_broadcast2
+// CHECK: %[[E:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [2], strides = [1]} : vector<4xf16> to vector<2xf16>
+// CHECK-NEXT: %[[B:.*]] = vector.broadcast %[[E]] : vector<2xf16> to vector<2x2xf16>
+// CHECK-NEXT: return %[[B]] : vector<2x2xf16>
+func @extract_strided_broadcast2(%arg0: vector<4xf16>) -> vector<2x2xf16> {
+ %0 = vector.broadcast %arg0 : vector<4xf16> to vector<16x4xf16>
+ %1 = vector.extract_strided_slice %0
+ {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
+ vector<16x4xf16> to vector<2x2xf16>
+ return %1 : vector<2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: consecutive_shape_cast
+// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
+// CHECK-NEXT: return %[[C]] : vector<4x4xf16>
+func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> {
+ %0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16>
+ %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<4x4xf16>
+ return %1 : vector<4x4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_to_shapecast
+// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16>
+// CHECK-NEXT: return %[[C]] : vector<1x4x4xf16>
+func @broadcast_to_shapecast(%arg0: vector<4x4xf16>) -> vector<1x4x4xf16> {
+ %0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16>
+ return %0 : vector<1x4x4xf16>
+}
More information about the Mlir-commits
mailing list