[Mlir-commits] [mlir] a970e69 - [mlir][vector] add pattern to cast away leading unit dim for elementwise op

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 7 08:05:09 PDT 2021


Author: thomasraoux
Date: 2021-05-07T07:54:09-07:00
New Revision: a970e69d6b62d60c4c222e2a4be0a73999c97651

URL: https://github.com/llvm/llvm-project/commit/a970e69d6b62d60c4c222e2a4be0a73999c97651
DIFF: https://github.com/llvm/llvm-project/commit/a970e69d6b62d60c4c222e2a4be0a73999c97651.diff

LOG: [mlir][vector] add pattern to cast away leading unit dim for elementwise op

Differential Revision: https://reviews.llvm.org/D102034

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 40fb659d020e7..15a211bcec669 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3175,7 +3175,7 @@ struct CastAwayTransferWriteLeadingOneDim
   }
 };
 
-struct CastAwayBrodcastLeadingOneDim
+struct CastAwayBroadcastLeadingOneDim
     : public OpRewritePattern<vector::BroadcastOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -3200,6 +3200,44 @@ struct CastAwayBrodcastLeadingOneDim
   }
 };
 
+class CastAwayElementwiseLeadingOneDim : public RewritePattern {
+public: // Rewrites elementwise ops to drop leading unit dims via shape_cast.
+  CastAwayElementwiseLeadingOneDim(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+  // Succeeds only for elementwise ops with one trimmable vector result.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
+      return failure(); // Not elementwise-mappable or not single-result.
+    auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
+    if (!vecType)
+      return failure(); // The result is not a vector.
+    VectorType newVecType = trimLeadingOneDims(vecType);
+    if (newVecType == vecType)
+      return failure(); // Trimming left the type unchanged; nothing to do.
+    // Shape-cast vector operands to the trimmed shape; keep others as is.
+    SmallVector<Value, 4> newOperands;
+    for (Value operand : op->getOperands()) {
+      if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
+        auto newType = // Trimmed shape, but the operand's own element type.
+            VectorType::get(newVecType.getShape(), opVecType.getElementType());
+        newOperands.push_back(rewriter.create<vector::ShapeCastOp>(
+            op->getLoc(), newType, operand));
+      } else {
+        newOperands.push_back(operand); // Non-vector operand is reused.
+      }
+    }
+    OperationState state(op->getLoc(), op->getName()); // Same op name.
+    state.addAttributes(op->getAttrs());
+    state.addOperands(newOperands);
+    state.addTypes(newVecType);
+    Operation *newOp = rewriter.createOperation(state); // Trimmed result type.
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, vecType, // Cast back.
+                                                     newOp->getResult(0));
+    return success();
+  }
+};
+
 // Returns the values in `arrayAttr` as an integer vector.
 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
   return llvm::to_vector<4>(
@@ -3795,12 +3833,13 @@ void mlir::vector::populateSplitVectorTransferPatterns(
 
 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
-               CastAwayInsertStridedSliceLeadingOneDim,
-               CastAwayTransferReadLeadingOneDim,
-               CastAwayTransferWriteLeadingOneDim,
-               CastAwayBrodcastLeadingOneDim, ShapeCastOpFolder>(
-      patterns.getContext());
+  patterns
+      .add<CastAwayExtractStridedSliceLeadingOneDim,
+           CastAwayInsertStridedSliceLeadingOneDim,
+           CastAwayTransferReadLeadingOneDim,
+           CastAwayTransferWriteLeadingOneDim, CastAwayBroadcastLeadingOneDim,
+           CastAwayElementwiseLeadingOneDim, ShapeCastOpFolder>(
+          patterns.getContext());
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(

diff  --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 47a12958aa144..abfed46e82f10 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -689,6 +689,34 @@ func @cast_away_broadcast_leading_one_dims(
   return %0, %1, %2: vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>
 }
 
+// CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
+func @cast_away_elementwise_leading_one_dims(
+  %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,
+  %arg3: vector<1x4xf32>, %arg4: i1) ->
+  (vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) {
+  // CHECK:  vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
+  // CHECK:  vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
+  // CHECK:  addf %{{.*}}, %{{.*}} : vector<8xf32>
+  // CHECK:  vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
+  %0 = addf %arg0, %arg0 : vector<1x1x8xf32>
+  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK:  cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32>
+  // CHECK:  vector.shape_cast %{{.*}} : vector<4xi1> to vector<1x4xi1>
+  %1 = cmpf ogt, %arg2, %arg3 : vector<1x4xf32>
+  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK:  select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32>
+  // CHECK:  vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+  %2 = select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32>
+  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK:  select %arg4, %{{.*}}, %{{.*}} : vector<4xf32>
+  // CHECK:  vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+  %3 = select %arg4, %arg3, %arg2 : vector<1x4xf32>
+  return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
+}
+
 // CHECK-LABEL: func @bubble_down_bitcast_in_extract
 //  CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
 func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {


        


More information about the Mlir-commits mailing list