[Mlir-commits] [mlir] [mlir][tensor] Update `GeneralizeOuterUnitDimsPackOpPattern` (PR #115312)

Han-Chung Wang llvmlistbot at llvm.org
Mon Nov 11 13:19:03 PST 2024


================
@@ -1153,71 +1153,65 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   Location loc = packOp.getLoc();
 
   Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
-  auto inputShape = packOp.getSourceType().getShape();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       packOp.getDimAndTileMapping();
   int64_t srcRank = packOp.getSourceRank();
-
   int64_t destRank = packOp.getDestRank();
-  size_t numTiles = destRank - srcRank;
-
-  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
-  //    %extracted_tile = tensor.extract_slice(%pack_op_input)
-  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
-  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
+  int64_t numTiles = destRank - srcRank;
 
-  // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
-  // all outer dims are 1.
-  SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
-  // The shape of the output for ExtractSliceOp. All leading unit dims are
-  // effectively rank-reduced, hence skipped.
-  SmallVector<int64_t> outputShapeForExtractSlice;
+  if (!llvm::all_of(packOp.getInnerDimsPos(),
+                    [&srcRank, &numTiles](int64_t dimPos) {
+                      return dimPos >= (srcRank - numTiles - 1);
+                    }))
+    return rewriter.notifyMatchFailure(
+        packOp, "Attempting to tile non-trailing source dims!");
 
-  // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
-  // be equal to the inner tile sizes.
+  // 1. Extract the inner tile sizes.
+  // Where possible, values are replaced with constant attributes (to match the
+  // behaviour of `getPackOpSourceOrPaddedSource`).
+  SmallVector<OpFoldResult> tileSizes;
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
-      auto [tileSize, tileSizeOfr] =
+      // Rather than taking the tile size as is, extract the actual constant
+      // value Attribute where possible, e.g.:
+      //    [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
+      auto [_, tileSize] =
           getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
-      extractSliceSizes.push_back(tileSizeOfr);
-      outputShapeForExtractSlice.push_back(tileSize);
+      tileSizes.push_back(tileSize);
     }
   }
 
-  Type elemType = packOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
-
-  Value tile = rewriter.create<tensor::ExtractSliceOp>(
-      loc, readType, input, readOffsets, extractSliceSizes, readStrides);
-
-  // 2. Transpose the tile to match the inner tile order:
+  // 2. Transpose the input to match the inner tile order:
   //    %init = tensor.empty()
   //    %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
----------------
hanhanW wrote:

This comment needs to be updated: it still refers to `%extracted_tile`, but there is no extracted tile anymore, is there?

https://github.com/llvm/llvm-project/pull/115312


More information about the Mlir-commits mailing list