[PATCH] D137730: [mlir][TilingInterface] Fix a crash in PartialTilingInterface for some inputs
Murali Vijayaraghavan via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 17 15:45:31 PST 2022
vmurali updated this revision to Diff 476257.
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D137730/new/
https://reviews.llvm.org/D137730
Files:
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
Index: mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
===================================================================
--- mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -197,3 +197,36 @@
// CHECK: linalg.yield
// CHECK: } -> tensor<?x?xf32>
// CHECK: return %[[R]] : tensor<?x?xf32>
+
+// -----
+
+func.func @reduction_bug(%arg0: tensor<32x32xi32>, %arg1: tensor<32x32xi32>, %out: tensor<32xi32>) -> tensor<32xi32> {
+ %red = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d0)>],
+ iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<32x32xi32>, tensor<32x32xi32>) outs(%out : tensor<32xi32>) {
+ ^bb0(%a: i32, %b: i32, %c: i32):
+ %r1 = arith.muli %a, %b: i32
+ %r2 = arith.addi %c, %r1 : i32
+ linalg.yield %r2 : i32
+ } -> tensor<32xi32>
+ return %red : tensor<32xi32>
+}
+
+transform.sequence failures(suppress) {
+^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [0, 0, 8] }
+}
+
+// // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0)>
+// // CHECK-LABEL: func @reduction_bug
+// // CHECK: %[[RED:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// // CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<32x32xi32>, tensor<32x32xi32>) outs(%[[F]] : tensor<32xi32>) {
+// // CHECK: arith.muli
+// // CHECK: arith.addi
+// // CHECK: linalg.yield
+// // CHECK: } -> tensor<32xi32>
+// // CHECK: return %[[RED]] : tensor<32xi32>
Index: mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
===================================================================
--- mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -453,6 +453,19 @@
break;
}
}
+ {
+ auto origResultTensor = cast<DestinationStyleOpInterface>(op.getOperation())
+ .getDpsInitOperand(0);
+ size_t origResultSize = 0;
+ if (auto shapedType =
+ origResultTensor->get().getType().dyn_cast<ShapedType>())
+ origResultSize = shapedType.getShape().size();
+ if (iterationDomain.size() != origResultSize + 1) {
+ return b.notifyMatchFailure(
+ op, "only support result tensor whose rank is exactly one dimension "
+ "smaller than the number of loops.");
+ }
+ }
// 1. create the inital tensor value.
FailureOr<Operation *> identityTensor =
op.generateInitialTensorForPartialReduction(b, loc, tileSize,
Index: mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -276,6 +276,9 @@
SmallVector<int64_t> newOutputShape;
ArrayRef<int64_t> oldShape =
linalgOp.getShape(linalgOp.getDpsInitOperand(0));
+ assert(sizes.size() == oldShape.size() + 1 &&
+ "result tensor should have rank exactly one dimension smaller than "
+ "the number of loops.");
SmallVector<Value> dynamicDims;
for (int64_t idx : llvm::seq<int64_t>(0, oldShape.size() + 1)) {
if (idx == insertSplitDimension) {
-------------- next part --------------
A non-text attachment was scrubbed...
Name: D137730.476257.patch
Type: text/x-patch
Size: 3833 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20221117/866cedf8/attachment.bin>
More information about the llvm-commits mailing list