[Mlir-commits] [mlir] Let `memref.{expand, collapse}_shape` implement `ReifyRankedShapedTypeOpInterface` (PR #89111)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Thu Apr 18 00:54:25 PDT 2024


================
@@ -2079,6 +2080,95 @@ void ExpandShapeOp::getAsmResultNames(
   setNameFn(getResult(), "expand_shape");
 }
 
+LogicalResult ExpandShapeOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  SmallVector<OpFoldResult> resultDims;
+  ArrayRef<int64_t> expandedShape = this->getResultType().getShape();
+  for (size_t expanded_dim = 0; expanded_dim < expandedShape.size();
+       ++expanded_dim) {
+    if (ShapedType::isDynamic(expandedShape[expanded_dim])) {
+      // Dynamic dimension case. Map expanded_dim to the corresponded
+      // collapsed dim. All other expanded dimensions corresponding to
+      // that collapsed dim must be static-size. Compute their product
+      // to divide the result size by.
+      auto reassoc = this->getReassociationIndices();
+      for (size_t collapsed_dim = 0; collapsed_dim < reassoc.size();
+           ++collapsed_dim) {
+        ReassociationIndices associated_dims = reassoc[collapsed_dim];
+        bool found_expanded_dim = false;
+        int64_t other_associated_dims_product_size = 1;
+        for (size_t associated_dim : associated_dims) {
+          if (associated_dim == expanded_dim) {
+            found_expanded_dim = true;
+          } else {
+            assert(!ShapedType::isDynamic(expandedShape[associated_dim]) &&
+                   "At most one dimension of a reassociation group may be "
+                   "dynamic in the result type.");
+            other_associated_dims_product_size *= expandedShape[associated_dim];
+          }
+        }
+        if (!found_expanded_dim) {
+          continue;
+        }
+        Value srcDimSize =
+            builder.create<memref::DimOp>(getLoc(), getSrc(), collapsed_dim);
+        Value resultDimSize = builder.create<arith::DivSIOp>(
----------------
ftynse wrote:

I'd rather use an unsigned division here: it lowers to simpler code on most targets, and sizes are unsigned.

https://github.com/llvm/llvm-project/pull/89111


More information about the Mlir-commits mailing list