[Mlir-commits] [mlir] 813bfff - [mlir][VectorOps] Adds canonicalization rewrite patterns for vector ShapeCastOp.

Andy Davis llvmlistbot at llvm.org
Tue Feb 11 13:11:53 PST 2020


Author: Andy Davis
Date: 2020-02-11T13:11:45-08:00
New Revision: 813bfffec34b87d32c9c834718f660afb5275bc8

URL: https://github.com/llvm/llvm-project/commit/813bfffec34b87d32c9c834718f660afb5275bc8
DIFF: https://github.com/llvm/llvm-project/commit/813bfffec34b87d32c9c834718f660afb5275bc8.diff

LOG: [mlir][VectorOps] Adds canonicalization rewrite patterns for vector ShapeCastOp.

Summary:
Adds two rewrite patterns for the vector ShapeCastOp.
*) ShapeCastOp decomposer: decomposes a ShapeCastOp on a tuple of vectors into multiple ShapeCastOps, each on a vector type.
*) ShapeCastOp folder: folds away cancelling shape cast ops (e.g. shape_cast A -> B followed by shape_cast B -> A).

Reviewers: nicolasvasilache, aartbik

Reviewed By: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74327

Added: 
    

Modified: 
    mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
    mlir/test/Dialect/VectorOps/vector-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 440c7707ce75..8bdeb92afe5e 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -646,6 +646,90 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
   }
 };
 
+/// Rewrites a ShapeCastOp whose source and result are tuples of vectors into
+/// one vector-typed ShapeCastOp per tuple element, re-packed with a TupleOp.
+struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(vector::ShapeCastOp op,
+                                     PatternRewriter &rewriter) const override {
+    // Only tuple -> tuple shape casts are decomposed here.
+    Value src = op.source();
+    auto srcTuple = src.getType().dyn_cast_or_null<TupleType>();
+    auto dstTuple = op.result().getType().dyn_cast_or_null<TupleType>();
+    if (!srcTuple || !dstTuple)
+      return matchFailure();
+    unsigned numElements = srcTuple.size();
+    assert(numElements == dstTuple.size());
+
+    // Extract each source element and shape cast it to the matching result
+    // element type; ops are created in element order.
+    Location loc = op.getLoc();
+    SmallVector<Value, 8> casted;
+    casted.reserve(dstTuple.size());
+    for (unsigned idx = 0; idx != numElements; ++idx) {
+      Value element = rewriter.create<vector::TupleGetOp>(
+          loc, srcTuple.getType(idx), src, rewriter.getI64IntegerAttr(idx));
+      Type castType = dstTuple.getType(idx);
+      casted.push_back(
+          rewriter.create<vector::ShapeCastOp>(loc, castType, element));
+    }
+
+    // Re-pack the per-element casts into a tuple that replaces 'op'.
+    rewriter.replaceOpWithNewOp<vector::TupleOp>(op, dstTuple, casted);
+    return matchSuccess();
+  }
+};
+
+/// Folds away a pair of ShapeCastOps that undo each other.
+//
+// A vector shape_cast whose operand comes from another shape_cast is replaced
+// by that producer's operand when the producer's operand type equals this
+// op's result type (i.e. A -> B followed by B -> A).
+//
+// Example:
+//
+//  The following MLIR with cancelling ShapeCastOps:
+//
+//   %0 = source : vector<5x4x2xf32>
+//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
+//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
+//   %3 = user %2 : vector<5x4x2xf32>
+//
+//  Should canonicalize to the following:
+//
+//   %0 = source : vector<5x4x2xf32>
+//   %1 = user %0 : vector<5x4x2xf32>
+struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(vector::ShapeCastOp op,
+                                     PatternRewriter &rewriter) const override {
+    // Bail out unless this is a vector -> vector shape cast.
+    Value source = op.source();
+    auto sourceType = source.getType().dyn_cast_or_null<VectorType>();
+    auto resultType = op.result().getType().dyn_cast_or_null<VectorType>();
+    if (!sourceType || !resultType)
+      return matchFailure();
+
+    // The operand must itself be produced by a ShapeCastOp.
+    auto producer =
+        dyn_cast_or_null<vector::ShapeCastOp>(source.getDefiningOp());
+    if (!producer)
+      return matchFailure();
+
+    // 'producer' yields 'source', so its result type already equals
+    // 'sourceType'; the pair cancels iff the producer's own input has
+    // this op's result type.
+    Value original = producer.source();
+    if (original.getType().cast<VectorType>() != resultType)
+      return matchFailure();
+
+    rewriter.replaceOp(op, original);
+    return matchSuccess();
+  }
+};
+
 // Patter rewrite which forward tuple elements to their users.
 // User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer)))))
 //   -> User(Producer)
