[compiler-rt] [libc] [mlir] [llvm] [clang] [clang-tools-extra] [flang] [mlir][Linalg] Support dynamic shapes in `lower_pack` transform (PR #76003)

via cfe-commits cfe-commits at lists.llvm.org
Tue Dec 19 21:34:21 PST 2023


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/76003

>From 860a2f794bdf12ff1f08d4802570757e805264b0 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Mon, 18 Dec 2023 15:53:41 -0600
Subject: [PATCH 1/8] [mlir][Linalg] Support dynamic sizes in `lower_pack`
 transform

---
 .../Linalg/TransformOps/LinalgTransformOps.td |  3 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |  2 +-
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 69 +++++++++++++------
 .../Dialect/Linalg/transform-lower-pack.mlir  | 20 ++++++
 4 files changed, 70 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 77ed9db5e71bd1..4abd3740b57105 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -498,7 +498,8 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
 
   let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
   let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
-                      Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
+                      Type<Or<[Transform_ConcreteOpType<"tensor.expand_shape">.predicate,
+                               Transform_ConcreteOpType<"tensor.reshape">.predicate]>>:$expand_shape_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
   let assemblyFormat = [{
     $target attr-dict `:` functional-type(operands, results)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a848d12fbbb50e..344e801835ccc9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1089,7 +1089,7 @@ collapseOpIterationDims(LinalgType op,
 
 struct LowerPackResult {
   tensor::PadOp padOp;
-  tensor::ExpandShapeOp expandShapeOp;
+  Operation *expandShapeOp;
   linalg::TransposeOp transposeOp;
 };
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 9d230e2c2e5749..359274866748fc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -218,21 +218,11 @@ struct PackedOperandsDimList {
 
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                              tensor::PackOp packOp) {
-  // 1. Filter out NYI cases.
-  auto packedTensorType =
-      cast<RankedTensorType>(packOp->getResultTypes().front());
-  if (llvm::any_of(packOp.getStaticInnerTiles(),
-                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
-    return rewriter.notifyMatchFailure(
-        packOp,
-        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
-  }
-
   Location loc = packOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(packOp);
 
-  // 2. Compute the permutation vector to shuffle packed shape into the shape
+  // 1. Compute the permutation vector to shuffle packed shape into the shape.
   // before any outer or inner permutations have been applied. The permutation
   // can be obtained from two permutations:
   //   a) Compute the permutation vector to move the last `numPackedDims` into
@@ -240,6 +230,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   //   b) Compute the permutation vector to move outer dims if the pack op
   //      has outer_dims_perm.
   // Apply (b) permutation on (a) permutation to get the final permutation.
+  auto packedTensorType =
+      cast<RankedTensorType>(packOp->getResultTypes().front());
   int64_t numPackedDims = packOp.getInnerDimsPos().size();
   int64_t packedRank = packedTensorType.getRank();
   auto lastDims = llvm::to_vector(
@@ -259,12 +251,12 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
   applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
 
-  // 3. Compute the stripMinedShape: this is the packed shape before any outer
+  // 2. Compute the stripMinedShape: this is the packed shape before any outer.
   // or inner permutations have been applied.
   SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
   applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
 
-  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+  // 3. Pad the source of packOp to a shape we can expand into stripMinedShape.
   SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
                                  rewriter.getIndexAttr(0));
   SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
@@ -351,24 +343,57 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                              /*transposeOp=*/nullptr};
     }
   }
-  // 5. Expand from the padded result to the stripMinedShape.
-  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
-      loc,
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
-      padOp.getResult(), packingMetadata.reassociations);
 
-  // 6. Transpose stripMinedShape to packedShape.
+  RankedTensorType expandSourceType = padOp.getResult().getType().cast<RankedTensorType>();
+  RankedTensorType expandDestType = RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
+
+  // Dynamic dim is factorable only if the expanded version has at most one dynamic dim
+  bool isFactorable = true;
+  for (const auto &[i, rIndcs] : llvm::enumerate(packingMetadata.reassociations)) {
+    if (!expandSourceType.isDynamicDim(i))
+      continue;
+    int64_t numDyn = 0;
+    for (auto j : rIndcs) {
+      if ((stripMinedShape[j] == ShapedType::kDynamic) && (++numDyn > 1)) {
+        isFactorable = false;
+        break;
+      }
+    }
+  }
+
+  // 4. Expand from the padded result to the stripMinedShape.
   SmallVector<int64_t> transpPerm =
       invertPermutationVector(packedToStripMinedShapePerm);
+  Operation *reshapeOp;
+  if (!isFactorable) {
+    SmallVector<OpFoldResult> sizes =
+        tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+    applyPermutationToVector(sizes, transpPerm);
+    Value shapeInitTensor =
+        rewriter.create<tensor::EmptyOp>(loc, RankedTensorType::get({expandDestType.getRank()}, rewriter.getIndexType()), ValueRange{}); 
+    Value shapeTensor = shapeInitTensor;
+    for (const auto &[i, size] : llvm::enumerate(sizes)) {
+      Value dim = (expandDestType.isDynamicDim(i)) ? cast<Value>(size) : rewriter.create<arith::ConstantIndexOp>(loc, getConstantIntValue(size).value()).getResult();
+      shapeTensor = rewriter.create<tensor::InsertOp>(loc, dim, shapeTensor, SmallVector<Value>({rewriter.create<arith::ConstantIndexOp>(loc, i).getResult()}));
+    }
+    reshapeOp = rewriter.create<tensor::ReshapeOp>(loc, expandDestType, padOp.getResult(), shapeTensor);
+  } else {
+    reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc,
+        expandDestType,
+        padOp.getResult(), packingMetadata.reassociations);
+  }
+
+  // 5. Transpose stripMinedShape to packedShape.
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
-      loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
+      loc, reshapeOp->getResult(0), packOp.getDest(), transpPerm);
 
   LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
-             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
+             DBGS() << "reshape op: " << &reshapeOp; DBGSNL();
              llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
              DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
 
-  // 7. Replace packOp by transposeOp.
+  // 6. Replace packOp by transposeOp.
   rewriter.replaceOp(packOp, transposeOp->getResults());
 
   return LowerPackResult{padOp, reshapeOp, transposeOp};
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 316df431a9c0c8..6a203dab91e58b 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -61,6 +61,26 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func.func @pack_all_dyn(
+func.func @pack_all_dyn(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %arg1
+    : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+
+  return %pack : tensor<?x?x?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.pack">
+    transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.reshape">, !transform.op<"linalg.transpose">)
+      transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_as_pad(
 func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32

>From c8db4ac07c017dbdfbd8f91d47f32015ca9dce67 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 19 Dec 2023 19:11:22 -0600
Subject: [PATCH 2/8] Refactor

---
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 54 ++++++++++---------
 1 file changed, 28 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 359274866748fc..21446d07b784a9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -344,44 +344,46 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
     }
   }
 
-  RankedTensorType expandSourceType = padOp.getResult().getType().cast<RankedTensorType>();
-  RankedTensorType expandDestType = RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
-
-  // Dynamic dim is factorable only if the expanded version has at most one dynamic dim
-  bool isFactorable = true;
-  for (const auto &[i, rIndcs] : llvm::enumerate(packingMetadata.reassociations)) {
-    if (!expandSourceType.isDynamicDim(i))
-      continue;
-    int64_t numDyn = 0;
-    for (auto j : rIndcs) {
-      if ((stripMinedShape[j] == ShapedType::kDynamic) && (++numDyn > 1)) {
-        isFactorable = false;
-        break;
-      }
-    }
-  }
-
   // 4. Expand from the padded result to the stripMinedShape.
+  // Check if any dims are not factorable.  A dim is factorable if the expansion
+  // requires at most dynamnic dim
+  RankedTensorType expandDestType = RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
   SmallVector<int64_t> transpPerm =
       invertPermutationVector(packedToStripMinedShapePerm);
   Operation *reshapeOp;
-  if (!isFactorable) {
+  if (llvm::any_of(packingMetadata.reassociations,
+                   [&](const auto &rAssoc) -> bool {
+                     return llvm::count_if(rAssoc, [&](int64_t r) {
+                              return stripMinedShape[r] == ShapedType::kDynamic;
+                            }) > 1;
+                   })) {
     SmallVector<OpFoldResult> sizes =
         tensor::getMixedSizes(rewriter, loc, packOp.getDest());
     applyPermutationToVector(sizes, transpPerm);
-    Value shapeInitTensor =
-        rewriter.create<tensor::EmptyOp>(loc, RankedTensorType::get({expandDestType.getRank()}, rewriter.getIndexType()), ValueRange{}); 
+    // Create a `tensor` of `index` types for the `shape` operand of `tensor.reshape`
+    Value shapeInitTensor = rewriter.create<tensor::EmptyOp>(
+        loc,
+        RankedTensorType::get({expandDestType.getRank()},
+                              rewriter.getIndexType()),
+        ValueRange{});
     Value shapeTensor = shapeInitTensor;
     for (const auto &[i, size] : llvm::enumerate(sizes)) {
-      Value dim = (expandDestType.isDynamicDim(i)) ? cast<Value>(size) : rewriter.create<arith::ConstantIndexOp>(loc, getConstantIntValue(size).value()).getResult();
-      shapeTensor = rewriter.create<tensor::InsertOp>(loc, dim, shapeTensor, SmallVector<Value>({rewriter.create<arith::ConstantIndexOp>(loc, i).getResult()}));
+      Value dim = (expandDestType.isDynamicDim(i))
+                      ? cast<Value>(size)
+                      : rewriter
+                            .create<arith::ConstantIndexOp>(
+                                loc, getConstantIntValue(size).value())
+                            .getResult();
+      shapeTensor = rewriter.create<tensor::InsertOp>(
+          loc, dim, shapeTensor,
+          SmallVector<Value>(
+              {rewriter.create<arith::ConstantIndexOp>(loc, i).getResult()}));
     }
-    reshapeOp = rewriter.create<tensor::ReshapeOp>(loc, expandDestType, padOp.getResult(), shapeTensor);
+    reshapeOp = rewriter.create<tensor::ReshapeOp>(
+        loc, expandDestType, padOp.getResult(), shapeTensor);
   } else {
     reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
-        loc,
-        expandDestType,
-        padOp.getResult(), packingMetadata.reassociations);
+        loc, expandDestType, padOp.getResult(), packingMetadata.reassociations);
   }
 
   // 5. Transpose stripMinedShape to packedShape.

>From e68b32e372de420b2e6ece98e574836920014c54 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 19 Dec 2023 21:49:38 -0600
Subject: [PATCH 3/8] Add regression test

---
 .../Dialect/Linalg/transform-lower-pack.mlir  | 36 ++++++++++++++++---
 1 file changed, 31 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 6a203dab91e58b..13d74cbe433264 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -61,11 +61,37 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: func.func @pack_all_dyn(
-func.func @pack_all_dyn(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
-  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %arg1
-    : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
-
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 64)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 128)>
+// CHECK: func.func @pack_dyn_tiles(
+// CHECK-SAME:                            %[[ARG0:.*]]: [[TENSOR_TY_0:tensor<64x128xf32>]]
+// CHECK-SAME:                            %[[ARG1:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME:                            %[[TILE0:.*]]: index,
+// CHECK-SAME:                            %[[TILE1:.*]]: index
+func.func @pack_dyn_tiles(%arg0: tensor<64x128xf32>, %arg1: tensor<?x?x?x?xf32>, %tile_0: index, %tile_1: index) -> tensor<?x?x?x?xf32> {
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG:  %[[PAD0:.*]] = affine.apply #[[MAP0]]()[%[[TILE0]], %[[DIM0]]]
+// CHECK-DAG:  %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG:  %[[PAD1:.*]] = affine.apply #[[MAP1]]()[%[[TILE1]], %[[DIM1]]]
+// CHECK-DAG:   %[[CST:.*]]  = arith.constant 0.000000e+00 : f32
+// CHECK:      %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0] high[%[[PAD0]], %[[PAD1]]]
+// CHECK-NEXT:                   ^bb0
+// CHECK-NEXT:                    tensor.yield %[[CST]] : f32
+// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:  %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK-DAG:  %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]]
+// CHECK-NEXT:  %[[INIT_SHAPE:.*]] = tensor.empty() : tensor<4xindex>
+// CHECK-NEXT:  %[[SHAPE0:.*]] = tensor.insert %[[DIM0]] into %[[INIT_SHAPE]][%[[C0]]]
+// CHECK-NEXT:  %[[SHAPE1:.*]] = tensor.insert %[[DIM2]] into %[[SHAPE0]][%[[C1]]]
+// CHECK-NEXT:  %[[SHAPE2:.*]] = tensor.insert %[[DIM1]] into %[[SHAPE1]][%[[C2]]]
+// CHECK-NEXT:  %[[SHAPE3:.*]] = tensor.insert %[[DIM3]] into %[[SHAPE2]][%[[C3]]]
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.reshape %[[PADDED]](%[[SHAPE3]])
+// CHECK-NEXT:  %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[EXPANDED]] : {{.*}}) outs(%[[ARG1]] {{.*}}) permutation = [0, 2, 1, 3] 
+  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [%tile_0, %tile_1] into %arg1
+    : tensor<64x128xf32> -> tensor<?x?x?x?xf32>
   return %pack : tensor<?x?x?x?xf32>
 }
 

>From 0975552abe2d404388af48eafc39b464f69a4834 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 19 Dec 2023 21:53:42 -0600
Subject: [PATCH 4/8] Fix comment

---
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 21446d07b784a9..1f63d0ab706cdb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -345,12 +345,14 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   }
 
   // 4. Expand from the padded result to the stripMinedShape.
