[Mlir-commits] [mlir] 7897a94 - [mlir][vector] Fold extract(shape_cast) for same element count
Lei Zhang
llvmlistbot at llvm.org
Tue Aug 15 12:17:35 PDT 2023
Author: Lei Zhang
Date: 2023-08-15T11:28:35-07:00
New Revision: 7897a944d9e72d79f38a443afffbbfd1accfe4ad
URL: https://github.com/llvm/llvm-project/commit/7897a944d9e72d79f38a443afffbbfd1accfe4ad
DIFF: https://github.com/llvm/llvm-project/commit/7897a944d9e72d79f38a443afffbbfd1accfe4ad.diff
LOG: [mlir][vector] Fold extract(shape_cast) for same element count
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D157930
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f5a7cdc556b515..5f5909ec998105 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1808,12 +1808,34 @@ class ExtractOpNonSplatConstantFolder final
}
};
+// Rewrites `extract(shape_cast(%src))` as a single `shape_cast(%src)` when the
+// extracted vector holds exactly as many elements as `%src`.
+LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
+                                                  PatternRewriter &rewriter) {
+  auto shapeCast = extractOp.getVector().getDefiningOp<ShapeCastOp>();
+  // A shape_cast can only produce vectors, so scalar extracts are skipped.
+  auto resultVecType = dyn_cast<VectorType>(extractOp.getResult().getType());
+  if (!shapeCast || !resultVecType)
+    return failure();
+
+  // shape_cast keeps the element count, so equal counts mean the extract only
+  // peels off unit-size leading dims and both ops collapse into one cast.
+  VectorType srcVecType = shapeCast.getSourceVectorType();
+  if (srcVecType.getNumElements() != resultVecType.getNumElements())
+    return failure();
+
+  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, resultVecType,
+                                                   shapeCast.getSource());
+  return success();
+}
+
} // namespace
+// Registers the canonicalization patterns that apply to vector.extract.
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
ExtractOpFromBroadcast>(context);
+ // Function-based pattern: rewrites extract(shape_cast(x)) into a single
+ // shape_cast of x when the element count does not change.
+ results.add(foldExtractFromShapeCastToShapeCast);
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 2f76fc5d5ebdb2..17a9e381b61708 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -669,6 +669,18 @@ func.func @dont_fold_0d_extract_shapecast(%arg0 : vector<f32>) -> f32 {
// -----
+// CHECK-LABEL: fold_extract_shapecast_to_shapecast
+// CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
+// CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
+// CHECK: return %[[R]]
+func.func @fold_extract_shapecast_to_shapecast(%source : vector<3x4xf32>) -> vector<12xf32> {
+  %unit_lead = vector.shape_cast %source : vector<3x4xf32> to vector<1x12xf32>
+  %flat = vector.extract %unit_lead[0] : vector<1x12xf32>
+  return %flat : vector<12xf32>
+}
+
+// -----
+
// CHECK-LABEL: dont_fold_expand_collapse
// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
More information about the Mlir-commits mailing list