[Mlir-commits] [mlir] 009c053 - [mlir][linalg] Allow outer dims perm and untiled dims in pack/unpack generalization

Quinn Dawkins llvmlistbot at llvm.org
Tue May 2 09:29:29 PDT 2023


Author: Quinn Dawkins
Date: 2023-05-02T12:26:45-04:00
New Revision: 009c053e3f822d0df556c6b39f632e31594373de

URL: https://github.com/llvm/llvm-project/commit/009c053e3f822d0df556c6b39f632e31594373de
DIFF: https://github.com/llvm/llvm-project/commit/009c053e3f822d0df556c6b39f632e31594373de.diff

LOG: [mlir][linalg] Allow outer dims perm and untiled dims in pack/unpack generalization

Extends the pack/unpack generalization patterns to work for any pack or
unpack op with only full tiles, i.e. where every tiled outer dimension
of the packed tensor is 1; untiled outer dimensions may have any size,
and an outer dims permutation is now supported. This produces a
combination of rank-reduced insert/extract slice ops paired with a
transpose on the reduced shape, similar to what the pattern previously
produced for fully tiled pack/unpack ops. Note that only the outer dims
are rank-reduced by this pattern; the shape of the inner tile is left
intact.
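
For illustration, here is a sketch of what the pattern now produces for
a pack op with an outer dims permutation and an untiled outer dim,
mirroring the simple_KCRS_to_KRSCsr test case added below:

  %0 = tensor.pack %arg0 outer_dims_perm = [0, 2, 3, 1]
      inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1
      : tensor<3x1x32x8xf32> -> tensor<3x1x1x1x8x32xf32>

generalizes to roughly the following sequence (the exact IR is pinned
down by the new CHECK lines):

  // Rank-reduced read: the unit outer dim is dropped, the untiled
  // outer dim of size 3 is kept.
  %tile = tensor.extract_slice %arg0[0, 0, 0, 0] [3, 1, 32, 8] [1, 1, 1, 1]
      : tensor<3x1x32x8xf32> to tensor<3x32x8xf32>
  // Transpose on the reduced shape to match the inner tile order.
  %empty = tensor.empty() : tensor<3x8x32xf32>
  %transp = linalg.transpose ins(%tile : tensor<3x32x8xf32>)
      outs(%empty : tensor<3x8x32xf32>) permutation = [0, 2, 1]
  // Rank-reduced write into the packed destination.
  %0 = tensor.insert_slice %transp into %arg1[0, 0, 0, 0, 0, 0]
      [3, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
      : tensor<3x8x32xf32> into tensor<3x1x1x1x8x32xf32>

The unpack pattern is symmetric, applying the inverted permutation on
the way out of the packed layout.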

Differential Revision: https://reviews.llvm.org/D147555

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
    mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4f3f2dc0c734b..4a5c69c5bc061 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1241,66 +1241,124 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                  /*nofold=*/false, loc, builder);
 }
 
+// Normalizes a permutation on a higher-rank space to its actual size, e.g.
+//   perm = [1, 4, 2]
+// becomes
+//   norm = [0, 2, 1]
 static SmallVector<int64_t>
-getPackUnpackNormalizedInnerPerm(int rank, ArrayRef<int64_t> innerDimsPos) {
+getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
   constexpr int64_t kNonTiledMarker = -1;
   SmallVector<int64_t> vec(rank, kNonTiledMarker);
-  for (auto [index, value] : llvm::enumerate(innerDimsPos))
+  for (auto [index, value] : llvm::enumerate(perm))
     vec[value] = index;
-  SmallVector<int64_t> perm = llvm::to_vector(llvm::make_filter_range(
+  SmallVector<int64_t> normalizedPerm = llvm::to_vector(llvm::make_filter_range(
       vec, [&](int64_t v) { return v != kNonTiledMarker; }));
+  // This inverts the permutation in addition to normalizing it, so invert back.
+  return invertPermutationVector(normalizedPerm);
+}
+
+// Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
+// assuming rank reduction of unit outer dims.
+static SmallVector<int64_t>
+getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
+                             ArrayRef<int64_t> innerDimsPos,
+                             ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> rankReducedOuterDimsPerm;
+  SmallVector<int64_t> outerDims;
+  SmallVector<int64_t> innerDims;
+  int64_t dim = 0;
+  int64_t unpackedRank = shape.size();
+  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
+    if (llvm::is_contained(innerDimsPos, i)) {
+      innerDims.push_back(dim++);
+      continue;
+    }
+    if (shape[i] == 1)
+      continue;
+    outerDims.push_back(dim++);
+    if (!outerDimsPerm.empty())
+      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
+  }
+
+  // Get the position of the inner dims after permutation.
+  SmallVector<int64_t> innerPerm =
+      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
+  applyPermutationToVector<int64_t>(innerDims, innerPerm);
+
+  // Ditto for the outer dims.
+  SmallVector<int64_t> perm = outerDims;
+
+  rankReducedOuterDimsPerm =
+      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
+  if (!rankReducedOuterDimsPerm.empty())
+    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
+
+  // The tile always ends up as the innermost dims after packing.
+  perm.append(innerDims);
+
   return perm;
 }
 
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
-  // TODO: support the case that outer dimensions are not all 1s A
-  // tensor.expand_shape will be generated in this case.
-  int64_t srcRank = packOp.getSourceRank();
-  if (llvm::any_of(packOp.getDestType().getShape().take_front(srcRank),
-                   [](int64_t val) { return val != 1; })) {
-    return rewriter.notifyMatchFailure(
-        packOp, "require the outer dimension of the result are all 1s");
-  }
-
   if (llvm::any_of(packOp.getMixedTiles(),
                    [](OpFoldResult tile) { return tile.is<Value>(); })) {
     return rewriter.notifyMatchFailure(packOp,
                                        "require inner tile sizes being static");
   }
 
-  // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
+  // TODO: support the case that outer dimensions are not all 1s. A
+  // tensor.expand_shape will be generated in this case.
+  auto innerDimsPos = packOp.getInnerDimsPos();
+  int64_t srcRank = packOp.getSourceRank();
+  auto destShape = packOp.getDestType().getShape();
+  if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
+        return destShape[index] != 1;
+      })) {
+    return rewriter.notifyMatchFailure(
+        packOp, "require the tiled outer dimensions of the result are all 1s");
+  }
+
+  // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
+  // outer dims.
   Location loc = packOp.getLoc();
+  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
+  auto inputShape = packOp.getSourceType().getShape();
+  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+      packOp.getDimAndTileMapping();
   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> readSizes;
   SmallVector<int64_t> readShape;
-  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
-      packOp.getDimAndTileMapping();
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
-    if (!dimAndTileMapping.count(i)) {
-      readSizes.push_back(oneIdxAttr);
+    if (dimAndTileMapping.count(i)) {
+      readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
+                              .value_or(ShapedType::kDynamic));
+      readSizes.push_back(dimAndTileMapping[i]);
       continue;
     }
-    readSizes.push_back(dimAndTileMapping[i]);
-    readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
-                            .value_or(ShapedType::kDynamic));
+    if (ShapedType::isDynamic(inputShape[i])) {
+      readSizes.push_back(
+          rewriter.create<tensor::DimOp>(loc, input, i).getResult());
+    } else {
+      readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
+    }
+    if (inputShape[i] != 1)
+      readShape.push_back(inputShape[i]);
   }
+
   Type elemType = packOp.getSourceType().getElementType();
   auto readType = RankedTensorType::get(readShape, elemType);
 
-  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
   Value tile = rewriter.create<tensor::ExtractSliceOp>(
       loc, readType, input, readOffsets, readSizes, readStrides);
 
   // 2. Transpose the tile to match the inner tile order.
-  SmallVector<int64_t> perm =
-      getPackUnpackNormalizedInnerPerm(srcRank, packOp.getInnerDimsPos());
-  // The permutation is inverted when normalizing so invert back to match the
-  // ordering in the pack op.
-  perm = invertPermutationVector(perm);
+
+  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
+      inputShape, innerDimsPos, packOp.getOuterDimsPerm());
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
@@ -1316,9 +1374,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   int64_t destRank = packOp.getDestRank();
   SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
   SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
-  SmallVector<OpFoldResult> writeSizes(srcRank, oneIdxAttr);
-  for (auto size : transpShape)
-    writeSizes.push_back(rewriter.getIndexAttr(size));
+  SmallVector<OpFoldResult> writeSizes =
+      tensor::getMixedSizes(rewriter, loc, packOp.getDest());
 
   auto insert = rewriter.create<tensor::InsertSliceOp>(
       loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
@@ -1333,35 +1390,59 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   int64_t srcRank = unpackOp.getSourceRank();
   int64_t destRank = unpackOp.getDestRank();
   ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
-  if (llvm::any_of(srcShape.take_front(destRank),
-                   [](int64_t val) { return val != 1; })) {
+  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
+  if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
+        return srcShape[index] != 1;
+      })) {
     return rewriter.notifyMatchFailure(
-        unpackOp, "require the outer dimension of the result are all 1s");
+        unpackOp,
+        "require the tiled outer dimensions of the result are all 1s");
   }
 
   // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
   Location loc = unpackOp.getLoc();
+  Value source = unpackOp.getSource();
+  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+      unpackOp.getDimAndTileMapping();
   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
+  SmallVector<OpFoldResult> readSizes;
+  SmallVector<int64_t> readShape;
+  for (auto i : llvm::seq<unsigned>(0, destRank)) {
+    if (dimAndTileMapping.count(i)) {
+      readSizes.push_back(oneIdxAttr);
+      continue;
+    }
 
+    if (ShapedType::isDynamic(srcShape[i])) {
+      readSizes.push_back(
+          rewriter.create<tensor::DimOp>(loc, source, i).getResult());
+    } else {
+      readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+    }
+    if (srcShape[i] != 1)
+      readShape.push_back(srcShape[i]);
+  }
   auto mixedTiles = unpackOp.getMixedTiles();
-  SmallVector<OpFoldResult> readSizes(destRank, oneIdxAttr);
   readSizes.append(mixedTiles.begin(), mixedTiles.end());
 
   // Explicitly create the type for extract_slice op because the inner tile
   // size could be 1. We want to represent the whole inner tile in this case.
-  ArrayRef<int64_t> readShape = srcShape.drop_front(destRank);
+  auto tileShape = srcShape.drop_front(destRank);
+  // Append the inner tile shape to the permuted and rank-reduced outer shape.
+  readShape.append(tileShape.begin(), tileShape.end());
   Type elemType = unpackOp.getSourceType().getElementType();
   auto readType = RankedTensorType::get(readShape, elemType);
   Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
       loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);
 
   // 2. Transpose the tile to match the outer corresponding tile order.
-  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
-  SmallVector<int64_t> perm =
-      getPackUnpackNormalizedInnerPerm(srcRank, innerDimsPos);
+  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
+      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
+  // Unpack is a transition out of packed space, so we invert the permutation.
+  perm = invertPermutationVector(perm);
   SmallVector<int64_t> transpShape(readShape);
   applyPermutationToVector<int64_t>(transpShape, perm);
 
@@ -1375,11 +1456,13 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
   SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
   SmallVector<OpFoldResult> tileSizes;
-  for (int dim : innerDimsPos)
-    tileSizes.push_back(getAsOpFoldResult(
-        rewriter.createOrFold<tensor::DimOp>(loc, unpackOp.getDest(), dim)));
+  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
+  for (auto i : llvm::seq<unsigned>(0, destRank)) {
+    if (dimAndTileMapping.count(i) || destShape[i] != 1)
+      tileSizes.push_back(getAsOpFoldResult(
+          rewriter.createOrFold<tensor::DimOp>(loc, unpackOp.getDest(), i)));
+  }
 
-  applyPermutationToVector<OpFoldResult>(tileSizes, perm);
   auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
       loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
 
@@ -1387,10 +1470,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   SmallVector<OpFoldResult> writeSizes;
   SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
   SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
-  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
-      unpackOp.getDimAndTileMapping();
   for (int i = 0, idx = 0; i < destRank; ++i) {
-    if (dimAndTileMapping.count(i))
+    if (dimAndTileMapping.count(i) || destShape[i] != 1)
       writeSizes.push_back(tileSizes[idx++]);
     else
       writeSizes.push_back(oneIdxAttr);

diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 8e9b77ed6f679..283cb43e2997b 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -76,3 +76,22 @@ func.func @simple_CHW_to_CHWhwc(%arg0: tensor<3x5x7xf32>, %arg1: tensor<1x1x1x5x
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 5, 7, 3] [1, 1, 1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
+
+// -----
+
+func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<3x1x32x8xf32>, %arg1: tensor<3x1x1x1x8x32xf32>) -> tensor<3x1x1x1x8x32xf32> {
+  %0 = tensor.pack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<3x1x32x8xf32> -> tensor<3x1x1x1x8x32xf32>
+  return %0 : tensor<3x1x1x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @simple_KCRS_to_KRSCsr
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [3, 1, 32, 8] [1, 1, 1, 1]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<3x8x32xf32>
+// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-SAME:      ins(%[[TILE]] : tensor<3x32x8xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<3x8x32xf32>)
+// CHECK-SAME:      permutation = [0, 2, 1]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0, 0, 0] [3, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
+// CHECK:         return %[[INSERT]]

diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
index cc734c24d4f56..a596690c2e4fd 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
@@ -55,3 +55,42 @@ func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32
 //                They have the same type, so the insert_slice op is folded
 //                away.
 // CHECK:         return %[[TRANSP]]
+
+// -----
+
+func.func @simple_NCHWc_to_NCHW(%arg0: tensor<2x1x16x8x32xf32>, %arg1: tensor<2x32x16x8xf32>) -> tensor<2x32x16x8xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %arg1 : tensor<2x1x16x8x32xf32> -> tensor<2x32x16x8xf32>
+  return %0 : tensor<2x32x16x8xf32>
+}
+// CHECK-LABEL: func.func @simple_NCHWc_to_NCHW
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<2x32x16x8xf32>
+// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-SAME:      ins(%[[TILE]] : tensor<2x16x8x32xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<2x32x16x8xf32>)
+// CHECK-SAME:      permutation = [0, 3, 1, 2]
+//                They have the same type, so the insert_slice op is folded
+//                away.
+// CHECK:         return %[[TRANSP]]
+
+
+// -----
+
+func.func @simple_NHWC_to_NCHW(%arg0: tensor<1x16x8x32xf32>, %arg1: tensor<1x32x16x8xf32>) -> tensor<1x32x16x8xf32> {
+  %0 = tensor.unpack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [] inner_tiles = [] into %arg1 : tensor<1x16x8x32xf32> -> tensor<1x32x16x8xf32>
+  return %0 : tensor<1x32x16x8xf32>
+}
+// CHECK-LABEL: func.func @simple_NHWC_to_NCHW
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 16, 8, 32] [1, 1, 1, 1]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x16x8xf32>
+// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-SAME:      ins(%[[TILE]] : tensor<16x8x32xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<32x16x8xf32>)
+// CHECK-SAME:      permutation = [2, 0, 1]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0] [1, 32, 16, 8] [1, 1, 1, 1]
+// CHECK:         return %[[INSERT]]
