[Mlir-commits] [libc] [clang] [mlir] [flang] [compiler-rt] [clang-tools-extra] [llvm] [mlir][Linalg] Support dynamic shapes in `lower_pack` transform (PR #76003)
llvmlistbot at llvm.org
Thu Dec 21 07:33:12 PST 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir
Author: None (srcarroll)
Changes:
When an expanded dim is not factorable (i.e. more than one of the sizes it expands into is dynamic), emit a `tensor.reshape` instead of a `tensor.expand_shape`.
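
For illustration only (not taken from the patch), a minimal sketch of the distinction; `%padded`, `%padded_dyn`, `%d0`/`%d1` (outer dim sizes) and `%t0`/`%t1` (inner tile sizes) are hypothetical values, and the shape operand is built with `tensor.from_elements` here for brevity whereas the patch uses `tensor.empty` + `tensor.insert`:

```mlir
// Factorable case (sketch): each reassociation group has at most one dynamic
// size, so splitting `?` into `?x8` is expressible with tensor.expand_shape.
%e = tensor.expand_shape %padded [[0, 1], [2, 3]]
    : tensor<?x128xf32> into tensor<?x8x16x8xf32>

// Non-factorable case (sketch): both the outer dim and the inner tile of a
// group are dynamic, so `?` would have to split into `?x?`, which
// tensor.expand_shape cannot express. The lowering instead materializes the
// result shape and reshapes to it.
%shape = tensor.from_elements %d0, %t0, %d1, %t1 : tensor<4xindex>
%r = tensor.reshape %padded_dyn(%shape)
    : (tensor<?x?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
```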
---
Full diff: https://github.com/llvm/llvm-project/pull/76003.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+2-1)
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+1-1)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+54-21)
- (modified) mlir/test/Dialect/Linalg/transform-lower-pack.mlir (+46)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 77ed9db5e71bd1..4abd3740b57105 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -498,7 +498,8 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
- Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
+ Type<Or<[Transform_ConcreteOpType<"tensor.expand_shape">.predicate,
+ Transform_ConcreteOpType<"tensor.reshape">.predicate]>>:$expand_shape_op,
Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
let assemblyFormat = [{
$target attr-dict `:` functional-type(operands, results)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a848d12fbbb50e..06e8586f4288b4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1089,7 +1089,7 @@ collapseOpIterationDims(LinalgType op,
struct LowerPackResult {
tensor::PadOp padOp;
- tensor::ExpandShapeOp expandShapeOp;
+ Operation *expandShapeOp; // `tensor::ExpandShapeOp` or `tensor::ReshapeOp`
linalg::TransposeOp transposeOp;
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 9d230e2c2e5749..4550589ded6df8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -218,21 +218,11 @@ struct PackedOperandsDimList {
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
tensor::PackOp packOp) {
- // 1. Filter out NYI cases.
- auto packedTensorType =
- cast<RankedTensorType>(packOp->getResultTypes().front());
- if (llvm::any_of(packOp.getStaticInnerTiles(),
- [](int64_t size) { return ShapedType::isDynamic(size); })) {
- return rewriter.notifyMatchFailure(
- packOp,
- "non-static shape NYI, needs a more powerful tensor.expand_shape op");
- }
-
Location loc = packOp->getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);
- // 2. Compute the permutation vector to shuffle packed shape into the shape
+ // 1. Compute the permutation vector to shuffle packed shape into the shape
// before any outer or inner permutations have been applied. The permutation
// can be obtained from two permutations:
// a) Compute the permutation vector to move the last `numPackedDims` into
@@ -240,6 +230,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
// b) Compute the permutation vector to move outer dims if the pack op
// has outer_dims_perm.
// Apply (b) permutation on (a) permutation to get the final permutation.
+ auto packedTensorType =
+ cast<RankedTensorType>(packOp->getResultTypes().front());
int64_t numPackedDims = packOp.getInnerDimsPos().size();
int64_t packedRank = packedTensorType.getRank();
auto lastDims = llvm::to_vector(
@@ -259,12 +251,12 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
- // 3. Compute the stripMinedShape: this is the packed shape before any outer
+ // 2. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
- // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+ // 3. Pad the source of packOp to a shape we can expand into stripMinedShape.
SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
@@ -351,24 +343,65 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
/*transposeOp=*/nullptr};
}
}
- // 5. Expand from the padded result to the stripMinedShape.
- auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
- loc,
- RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
- padOp.getResult(), packingMetadata.reassociations);
- // 6. Transpose stripMinedShape to packedShape.
+ // 4. Expand from the padded result to the stripMinedShape.
+ RankedTensorType expandDestType =
+ RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
SmallVector<int64_t> transpPerm =
invertPermutationVector(packedToStripMinedShapePerm);
+ Operation *reshapeOp;
+ // Check if any dims are not factorable and thus need a `tensor.reshape`
+ // instead of a `tensor.expand_shape` op. A dim is factorable if the expansion
+  // requires at most one dynamic dim
+ if (llvm::any_of(packingMetadata.reassociations,
+ [&](const auto &rAssoc) -> bool {
+ return llvm::count_if(rAssoc, [&](int64_t r) {
+ return stripMinedShape[r] == ShapedType::kDynamic;
+ }) > 1;
+ })) {
+ SmallVector<OpFoldResult> sizes =
+ tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+ applyPermutationToVector(sizes, transpPerm);
+ // Create a `tensor` of `index` types for the `shape` operand of
+ // `tensor.reshape`
+ Value shapeInitTensor = rewriter.create<tensor::EmptyOp>(
+ loc,
+ RankedTensorType::get({expandDestType.getRank()},
+ rewriter.getIndexType()),
+ ValueRange{});
+ Value shapeTensor = shapeInitTensor;
+ for (const auto &[i, size] : llvm::enumerate(sizes)) {
+ auto maybeConstInt = getConstantIntValue(size);
+ assert((maybeConstInt.has_value() || expandDestType.isDynamicDim(i)) &&
+ "expected dynamic dim");
+ Value dim =
+ (maybeConstInt.has_value())
+ ? rewriter
+ .create<arith::ConstantIndexOp>(loc, maybeConstInt.value())
+ .getResult()
+ : cast<Value>(size);
+ shapeTensor = rewriter.create<tensor::InsertOp>(
+ loc, dim, shapeTensor,
+ SmallVector<Value>(
+ {rewriter.create<arith::ConstantIndexOp>(loc, i).getResult()}));
+ }
+ reshapeOp = rewriter.create<tensor::ReshapeOp>(
+ loc, expandDestType, padOp.getResult(), shapeTensor);
+ } else {
+ reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
+ loc, expandDestType, padOp.getResult(), packingMetadata.reassociations);
+ }
+
+ // 5. Transpose stripMinedShape to packedShape.
auto transposeOp = rewriter.create<linalg::TransposeOp>(
- loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
+ loc, reshapeOp->getResult(0), packOp.getDest(), transpPerm);
LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
DBGS() << "reshape op: " << reshapeOp; DBGSNL();
llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
- // 7. Replace packOp by transposeOp.
+ // 6. Replace packOp by transposeOp.
rewriter.replaceOp(packOp, transposeOp->getResults());
return LowerPackResult{padOp, reshapeOp, transposeOp};
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 316df431a9c0c8..13d74cbe433264 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -61,6 +61,52 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 64)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 128)>
+// CHECK: func.func @pack_dyn_tiles(
+// CHECK-SAME: %[[ARG0:.*]]: [[TENSOR_TY_0:tensor<64x128xf32>]]
+// CHECK-SAME: %[[ARG1:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[TILE0:.*]]: index,
+// CHECK-SAME: %[[TILE1:.*]]: index
+func.func @pack_dyn_tiles(%arg0: tensor<64x128xf32>, %arg1: tensor<?x?x?x?xf32>, %tile_0: index, %tile_1: index) -> tensor<?x?x?x?xf32> {
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG: %[[PAD0:.*]] = affine.apply #[[MAP0]]()[%[[TILE0]], %[[DIM0]]]
+// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[PAD1:.*]] = affine.apply #[[MAP1]]()[%[[TILE1]], %[[DIM1]]]
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0] high[%[[PAD0]], %[[PAD1]]]
+// CHECK-NEXT: ^bb0
+// CHECK-NEXT: tensor.yield %[[CST]] : f32
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK-DAG: %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]]
+// CHECK-NEXT: %[[INIT_SHAPE:.*]] = tensor.empty() : tensor<4xindex>
+// CHECK-NEXT: %[[SHAPE0:.*]] = tensor.insert %[[DIM0]] into %[[INIT_SHAPE]][%[[C0]]]
+// CHECK-NEXT: %[[SHAPE1:.*]] = tensor.insert %[[DIM2]] into %[[SHAPE0]][%[[C1]]]
+// CHECK-NEXT: %[[SHAPE2:.*]] = tensor.insert %[[DIM1]] into %[[SHAPE1]][%[[C2]]]
+// CHECK-NEXT: %[[SHAPE3:.*]] = tensor.insert %[[DIM3]] into %[[SHAPE2]][%[[C3]]]
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.reshape %[[PADDED]](%[[SHAPE3]])
+// CHECK-NEXT: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[EXPANDED]] : {{.*}}) outs(%[[ARG1]] {{.*}}) permutation = [0, 2, 1, 3]
+ %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [%tile_0, %tile_1] into %arg1
+ : tensor<64x128xf32> -> tensor<?x?x?x?xf32>
+ return %pack : tensor<?x?x?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"tensor.pack">
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.reshape">, !transform.op<"linalg.transpose">)
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func.func @pack_as_pad(
func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
%cst_0 = arith.constant 0.0 : f32
``````````
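
Since the second result of `structured.lower_pack` can now be either op, transform scripts have to request the handle type that matches what the lowering produces for the given pack op. A minimal usage sketch (handle and payload names are hypothetical; the static-tile form matches the pre-existing tests):

```mlir
// Static inner tiles: the lowering still emits tensor.expand_shape, so the
// existing result handle type keeps verifying.
%pad, %expand, %transpose = transform.structured.lower_pack %pack_static
    : (!transform.op<"tensor.pack">)
    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">,
        !transform.op<"linalg.transpose">)

// Non-factorable (e.g. fully dynamic) inner tiles: request tensor.reshape for
// the second result, as in the new test above.
%pad2, %reshape, %transpose2 = transform.structured.lower_pack %pack_dyn
    : (!transform.op<"tensor.pack">)
    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.reshape">,
        !transform.op<"linalg.transpose">)
```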
https://github.com/llvm/llvm-project/pull/76003