[Mlir-commits] [mlir] 2c9ba9c - [mlir] Fix type transformation in DropUnitDimFromElementwiseOps (#75430)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 14 09:20:58 PST 2023
Author: Jerry Wu
Date: 2023-12-14T12:20:54-05:00
New Revision: 2c9ba9c34a22db98faf9405b157195af4e09c2ed
URL: https://github.com/llvm/llvm-project/commit/2c9ba9c34a22db98faf9405b157195af4e09c2ed
DIFF: https://github.com/llvm/llvm-project/commit/2c9ba9c34a22db98faf9405b157195af4e09c2ed.diff
LOG: [mlir] Fix type transformation in DropUnitDimFromElementwiseOps (#75430)
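Use each operand's own type and the result's type to build the corresponding
new (unit-dim-dropped) types in `DropUnitDimFromElementwiseOps`, instead of
deriving all of them from the first operand's type. This fixes elementwise ops
whose result type differs from their operand type (e.g. `arith.sitofp`). The
pattern now also bails out on ops that have regions.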
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 45eb7274cd2d3c..5936b0b54af4e3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1480,17 +1480,17 @@ struct DropUnitDimFromElementwiseOps final
using OpTraitRewritePattern::OpTraitRewritePattern;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- if (op->getNumResults() != 1)
+ if (op->getNumResults() != 1 || op->getNumRegions() != 0)
return failure();
- // Check the pre-condiitions. For `Elementwise` Ops all operands
- // are guaranteed to have identical shapes and it suffices to only check the
- // first one.
- auto op1 = op->getOperands()[0];
- auto sourceVectorType = dyn_cast<VectorType>(op1.getType());
- if (!sourceVectorType)
+ auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
+ if (!resultVectorType)
return failure();
+ // Check the pre-conditions. For `Elementwise` Ops all operands are
+ // guaranteed to have identical shapes and it suffices to only check the
+ // first one.
+ auto sourceVectorType = cast<VectorType>(op->getOperands()[0].getType());
if (sourceVectorType.getRank() < 2)
return failure();
@@ -1506,23 +1506,26 @@ struct DropUnitDimFromElementwiseOps final
// Drop leading/trailing unit dim by applying vector.shape_cast to all
// operands
int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
- VectorType newVType = VectorType::Builder(sourceVectorType).dropDim(dim);
SmallVector<Value> newOperands;
auto loc = op->getLoc();
for (auto operand : op->getOperands()) {
+ auto opVectorType = cast<VectorType>(operand.getType());
+ VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
newOperands.push_back(opSC);
}
+ VectorType newResultVectorType =
+ VectorType::Builder(resultVectorType).dropDim(dim);
// Create an updated elementwise Op without leading/trailing unit dim
Operation *elementwiseOp =
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
- newVType, op->getAttrs());
+ newResultVectorType, op->getAttrs());
- // Restore the leading/trailing unit dim by applying vector.shape_cast to
- // the result
- rewriter.replaceOpWithNewOp<ShapeCastOp>(op, sourceVectorType,
+ // Restore the leading/trailing unit dim by applying vector.shape_cast
+ // to the result
+ rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
elementwiseOp->getResult(0));
return success();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index b81491b9c07404..3708d741141be0 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -321,6 +321,21 @@ func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
// -----
+func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
+ %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xi8> to vector<1x8x[2]xi8>
+ %add = arith.sitofp %sc_arg0 : vector<1x8x[2]xi8> to vector<1x8x[2]xf32>
+ %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+ return %res : vector<8x[2]xf32>
+}
+
+// CHECK-LABEL: func.func @fold_unit_dim_sitofp(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
+// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xi8> to vector<8x[2]xi8>
+// CHECK: %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
+// CHECK: return %[[VAL_2]] : vector<8x[2]xf32>
+
+// -----
+
// All shape casts are folded away
func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
@@ -341,4 +356,3 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
// CHECK: return %[[VAL_4]] : vector<8xi32>
-
More information about the Mlir-commits
mailing list