[Mlir-commits] [mlir] d7904a7 - [MLIR] Fold outer dims permutation to pack when propagating

Lorenzo Chelini llvmlistbot at llvm.org
Fri Jan 13 07:11:49 PST 2023


Author: Lorenzo Chelini
Date: 2023-01-13T16:11:44+01:00
New Revision: d7904a702fe80e482b7fbb132c46863afd6eb3be

URL: https://github.com/llvm/llvm-project/commit/d7904a702fe80e482b7fbb132c46863afd6eb3be
DIFF: https://github.com/llvm/llvm-project/commit/d7904a702fe80e482b7fbb132c46863afd6eb3be.diff

LOG: [MLIR] Fold outer dims permutation to pack when propagating

Instead of folding the transpose into the linalg.generic, keep the
transposition in the packing operation, effectively making the
linalg.generic transparent to the propagation. Additionally, if the init
operand of the generic has users, pack the init and pass it as the
operand to the generic.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D141483

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
    mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 5e540979f58ed..5660704606d0c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -87,11 +87,39 @@ static PackInfo getPackingInfoFromConsumer(
   return packInfo;
 }
 
+static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
+                                             ArrayRef<AffineExpr> exprs) {
+  // Restrict `perm` to the dims that appear in `exprs` and express them as
+  // result positions of `exprs` (the `outer_dims_perm` to use). Example:
+  //   exprs : (d0, d1, d2, d3) -> (d2, d3)
+  //   perm  : [0, 3, 1, 2]
+  // Map each dim to its result position (d2 -> 0, d3 -> 1), then scan
+  // `perm` in order, skipping dims absent from `exprs`: 0 (skip), 3 -> 1,
+  // 1 (skip), 2 -> 0, giving outer_dims_perm = [1, 0]. Returns an empty
+  // vector when `exprs` has a single result. Every expr is expected to be
+  // an AffineDimExpr (the `cast` below assumes this).
+  assert(!perm.empty() && "expect perm not to be empty");
+  assert(!exprs.empty() && "expect exprs not to be empty");
+  if (exprs.size() == 1)
+    return {};
+  SmallVector<int64_t> outerDimsPerm;
+  DenseMap<int64_t, int64_t> currentPositionTileLoops;
+  for (auto [pos, expr] : llvm::enumerate(exprs)) {
+    unsigned posInDomain = expr.cast<AffineDimExpr>().getPosition();
+    currentPositionTileLoops[posInDomain] = pos;
+  }
+  for (int64_t loopIdx : perm) {
+    if (currentPositionTileLoops.count(loopIdx))
+      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
+  }
+  return outerDimsPerm;
+}
+
 /// Returns a tuple for packed operand and indexing_map with the assumptions:
 ///   1) The generic op is the producer of the pack op.
 ///   2) The generic op has only one result.
 /// If the operand is a scalar or packing dimensions are all irrelevant to the
-/// operand, the opreand and the updated indexing map will be returned.
+/// operand, the operand and the updated indexing map will be returned.
 /// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
 ///
 ///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
@@ -148,16 +176,26 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
     exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
   }
 
-  // Step 2. Fold transpose variants (i.e., outerDimsPerm) into generic op.
-  // TODO: should we propagate the permutation of outer dims to the pack op?
+  // Step 2. Handle outer dim permutations.
   SmallVector<int64_t> outerDimsPerm;
   if (!packInfo.outerDimsOnDomainPerm.empty()) {
+    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
+
+    // Step 2.1: Fold transpose into the linalg.generic.
     SmallVector<int64_t> inversedOuterPerm =
         invertPermutationVector(packInfo.outerDimsOnDomainPerm);
     for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
       int64_t dimPos = exprs[i].cast<AffineDimExpr>().getPosition();
       exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
     }
+    // Step 2.2: Undo the transposition on `exprs` and propagate the
+    // transposition on the pack using outerDimsPerm.
+    if (!outerDimsPerm.empty()) {
+      SmallVector<AffineExpr> auxVec = exprs;
+      for (const auto &en : enumerate(outerDimsPerm))
+        auxVec[en.index()] = exprs[en.value()];
+      exprs = auxVec;
+    }
   }
   auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
 
@@ -254,9 +292,7 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
     indexingMaps.push_back(packedIndexingMap);
   }
 
