[Mlir-commits] [mlir] [DRAFT] Generalize expand_shape to take shape as explicit input (PR #69267)

llvmlistbot at llvm.org
Thu Oct 19 09:14:28 PDT 2023


github-actions[bot] wrote:



:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff 9f93a99a096c093b5c205cf9143d88bbbbba1b53 e8ac533dd84b1c79b06ae6f112f28518dfb6d57e -- mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h mlir/include/mlir/Dialect/Utils/StaticValueUtils.h mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp mlir/lib/Dialect/Tensor/IR/TensorOps.cpp mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp mlir/lib/Dialect/Utils/StaticValueUtils.cpp
``````````

</details>
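
To apply these fixes rather than only view them, one option is to check out the PR branch and run `git-clang-format` against the base commit, so that only the changed lines are reformatted in place. This is a sketch, assuming the branch at e8ac533 is checked out locally and clang-format is on your PATH:

``````````bash
# Reformat only the lines that changed relative to the base commit,
# modifying the files in the working tree in place.
git-clang-format 9f93a99a096c093b5c205cf9143d88bbbbba1b53
``````````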

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7be9315b9..6887f3ff9 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -549,22 +549,20 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
   auto resultType =
       RankedTensorType::get(resultShape, shapedType.getElementType());
 
-
   SmallVector<OpFoldResult> inputShape =
-        tensor::getMixedSizes(rewriter, loc, tensor);
-    SmallVector<OpFoldResult> outputShape;
-    if (failed(tensor::ExpandShapeOp::inferOutputShape(
-            rewriter, loc, resultType,
-            reassociationIndices, inputShape,
-            outputShape))) {
-      (void)rewriter.notifyMatchFailure(
-          loc, "unable to infer output shape argument for tensor.expand_shape");
-      return {};
-    }
+      tensor::getMixedSizes(rewriter, loc, tensor);
+  SmallVector<OpFoldResult> outputShape;
+  if (failed(tensor::ExpandShapeOp::inferOutputShape(
+          rewriter, loc, resultType, reassociationIndices, inputShape,
+          outputShape))) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "unable to infer output shape argument for tensor.expand_shape");
+    return {};
+  }
 
   // Emit 'tensor.expand_shape' op
-  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
-                                                reassociationIndices, outputShape);
+  return rewriter.create<tensor::ExpandShapeOp>(
+      loc, resultType, tensor, reassociationIndices, outputShape);
 }
 
 static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 2b43ffc0c..8aabc7a64 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -203,12 +203,11 @@ Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
           convertReassociationMapsToIndices(reassociationMap), inputShape,
           outputShape))) {
     (void)rewriter.notifyMatchFailure(
-        loc,
-        "unable to infer output shape argument for tensor.expand_shape");
+        loc, "unable to infer output shape argument for tensor.expand_shape");
     return {};
   }
-  return rewriter.create<tensor::ExpandShapeOp>(
-      loc, resultTy, operand, reassociationMap, outputShape);
+  return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
+                                                reassociationMap, outputShape);
 }
 
 class ReshapeConverterCollapseExpand
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index b92a68309..f1dbc8f5d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -508,14 +508,14 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
       });
   Value result = genericOp.getResults().front();
   SmallVector<OpFoldResult> inputShape =
-        tensor::getMixedSizes(rewriter, loc, result);
+      tensor::getMixedSizes(rewriter, loc, result);
   SmallVector<OpFoldResult> expandOutputShape;
   if (failed(tensor::ExpandShapeOp::inferOutputShape(
           rewriter, loc, outputType.cast<RankedTensorType>(),
-          outputReassocIndices, inputShape,
-          expandOutputShape))) {
+          outputReassocIndices, inputShape, expandOutputShape))) {
     return rewriter.notifyMatchFailure(
-        convOp, "unable to infer output shape argument for tensor.expand_shape");
+        convOp,
+        "unable to infer output shape argument for tensor.expand_shape");
   }
 
   auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 820dd267f..642b25f66 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -276,7 +276,7 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
              ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
          "unknown rank reduction strategy");
   SmallVector<OpFoldResult> inputShape =
-        tensor::getMixedSizes(rewriter, loc, result);
+      tensor::getMixedSizes(rewriter, loc, result);
   SmallVector<OpFoldResult> outputShape;
   if (failed(tensor::ExpandShapeOp::inferOutputShape(
           rewriter, loc, origResultType, reassociation, inputShape,
@@ -284,8 +284,8 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
     return failure();
   }
   return rewriter
-      .create<tensor::ExpandShapeOp>(loc, origResultType, result,
-                                     reassociation, outputShape)
+      .create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation,
+                                     outputShape)
       .getResult();
 }
 
@@ -549,12 +549,11 @@ LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
       resultReplacements.push_back(result);
       continue;
     }
-    FailureOr<Value> expandedValue = expandValue(rewriter, loc, result, origDest,
-                                             reassociations[opOperandIndex],
-                                             options.rankReductionStrategy);
+    FailureOr<Value> expandedValue = expandValue(
+        rewriter, loc, result, origDest, reassociations[opOperandIndex],
+        options.rankReductionStrategy);
     if (failed(expandedValue)) {
-      return rewriter.notifyMatchFailure(genericOp,
-                                         "unable to expand result");
+      return rewriter.notifyMatchFailure(genericOp, "unable to expand result");
     }
     resultReplacements.push_back(*expandedValue);
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index eff94d871..9c27c4a2c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1592,19 +1592,21 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
       if (isa<MemRefType>(collapsedOpResult.getType())) {
         SmallVector<OpFoldResult> collapsedOpShape =
             memref::getMixedSizes(rewriter, loc, collapsedOpResult);
-        MemRefType expandShapeResultType =
-            MemRefType::get(originalResultType.getShape(), originalResultType.getElementType());
+        MemRefType expandShapeResultType = MemRefType::get(
+            originalResultType.getShape(), originalResultType.getElementType());
         SmallVector<OpFoldResult> outputShape;
 
         if (failed(memref::ExpandShapeOp::inferOutputShape(
-                rewriter, loc, expandShapeResultType, reassociation, collapsedOpShape,
-                outputShape))) {
+                rewriter, loc, expandShapeResultType, reassociation,
+                collapsedOpShape, outputShape))) {
           return rewriter.notifyMatchFailure(
-              genericOp, "unable to infer output shape argument for memref.expand_shape");
+              genericOp,
+              "unable to infer output shape argument for memref.expand_shape");
         }
 
         Value result = rewriter.create<memref::ExpandShapeOp>(
-            loc, expandShapeResultType, collapsedOpResult, reassociation, outputShape);
+            loc, expandShapeResultType, collapsedOpResult, reassociation,
+            outputShape);
         results.push_back(result);
       } else {
         SmallVector<OpFoldResult> collapsedOpShape =

``````````

</details>
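
For reference, every hunk above exercises the same pattern introduced by this draft PR: compute the mixed sizes of the source tensor, ask `ExpandShapeOp::inferOutputShape` for the explicit output-shape operand, and pass it to the builder. The following is a minimal sketch of that pattern, drawn from the call sites in the diff; the helper name `expandToType` is hypothetical, and the `inferOutputShape` signature is the one proposed in this PR:

``````````cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Expand `source` to `resultType` along `reassociation`, supplying the
// explicit output-shape operand that the generalized expand_shape takes.
static Value expandToType(PatternRewriter &rewriter, Location loc,
                          Value source, RankedTensorType resultType,
                          ArrayRef<ReassociationIndices> reassociation) {
  // Mixed static/dynamic sizes of the (collapsed) source tensor.
  SmallVector<OpFoldResult> inputShape =
      tensor::getMixedSizes(rewriter, loc, source);

  // Infer the expanded shape; this fails when a dynamic output extent
  // cannot be recovered from the input shape and the result type alone.
  SmallVector<OpFoldResult> outputShape;
  if (failed(tensor::ExpandShapeOp::inferOutputShape(
          rewriter, loc, resultType, reassociation, inputShape,
          outputShape))) {
    (void)rewriter.notifyMatchFailure(
        loc, "unable to infer output shape argument for tensor.expand_shape");
    return {};
  }

  // Create the op with the explicit output-shape operand.
  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, source,
                                                reassociation, outputShape);
}
``````````

Call sites that start from an affine-map reassociation first convert it with `convertReassociationMapsToIndices`, as in the TosaToTensor hunk above.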


https://github.com/llvm/llvm-project/pull/69267

