[Mlir-commits] [mlir] 7897a94 - [mlir][vector] Fold extract(shape_cast) for same element count

Lei Zhang llvmlistbot at llvm.org
Tue Aug 15 12:17:35 PDT 2023


Author: Lei Zhang
Date: 2023-08-15T11:28:35-07:00
New Revision: 7897a944d9e72d79f38a443afffbbfd1accfe4ad

URL: https://github.com/llvm/llvm-project/commit/7897a944d9e72d79f38a443afffbbfd1accfe4ad
DIFF: https://github.com/llvm/llvm-project/commit/7897a944d9e72d79f38a443afffbbfd1accfe4ad.diff

LOG: [mlir][vector] Fold extract(shape_cast) for same element count

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D157930

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f5a7cdc556b515..5f5909ec998105 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1808,12 +1808,34 @@ class ExtractOpNonSplatConstantFolder final
   }
 };
 
+// Folds extract(shape_cast(x)) into shape_cast(x) when the extract result is a
+// vector with the same number of elements as x.
+LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
+                                                  PatternRewriter &rewriter) {
+  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
+  if (!castOp)
+    return failure();
+
+  VectorType sourceType = castOp.getSourceVectorType();
+  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
+  if (!targetType)
+    return failure();
+
+  if (sourceType.getNumElements() != targetType.getNumElements())
+    return failure();
+
+  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
+                                                   castOp.getSource());
+  return success();
+}
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
   results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
               ExtractOpFromBroadcast>(context);
+  results.add(foldExtractFromShapeCastToShapeCast);
 }
 
 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 2f76fc5d5ebdb2..17a9e381b61708 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -669,6 +669,18 @@ func.func @dont_fold_0d_extract_shapecast(%arg0 : vector<f32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: fold_extract_shapecast_to_shapecast
+//  CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
+//       CHECK:   %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
+//       CHECK:   return %[[R]]
+func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> {
+  %0 = vector.shape_cast %arg0 : vector<3x4xf32> to vector<1x12xf32>
+  %r = vector.extract %0[0] : vector<1x12xf32>
+  return %r : vector<12xf32>
+}
+
+// -----
+
 // CHECK-LABEL: dont_fold_expand_collapse
 //       CHECK:   %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
 //       CHECK:   %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>


        


More information about the Mlir-commits mailing list