[Mlir-commits] [mlir] 2c9ba9c - [mlir] Fix type transformation in DropUnitDimFromElementwiseOps (#75430)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 14 09:20:58 PST 2023


Author: Jerry Wu
Date: 2023-12-14T12:20:54-05:00
New Revision: 2c9ba9c34a22db98faf9405b157195af4e09c2ed

URL: https://github.com/llvm/llvm-project/commit/2c9ba9c34a22db98faf9405b157195af4e09c2ed
DIFF: https://github.com/llvm/llvm-project/commit/2c9ba9c34a22db98faf9405b157195af4e09c2ed.diff

LOG: [mlir] Fix type transformation in DropUnitDimFromElementwiseOps (#75430)

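Use each operand's own type and the op's result type to build the
corresponding new types in `DropUnitDimFromElementwiseOps`, instead of
deriving every new type from the first operand's type. This fixes ops
whose result type differs from their operand type, such as
`arith.sitofp` (e.g. vector<...xi8> to vector<...xf32>).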

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 45eb7274cd2d3c..5936b0b54af4e3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1480,17 +1480,17 @@ struct DropUnitDimFromElementwiseOps final
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (op->getNumResults() != 1)
+    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
       return failure();
 
-    // Check the pre-condiitions. For `Elementwise` Ops all operands
-    // are guaranteed to have identical shapes and it suffices to only check the
-    // first one.
-    auto op1 = op->getOperands()[0];
-    auto sourceVectorType = dyn_cast<VectorType>(op1.getType());
-    if (!sourceVectorType)
+    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
+    if (!resultVectorType)
       return failure();
 
+    // Check the pre-conditions. For `Elementwise` Ops all operands are
+    // guaranteed to have identical shapes and it suffices to only check the
+    // first one.
+    auto sourceVectorType = cast<VectorType>(op->getOperands()[0].getType());
     if (sourceVectorType.getRank() < 2)
       return failure();
 
@@ -1506,23 +1506,26 @@ struct DropUnitDimFromElementwiseOps final
     // Drop leading/trailing unit dim by applying vector.shape_cast to all
     // operands
     int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
-    VectorType newVType = VectorType::Builder(sourceVectorType).dropDim(dim);
 
     SmallVector<Value> newOperands;
     auto loc = op->getLoc();
     for (auto operand : op->getOperands()) {
+      auto opVectorType = cast<VectorType>(operand.getType());
+      VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
       auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
       newOperands.push_back(opSC);
     }
 
+    VectorType newResultVectorType =
+        VectorType::Builder(resultVectorType).dropDim(dim);
     // Create an updated elementwise Op without leading/trailing unit dim
     Operation *elementwiseOp =
         rewriter.create(loc, op->getName().getIdentifier(), newOperands,
-                        newVType, op->getAttrs());
+                        newResultVectorType, op->getAttrs());
 
-    // Restore the leading/trailing unit dim by applying vector.shape_cast to
-    // the result
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(op, sourceVectorType,
+    // Restore the leading/trailing unit dim by applying vector.shape_cast
+    // to the result
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
                                              elementwiseOp->getResult(0));
 
     return success();

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index b81491b9c07404..3708d741141be0 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -321,6 +321,21 @@ func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
 
 // -----
 
+func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
+   %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xi8> to vector<1x8x[2]xi8>
+   %add = arith.sitofp %sc_arg0 : vector<1x8x[2]xi8> to vector<1x8x[2]xf32>
+   %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+   return %res : vector<8x[2]xf32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dim_sitofp(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
+// CHECK:           %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xi8> to vector<8x[2]xi8>
+// CHECK:           %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
+// CHECK:           return %[[VAL_2]] : vector<8x[2]xf32>
+
+// -----
+
 // All shape casts are folded away
 
 func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
@@ -341,4 +356,3 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
-


        


More information about the Mlir-commits mailing list