[Mlir-commits] [mlir] [DRAFT] Generalize expand_shape to take shape as explicit input (PR #69267)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 19 09:14:28 PDT 2023
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff 9f93a99a096c093b5c205cf9143d88bbbbba1b53 e8ac533dd84b1c79b06ae6f112f28518dfb6d57e -- mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h mlir/include/mlir/Dialect/Utils/StaticValueUtils.h mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp mlir/lib/Dialect/Tensor/IR/TensorOps.cpp mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp mlir/lib/Dialect/Utils/StaticValueUtils.cpp
``````````
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7be9315b9..6887f3ff9 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -549,22 +549,20 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
auto resultType =
RankedTensorType::get(resultShape, shapedType.getElementType());
-
SmallVector<OpFoldResult> inputShape =
- tensor::getMixedSizes(rewriter, loc, tensor);
- SmallVector<OpFoldResult> outputShape;
- if (failed(tensor::ExpandShapeOp::inferOutputShape(
- rewriter, loc, resultType,
- reassociationIndices, inputShape,
- outputShape))) {
- (void)rewriter.notifyMatchFailure(
- loc, "unable to infer output shape argument for tensor.expand_shape");
- return {};
- }
+ tensor::getMixedSizes(rewriter, loc, tensor);
+ SmallVector<OpFoldResult> outputShape;
+ if (failed(tensor::ExpandShapeOp::inferOutputShape(
+ rewriter, loc, resultType, reassociationIndices, inputShape,
+ outputShape))) {
+ (void)rewriter.notifyMatchFailure(
+ loc, "unable to infer output shape argument for tensor.expand_shape");
+ return {};
+ }
// Emit 'tensor.expand_shape' op
- return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
- reassociationIndices, outputShape);
+ return rewriter.create<tensor::ExpandShapeOp>(
+ loc, resultType, tensor, reassociationIndices, outputShape);
}
static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 2b43ffc0c..8aabc7a64 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -203,12 +203,11 @@ Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
convertReassociationMapsToIndices(reassociationMap), inputShape,
outputShape))) {
(void)rewriter.notifyMatchFailure(
- loc,
- "unable to infer output shape argument for tensor.expand_shape");
+ loc, "unable to infer output shape argument for tensor.expand_shape");
return {};
}
- return rewriter.create<tensor::ExpandShapeOp>(
- loc, resultTy, operand, reassociationMap, outputShape);
+ return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
+ reassociationMap, outputShape);
}
class ReshapeConverterCollapseExpand
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index b92a68309..f1dbc8f5d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -508,14 +508,14 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
});
Value result = genericOp.getResults().front();
SmallVector<OpFoldResult> inputShape =
- tensor::getMixedSizes(rewriter, loc, result);
+ tensor::getMixedSizes(rewriter, loc, result);
SmallVector<OpFoldResult> expandOutputShape;
if (failed(tensor::ExpandShapeOp::inferOutputShape(
rewriter, loc, outputType.cast<RankedTensorType>(),
- outputReassocIndices, inputShape,
- expandOutputShape))) {
+ outputReassocIndices, inputShape, expandOutputShape))) {
return rewriter.notifyMatchFailure(
- convOp, "unable to infer output shape argument for tensor.expand_shape");
+ convOp,
+ "unable to infer output shape argument for tensor.expand_shape");
}
auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 820dd267f..642b25f66 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -276,7 +276,7 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
"unknown rank reduction strategy");
SmallVector<OpFoldResult> inputShape =
- tensor::getMixedSizes(rewriter, loc, result);
+ tensor::getMixedSizes(rewriter, loc, result);
SmallVector<OpFoldResult> outputShape;
if (failed(tensor::ExpandShapeOp::inferOutputShape(
rewriter, loc, origResultType, reassociation, inputShape,
@@ -284,8 +284,8 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
return failure();
}
return rewriter
- .create<tensor::ExpandShapeOp>(loc, origResultType, result,
- reassociation, outputShape)
+ .create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation,
+ outputShape)
.getResult();
}
@@ -549,12 +549,11 @@ LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
resultReplacements.push_back(result);
continue;
}
- FailureOr<Value> expandedValue = expandValue(rewriter, loc, result, origDest,
- reassociations[opOperandIndex],
- options.rankReductionStrategy);
+ FailureOr<Value> expandedValue = expandValue(
+ rewriter, loc, result, origDest, reassociations[opOperandIndex],
+ options.rankReductionStrategy);
if (failed(expandedValue)) {
- return rewriter.notifyMatchFailure(genericOp,
- "unable to expand result");
+ return rewriter.notifyMatchFailure(genericOp, "unable to expand result");
}
resultReplacements.push_back(*expandedValue);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index eff94d871..9c27c4a2c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1592,19 +1592,21 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
if (isa<MemRefType>(collapsedOpResult.getType())) {
SmallVector<OpFoldResult> collapsedOpShape =
memref::getMixedSizes(rewriter, loc, collapsedOpResult);
- MemRefType expandShapeResultType =
- MemRefType::get(originalResultType.getShape(), originalResultType.getElementType());
+ MemRefType expandShapeResultType = MemRefType::get(
+ originalResultType.getShape(), originalResultType.getElementType());
SmallVector<OpFoldResult> outputShape;
if (failed(memref::ExpandShapeOp::inferOutputShape(
- rewriter, loc, expandShapeResultType, reassociation, collapsedOpShape,
- outputShape))) {
+ rewriter, loc, expandShapeResultType, reassociation,
+ collapsedOpShape, outputShape))) {
return rewriter.notifyMatchFailure(
- genericOp, "unable to infer output shape argument for memref.expand_shape");
+ genericOp,
+ "unable to infer output shape argument for memref.expand_shape");
}
Value result = rewriter.create<memref::ExpandShapeOp>(
- loc, expandShapeResultType, collapsedOpResult, reassociation, outputShape);
+ loc, expandShapeResultType, collapsedOpResult, reassociation,
+ outputShape);
results.push_back(result);
} else {
SmallVector<OpFoldResult> collapsedOpShape =
``````````
</details>
https://github.com/llvm/llvm-project/pull/69267
More information about the Mlir-commits mailing list