[flang-commits] [compiler-rt] [libc] [mlir] [llvm] [clang] [clang-tools-extra] [flang] [mlir][Linalg] Support dynamic shapes in `lower_pack` transform (PR #76003)
via flang-commits
flang-commits at lists.llvm.org
Tue Dec 19 21:12:06 PST 2023
https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/76003
>From 860a2f794bdf12ff1f08d4802570757e805264b0 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Mon, 18 Dec 2023 15:53:41 -0600
Subject: [PATCH 1/7] [mlir][Linalg] Support dynamic sizes in `lower_pack`
transform
---
.../Linalg/TransformOps/LinalgTransformOps.td | 3 +-
.../Dialect/Linalg/Transforms/Transforms.h | 2 +-
.../Dialect/Linalg/Transforms/Transforms.cpp | 69 +++++++++++++------
.../Dialect/Linalg/transform-lower-pack.mlir | 20 ++++++
4 files changed, 70 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 77ed9db5e71bd1..4abd3740b57105 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -498,7 +498,8 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
- Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
+ Type<Or<[Transform_ConcreteOpType<"tensor.expand_shape">.predicate,
+ Transform_ConcreteOpType<"tensor.reshape">.predicate]>>:$expand_shape_op,
Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
let assemblyFormat = [{
$target attr-dict `:` functional-type(operands, results)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a848d12fbbb50e..344e801835ccc9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1089,7 +1089,7 @@ collapseOpIterationDims(LinalgType op,
struct LowerPackResult {
tensor::PadOp padOp;
- tensor::ExpandShapeOp expandShapeOp;
+ Operation *expandShapeOp;
linalg::TransposeOp transposeOp;
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 9d230e2c2e5749..359274866748fc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -218,21 +218,11 @@ struct PackedOperandsDimList {
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
tensor::PackOp packOp) {
- // 1. Filter out NYI cases.
- auto packedTensorType =
- cast<RankedTensorType>(packOp->getResultTypes().front());
- if (llvm::any_of(packOp.getStaticInnerTiles(),
- [](int64_t size) { return ShapedType::isDynamic(size); })) {
- return rewriter.notifyMatchFailure(
- packOp,
- "non-static shape NYI, needs a more powerful tensor.expand_shape op");
- }
-
Location loc = packOp->getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);
- // 2. Compute the permutation vector to shuffle packed shape into the shape
+ // 1. Compute the permutation vector to shuffle packed shape into the shape.
// before any outer or inner permutations have been applied. The permutation
// can be obtained from two permutations:
// a) Compute the permutation vector to move the last `numPackedDims` into
@@ -240,6 +230,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
// b) Compute the permutation vector to move outer dims if the pack op
// has outer_dims_perm.
// Apply (b) permutation on (a) permutation to get the final permutation.
+ auto packedTensorType =
+ cast<RankedTensorType>(packOp->getResultTypes().front());
int64_t numPackedDims = packOp.getInnerDimsPos().size();
int64_t packedRank = packedTensorType.getRank();
auto lastDims = llvm::to_vector(
@@ -259,12 +251,12 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
- // 3. Compute the stripMinedShape: this is the packed shape before any outer
+ // 2. Compute the stripMinedShape: this is the packed shape before any outer.
// or inner permutations have been applied.
SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
- // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+ // 3. Pad the source of packOp to a shape we can expand into stripMinedShape.
SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
@@ -351,24 +343,57 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
/*transposeOp=*/nullptr};
}
}
- // 5. Expand from the padded result to the stripMinedShape.
- auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
- loc,
- RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
- padOp.getResult(), packingMetadata.reassociations);
- // 6. Transpose stripMinedShape to packedShape.
+ RankedTensorType expandSourceType = padOp.getResult().getType().cast<RankedTensorType>();
+ RankedTensorType expandDestType = RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
+
+ // Dynamic dim is factorable only if the expanded version has at most one dynamic dim
+ bool isFactorable = true;
+ for (const auto &[i, rIndcs] : llvm::enumerate(packingMetadata.reassociations)) {
+ if (!expandSourceType.isDynamicDim(i))
+ continue;
+ int64_t numDyn = 0;
+ for (auto j : rIndcs) {
+ if ((stripMinedShape[j] == ShapedType::kDynamic) && (++numDyn > 1)) {
+ isFactorable = false;
+ break;
+ }
+ }
+ }
+
+ // 4. Expand from the padded result to the stripMinedShape.
SmallVector<int64_t> transpPerm =
invertPermutationVector(packedToStripMinedShapePerm);
+ Operation *reshapeOp;
+ if (!isFactorable) {
+ SmallVector<OpFoldResult> sizes =
+ tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+ applyPermutationToVector(sizes, transpPerm);
+ Value shapeInitTensor =
+ rewriter.create<tensor::EmptyOp>(loc, RankedTensorType::get({expandDestType.getRank()}, rewriter.getIndexType()), ValueRange{});
+ Value shapeTensor = shapeInitTensor;
+ for (const auto &[i, size] : llvm::enumerate(sizes)) {
+ Value dim = (expandDestType.isDynamicDim(i)) ? cast<Value>(size) : rewriter.create<arith::ConstantIndexOp>(loc, getConstantIntValue(size).value()).getResult();
+ shapeTensor = rewriter.create<tensor::InsertOp>(loc, dim, shapeTensor, SmallVector<Value>({rewriter.create<arith::ConstantIndexOp>(loc, i).getResult()}));
+ }
+ reshapeOp = rewriter.create<tensor::ReshapeOp>(loc, expandDestType, padOp.getResult(), shapeTensor);
+ } else {
+ reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
+ loc,
+ expandDestType,
+ padOp.getResult(), packingMetadata.reassociations);
+ }
+
+ // 5. Transpose stripMinedShape to packedShape.
auto transposeOp = rewriter.create<linalg::TransposeOp>(
- loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
+ loc, reshapeOp->getResult(0), packOp.getDest(), transpPerm);
LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
- DBGS() << "reshape op: " << reshapeOp; DBGSNL();
+ DBGS() << "reshape op: " << &reshapeOp; DBGSNL();
llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
- // 7. Replace packOp by transposeOp.
+ // 6. Replace packOp by transposeOp.
rewriter.replaceOp(packOp, transposeOp->getResults());
return LowerPackResult{padOp, reshapeOp, transposeOp};
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 316df431a9c0c8..6a203dab91e58b 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -61,6 +61,26 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func.func @pack_all_dyn(
+func.func @pack_all_dyn(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %arg1
+ : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+
+ return %pack : tensor<?x?x?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"tensor.pack">
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.reshape">, !transform.op<"linalg.transpose">)
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func.func @pack_as_pad(
func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
%cst_0 = arith.constant 0.0 : f32
>From c8db4ac07c017dbdfbd8f91d47f32015ca9dce67 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 19 Dec 2023 19:11:22 -0600
Subject: [PATCH 2/7] Refactor
---
.../Dialect/Linalg/Transforms/Transforms.cpp | 54 ++++++++++---------
1 file changed, 28 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 359274866748fc..21446d07b784a9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -344,44 +344,46 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
}
}
- RankedTensorType expandSourceType = padOp.getResult().getType().cast<RankedTensorType>();
- RankedTensorType expandDestType = RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
-
- // Dynamic dim is factorable only if the expanded version has at most one dynamic dim
- bool isFactorable = true;
- for (const auto &[i, rIndcs] : llvm::enumerate(packingMetadata.reassociations)) {
- if (!expandSourceType.isDynamicDim(i))
- continue;
- int64_t numDyn = 0;
- for (auto j : rIndcs) {
- if ((stripMinedShape[j] == ShapedType::kDynamic) && (++numDyn > 1)) {
- isFactorable = false;
- break;
- }
- }
- }
-
// 4. Expand from the padded result to the stripMinedShape.
+ // Check if any dims are not factorable. A dim is factorable if the expansion
+ // requires at most dynamnic dim
+ RankedTensorType expandDestType = RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
SmallVector<int64_t> transpPerm =
invertPermutationVector(packedToStripMinedShapePerm);
Operation *reshapeOp;
- if (!isFactorable) {
+ if (llvm::any_of(packingMetadata.reassociations,
+ [&](const auto &rAssoc) -> bool {
+ return llvm::count_if(rAssoc, [&](int64_t r) {
+ return stripMinedShape[r] == ShapedType::kDynamic;
+ }) > 1;
+ })) {
SmallVector<OpFoldResult> sizes =
tensor::getMixedSizes(rewriter, loc, packOp.getDest());
applyPermutationToVector(sizes, transpPerm);
- Value shapeInitTensor =
- rewriter.create<tensor::EmptyOp>(loc, RankedTensorType::get({expandDestType.getRank()}, rewriter.getIndexType()), ValueRange{});
+ // Create a `tensor` of `index` types for the `shape` operand of `tensor.reshape`
+ Value shapeInitTensor = rewriter.create<tensor::EmptyOp>(
+ loc,
+ RankedTensorType::get({expandDestType.getRank()},
+ rewriter.getIndexType()),
+ ValueRange{});
Value shapeTensor = shapeInitTensor;
for (const auto &[i, size] : llvm::enumerate(sizes)) {
- Value dim = (expandDestType.isDynamicDim(i)) ? cast<Value>(size) : rewriter.create<arith::ConstantIndexOp>(loc, getConstantIntValue(size).value()).getResult();
- shapeTensor = rewriter.create<tensor::InsertOp>(loc, dim, shapeTensor, SmallVector<Value>({rewriter.create<arith::ConstantIndexOp>(loc, i).getResult()}));
+ Value dim = (expandDestType.isDynamicDim(i))
+ ? cast<Value>(size)
+ : rewriter
+ .create<arith::ConstantIndexOp>(
+ loc, getConstantIntValue(size).value())
+ .getResult();
+ shapeTensor = rewriter.create<tensor::InsertOp>(
+ loc, dim, shapeTensor,
+ SmallVector<Value>(
+ {rewriter.create<arith::ConstantIndexOp>(loc, i).getResult()}));
}
- reshapeOp = rewriter.create<tensor::ReshapeOp>(loc, expandDestType, padOp.getResult(), shapeTensor);
+ reshapeOp = rewriter.create<tensor::ReshapeOp>(
+ loc, expandDestType, padOp.getResult(), shapeTensor);
} else {
reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
- loc,
- expandDestType,
- padOp.getResult(), packingMetadata.reassociations);
+ loc, expandDestType, padOp.getResult(), packingMetadata.reassociations);
}
// 5. Transpose stripMinedShape to packedShape.
>From e68b32e372de420b2e6ece98e574836920014c54 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 19 Dec 2023 21:49:38 -0600
Subject: [PATCH 3/7] Add regression test
---
.../Dialect/Linalg/transform-lower-pack.mlir | 36 ++++++++++++++++---
1 file changed, 31 insertions(+), 5 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 6a203dab91e58b..13d74cbe433264 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -61,11 +61,37 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func.func @pack_all_dyn(
-func.func @pack_all_dyn(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
- %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %arg1
- : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
-
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 64)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 128)>
+// CHECK: func.func @pack_dyn_tiles(
+// CHECK-SAME: %[[ARG0:.*]]: [[TENSOR_TY_0:tensor<64x128xf32>]]
+// CHECK-SAME: %[[ARG1:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[TILE0:.*]]: index,
+// CHECK-SAME: %[[TILE1:.*]]: index
+func.func @pack_dyn_tiles(%arg0: tensor<64x128xf32>, %arg1: tensor<?x?x?x?xf32>, %tile_0: index, %tile_1: index) -> tensor<?x?x?x?xf32> {
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG: %[[PAD0:.*]] = affine.apply #[[MAP0]]()[%[[TILE0]], %[[DIM0]]]
+// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[PAD1:.*]] = affine.apply #[[MAP1]]()[%[[TILE1]], %[[DIM1]]]
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0] high[%[[PAD0]], %[[PAD1]]]
+// CHECK-NEXT: ^bb0
+// CHECK-NEXT: tensor.yield %[[CST]] : f32
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK-DAG: %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]]
+// CHECK-NEXT: %[[INIT_SHAPE:.*]] = tensor.empty() : tensor<4xindex>
+// CHECK-NEXT: %[[SHAPE0:.*]] = tensor.insert %[[DIM0]] into %[[INIT_SHAPE]][%[[C0]]]
+// CHECK-NEXT: %[[SHAPE1:.*]] = tensor.insert %[[DIM2]] into %[[SHAPE0]][%[[C1]]]
+// CHECK-NEXT: %[[SHAPE2:.*]] = tensor.insert %[[DIM1]] into %[[SHAPE1]][%[[C2]]]
+// CHECK-NEXT: %[[SHAPE3:.*]] = tensor.insert %[[DIM3]] into %[[SHAPE2]][%[[C3]]]
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.reshape %[[PADDED]](%[[SHAPE3]])
+// CHECK-NEXT: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[EXPANDED]] : {{.*}}) outs(%[[ARG1]] {{.*}}) permutation = [0, 2, 1, 3]
+ %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [%tile_0, %tile_1] into %arg1
+ : tensor<64x128xf32> -> tensor<?x?x?x?xf32>
return %pack : tensor<?x?x?x?xf32>
}
>From 0975552abe2d404388af48eafc39b464f69a4834 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 19 Dec 2023 21:53:42 -0600
Subject: [PATCH 4/7] Fix comment
---
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 21446d07b784a9..1f63d0ab706cdb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -345,12 +345,14 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
}
// 4. Expand from the padded result to the stripMinedShape.
- // Check if any dims are not factorable. A dim is factorable if the expansion
- // requires at most dynamnic dim
- RankedTensorType expandDestType = RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
+ RankedTensorType expandDestType =
+ RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
SmallVector<int64_t> transpPerm =
invertPermutationVector(packedToStripMinedShapePerm);
Operation *reshapeOp;
+ // Check if any dims are not factorable and thus need a `tensor.reshape`
+ // instead of a `tensor.expand_shape` op. A dim is factorable if the expansion
+ // requires at most one dynamic dim.
if (llvm::any_of(packingMetadata.reassociations,
[&](const auto &rAssoc) -> bool {
return llvm::count_if(rAssoc, [&](int64_t r) {
@@ -360,7 +362,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
SmallVector<OpFoldResult> sizes =
tensor::getMixedSizes(rewriter, loc, packOp.getDest());
applyPermutationToVector(sizes, transpPerm);
- // Create a `tensor` of `index` types for the `shape` operand of `tensor.reshape`
+ // Create a `tensor` of `index` types for the `shape` operand of
+ // `tensor.reshape`
Value shapeInitTensor = rewriter.create<tensor::EmptyOp>(
loc,
RankedTensorType::get({expandDestType.getRank()},
>From 48deca06d650959ba3727df9697566a0fd6a6cd2 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 19 Dec 2023 22:31:12 -0600
Subject: [PATCH 5/7] Properly check optional value
---
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 15 +++++++++------
1 file changed, 9 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 1f63d0ab706cdb..2a1c72942df0bb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -371,12 +371,15 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
ValueRange{});
Value shapeTensor = shapeInitTensor;
for (const auto &[i, size] : llvm::enumerate(sizes)) {
- Value dim = (expandDestType.isDynamicDim(i))
- ? cast<Value>(size)
- : rewriter
- .create<arith::ConstantIndexOp>(
- loc, getConstantIntValue(size).value())
- .getResult();
+ auto maybeConstInt = getConstantIntValue(size);
+ assert(maybeConstInt.has_value() ||
+ expandDestType.isDynamicDim(i) && "expected dynamic dim");
+ Value dim =
+ (maybeConstInt.has_value())
+ ? rewriter
+ .create<arith::ConstantIndexOp>(loc, maybeConstInt.value())
+ .getResult()
+ : cast<Value>(size);
shapeTensor = rewriter.create<tensor::InsertOp>(
loc, dim, shapeTensor,
SmallVector<Value>(
>From 194f8194659908f8127b99a807033192e1477def Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 19 Dec 2023 22:37:10 -0600
Subject: [PATCH 6/7] Revert accidental change
---
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 2a1c72942df0bb..6018d58b94eb72 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -372,8 +372,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
Value shapeTensor = shapeInitTensor;
for (const auto &[i, size] : llvm::enumerate(sizes)) {
auto maybeConstInt = getConstantIntValue(size);
- assert(maybeConstInt.has_value() ||
- expandDestType.isDynamicDim(i) && "expected dynamic dim");
+ assert((maybeConstInt.has_value() || expandDestType.isDynamicDim(i)) &&
+ "expected dynamic dim");
Value dim =
(maybeConstInt.has_value())
? rewriter
@@ -397,7 +397,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
loc, reshapeOp->getResult(0), packOp.getDest(), transpPerm);
LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
- DBGS() << "reshape op: " << &reshapeOp; DBGSNL();
+ DBGS() << "reshape op: " << reshapeOp; DBGSNL();
llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
>From f14c48803b0799631dab840a8a8fa75fd92b70f4 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 19 Dec 2023 23:00:44 -0600
Subject: [PATCH 7/7] Add clarifying comment
---
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 2 +-
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 344e801835ccc9..06e8586f4288b4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1089,7 +1089,7 @@ collapseOpIterationDims(LinalgType op,
struct LowerPackResult {
tensor::PadOp padOp;
- Operation *expandShapeOp;
+ Operation *expandShapeOp; // `tensor::ExpandShapeOp` or `tensor::ReshapeOp`
linalg::TransposeOp transposeOp;
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 6018d58b94eb72..3e41399c336a93 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -222,7 +222,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);
- // 1. Compute the permutation vector to shuffle packed shape into the shape.
+ // 1. Compute the permutation vector to shuffle packed shape into the shape
// before any outer or inner permutations have been applied. The permutation
// can be obtained from two permutations:
// a) Compute the permutation vector to move the last `numPackedDims` into
@@ -251,7 +251,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
- // 2. Compute the stripMinedShape: this is the packed shape before any outer.
+ // 2. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
More information about the flang-commits
mailing list