-  int64_t numLoops = genericOp.getNumLoops();
   int64_t numInnerLoops = packInfo.getNumTiledLoops();
-  int64_t newNumLoops = numLoops + numInnerLoops;
   SmallVector<utils::IteratorType> iterTypes =
       genericOp.getIteratorTypesArray();
   iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
@@ -265,24 +301,18 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
   auto [packedOutOperand, packedOutIndexingMap] =
       getOrCreatePackedViewOfOperand(rewriter, loc, packInfo, genericOp,
                                      opOperand);
-  SmallVector<AffineExpr> outExprs(
-      packedOutIndexingMap.getResults().drop_back(numInnerLoops));
-  // Apply transpose to the indexing map, because we'll replace the init operand
-  // with the destination of pack op.
-  auto outerDimsPerm = packOp.getOuterDimsPerm();
-  if (!outerDimsPerm.empty()) {
-    applyPermutationToVector<AffineExpr>(outExprs, outerDimsPerm);
-  }
-  for (int i = 0; i < numInnerLoops; ++i)
-    outExprs.push_back(rewriter.getAffineDimExpr(numLoops + i));
-  AffineMap outMap =
-      AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext());
-  indexingMaps.push_back(outMap);
+  indexingMaps.push_back(packedOutIndexingMap);
 
+  // We'll replace the init operand with the destination of pack op if the init
+  // operand has no users in the body of the linalg.generic (pure elementwise).
+  // If it has users we need to pack the init operand too and replace the init
+  // with the packing result.
+  Value dest = (genericOp.getRegionOutputArgs()[0].use_empty())
+                   ? packOp.getDest()
+                   : packedOutOperand;
   auto newGenericOp = rewriter.create<linalg::GenericOp>(
-      loc, packOp.getDestType(), inputOperands, packOp.getDest(), indexingMaps,
-      iterTypes, /*bodyBuild=*/nullptr,
-      linalg::getPrunedAttributeList(genericOp));
+      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
   rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                              newGenericOp.getRegion().begin());
   return newGenericOp;

diff  --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index bb84272bf8b02..cd9d3acc7a635 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -96,17 +96,16 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %dest: ten
     into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
   return %pack : tensor<16x4x32x16xi32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK:      func.func @elem_pack_transpose_outer_dims
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16xi32>
+// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
 // CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:     into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<4x16x32x16xi32>
+// CHECK-SAME:     outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:     into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
 // CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]]]
 // CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[PACK_ARG0]]
 // CHECK-SAME:     outs(%[[DEST]]
@@ -131,17 +130,16 @@ func.func @elem_pack_transpose_inner_and_outer_dims(%arg0: tensor<128x256xi32>,
     into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32>
   return %pack : tensor<16x4x16x32xi32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK:      func.func @elem_pack_transpose_inner_and_outer_dims
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
+// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
 // CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME:     outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
 // CHECK-SAME:     into %[[ARG0_EMPTY]]
 // CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]]]
 // CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[PACK_ARG0]]
 // CHECK-SAME:     outs(%[[DEST]]
@@ -285,7 +283,7 @@ func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
 #map2 = affine_map<(d0, d1) -> (d1)>
-func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32>
+func.func @transpose_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32>
 {
   %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
   %transpose = linalg.generic {
@@ -308,3 +306,61 @@ func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x
     into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32>
   return %4 : tensor<200x4x16x100x16x32xi32>
 }
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)>
+// CHECK:     func.func @transpose_pack_with_outer_dims
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<200x4x16x100x16x32xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:  outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32]
+// CHECK-SAME:  into %[[ARG0_EMPTY]]
+// CHECK: %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
+// CHECK: %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:  inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME:  into %[[ARG2_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]], #[[MAP]]]
+// CHECK-SAME:  ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
+// CHECK-SAME:  outs(%[[DEST]]
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32>{
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg4 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %empty = tensor.empty() : tensor<16x4x32x16xi32>
+  %pack = tensor.pack %elem
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [32, 16]
+    into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+  return %pack : tensor<16x4x32x16xi32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @elem_pack_transpose_outer_dims
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG0_EMPTY]]
+// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACKED_ARG1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG1_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME:  ins(%[[PACKED_ARG0]]
+// CHECK-SAME:  outs(%[[PACKED_ARG1]]


        


More information about the Mlir-commits mailing list