-  // Check if any dims are not factorable.  A dim is factorable if the expansion
-  // requires at most dynamnic dim
-  RankedTensorType expandDestType = RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
+  RankedTensorType expandDestType =
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
   SmallVector<int64_t> transpPerm =
       invertPermutationVector(packedToStripMinedShapePerm);
   Operation *reshapeOp;
+  // Check if any dims are not factorable and thus need a `tensor.reshape`
+  // instead of a `tensor.expand_shape` op. A dim is factorable if the expansion
+  // requires at most dynamnic dim
   if (llvm::any_of(packingMetadata.reassociations,
                    [&](const auto &rAssoc) -> bool {
                      return llvm::count_if(rAssoc, [&](int64_t r) {
@@ -360,7 +362,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
     SmallVector<OpFoldResult> sizes =
         tensor::getMixedSizes(rewriter, loc, packOp.getDest());
     applyPermutationToVector(sizes, transpPerm);
-    // Create a `tensor` of `index` types for the `shape` operand of `tensor.reshape`
+    // Create a `tensor` of `index` types for the `shape` operand of
+    // `tensor.reshape`
     Value shapeInitTensor = rewriter.create<tensor::EmptyOp>(
         loc,
         RankedTensorType::get({expandDestType.getRank()},

>From 48deca06d650959ba3727df9697566a0fd6a6cd2 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 19 Dec 2023 22:31:12 -0600
Subject: [PATCH 5/8] Properly check optional value

---
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 1f63d0ab706cdb..2a1c72942df0bb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -371,12 +371,15 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
         ValueRange{});
     Value shapeTensor = shapeInitTensor;
     for (const auto &[i, size] : llvm::enumerate(sizes)) {
-      Value dim = (expandDestType.isDynamicDim(i))
-                      ? cast<Value>(size)
-                      : rewriter
-                            .create<arith::ConstantIndexOp>(
-                                loc, getConstantIntValue(size).value())
-                            .getResult();
+      auto maybeConstInt = getConstantIntValue(size);
+      assert(maybeConstInt.has_value() ||
+             expandDestType.isDynamicDim(i) && "expected dynamic dim");
+      Value dim =
+          (maybeConstInt.has_value())
+              ? rewriter
+                    .create<arith::ConstantIndexOp>(loc, maybeConstInt.value())
+                    .getResult()
+              : cast<Value>(size);
       shapeTensor = rewriter.create<tensor::InsertOp>(
           loc, dim, shapeTensor,
           SmallVector<Value>(

>From 194f8194659908f8127b99a807033192e1477def Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 19 Dec 2023 22:37:10 -0600
Subject: [PATCH 6/8] Revert accidental change

---
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 2a1c72942df0bb..6018d58b94eb72 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -372,8 +372,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
     Value shapeTensor = shapeInitTensor;
     for (const auto &[i, size] : llvm::enumerate(sizes)) {
       auto maybeConstInt = getConstantIntValue(size);
-      assert(maybeConstInt.has_value() ||
-             expandDestType.isDynamicDim(i) && "expected dynamic dim");
+      assert((maybeConstInt.has_value() || expandDestType.isDynamicDim(i)) &&
+             "expected dynamic dim");
       Value dim =
           (maybeConstInt.has_value())
               ? rewriter
@@ -397,7 +397,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
       loc, reshapeOp->getResult(0), packOp.getDest(), transpPerm);
 
   LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
-             DBGS() << "reshape op: " << &reshapeOp; DBGSNL();
+             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
              llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
              DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
 

>From f14c48803b0799631dab840a8a8fa75fd92b70f4 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 19 Dec 2023 23:00:44 -0600
Subject: [PATCH 7/8] Add clarifying comment

---
 mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 2 +-
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp        | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 344e801835ccc9..06e8586f4288b4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1089,7 +1089,7 @@ collapseOpIterationDims(LinalgType op,
 
 struct LowerPackResult {
   tensor::PadOp padOp;
-  Operation *expandShapeOp;
+  Operation *expandShapeOp; // `tensor::ExpandShapeOp` or `tensor::ReshapeOp`
   linalg::TransposeOp transposeOp;
 };
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 6018d58b94eb72..3e41399c336a93 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -222,7 +222,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(packOp);
 
-  // 1. Compute the permutation vector to shuffle packed shape into the shape.
+  // 1. Compute the permutation vector to shuffle packed shape into the shape
   // before any outer or inner permutations have been applied. The permutation
   // can be obtained from two permutations:
   //   a) Compute the permutation vector to move the last `numPackedDims` into
@@ -251,7 +251,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
   applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
 
-  // 2. Compute the stripMinedShape: this is the packed shape before any outer.
+  // 2. Compute the stripMinedShape: this is the packed shape before any outer
   // or inner permutations have been applied.
   SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
   applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

>From a787e4e8c682eeb64ec1ea12d0538a626eaf12be Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 19 Dec 2023 23:34:00 -0600
Subject: [PATCH 8/8] Fix comment

---
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 3e41399c336a93..4550589ded6df8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -352,7 +352,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   Operation *reshapeOp;
   // Check if any dims are not factorable and thus need a `tensor.reshape`
   // instead of a `tensor.expand_shape` op. A dim is factorable if the expansion
-  // requires at most dynamnic dim
+  // requires at most one dynamic dim.
   if (llvm::any_of(packingMetadata.reassociations,
                    [&](const auto &rAssoc) -> bool {
                      return llvm::count_if(rAssoc, [&](int64_t r) {



More information about the cfe-commits mailing list