@@ -784,8 +868,10 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
 // TODO(andydavis) Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<SplitTransferReadOp, SplitTransferWriteOp, TupleGetFolderOp>(
-      context);
+  // ShapeCastOpDecomposer splits tuple shape casts into per-vector casts;
+  // ShapeCastOpFolder removes pairs of vector shape casts that cancel out.
+  patterns.insert<ShapeCastOpDecomposer, ShapeCastOpFolder, SplitTransferReadOp,
+                  SplitTransferWriteOp, TupleGetFolderOp>(context);
 }
 
 void mlir::vector::populateVectorSlicesLoweringPatterns(

diff  --git a/mlir/test/Dialect/VectorOps/vector-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
index 1153ffb2999f..7582758384ce 100644
--- a/mlir/test/Dialect/VectorOps/vector-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
@@ -346,3 +346,62 @@ func @vector_transfers_vector_element_type() {
 
   return
 }
+
+// Test that a ShapeCastOp on a tuple of vectors decomposes into per-vector ops.
+// CHECK-LABEL: func @shape_cast_decomposition
+//  CHECK-SAME: %[[A0:[a-zA-Z0-9_]+]]: vector<5x4x2xf32>, %[[A1:[a-zA-Z0-9_]+]]: vector<3x4x2xf32>
+//       CHECK: %[[V0:.*]] = vector.shape_cast %[[A0]] : vector<5x4x2xf32> to vector<20x2xf32>
+//  CHECK-NEXT: %[[V1:.*]] = vector.shape_cast %[[A1]] : vector<3x4x2xf32> to vector<12x2xf32>
+//  CHECK-NEXT: return %[[V0]], %[[V1]] : vector<20x2xf32>, vector<12x2xf32>
+
+func @shape_cast_decomposition(%arg0 : vector<5x4x2xf32>,
+                               %arg1 : vector<3x4x2xf32>)
+  -> (vector<20x2xf32>, vector<12x2xf32>) {
+  %0 = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>
+  %1 = vector.shape_cast %0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
+                              tuple<vector<20x2xf32>, vector<12x2xf32>>
+  %2 = vector.tuple_get %1, 0 : tuple<vector<20x2xf32>, vector<12x2xf32>>
+  %3 = vector.tuple_get %1, 1 : tuple<vector<20x2xf32>, vector<12x2xf32>>
+  return %2, %3 : vector<20x2xf32>, vector<12x2xf32>
+}
+
+// Test that cancelling ShapeCastOps are canonicalized away.
+// Example:
+//
+//  The following MLIR with cancelling ShapeCastOps:
+//
+//   %0 = source : vector<5x4x2xf32>
+//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
+//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
+//   %3 = user %2 : vector<5x4x2xf32>
+//
+//  Should canonicalize to the following:
+//
+//   %0 = source : vector<5x4x2xf32>
+//   %1 = user %0 : vector<5x4x2xf32>
+//
+
+// CHECK-LABEL: func @shape_cast_fold
+//  CHECK-SAME: %[[A0:[a-zA-Z0-9_]+]]: vector<5x4x2xf32>, %[[A1:[a-zA-Z0-9_]+]]: vector<3x4x2xf32>
+//   CHECK-NOT: vector.shape_cast
+//       CHECK: return %[[A0]], %[[A1]] : vector<5x4x2xf32>, vector<3x4x2xf32>
+
+func @shape_cast_fold(%arg0 : vector<5x4x2xf32>, %arg1 : vector<3x4x2xf32>)
+  -> (vector<5x4x2xf32>, vector<3x4x2xf32>) {
+  %0 = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>
+
+  %1 = vector.shape_cast %0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
+                              tuple<vector<20x2xf32>, vector<12x2xf32>>
+
+  %2 = vector.tuple_get %1, 0 : tuple<vector<20x2xf32>, vector<12x2xf32>>
+  %3 = vector.tuple_get %1, 1 : tuple<vector<20x2xf32>, vector<12x2xf32>>
+
+  %4 = vector.tuple %2, %3 : vector<20x2xf32>, vector<12x2xf32>
+  %5 = vector.shape_cast %4 : tuple<vector<20x2xf32>, vector<12x2xf32>> to
+                              tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
+
+  %6 = vector.tuple_get %5, 0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
+  %7 = vector.tuple_get %5, 1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
+
+  return %6, %7 : vector<5x4x2xf32>, vector<3x4x2xf32>
+}


        


More information about the Mlir-commits mailing list