[Mlir-commits] [mlir] [DRAFT] Generalize expand_shape to take shape as explicit input (PR #69267)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 26 13:06:56 PDT 2023
================
@@ -16,6 +17,76 @@
using namespace mlir;
+LogicalResult
+mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
+ RankedTensorType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape,
+ SmallVectorImpl<OpFoldResult> &outputShape) {
+ outputShape.clear();
+ SmallVector<Value> outputShapeValues;
+ SmallVector<int64_t> outputShapeInts;
+ // For zero-rank inputs, all dims in result shape are unit extent.
+ if (inputShape.empty()) {
+ outputShapeInts.resize(expandedType.getRank(), 1);
+ outputShape.assign(getMixedValues(outputShapeInts, outputShapeValues, b));
+ return success();
+ }
+
+ outputShapeValues.resize(expandedType.getRank());
+ outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
+
+ for (const auto &it : llvm::enumerate(reassociation)) {
+ ReassociationIndices indexGroup = it.value();
+
+ int64_t indexGroupStaticSizesProductInt = 1;
+ bool foundDynamic = false;
+ for (int64_t index : indexGroup) {
+ int64_t outputDimSize = expandedType.getDimSize(index);
+ // Cannot infer expanded shape with multiple dynamic dims in the
+ // same reassociation group!
+ if (ShapedType::isDynamic(outputDimSize)) {
+ if (foundDynamic)
+ return failure();
+ foundDynamic = true;
+ } else {
+ indexGroupStaticSizesProductInt *= outputDimSize;
+ }
+ }
+ Value indexGroupStaticSizesProduct =
+ b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
+
+ int64_t inputIndex = it.index();
+ for (int64_t index : indexGroup) {
+ if (ShapedType::isDynamic(expandedType.getDimSize(index))) {
+ // Call get<Value>() under the assumption that we're not casting
+ // dynamism.
+ Value indexGroupSize = inputShape[inputIndex].get<Value>();
+ Value dynamicDimSize = b.createOrFold<arith::DivUIOp>(
+ loc, indexGroupSize, indexGroupStaticSizesProduct);
+ outputShapeValues[index] = dynamicDimSize;
+ }
+ }
+
+ for (int64_t index : indexGroup) {
+ int64_t outputDimSize = expandedType.getDimSize(index);
+ if (ShapedType::isDynamic(outputDimSize))
+ continue;
+ outputShapeInts[index] = outputDimSize;
+ }
+ }
+
+ assert(static_cast<uint64_t>(
----------------
MaheshRavishankar wrote:
Just return `failure()` here.
Also, it seems like this method never actually returns `failure`. One option is to make this function return `void` (in which case the `assert` here is fine). On balance, though, it is probably better to keep the `LogicalResult` return type and return `failure()`.
https://github.com/llvm/llvm-project/pull/69267
More information about the Mlir-commits mailing list