[llvm] [mlir] [MLIR] Generalize expand_shape to take shape as explicit input (PR #69267)

Gaurav Shukla via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 16 09:52:39 PDT 2024


================
@@ -16,6 +18,83 @@
 
 using namespace mlir;
 
+LogicalResult mlir::inferExpandShapeOutputShape(
+    OpBuilder &b, Location loc, RankedTensorType expandedType,
+    ArrayRef<ReassociationIndices> reassociation,
+    ArrayRef<OpFoldResult> inputShape,
+    std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape) {
+  outputShape.clear();
+  SmallVector<Value> outputShapeValues;
+  SmallVector<int64_t> outputShapeInts;
+  // For zero-rank inputs, all dims in result shape are unit extent.
+  if (inputShape.empty()) {
+    outputShapeInts.resize(expandedType.getRank(), 1);
+    outputShape.assign(getMixedValues(outputShapeInts, outputShapeValues, b));
+    return success();
+  }
+
+  outputShapeValues.resize(expandedType.getRank());
+  outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
+
+  for (const auto &it : llvm::enumerate(reassociation)) {
+    ReassociationIndices indexGroup = it.value();
+
+    int64_t indexGroupStaticSizesProductInt = 1;
+    bool foundDynamic = false;
+    for (int64_t index : indexGroup) {
+      int64_t outputDimSize = expandedType.getDimSize(index);
+      // Cannot infer expanded shape with multiple dynamic dims in the
+      // same reassociation group!
+      if (ShapedType::isDynamic(outputDimSize)) {
+        if (foundDynamic)
+          return failure();
+        foundDynamic = true;
+      } else {
+        indexGroupStaticSizesProductInt *= outputDimSize;
+      }
+    }
+    Value indexGroupStaticSizesProduct =
+        b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
+
+    int64_t inputIndex = it.index();
+    for (int64_t index : indexGroup) {
+      if (ShapedType::isDynamic(expandedType.getDimSize(index))) {
+        // Call get<Value>() under the assumption that we're not casting
+        // dynamism.
+        Value indexGroupSize = inputShape[inputIndex].get<Value>();
+
+        // Create an AffineMap representing the division operation.
+        MLIRContext *context = b.getContext();
+        AffineExpr dividend = getAffineSymbolExpr(0, context);
+        AffineExpr divisor = getAffineSymbolExpr(1, context);
+        AffineMap divisionMap = AffineMap::get(/*numDims=*/0, /*numSymbols=*/2,
+                                               {dividend.floorDiv(divisor)});
+        Value dynamicDimSize = b.createOrFold<affine::AffineApplyOp>(
+            loc, divisionMap,
+            ValueRange({indexGroupSize, indexGroupStaticSizesProduct}));
+        outputShapeValues[index] = dynamicDimSize;
+      }
+    }
+
+    for (int64_t index : indexGroup) {
+      int64_t outputDimSize = expandedType.getDimSize(index);
+      if (ShapedType::isDynamic(outputDimSize))
+        continue;
+      outputShapeInts[index] = outputDimSize;
+    }
+  }
+
+  if (static_cast<uint64_t>(
+          llvm::count(outputShapeInts, ShapedType::kDynamic)) ==
+      (outputShapeValues.size() - llvm::count(outputShapeValues, Value{})))
+    return failure();
+
+  llvm::erase(outputShapeValues, Value{});
----------------
Shukla-Gaurav wrote:

Thanks for pointing this out; I've updated the code :)

https://github.com/llvm/llvm-project/pull/69267


More information about the llvm-commits mailing list