[Mlir-commits] [mlir] 7417541 - [mlir][vector] Add canonicalization for extract/insert -> shapecast

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 3 10:45:08 PDT 2021


Author: thomasraoux
Date: 2021-05-03T10:41:15-07:00
New Revision: 7417541fd8d764c42e5c7f3647e73ae6913b0fd7

URL: https://github.com/llvm/llvm-project/commit/7417541fd8d764c42e5c7f3647e73ae6913b0fd7
DIFF: https://github.com/llvm/llvm-project/commit/7417541fd8d764c42e5c7f3647e73ae6913b0fd7.diff

LOG: [mlir][vector] Add canonicalization for extract/insert -> shapecast

Differential Revision: https://reviews.llvm.org/D101643

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index d08c73ff2f4d6..3d83bcd0aa57a 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -519,6 +519,7 @@ def Vector_ExtractOp :
       return vector().getType().cast<VectorType>();
     }
   }];
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
@@ -763,6 +764,7 @@ def Vector_InsertOp :
       return dest().getType().cast<VectorType>();
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Vector_InsertSlicesOp :

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 9fd9e1e40866a..2958088b258df 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1142,6 +1142,33 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
   return OpFoldResult();
 }
 
+namespace {
+
+// Replaces an ExtractOp with a ShapeCastOp when the result is a vector with
+// the same element count as the source, i.e. only unit dims are removed.
+class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern<ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto dstVecType = extractOp.getResult().getType().dyn_cast<VectorType>();
+    if (!dstVecType || extractOp.getVectorType().getNumElements() !=
+                           dstVecType.getNumElements())
+      return failure();
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, dstVecType,
+                                             extractOp.vector());
+    return success();
+  }
+};
+
+} // namespace
+
+void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<ExtractToShapeCast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractSlicesOp
 //===----------------------------------------------------------------------===//
@@ -1536,6 +1563,33 @@ static LogicalResult verify(InsertOp op) {
   return success();
 }
 
+namespace {
+
+// Replaces an InsertOp with a ShapeCastOp when the source is a vector with the
+// same element count as the destination, i.e. only unit dims are added.
+class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
+    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
+                           srcVecType.getNumElements())
+      return failure();
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(
+        insertOp, insertOp.getDestVectorType(), insertOp.source());
+    return success();
+  }
+};
+
+} // namespace
+
+void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<InsertToShapeCast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // InsertSlicesOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6d25e40ace3e4..0806e1df9ee82 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -504,16 +504,18 @@ func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> {
 //       CHECK:   %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>
 //       CHECK:   %[[R1:.*]] = vector.extract %[[A0]][1, 0, 2] : vector<5x1x3x2xf32>
 //       CHECK:   %[[R2:.*]] = vector.extract %[[A1]][7] : vector<8x4x2xf32>
-//       CHECK:   return %[[R0]], %[[R1]], %[[R2]] : f32, vector<2xf32>, vector<4x2xf32>
+//       CHECK:   return %[[R0]], %[[R1]], %[[R2]], %[[A1]] : f32, vector<2xf32>, vector<4x2xf32>, vector<8x4x2xf32>
 func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
                              %arg1 : vector<8x4x2xf32>)
-  -> (f32, vector<2xf32>, vector<4x2xf32>) {
+  -> (f32, vector<2xf32>, vector<4x2xf32>, vector<8x4x2xf32>) {
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>
   %1 = vector.shape_cast %arg1 : vector<8x4x2xf32> to vector<4x2x4x2xf32>
+  %2 = vector.shape_cast %arg1 : vector<8x4x2xf32> to vector<1x8x4x2xf32>
   %r1 = vector.extract %0[4, 1] : vector<15x2xf32>
   %r2 = vector.extract %0[5] : vector<15x2xf32>
   %r3 = vector.extract %1[3, 1] : vector<4x2x4x2xf32>
-  return %r1, %r2, %r3 : f32, vector<2xf32>, vector<4x2xf32>
+  %r4 = vector.extract %2[0] : vector<1x8x4x2xf32>
+  return %r1, %r2, %r3, %r4 : f32, vector<2xf32>, vector<4x2xf32>, vector<8x4x2xf32>
 }
 
 // -----
@@ -932,3 +934,17 @@ func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
     vector<1x4xf32>, tensor<4x4xf32>
   return %w2 : tensor<4x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_to_shapecast
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
+func @insert_extract_to_shapecast(%arg0 : vector<1x1x4xf32>,
+  %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
+  %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32>
+  %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+  return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
+}


        


More information about the Mlir-commits mailing list