[Mlir-commits] [mlir] [mlir][tensor] Generalize/restrict `GeneralizeOuterUnitDimsPackOpPattern` (PR #114315)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Nov 6 01:00:09 PST 2024


================
@@ -1148,69 +1172,104 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
         packOp, "require the tiled outer dimensions of the result are all 1s");
   }
 
-  // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
-  // outer dims.
+  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
+  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   Location loc = packOp.getLoc();
+
   Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
   auto inputShape = packOp.getSourceType().getShape();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       packOp.getDimAndTileMapping();
-  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
-  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   int64_t srcRank = packOp.getSourceRank();
+
+  int64_t destRank = packOp.getDestRank();
+  size_t numTiles = destRank - srcRank;
+
+  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+  //    %extracted_tile = tensor.extract_slice(%pack_op_input)
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
-  SmallVector<OpFoldResult> readSizes;
-  SmallVector<OpFoldResult> transShapeForEmpty;
-  SmallVector<int64_t> readShapeForExtractSlice;
+
+  // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
+  // all outer dims are 1.
+  SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
+  // The shape of the output for ExtractSliceOp. All leading unit dims are
+  // effectively rank-reduced, hence skipped.
+  SmallVector<int64_t> outputShapeForExtractSlice;
+
+  // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
+  // be equal to the inner tile sizes.
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
-      readShapeForExtractSlice.push_back(
-          getConstantIntValue(dimAndTileMapping[i])
-              .value_or(ShapedType::kDynamic));
-      readSizes.push_back(dimAndTileMapping[i]);
-      transShapeForEmpty.push_back(dimAndTileMapping[i]);
-      continue;
-    }
-    if (ShapedType::isDynamic(inputShape[i])) {
-      readSizes.push_back(
-          rewriter.create<tensor::DimOp>(loc, input, i).getResult());
-    } else {
-      readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
-    }
-    if (inputShape[i] != 1) {
-      readShapeForExtractSlice.push_back(inputShape[i]);
-      transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+      auto [tileSize, tileSizeOfr] =
+          getSimplifiedDimSizePair(dimAndTileMapping[i], rewriter);
+      extractSliceSizes.push_back(tileSizeOfr);
+      outputShapeForExtractSlice.push_back(tileSize);
     }
   }
 
   Type elemType = packOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
+  auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
 
   Value tile = rewriter.create<tensor::ExtractSliceOp>(
-      loc, readType, input, readOffsets, readSizes, readStrides);
+      loc, readType, input, readOffsets, extractSliceSizes, readStrides);
 
-  // 2. Transpose the tile to match the inner tile order.
+  // 2. Transpose the tile to match the inner tile order:
+  //    %init = tensor.empty()
+  //    %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
+  //  NOTE: Outer dims are 1 and hence effectively ignored.
   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
       inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
 
-  applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
+  // 2.1 Create tensor.empty (init value for TransposeOp)
+  SmallVector<OpFoldResult> transShapeForEmptyOpDynamic;
+  SmallVector<int64_t> transShapeForEmptyOpStatic;
+
+  // Acquire tensor shape required to create EmptyOp. This will match the inner
+  // tile sizes, but the actual data format will depend on whether the tile
+  // sizes are static or dynamic (each case leads to a different builder for
+  // EmptyOp). Conservatively, prepare for both scenarios.
+  size_t idx = numTiles;
+  while (idx != 0) {
+    transShapeForEmptyOpDynamic.push_back(extractSliceSizes[srcRank - idx]);
+    transShapeForEmptyOpStatic.push_back(
+        outputShapeForExtractSlice[numTiles - idx]);
+    idx--;
+  }
 
-  Value empty =
-      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
+  applyPermutationToVector<int64_t>(transShapeForEmptyOpStatic, perm);
+  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOpDynamic, perm);
+
+  Value empty = ShapedType::isDynamicShape(transShapeForEmptyOpStatic)
----------------
banach-space wrote:

> Shouldn't the builder for the static case always produce the same result as the dynamic case? Can we just keep the dynamic path?

Great point! 

It turns out that [EmptyOp::build](https://github.com/llvm/llvm-project/blob/08411c855f77bd7416725c280ad3dccdc00b7dd6/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp#L887-L894) already handles both cases. It passes the `OpFoldResult` sizes to [dispatchIndexOpFoldResults](https://github.com/llvm/llvm-project/blob/08411c855f77bd7416725c280ad3dccdc00b7dd6/mlir/lib/Dialect/Utils/StaticValueUtils.cpp#L61-L66), which splits them into static sizes and dynamic `Value`s. So the dynamic path alone is sufficient, and the separate static-shape builder can be dropped :)

https://github.com/llvm/llvm-project/pull/114315


More information about the Mlir-commits mailing list