[Mlir-commits] [mlir] [mlir] Fix type transformation in DropUnitDimFromElementwiseOps (PR #75430)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 13 23:10:42 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Jerry Wu (pzread)

<details>
<summary>Changes</summary>

Use operand and result types to build the corresponding new types in `DropUnitDimFromElementwiseOps`.

Elementwise ops only guarantee that their operands and results have the same shape; they do not guarantee that the operands and results share the same element type.

---
Full diff: https://github.com/llvm/llvm-project/pull/75430.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+22-12) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+15-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 45eb7274cd2d3c..1175da921d7ba1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1480,17 +1480,24 @@ struct DropUnitDimFromElementwiseOps final
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (op->getNumResults() != 1)
+    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
       return failure();
 
-    // Check the pre-condiitions. For `Elementwise` Ops all operands
-    // are guaranteed to have identical shapes and it suffices to only check the
-    // first one.
-    auto op1 = op->getOperands()[0];
-    auto sourceVectorType = dyn_cast<VectorType>(op1.getType());
-    if (!sourceVectorType)
+    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
+    if (!resultVectorType)
       return failure();
 
+    if (llvm::any_of(op->getOperands(), [](auto operand) {
+          return !isa<VectorType>(operand.getType());
+        })) {
+      return failure();
+    }
+
+    // Check the pre-conditions. For `Elementwise` Ops all operands are
+    // guaranteed to have identical shapes and it suffices to only check the
+    // first one.
+    auto sourceVectorType = cast<VectorType>(op->getOperands()[0].getType());
+
     if (sourceVectorType.getRank() < 2)
       return failure();
 
@@ -1506,23 +1513,26 @@ struct DropUnitDimFromElementwiseOps final
     // Drop leading/trailing unit dim by applying vector.shape_cast to all
     // operands
     int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
-    VectorType newVType = VectorType::Builder(sourceVectorType).dropDim(dim);
 
     SmallVector<Value> newOperands;
     auto loc = op->getLoc();
     for (auto operand : op->getOperands()) {
+      auto opVectorType = cast<VectorType>(operand.getType());
+      VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
       auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
       newOperands.push_back(opSC);
     }
 
+    VectorType newResultVectorType =
+        VectorType::Builder(resultVectorType).dropDim(dim);
     // Create an updated elementwise Op without leading/trailing unit dim
     Operation *elementwiseOp =
         rewriter.create(loc, op->getName().getIdentifier(), newOperands,
-                        newVType, op->getAttrs());
+                        newResultVectorType, op->getAttrs());
 
-    // Restore the leading/trailing unit dim by applying vector.shape_cast to
-    // the result
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(op, sourceVectorType,
+    // Restore the leading/trailing unit dim by applying vector.shape_cast
+    // to the result
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
                                              elementwiseOp->getResult(0));
 
     return success();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index b81491b9c07404..3708d741141be0 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -321,6 +321,21 @@ func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
 
 // -----
 
+func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
+   %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xi8> to vector<1x8x[2]xi8>
+   %add = arith.sitofp %sc_arg0 : vector<1x8x[2]xi8> to vector<1x8x[2]xf32>
+   %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+   return %res : vector<8x[2]xf32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dim_sitofp(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
+// CHECK:           %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xi8> to vector<8x[2]xi8>
+// CHECK:           %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
+// CHECK:           return %[[VAL_2]] : vector<8x[2]xf32>
+
+// -----
+
 // All shape casts are folded away
 
 func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
@@ -341,4 +356,3 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
-

``````````

</details>


https://github.com/llvm/llvm-project/pull/75430


More information about the Mlir-commits mailing list