[Mlir-commits] [mlir] e98e13a - [mlir][Vector] Fold ShuffleOp(SplatOp(X), SplatOp(X)) to SplatOp(X).
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jul 3 19:27:02 PDT 2022
Author: jacquesguan
Date: 2022-07-04T10:06:06+08:00
New Revision: e98e13ac8f385568f30d2ded7ea146a1f2a02d1f
URL: https://github.com/llvm/llvm-project/commit/e98e13ac8f385568f30d2ded7ea146a1f2a02d1f
DIFF: https://github.com/llvm/llvm-project/commit/e98e13ac8f385568f30d2ded7ea146a1f2a02d1f.diff
LOG: [mlir][Vector] Fold ShuffleOp(SplatOp(X), SplatOp(X)) to SplatOp(X).
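This patch adds a ShuffleOp canonicalization pattern that rewrites ShuffleOp(SplatOp(X), SplatOp(X)) to SplatOp(X): shuffling two splats of the same scalar X always yields a splat of X.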
Differential Revision: https://reviews.llvm.org/D128969
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 6f68f83ed05f9..9a29825eda506 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -487,6 +487,7 @@ def Vector_ShuffleOp :
}];
let assemblyFormat = "operands $mask attr-dict `:` type(operands)";
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
}
def Vector_ExtractElementOp :
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 38f38f8867059..6c1ba2161b83c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1882,6 +1882,36 @@ OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
return DenseElementsAttr::get(getVectorType(), results);
}
+namespace {
+
+/// Canonicalization pattern for `vector.shuffle`: when both operands are
+/// produced by `vector.splat` of the very same scalar, the shuffle is replaced
+/// by a single `vector.splat` of that scalar carrying the shuffle's result
+/// type. The two splats may have different vector types; only the splatted
+/// scalar has to match.
+struct ShuffleSplat final : OpRewritePattern<ShuffleOp> {
+  using OpRewritePattern<ShuffleOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShuffleOp shuffle,
+                                PatternRewriter &rewriter) const override {
+    // Bail out unless each operand is defined by a SplatOp.
+    auto lhs = shuffle.getV1().getDefiningOp<SplatOp>();
+    if (!lhs)
+      return failure();
+    auto rhs = shuffle.getV2().getDefiningOp<SplatOp>();
+    if (!rhs)
+      return failure();
+
+    // Both splats must broadcast the identical SSA value.
+    auto scalar = lhs.getInput();
+    if (scalar != rhs.getInput())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<SplatOp>(shuffle, shuffle.getType(), scalar);
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers `ShuffleSplat` as the canonicalization pattern for
+/// `vector.shuffle`.
+void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                            MLIRContext *ctx) {
+  patterns.add<ShuffleSplat>(ctx);
+}
+
//===----------------------------------------------------------------------===//
// InsertElementOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 515a2d1726b6f..e7747c736867f 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1655,3 +1655,17 @@ func.func @insert_extract_strided_slice(%x: vector<8x16xf32>) -> (vector<8x16xf3
: vector<2x4xf32> into vector<8x16xf32>
return %1 : vector<8x16xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @shuffle_splat
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32>
+// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32>
+func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
+  %v0 = vector.splat %x : vector<4xi32>
+  %v1 = vector.splat %x : vector<2xi32>
+  %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// -----
+
+// Negative test: splats of different scalars must not be combined. Mask
+// entries 2 and 3 read %x (from %v0) while 4 and 5 read %y (from %v1), so the
+// shuffle has to survive canonicalization.
+// CHECK-LABEL: func @shuffle_nofold_different_splats
+// CHECK-SAME: (%[[X:.*]]: i32, %[[Y:.*]]: i32)
+// CHECK-DAG: %[[SX:.*]] = vector.splat %[[X]] : vector<4xi32>
+// CHECK-DAG: %[[SY:.*]] = vector.splat %[[Y]] : vector<2xi32>
+// CHECK: %[[S:.*]] = vector.shuffle %[[SX]], %[[SY]] [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
+// CHECK: return %[[S]] : vector<4xi32>
+func.func @shuffle_nofold_different_splats(%x : i32, %y : i32) -> vector<4xi32> {
+  %v0 = vector.splat %x : vector<4xi32>
+  %v1 = vector.splat %y : vector<2xi32>
+  %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
+  return %shuffle : vector<4xi32>
+}
+
More information about the Mlir-commits
mailing list