[Mlir-commits] [mlir] [DRAFT] Generalize expand_shape to take shape as explicit input (PR #69267)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 17 14:42:44 PDT 2023
================
@@ -16,6 +17,76 @@
using namespace mlir;
+LogicalResult
+mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
+ RankedTensorType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape,
+ SmallVectorImpl<OpFoldResult> &outputShape) {
+ outputShape.clear();
+ SmallVector<Value> outputShapeValues;
+ SmallVector<int64_t> outputShapeInts;
+ // For zero-rank inputs, all dims in result shape are unit extent.
+ if (inputShape.empty()) {
+ outputShapeInts.resize(expandedType.getRank(), 1);
+ outputShape.assign(getMixedValues(outputShapeInts, outputShapeValues, b));
+ return success();
+ }
+
+ outputShapeValues.resize(expandedType.getRank());
+ outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
+
+ for (const auto &it : llvm::enumerate(reassociation)) {
+ ReassociationIndices indexGroup = it.value();
+
+ int64_t indexGroupStaticSizesProductInt = 1;
+ bool foundDynamic = false;
+ for (int64_t index : indexGroup) {
+ int64_t outputDimSize = expandedType.getDimSize(index);
+ // Cannot infer expanded shape with multiple dynamic dims in the
+ // same reassociation group!
+ if (ShapedType::isDynamic(outputDimSize)) {
+ if (foundDynamic)
+ return failure();
+ foundDynamic = true;
+ } else {
+ indexGroupStaticSizesProductInt *= outputDimSize;
+ }
+ }
+ Value indexGroupStaticSizesProduct =
+ b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
+
+ int64_t inputIndex = it.index();
+ for (int64_t index : indexGroup) {
+ if (ShapedType::isDynamic(expandedType.getDimSize(index))) {
+ // Call get<Value>() under the assumption that we're not casting
+ // dynamism.
+ Value indexGroupSize = inputShape[inputIndex].get<Value>();
----------------
MaheshRavishankar wrote:
We want to avoid creating `arith.divui` ops directly. Use `affine.apply` operations instead. Those fold better.
https://github.com/llvm/llvm-project/pull/69267
More information about the Mlir-commits mailing list