[llvm] [mlir] [MLIR] Generalize expand_shape to take shape as explicit input (PR #69267)
Gaurav Shukla via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 16 09:52:39 PDT 2024
================
@@ -16,6 +18,83 @@
using namespace mlir;
+LogicalResult mlir::inferExpandShapeOutputShape(
+ OpBuilder &b, Location loc, RankedTensorType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape,
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape) {
+ outputShape.clear();
+ SmallVector<Value> outputShapeValues;
+ SmallVector<int64_t> outputShapeInts;
+ // For zero-rank inputs, all dims in result shape are unit extent.
+ if (inputShape.empty()) {
+ outputShapeInts.resize(expandedType.getRank(), 1);
+ outputShape.assign(getMixedValues(outputShapeInts, outputShapeValues, b));
+ return success();
+ }
+
+ outputShapeValues.resize(expandedType.getRank());
+ outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
+
+ for (const auto &it : llvm::enumerate(reassociation)) {
+ ReassociationIndices indexGroup = it.value();
+
+ int64_t indexGroupStaticSizesProductInt = 1;
+ bool foundDynamic = false;
+ for (int64_t index : indexGroup) {
+ int64_t outputDimSize = expandedType.getDimSize(index);
+ // Cannot infer expanded shape with multiple dynamic dims in the
+ // same reassociation group!
+ if (ShapedType::isDynamic(outputDimSize)) {
+ if (foundDynamic)
+ return failure();
+ foundDynamic = true;
+ } else {
+ indexGroupStaticSizesProductInt *= outputDimSize;
+ }
+ }
+ Value indexGroupStaticSizesProduct =
+ b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
+
+ int64_t inputIndex = it.index();
+ for (int64_t index : indexGroup) {
+ if (ShapedType::isDynamic(expandedType.getDimSize(index))) {
+ // Call get<Value>() under the assumption that we're not casting
+ // dynamism.
+ Value indexGroupSize = inputShape[inputIndex].get<Value>();
+
+ // Create an AffineMap representing the division operation.
+ MLIRContext *context = b.getContext();
+ AffineExpr dividend = getAffineSymbolExpr(0, context);
+ AffineExpr divisor = getAffineSymbolExpr(1, context);
+ AffineMap divisionMap = AffineMap::get(/*numDims=*/0, /*numSymbols=*/2,
+ {dividend.floorDiv(divisor)});
+ Value dynamicDimSize = b.createOrFold<affine::AffineApplyOp>(
+ loc, divisionMap,
+ ValueRange({indexGroupSize, indexGroupStaticSizesProduct}));
+ outputShapeValues[index] = dynamicDimSize;
+ }
+ }
+
+ for (int64_t index : indexGroup) {
+ int64_t outputDimSize = expandedType.getDimSize(index);
+ if (ShapedType::isDynamic(outputDimSize))
+ continue;
+ outputShapeInts[index] = outputDimSize;
+ }
+ }
+
+ if (static_cast<uint64_t>(
+ llvm::count(outputShapeInts, ShapedType::kDynamic)) ==
+ (outputShapeValues.size() - llvm::count(outputShapeValues, Value{})))
+ return failure();
+
+ llvm::erase(outputShapeValues, Value{});
----------------
Shukla-Gaurav wrote:
Thanks for pointing this out; I've updated the code :)
https://github.com/llvm/llvm-project/pull/69267
More information about the llvm-commits
mailing list