[Mlir-commits] [mlir] 7417541 - [mlir][vector] Add canonicalization for extract/insert -> shapecast
llvmlistbot at llvm.org
Mon May 3 10:45:08 PDT 2021
Author: thomasraoux
Date: 2021-05-03T10:41:15-07:00
New Revision: 7417541fd8d764c42e5c7f3647e73ae6913b0fd7
URL: https://github.com/llvm/llvm-project/commit/7417541fd8d764c42e5c7f3647e73ae6913b0fd7
DIFF: https://github.com/llvm/llvm-project/commit/7417541fd8d764c42e5c7f3647e73ae6913b0fd7.diff
LOG: [mlir][vector] Add canonicalization for extract/insert -> shapecast
Differential Revision: https://reviews.llvm.org/D101643
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index d08c73ff2f4d6..3d83bcd0aa57a 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -519,6 +519,7 @@ def Vector_ExtractOp :
return vector().getType().cast<VectorType>();
}
}];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
}
@@ -763,6 +764,7 @@ def Vector_InsertOp :
return dest().getType().cast<VectorType>();
}
}];
+ let hasCanonicalizer = 1;
}
def Vector_InsertSlicesOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 9fd9e1e40866a..2958088b258df 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1142,6 +1142,33 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
return OpFoldResult();
}
+namespace {
+
+/// Rewrites a vector.extract into a vector.shape_cast when the extracted
+/// value still holds every element of the source vector, i.e. each indexed
+/// (leading) dimension has size 1 and the extract merely drops it.
+class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern<ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // A scalar result cannot be produced by a shape_cast.
+    auto resultType =
+        extractOp.getResult().getType().dyn_cast<VectorType>();
+    if (!resultType)
+      return failure();
+
+    // Fewer result elements than source elements means a non-unit
+    // dimension was indexed into; that is a real extraction.
+    VectorType sourceType = extractOp.getVectorType();
+    if (sourceType.getNumElements() != resultType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, resultType,
+                                             extractOp.vector());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the vector.extract canonicalizations (currently only the
+/// extract -> shape_cast rewrite above).
+void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<ExtractToShapeCast>(context);
+}
+
//===----------------------------------------------------------------------===//
// ExtractSlicesOp
//===----------------------------------------------------------------------===//
@@ -1536,6 +1563,33 @@ static LogicalResult verify(InsertOp op) {
return success();
}
+namespace {
+
+/// Rewrites a vector.insert into a vector.shape_cast of the inserted value
+/// when the source has as many elements as the destination, i.e. only unit
+/// (leading) dimensions are being added. In that case the destination is
+/// entirely overwritten, so only the source operand feeds the new op.
+class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    // Inserting a scalar cannot be expressed as a shape_cast.
+    auto sourceType = insertOp.getSourceType().dyn_cast<VectorType>();
+    if (!sourceType)
+      return failure();
+
+    // A smaller source only overwrites part of the destination.
+    VectorType destType = insertOp.getDestVectorType();
+    if (sourceType.getNumElements() != destType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(insertOp, destType,
+                                             insertOp.source());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the vector.insert canonicalizations (currently only the
+/// insert -> shape_cast rewrite above).
+void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<InsertToShapeCast>(context);
+}
+
//===----------------------------------------------------------------------===//
// InsertSlicesOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6d25e40ace3e4..0806e1df9ee82 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -504,16 +504,18 @@ func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> {
// CHECK: %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>
// CHECK: %[[R1:.*]] = vector.extract %[[A0]][1, 0, 2] : vector<5x1x3x2xf32>
// CHECK: %[[R2:.*]] = vector.extract %[[A1]][7] : vector<8x4x2xf32>
-// CHECK: return %[[R0]], %[[R1]], %[[R2]] : f32, vector<2xf32>, vector<4x2xf32>
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[A1]] : f32, vector<2xf32>, vector<4x2xf32>, vector<8x4x2xf32>
func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
%arg1 : vector<8x4x2xf32>)
- -> (f32, vector<2xf32>, vector<4x2xf32>) {
+ -> (f32, vector<2xf32>, vector<4x2xf32>, vector<8x4x2xf32>) {
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>
%1 = vector.shape_cast %arg1 : vector<8x4x2xf32> to vector<4x2x4x2xf32>
+ // %2 only prepends a unit dimension to %arg1, so extracting [0] from it is
+ // expected to fold back to %arg1 itself (the fourth returned value).
+ %2 = vector.shape_cast %arg1 : vector<8x4x2xf32> to vector<1x8x4x2xf32>
%r1 = vector.extract %0[4, 1] : vector<15x2xf32>
%r2 = vector.extract %0[5] : vector<15x2xf32>
%r3 = vector.extract %1[3, 1] : vector<4x2x4x2xf32>
- return %r1, %r2, %r3 : f32, vector<2xf32>, vector<4x2xf32>
+ %r4 = vector.extract %2[0] : vector<1x8x4x2xf32>
+ return %r1, %r2, %r3, %r4 : f32, vector<2xf32>, vector<4x2xf32>, vector<8x4x2xf32>
}
// -----
@@ -932,3 +934,17 @@ func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
vector<1x4xf32>, tensor<4x4xf32>
return %w2 : tensor<4x4xf32>
}
+
+// -----
+
+// An extract that only drops unit leading dimensions, and an insert that only
+// adds them (element count unchanged), are expected to canonicalize into
+// vector.shape_cast. The insert's destination (%arg0) is fully overwritten,
+// so only %arg1 feeds the resulting shape_cast.
+// CHECK-LABEL: func @insert_extract_to_shapecast
+// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
+func @insert_extract_to_shapecast(%arg0 : vector<1x1x4xf32>,
+ %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
+ %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32>
+ %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+ return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
+}
More information about the Mlir-commits mailing list