[Mlir-commits] [mlir] [mlir][linalg] Fix and Refactor DecomposeOuterUnitDimsUnPackOpPattern (PR #119379)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Dec 11 13:09:35 PST 2024


================
@@ -1252,64 +1252,88 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
         "require the tiled outer dimensions of the result are all 1s");
   }
 
-  // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
+  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+  //    %extracted_tile = tensor.extract_slice(%unpack_op_input)
   Location loc = unpackOp.getLoc();
   Value source = unpackOp.getSource();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       unpackOp.getDimAndTileMapping();
   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
-  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
-  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
-  SmallVector<OpFoldResult> readSizes;
-  SmallVector<int64_t> readShape;
-  SmallVector<Value> dynamicDims;
+
+  // The sizes, offsets and strides attributes for ExtractSliceOp.
+  SmallVector<OpFoldResult> extractSliceSizes;
+  SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
+  SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
+  // The shape for ExtractSliceOp (due to rank-reducing, this is likely !=
+  // extractSliceSizes).
+  SmallVector<int64_t> readShapeForExtractSlice;
+
+  // Shape for EmptyOp that's used as the init value for TransposeOp below.
+  // This should match tile size + transposition.
+  SmallVector<OpFoldResult> shapeForEmptyOp;
+
   for (auto i : llvm::seq<unsigned>(0, destRank)) {
+    // Given the assumption that all outer tiled dims are 1, the corresponding
+    // slice size to read is also 1. As this will be rank-reducing "extract
+    // slice" (i.e. the unit dims will be "collapsed"), there's no need to
+    // update:
+    //  * the output shape for ExtractSliceOp, nor
+    //  * the shape for EmptyOp.
     if (dimAndTileMapping.count(i)) {
-      readSizes.push_back(oneIdxAttr);
+      extractSliceSizes.push_back(oneIdxAttr);
       continue;
     }
 
+    // Compute sizes attribute for ExtractSliceOp + EmptyOp
     if (ShapedType::isDynamic(srcShape[i])) {
-      Value dynamicDim =
+      OpFoldResult dynamicDim =
           rewriter.create<tensor::DimOp>(loc, source, i).getResult();
-      readSizes.push_back(dynamicDim);
-      dynamicDims.push_back(dynamicDim);
+      extractSliceSizes.push_back(dynamicDim);
+      shapeForEmptyOp.push_back(dynamicDim);
     } else {
-      readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+      extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+      if (srcShape[i] != 1)
+        shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
+    }
+    // Compute the output shape for ExtractSliceOp (take into account
+    // rank-reducing)
+    if (srcShape[i] != 1) {
+      readShapeForExtractSlice.push_back(srcShape[i]);
     }
-    if (srcShape[i] != 1)
-      readShape.push_back(srcShape[i]);
   }
   auto mixedTiles = unpackOp.getMixedTiles();
-  readSizes.append(mixedTiles.begin(), mixedTiles.end());
+  // TODO: This effectively assumes that the tile sizes match the trailing
+  // sizes for ExtractSliceOp and EmptyOp - document this.
----------------
banach-space wrote:

> Perhaps a quick check that there's only ever 1s in the beginning

See my updated comments (sending shortly) and how I split dims into:
```cpp
[ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
```

* `outer-tiled-dims` - no need to worry about these; the check at the top of the pattern already requires them to be all 1s.
* `outer-untiled-dims` - we are not making any assumptions about these. However, I'd like to change that to match [DecomposeOuterUnitDimsPackOpPattern](https://github.com/llvm/llvm-project/blob/796a1cf70639697325a86a56a0e482add19e1d56/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp#L1147-L1151)

> and mixedTiles in the end?

* `tile-sizes` - the docs for [tensor.unpack](https://mlir.llvm.org/docs/Dialects/TensorOps/#tensorunpack-tensorunpackop) strongly imply, without quite stating it, that these are always the trailing dims (so it's already documented - my bad). The [tensor.pack docs](https://mlir.llvm.org/docs/Dialects/TensorOps/#tensorpack-tensorpackop) are more specific: "These tile sizes correspond to the least significant (“inner”) result tensor dimension sizes".

Hope I'm not missing something here!

https://github.com/llvm/llvm-project/pull/119379

