[Mlir-commits] [mlir] 15f52c1 - [mlir][Linalg][Transform] Add support to let `transform.structured.pack_greedily` pad to the next multiple of a static constant
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Mar 27 23:37:20 PDT 2023
Author: Nicolas Vasilache
Date: 2023-03-27T23:37:13-07:00
New Revision: 15f52c1502e6aa2f7553393d76da92b21c7cf493
URL: https://github.com/llvm/llvm-project/commit/15f52c1502e6aa2f7553393d76da92b21c7cf493
DIFF: https://github.com/llvm/llvm-project/commit/15f52c1502e6aa2f7553393d76da92b21c7cf493.diff
LOG: [mlir][Linalg][Transform] Add support to let `transform.structured.pack_greedily` pad to the next multiple of a static constant
This increases the flexibility of the transformation to allow mixed packing / padding specifications.
Differential Revision: https://reviews.llvm.org/D146969
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index e107911af8b98..3b80712adcebe 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -588,14 +588,27 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
Target a Linalg op and rewrite it into packed LinalgOp form by trying to
infer whether a known suboperation is embedded
- Different packing strategies are applied in order, when one applies
+ Different packing strategies are applied in order, when one applies
successfully, the transform returns:
1. Matmul packing: Try to infer a matmul operation embedded in the target op.
Specifically, this looks for 2 parallel dimensions that participate in
an outer-product and 1 reduction dimension.
These dimensions are referred as (m, n, k) to match canonical matmul
terminology.
- The packed sizes for (m, n, k) are specified by `matmul_packed_sizes`.
+
+ The packed sizes for (m, n, k) are specified by `matmul_packed_sizes`
+ and the optional `matmul_padded_sizes_next_multiple_of`.
+ When an entry `matmul_packed_sizes[i]` is non-zero, the corresponding
+ dimension is packed by `matmul_packed_sizes[i]`.
+ Otherwise, the dimension is merely padded to the next multiple of
+ `matmul_padded_sizes_next_multiple_of[i]`.
+
+ `matmul_padded_sizes_next_multiple_of` is optional and is expected to
+ either be empty or of size `3`, matching the size of `matmul_packed_sizes`.
+ For each individual element of `matmul_packed_sizes` and
+ `matmul_padded_sizes_next_multiple_of`, at most one of them is allowed to
+ be non-zero.
+
The ordering of the packed dimensions (mm, nn, kk) is specified by the
`matmul_inner_dims_order` attribute.
@@ -605,10 +618,15 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
3. An interchange transform is applied to isolate the dimensions to pack as
the most minor indexing dimensions of the linalg.generic. The most minor
dimensions are themselves ordered according to `inner_dims_order`.
- 4. Packing is performed by `packed_sizes` and following `inner_dims_order`.
+ 4. An elementwise traversal of `matmul_packed_sizes` and
+ `matmul_padded_sizes_next_multiple_of` is performed and for each
+ dimension `d`, either pack to `matmul_packed_sizes[d]` or pad to the
+ `matmul_padded_sizes_next_multiple_of[d]`.
+ 5. Packing/padding is performed by the amounts determined in step 4 and
+ following `inner_dims_order`.
By normalizing the most minor dimensions to `inner_dims_order`, the transform
- guarantees that packing immediates generates inner dimensions in a desirable
+ guarantees that packing immediately generates inner dimensions in a desirable
layout.
Outer dimension layout permutations are not controlled by this transform op
@@ -625,15 +643,23 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
// TODO: Transform_ConcreteOpType<linalg::LinalgOp> needs interface.
let arguments = (ins TransformHandleTypeInterface:$target,
Variadic<PDL_Operation>:$matmul_packed_sizes,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">
- :$static_matmul_packed_sizes,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">
- :$matmul_inner_dims_order);
+ ConfinedAttr<DefaultValuedAttr<DenseI64ArrayAttr, "{}">,
+ [DenseArrayCount<3>]>:$static_matmul_packed_sizes,
+ ConfinedAttr<DefaultValuedAttr<DenseI64ArrayAttr, "{}">,
+ [Attr<
+ Or<[DenseArrayCount<0>.predicate,
+ DenseArrayCount<3>.predicate]>,
+ "with 0 or 3 elements"
+ >]>
+ :$matmul_padded_sizes_next_multiple_of,
+ ConfinedAttr<DefaultValuedAttr<DenseI64ArrayAttr, "{}">,
+ [DenseArrayCount<3>]>:$matmul_inner_dims_order);
let results = (outs Transform_ConcreteOpType<"linalg.generic">:$packed_op);
let builders = [
OpBuilder<(ins "Value":$target,
"ArrayRef<OpFoldResult>":$mixedMatmulPackedSizes,
+ "ArrayRef<int64_t>":$matmulPaddededSizesNextMultipleOf,
CArg<"ArrayRef<int64_t>", "{}">:$matmulDimsInnerDimsOrder)>
];
@@ -641,7 +667,9 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
$target
oilist(
`matmul_packed_sizes` `=` custom<DynamicIndexList>($matmul_packed_sizes,
- $static_matmul_packed_sizes)
+ $static_matmul_packed_sizes)
+ (`matmul_padded_sizes_next_multiple_of` `=`
+ $matmul_padded_sizes_next_multiple_of^)?
`matmul_inner_dims_order` `=` $matmul_inner_dims_order
)
attr-dict
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 6ee0f13049977..44ef944682a1a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -26,6 +26,7 @@
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@@ -1298,11 +1299,18 @@ LogicalResult transform::PackGreedilyOp::verify() {
<< " is not a valid permutation";
}
// TODO: relax to allow empty once we have another strategy than just matmul.
- if (getMatmulInnerDimsOrder().size() != 3 ||
- getMixedMatmulPackedSizes().size() != 3) {
- return emitOpError() << " needs 3 entries for matmul_packed_sizes and "
- << getMatmulInnerDimsOrderAttrName()
- << " order for the matmul strategy";
+ if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
+ for (auto [s, nmo] :
+ llvm::zip_equal(getMixedMatmulPackedSizes(),
+ getMatmulPaddedSizesNextMultipleOf())) {
+ std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
+ if (nmo != 0 &&
+ (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
+ return emitOpError() << "at most one of the packed_size and the "
+ "padded_sizes_next_multiple_of can be nonzero "
+ "for the matmul strategy";
+ }
+ }
}
return success();
}
@@ -1318,8 +1326,12 @@ LogicalResult transform::PackGreedilyOp::verify() {
static FailureOr<PackResult>
packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
ArrayRef<OpFoldResult> mnkPackedSizes,
+ ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
ArrayRef<int64_t> mnkOrder) {
assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
+ assert(mnkPaddedSizesNextMultipleOf.empty() ||
+ mnkPaddedSizesNextMultipleOf.size() == 3 &&
+ "num of packing sizes next multiple should be empty or of size 3");
assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
assert(isPermutationVector(mnkOrder) && "expected a permutation");
@@ -1334,9 +1346,15 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
SmallVector<int64_t> mmnnkkPos(numPackedDims);
for (int64_t i = 0, e = numPackedDims; i < e; ++i)
mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
- SmallVector<OpFoldResult> packedSizes(mnkPackedSizes.size());
+ SmallVector<OpFoldResult> packedSizes(numPackedDims);
for (int64_t i = 0, e = numPackedDims; i < e; ++i)
packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
+ SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
+ for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
+ paddedSizesNextMultipleOf[mnkOrder[i]] =
+ mnkPaddedSizesNextMultipleOf.empty() ? 0
+ : mnkPaddedSizesNextMultipleOf[i];
+ }
// 1. Infer dims that are important for matmul.
FailureOr<EmbeddedMatmulDimsCandidates> res = inferMatmulDims(linalgOp);
@@ -1391,10 +1409,37 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
// desired outerPerm for each operand.
// This is left for future work.
- // Add leading zeros to match numLoops.
+ // TODO: this creates too much IR, go use reifyResultShapes.
+ SmallVector<Range, 4> loopRanges =
+ cast<LinalgOp>(genericOp.getOperation())
+ .createLoopRanges(rewriter, genericOp.getLoc());
+
+ // Add leading zeros to match numLoops, we only pack the last 3 dimensions
+ // post interchange.
+ LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
+ DBGS() << "paddedSizesNextMultipleOf: ");
+ DBGSNL(););
+ LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
+ [](Range r) { llvm::dbgs() << r.size; });
+ DBGSNL(););
SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
rewriter.getIndexAttr(0));
- llvm::append_range(adjustedPackedSizes, packedSizes);
+ for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
+ if (paddedSizesNextMultipleOf[i] == 0) {
+ adjustedPackedSizes.push_back(packedSizes[i]);
+ continue;
+ }
+ AffineExpr d0, s0;
+ bindDims(rewriter.getContext(), d0);
+ bindSymbols(rewriter.getContext(), s0);
+ adjustedPackedSizes.push_back(makeComposedFoldedAffineApply(
+ rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
+ {loopRanges[adjustedPackedSizes.size()].size,
+ rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
+ }
+ LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
+ DBGS() << "adjustedPackedSizes: ");
+ DBGSNL(););
// TODO: If we wanted to give the genericOp a name after packing, after
// calling `pack` would be a good time.
@@ -1424,6 +1469,8 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults,
/*rewriter=*/rewriter,
/*linalgOp=*/linalgOp,
/*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
+ /*mnkPaddedSizesNextMultipleOf=*/
+ getMatmulPaddedSizesNextMultipleOf(),
/*mnkOrder=*/getMatmulInnerDimsOrder());
if (succeeded(packResult)) {
results.push_back(packResult->packedLinalgOp);
diff --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
index 8645fa3813cff..2adfd9ae97aa7 100644
--- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
@@ -46,3 +46,28 @@ transform.sequence failures(propagate) {
"transform.structured.multitile_sizes"(%arg0) { target_size = 3, divisor = 2, dimension = 0 }
: (!pdl.operation) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i32>)
}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+ // expected-error at below {{not a valid permutation}}
+ transform.structured.pack_greedily %arg0
+ matmul_packed_sizes = [8, 0, 32]
+ matmul_inner_dims_order = [1, 1, 0]
+ : (!pdl.operation) -> !transform.op<"linalg.generic">
+
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+ // expected-error at below {{at most one of the packed_size and the padded_sizes_next_multiple_of can be nonzero}}
+ transform.structured.pack_greedily %arg0
+ matmul_packed_sizes = [1, 1, 1]
+ matmul_padded_sizes_next_multiple_of = [1, 1, 1]
+ matmul_inner_dims_order = [0, 1, 2]
+ : (!pdl.operation) -> !transform.op<"linalg.generic">
+
+}
diff --git a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
index 544f4391eb39a..fdb1699e7bb4a 100644
--- a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
+++ b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
@@ -226,3 +226,52 @@ transform.sequence failures(propagate) {
matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
: (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
}
+
+// -----
+
+!A_mk = tensor<1023x255xf32>
+!B_nk = tensor<127x255xf32>
+!C_nm = tensor<127x1023xf32>
+
+#mkn_accesses = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (n, k)>,
+ affine_map<(m, n, k) -> (n, m)>
+]
+#mkn_trait = {
+ indexing_maps = #mkn_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// Normalized dims are: ( k, m, n)(kk, mm, nn)
+// CHECK-DAG: #[[$km_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
+// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
+// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
+
+// CHECK-LABEL: @matmul_mk_nk_nm(
+func.func @matmul_mk_nk_nm(%A : !A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm {
+ // CHECK: linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]]
+ // CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]}
+ // CHECK-SAME: ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<1x8x32x130xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<1x128x8x130xf32>)
+ %0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) {
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ %d = arith.mulf %a, %b : f32
+ %e = arith.addf %c, %d : f32
+ linalg.yield %e : f32
+ } -> !C_nm
+ return %0 : !C_nm
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
+ transform.structured.pack_greedily %generic
+ // In this spec, the "n" dimension is not packed but rather padded to the
+ // next multiple of 10 (i.e. 130).
+ matmul_packed_sizes = [8, 0, 32]
+ matmul_padded_sizes_next_multiple_of = [0, 10, 0]
+ matmul_inner_dims_order = [1, 2, 0]
+ : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
+}
More information about the Mlir-commits
mailing list