[Mlir-commits] [mlir] 5f2618f - [mlir][linalg] Allow constant exprs in pack/unpack propagation through elementwise
Quinn Dawkins
llvmlistbot at llvm.org
Tue Feb 21 22:35:40 PST 2023
Author: Quinn Dawkins
Date: 2023-02-22T01:31:16-05:00
New Revision: 5f2618fe168fffc62ac9dc1396d32f8d53e79621
URL: https://github.com/llvm/llvm-project/commit/5f2618fe168fffc62ac9dc1396d32f8d53e79621
DIFF: https://github.com/llvm/llvm-project/commit/5f2618fe168fffc62ac9dc1396d32f8d53e79621.diff
LOG: [mlir][linalg] Allow constant exprs in pack/unpack propagation through elementwise
The pack/unpack propagation patterns currently assume that all indexing-map results
for non-scalar operands are AffineDimExprs, which leads to crashes when the
indexing map of an input operand being packed contains constant expressions.
Differential Revision: https://reviews.llvm.org/D144443
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
mlir/test/Dialect/Linalg/data-layout-propagation.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 206f6c51a4929..3848510cee598 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -163,8 +163,13 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
// Step 1. Construct the information of packing data dimensions; append inner
// dimensions to the indexing maps for the operand.
for (auto [index, expr] : llvm::enumerate(exprs)) {
- int64_t dimPos = expr.cast<AffineDimExpr>().getPosition();
- domainDimToOperandDim[dimPos] = index;
+ if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+ int64_t dimPos = dimExpr.getPosition();
+ domainDimToOperandDim[dimPos] = index;
+ continue;
+ }
+ assert(expr.isa<AffineConstantExpr>() &&
+ "Found non-constant and non-affine dim expression");
}
SmallVector<int64_t> innerDimsPos;
SmallVector<OpFoldResult> innerTileSizes;
@@ -186,8 +191,13 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
SmallVector<int64_t> inversedOuterPerm =
invertPermutationVector(packInfo.outerDimsOnDomainPerm);
for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
- int64_t dimPos = exprs[i].cast<AffineDimExpr>().getPosition();
- exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
+ if (auto dimExpr = exprs[i].dyn_cast<AffineDimExpr>()) {
+ int64_t dimPos = dimExpr.getPosition();
+ exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
+ continue;
+ }
+ assert(exprs[i].isa<AffineConstantExpr>() &&
+ "Attempted to permute non-constant and non-affine dim expression");
}
// Step 2.2: Undo the transposition on `exprs` and propagate the
// transposition on the pack using outerDimsPerm.
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index afe184b655adc..546e268c83742 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -231,9 +231,6 @@ func.func @elem_pack_transpose_inner_and_outer_dims2(%arg0: tensor<64xf32>, %des
// -----
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1) -> (d0)>
-#map2 = affine_map<(d0, d1) -> (d1)>
func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<100x200x4x16x16x32xi32>) -> tensor<100x200x4x16x16x32xi32>
{
%init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
@@ -280,6 +277,52 @@ func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x
// -----
+func.func @affine_constant_expr_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x1x1x1xi32>, %arg2: tensor<1x128x1x1xi32>, %dest: tensor<100x200x4x16x16x32xi32>) -> tensor<100x200x4x16x16x32xi32>
+{
+ %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
+ %transpose = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, 0)>,
+ affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100x1x1x1xi32>, tensor<1x128x1x1xi32>)
+ outs(%init_transpose : tensor<100x200x128x256xi32>) {
+ ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+ %0 = arith.addi %b0, %b1 : i32
+ %1 = arith.addi %0, %b2 : i32
+ linalg.yield %1 : i32
+ } -> tensor<100x200x128x256xi32>
+ %4 = tensor.pack %transpose
+ inner_dims_pos = [3, 2]
+ inner_tiles = [16, 32]
+ into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
+ return %4 : tensor<100x200x4x16x16x32xi32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, 0, 0, 0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (0, d1, 0, 0, d5)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3, d4, d5)>
+// CHECK: func.func @affine_constant_expr_pack
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<100x4x200x16x16x32xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [3, 1] inner_tiles = [16, 32]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<1x4x1x1x32xi32>
+// CHECK: %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG2_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
+// CHECK-SAME: outs(%[[DEST]]
+
+// -----
+
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
#map2 = affine_map<(d0, d1) -> (d1)>
More information about the Mlir-commits mailing list