[Mlir-commits] [mlir] e98e13a - [mlir][Vector] Fold ShuffleOp(SplatOp(X), SplatOp(X)) to SplatOp(X).

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jul 3 19:27:02 PDT 2022


Author: jacquesguan
Date: 2022-07-04T10:06:06+08:00
New Revision: e98e13ac8f385568f30d2ded7ea146a1f2a02d1f

URL: https://github.com/llvm/llvm-project/commit/e98e13ac8f385568f30d2ded7ea146a1f2a02d1f
DIFF: https://github.com/llvm/llvm-project/commit/e98e13ac8f385568f30d2ded7ea146a1f2a02d1f.diff

LOG: [mlir][Vector] Fold ShuffleOp(SplatOp(X), SplatOp(X)) to SplatOp(X).

This patch folds ShuffleOp(SplatOp(X), SplatOp(X)) to SplatOp(X).

Differential Revision: https://reviews.llvm.org/D128969

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 6f68f83ed05f9..9a29825eda506 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -487,6 +487,7 @@ def Vector_ShuffleOp :
   }];
   let assemblyFormat = "operands $mask attr-dict `:` type(operands)";
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Vector_ExtractElementOp :

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 38f38f8867059..6c1ba2161b83c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1882,6 +1882,36 @@ OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
   return DenseElementsAttr::get(getVectorType(), results);
 }
 
+namespace {
+
+/// Pattern to rewrite ShuffleOp(SplatOp(X), SplatOp(X)) to SplatOp(X).
+class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
+public:
+  using OpRewritePattern<ShuffleOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShuffleOp op,
+                                PatternRewriter &rewriter) const override {
+    auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
+    auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
+
+    if (!v1Splat || !v2Splat) // Both operands must be defined by a SplatOp.
+      return failure();
+
+    if (v1Splat.getInput() != v2Splat.getInput()) // Same input required.
+      return failure();
+
+    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
+    return success();
+  }
+};
+
+} // namespace
+
+void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<ShuffleSplat>(context); // Adds the ShuffleSplat pattern.
+}
+
 //===----------------------------------------------------------------------===//
 // InsertElementOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 515a2d1726b6f..e7747c736867f 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1655,3 +1655,17 @@ func.func @insert_extract_strided_slice(%x: vector<8x16xf32>) -> (vector<8x16xf3
         : vector<2x4xf32> into vector<8x16xf32>
   return %1 : vector<8x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @shuffle_splat
+//  CHECK-SAME:   (%[[ARG:.*]]: i32)
+//  CHECK-NEXT:   %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32>
+//  CHECK-NEXT:   return %[[SPLAT]] : vector<4xi32>
+func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
+  %v0 = vector.splat %x : vector<4xi32>
+  %v1 = vector.splat %x : vector<2xi32>
+  %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
+  return %shuffle : vector<4xi32>
+}
+


        


More information about the Mlir-commits mailing list