[Mlir-commits] [mlir] [mlir] Fix type transformation in DropUnitDimFromElementwiseOps (PR #75430)

Jerry Wu llvmlistbot at llvm.org
Thu Dec 14 00:05:12 PST 2023


https://github.com/pzread updated https://github.com/llvm/llvm-project/pull/75430

>From baef0db238b895685eee69a95eeec429c459f020 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Thu, 14 Dec 2023 07:00:01 +0000
Subject: [PATCH 1/2] Enhance condition and fix type conversion

---
 .../Vector/Transforms/VectorTransforms.cpp    | 34 ++++++++++++-------
 .../Vector/vector-transfer-flatten.mlir       | 16 ++++++++-
 2 files changed, 37 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 45eb7274cd2d3c..1175da921d7ba1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1480,17 +1480,24 @@ struct DropUnitDimFromElementwiseOps final
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (op->getNumResults() != 1)
+    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
       return failure();
 
-    // Check the pre-condiitions. For `Elementwise` Ops all operands
-    // are guaranteed to have identical shapes and it suffices to only check the
-    // first one.
-    auto op1 = op->getOperands()[0];
-    auto sourceVectorType = dyn_cast<VectorType>(op1.getType());
-    if (!sourceVectorType)
+    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
+    if (!resultVectorType)
       return failure();
 
+    if (llvm::any_of(op->getOperands(), [](auto operand) {
+          return !isa<VectorType>(operand.getType());
+        })) {
+      return failure();
+    }
+
+    // Check the pre-conditions. For `Elementwise` Ops all operands are
+    // guaranteed to have identical shapes and it suffices to only check the
+    // first one.
+    auto sourceVectorType = cast<VectorType>(op->getOperands()[0].getType());
+
     if (sourceVectorType.getRank() < 2)
       return failure();
 
@@ -1506,23 +1513,26 @@ struct DropUnitDimFromElementwiseOps final
     // Drop leading/trailing unit dim by applying vector.shape_cast to all
     // operands
     int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
-    VectorType newVType = VectorType::Builder(sourceVectorType).dropDim(dim);
 
     SmallVector<Value> newOperands;
     auto loc = op->getLoc();
     for (auto operand : op->getOperands()) {
+      auto opVectorType = cast<VectorType>(operand.getType());
+      VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
       auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
       newOperands.push_back(opSC);
     }
 
+    VectorType newResultVectorType =
+        VectorType::Builder(resultVectorType).dropDim(dim);
     // Create an updated elementwise Op without leading/trailing unit dim
     Operation *elementwiseOp =
         rewriter.create(loc, op->getName().getIdentifier(), newOperands,
-                        newVType, op->getAttrs());
+                        newResultVectorType, op->getAttrs());
 
-    // Restore the leading/trailing unit dim by applying vector.shape_cast to
-    // the result
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(op, sourceVectorType,
+    // Restore the leading/trailing unit dim by applying vector.shape_cast
+    // to the result
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
                                              elementwiseOp->getResult(0));
 
     return success();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index b81491b9c07404..3708d741141be0 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -321,6 +321,21 @@ func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
 
 // -----
 
+func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
+   %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xi8> to vector<1x8x[2]xi8>
+   %add = arith.sitofp %sc_arg0 : vector<1x8x[2]xi8> to vector<1x8x[2]xf32>
+   %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+   return %res : vector<8x[2]xf32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dim_sitofp(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
+// CHECK:           %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xi8> to vector<8x[2]xi8>
+// CHECK:           %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
+// CHECK:           return %[[VAL_2]] : vector<8x[2]xf32>
+
+// -----
+
 // All shape casts are folded away
 
 func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
@@ -341,4 +356,3 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
-

>From 9f8506a8ccee8be3b76a6c835e9ced8ff2390ef3 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Thu, 14 Dec 2023 08:02:24 +0000
Subject: [PATCH 2/2] Remove extra condition

---
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 1175da921d7ba1..5936b0b54af4e3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1487,17 +1487,10 @@ struct DropUnitDimFromElementwiseOps final
     if (!resultVectorType)
       return failure();
 
-    if (llvm::any_of(op->getOperands(), [](auto operand) {
-          return !isa<VectorType>(operand.getType());
-        })) {
-      return failure();
-    }
-
     // Check the pre-conditions. For `Elementwise` Ops all operands are
     // guaranteed to have identical shapes and it suffices to only check the
     // first one.
     auto sourceVectorType = cast<VectorType>(op->getOperands()[0].getType());
-
     if (sourceVectorType.getRank() < 2)
       return failure();
 



More information about the Mlir-commits mailing list