[Mlir-commits] [mlir] [mlir] Fix type transformation in DropUnitDimFromElementwiseOps (PR #75430)
Jerry Wu
llvmlistbot at llvm.org
Thu Dec 14 00:02:40 PST 2023
https://github.com/pzread updated https://github.com/llvm/llvm-project/pull/75430
>From 3ae4fae4a678b2b573bb939705ef20b960cc7e2f Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Thu, 14 Dec 2023 07:00:01 +0000
Subject: [PATCH 1/2] Enhance condition and fix type conversion
---
.../Vector/Transforms/VectorTransforms.cpp | 34 ++++++++++++-------
.../Vector/vector-transfer-flatten.mlir | 16 ++++++++-
2 files changed, 37 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 45eb7274cd2d3c..1175da921d7ba1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1480,17 +1480,24 @@ struct DropUnitDimFromElementwiseOps final
using OpTraitRewritePattern::OpTraitRewritePattern;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- if (op->getNumResults() != 1)
+ if (op->getNumResults() != 1 || op->getNumRegions() != 0)
return failure();
- // Check the pre-condiitions. For `Elementwise` Ops all operands
- // are guaranteed to have identical shapes and it suffices to only check the
- // first one.
- auto op1 = op->getOperands()[0];
- auto sourceVectorType = dyn_cast<VectorType>(op1.getType());
- if (!sourceVectorType)
+ auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
+ if (!resultVectorType)
return failure();
+ if (llvm::any_of(op->getOperands(), [](auto operand) {
+ return !isa<VectorType>(operand.getType());
+ })) {
+ return failure();
+ }
+
+ // Check the pre-conditions. For `Elementwise` Ops all operands are
+ // guaranteed to have identical shapes and it suffices to only check the
+ // first one.
+ auto sourceVectorType = cast<VectorType>(op->getOperands()[0].getType());
+
if (sourceVectorType.getRank() < 2)
return failure();
@@ -1506,23 +1513,26 @@ struct DropUnitDimFromElementwiseOps final
// Drop leading/trailing unit dim by applying vector.shape_cast to all
// operands
int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
- VectorType newVType = VectorType::Builder(sourceVectorType).dropDim(dim);
SmallVector<Value> newOperands;
auto loc = op->getLoc();
for (auto operand : op->getOperands()) {
+ auto opVectorType = cast<VectorType>(operand.getType());
+ VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
newOperands.push_back(opSC);
}
+ VectorType newResultVectorType =
+ VectorType::Builder(resultVectorType).dropDim(dim);
// Create an updated elementwise Op without leading/trailing unit dim
Operation *elementwiseOp =
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
- newVType, op->getAttrs());
+ newResultVectorType, op->getAttrs());
- // Restore the leading/trailing unit dim by applying vector.shape_cast to
- // the result
- rewriter.replaceOpWithNewOp<ShapeCastOp>(op, sourceVectorType,
+ // Restore the leading/trailing unit dim by applying vector.shape_cast
+ // to the result
+ rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
elementwiseOp->getResult(0));
return success();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index b81491b9c07404..3708d741141be0 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -321,6 +321,21 @@ func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
// -----
+func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
+ %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xi8> to vector<1x8x[2]xi8>
+ %add = arith.sitofp %sc_arg0 : vector<1x8x[2]xi8> to vector<1x8x[2]xf32>
+ %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+ return %res : vector<8x[2]xf32>
+}
+
+// CHECK-LABEL: func.func @fold_unit_dim_sitofp(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
+// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xi8> to vector<8x[2]xi8>
+// CHECK: %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
+// CHECK: return %[[VAL_2]] : vector<8x[2]xf32>
+
+// -----
+
// All shape casts are folded away
func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
@@ -341,4 +356,3 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
// CHECK: return %[[VAL_4]] : vector<8xi32>
-
>From 7eabd6150e1d694d50c831253cc7b1bf493b49eb Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Thu, 14 Dec 2023 08:02:24 +0000
Subject: [PATCH 2/2] Remove extra condition
---
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 7 -------
1 file changed, 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 1175da921d7ba1..5936b0b54af4e3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1487,17 +1487,10 @@ struct DropUnitDimFromElementwiseOps final
if (!resultVectorType)
return failure();
- if (llvm::any_of(op->getOperands(), [](auto operand) {
- return !isa<VectorType>(operand.getType());
- })) {
- return failure();
- }
-
// Check the pre-conditions. For `Elementwise` Ops all operands are
// guaranteed to have identical shapes and it suffices to only check the
// first one.
auto sourceVectorType = cast<VectorType>(op->getOperands()[0].getType());
-
if (sourceVectorType.getRank() < 2)
return failure();
More information about the Mlir-commits
mailing list