When an expanded dim is not factorable (its expansion would require more than one dynamic dim), emit a `tensor.reshape` instead of a `tensor.expand_shape`.

4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+2-1) 
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+1-1) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+54-21) 
- (modified) mlir/test/Dialect/Linalg/transform-lower-pack.mlir (+46) 

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 77ed9db5e71bd1..4abd3740b57105 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -498,7 +498,8 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
   let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
   let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
-                      Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
+                      Type<Or<[Transform_ConcreteOpType<"tensor.expand_shape">.predicate,
+                               Transform_ConcreteOpType<"tensor.reshape">.predicate]>>:$expand_shape_op,
   let assemblyFormat = [{
     $target attr-dict `:` functional-type(operands, results)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a848d12fbbb50e..06e8586f4288b4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1089,7 +1089,7 @@ collapseOpIterationDims(LinalgType op,
 struct LowerPackResult {
   tensor::PadOp padOp;
-  tensor::ExpandShapeOp expandShapeOp;
+  Operation *expandShapeOp; // `tensor::ExpandShapeOp` or `tensor::ReshapeOp`
   linalg::TransposeOp transposeOp;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 9d230e2c2e5749..4550589ded6df8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -218,21 +218,11 @@ struct PackedOperandsDimList {
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                              tensor::PackOp packOp) {
-  // 1. Filter out NYI cases.
-  auto packedTensorType =
-      cast<RankedTensorType>(packOp->getResultTypes().front());
-  if (llvm::any_of(packOp.getStaticInnerTiles(),
-                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
-    return rewriter.notifyMatchFailure(
-        packOp,
-        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
-  }
   Location loc = packOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
-  // 2. Compute the permutation vector to shuffle packed shape into the shape
+  // 1. Compute the permutation vector to shuffle packed shape into the shape
   // before any outer or inner permutations have been applied. The permutation
   // can be obtained from two permutations:
   //   a) Compute the permutation vector to move the last `numPackedDims` into
@@ -240,6 +230,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   //   b) Compute the permutation vector to move outer dims if the pack op
   //      has outer_dims_perm.
   // Apply (b) permutation on (a) permutation to get the final permutation.
+  auto packedTensorType =
+      cast<RankedTensorType>(packOp->getResultTypes().front());
   int64_t numPackedDims = packOp.getInnerDimsPos().size();
   int64_t packedRank = packedTensorType.getRank();
   auto lastDims = llvm::to_vector(
@@ -259,12 +251,12 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
   applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
-  // 3. Compute the stripMinedShape: this is the packed shape before any outer
+  // 2. Compute the stripMinedShape: this is the packed shape before any outer
   // or inner permutations have been applied.
   SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
   applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
-  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+  // 3. Pad the source of packOp to a shape we can expand into stripMinedShape.
   SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
   SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
@@ -351,24 +343,65 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
-  // 5. Expand from the padded result to the stripMinedShape.
-  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
-      loc,
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
-      padOp.getResult(), packingMetadata.reassociations);
-  // 6. Transpose stripMinedShape to packedShape.
+  // 4. Expand from the padded result to the stripMinedShape.
+  RankedTensorType expandDestType =
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
   SmallVector<int64_t> transpPerm =
+  Operation *reshapeOp;
+  // Check if any dims are not factorable and thus need a `tensor.reshape`
+  // instead of a `tensor.expand_shape` op. A dim is factorable if the expansion
+  // requires at most one dynamic dim.
+  if (llvm::any_of(packingMetadata.reassociations,
+                   [&](const auto &rAssoc) -> bool {
+                     return llvm::count_if(rAssoc, [&](int64_t r) {
+                              return stripMinedShape[r] == ShapedType::kDynamic;
+                            }) > 1;
+                   })) {
+    SmallVector<OpFoldResult> sizes =
+        tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+    applyPermutationToVector(sizes, transpPerm);
+    // Create a 1-D tensor of `index` elements to serve as the `shape` operand
+    // of `tensor.reshape`.
+    Value shapeInitTensor = rewriter.create<tensor::EmptyOp>(
+        loc,
+        RankedTensorType::get({expandDestType.getRank()},
+                              rewriter.getIndexType()),
+        ValueRange{});
+    Value shapeTensor = shapeInitTensor;
+    for (const auto &[i, size] : llvm::enumerate(sizes)) {
+      auto maybeConstInt = getConstantIntValue(size);
+      assert((maybeConstInt.has_value() || expandDestType.isDynamicDim(i)) &&
+             "expected dynamic dim");
+      Value dim =
+          (maybeConstInt.has_value())
+              ? rewriter
+                    .create<arith::ConstantIndexOp>(loc, maybeConstInt.value())
+                    .getResult()
+              : cast<Value>(size);
+      shapeTensor = rewriter.create<tensor::InsertOp>(
+          loc, dim, shapeTensor,
+          SmallVector<Value>(
+              {rewriter.create<arith::ConstantIndexOp>(loc, i).getResult()}));
+    }
+    reshapeOp = rewriter.create<tensor::ReshapeOp>(
+        loc, expandDestType, padOp.getResult(), shapeTensor);
+  } else {
+    reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandDestType, padOp.getResult(), packingMetadata.reassociations);
+  }
+  // 5. Transpose stripMinedShape to packedShape.
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
-      loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
+      loc, reshapeOp->getResult(0), packOp.getDest(), transpPerm);
              DBGS() << "reshape op: " << reshapeOp; DBGSNL();
              llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
              DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
-  // 7. Replace packOp by transposeOp.
+  // 6. Replace packOp by transposeOp.
   rewriter.replaceOp(packOp, transposeOp->getResults());
   return LowerPackResult{padOp, reshapeOp, transposeOp};
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 316df431a9c0c8..13d74cbe433264 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -61,6 +61,52 @@ module attributes {transform.with_named_sequence} {
 // -----
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 64)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 128)>
+// CHECK: func.func @pack_dyn_tiles(
+// CHECK-SAME:                            %[[ARG0:.*]]: [[TENSOR_TY_0:tensor<64x128xf32>]]
+// CHECK-SAME:                            %[[ARG1:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME:                            %[[TILE0:.*]]: index,
+// CHECK-SAME:                            %[[TILE1:.*]]: index
+func.func @pack_dyn_tiles(%arg0: tensor<64x128xf32>, %arg1: tensor<?x?x?x?xf32>, %tile_0: index, %tile_1: index) -> tensor<?x?x?x?xf32> {
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG:  %[[PAD0:.*]] = affine.apply #[[MAP0]]()[%[[TILE0]], %[[DIM0]]]
+// CHECK-DAG:  %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG:  %[[PAD1:.*]] = affine.apply #[[MAP1]]()[%[[TILE1]], %[[DIM1]]]
+// CHECK-DAG:   %[[CST:.*]]  = arith.constant 0.000000e+00 : f32
+// CHECK:      %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0] high[%[[PAD0]], %[[PAD1]]]
+// CHECK-NEXT:                   ^bb0
+// CHECK-NEXT:                    tensor.yield %[[CST]] : f32
+// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:  %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK-DAG:  %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]]
+// CHECK-NEXT:  %[[INIT_SHAPE:.*]] = tensor.empty() : tensor<4xindex>
+// CHECK-NEXT:  %[[SHAPE0:.*]] = tensor.insert %[[DIM0]] into %[[INIT_SHAPE]][%[[C0]]]
+// CHECK-NEXT:  %[[SHAPE1:.*]] = tensor.insert %[[DIM2]] into %[[SHAPE0]][%[[C1]]]
+// CHECK-NEXT:  %[[SHAPE2:.*]] = tensor.insert %[[DIM1]] into %[[SHAPE1]][%[[C2]]]
+// CHECK-NEXT:  %[[SHAPE3:.*]] = tensor.insert %[[DIM3]] into %[[SHAPE2]][%[[C3]]]
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.reshape %[[PADDED]](%[[SHAPE3]])
+// CHECK-NEXT:  %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[EXPANDED]] : {{.*}}) outs(%[[ARG1]] {{.*}}) permutation = [0, 2, 1, 3] 
+  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [%tile_0, %tile_1] into %arg1
+    : tensor<64x128xf32> -> tensor<?x?x?x?xf32>
+  return %pack : tensor<?x?x?x?xf32>
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.pack">
+    transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.reshape">, !transform.op<"linalg.transpose">)
+      transform.yield
+  }
+// -----
 // CHECK-LABEL: func.func @pack_as_pad(
 func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32




