[Mlir-commits] [mlir] [DRAFT] Generalize expand_shape to take shape as explicit input (PR #69267)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 17 14:42:44 PDT 2023


================
@@ -16,6 +17,76 @@
 
 using namespace mlir;
 
+LogicalResult
+mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
+                                  RankedTensorType expandedType,
+                                  ArrayRef<ReassociationIndices> reassociation,
+                                  ArrayRef<OpFoldResult> inputShape,
+                                  SmallVectorImpl<OpFoldResult> &outputShape) {
+  outputShape.clear();
+  SmallVector<Value> outputShapeValues;
+  SmallVector<int64_t> outputShapeInts;
+  // For zero-rank inputs, all dims in result shape are unit extent.
+  if (inputShape.empty()) {
+    outputShapeInts.resize(expandedType.getRank(), 1);
+    outputShape.assign(getMixedValues(outputShapeInts, outputShapeValues, b));
+    return success();
+  }
+
+  outputShapeValues.resize(expandedType.getRank());
+  outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
+
+  for (const auto &it : llvm::enumerate(reassociation)) {
+    ReassociationIndices indexGroup = it.value();
+
+    int64_t indexGroupStaticSizesProductInt = 1;
+    bool foundDynamic = false;
+    for (int64_t index : indexGroup) {
+      int64_t outputDimSize = expandedType.getDimSize(index);
+      // Cannot infer expanded shape with multiple dynamic dims in the
+      // same reassociation group!
+      if (ShapedType::isDynamic(outputDimSize)) {
+        if (foundDynamic)
+          return failure();
+        foundDynamic = true;
+      } else {
+        indexGroupStaticSizesProductInt *= outputDimSize;
+      }
+    }
+    Value indexGroupStaticSizesProduct =
+        b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
+
+    int64_t inputIndex = it.index();
+    for (int64_t index : indexGroup) {
+      if (ShapedType::isDynamic(expandedType.getDimSize(index))) {
+        // Call get<Value>() under the assumption that we're not casting
+        // dynamism.
+        Value indexGroupSize = inputShape[inputIndex].get<Value>();
----------------
MaheshRavishankar wrote:

We want to avoid creating `arith.divui` ops directly. Use `affine.apply` operations instead, since `affine.apply` ops fold better.

https://github.com/llvm/llvm-project/pull/69267


More information about the Mlir-commits mailing list