[Mlir-commits] [mlir] 4d27f06 - [mlir][Tensor] Fix ExtractSliceFromReshape transform edge case
Christopher Bate
llvmlistbot at llvm.org
Mon Sep 19 13:02:53 PDT 2022
Author: Christopher Bate
Date: 2022-09-19T14:02:45-06:00
New Revision: 4d27f06f9454a6733c3f801c8b992193702607b3
URL: https://github.com/llvm/llvm-project/commit/4d27f06f9454a6733c3f801c8b992193702607b3
DIFF: https://github.com/llvm/llvm-project/commit/4d27f06f9454a6733c3f801c8b992193702607b3.diff
LOG: [mlir][Tensor] Fix ExtractSliceFromReshape transform edge case
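The transformation would fail if none of the sliced dimensions were
linearized by the producing `tensor.collapse_shape`. In that case there are
no tiled dimensions, so no loops are created. The helper methods also could
not obtain an `MLIRContext` from their (empty) index arguments, so the context
is now passed in explicitly. This is a trivial edge case, but it was not
covered by tests. This change fixes the issue and adds a test.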
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D134088
Added:
Modified:
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index e6b6048f8180..f693b3503abf 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -441,14 +441,16 @@ class SliceFromCollapseHelper {
/// only one tiled dimension (D_0) and `arith.delinearize_index` produces the
/// multi-index (%3) that would be passed to this function to generate the
/// parameters for the `tensor.extract_slice` op (%4).
- SmallVector<Range> getExtractSliceParams(ArrayRef<ValueRange> multiIndices);
+ SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
+ ArrayRef<ValueRange> multiIndices);
/// This function takes indices in the index space of the "tiled dimensions"
/// described above and returns a set of Range variables that describe how the
/// slice should be inserted into the destination. In the example above, `%iv`
/// would be passed to this function to generate the parameters for the
/// `tensor.insert_slice` op producing %6.
- SmallVector<Range> getInsertSliceParams(ValueRange tileIndices);
+ SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
+ ValueRange tileIndices);
private:
SmallVector<ReassociationIndices> reassociationIndices;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
index 4acd5482e823..dcee9deff5ca 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
@@ -164,13 +164,14 @@ tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody(
}
}
- auto extractParams = helper.getExtractSliceParams(multiIndices);
+ SmallVector<Range> extractParams =
+ helper.getExtractSliceParams(builder.getContext(), multiIndices);
Value subTileResult = builder.create<tensor::ExtractSliceOp>(
loc, collapseShapeOp.getSrc(), extractParams);
SmallVector<Range> insertParams =
- helper.getInsertSliceParams(tileInductionVars);
+ helper.getInsertSliceParams(builder.getContext(), tileInductionVars);
// Collapse the dimensions of the source slice back down.
Value collapsedResult = builder.create<tensor::CollapseShapeOp>(
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 7f5b63814e69..9bca50f64321 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -298,11 +298,8 @@ llvm::SmallBitVector mlir::getLinearizedDimensions(
}
SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
- ArrayRef<ValueRange> multiIndices) {
- assert(!multiIndices.empty() && !multiIndices[0].empty() &&
- "multiIndices should not be empty");
+ MLIRContext *ctx, ArrayRef<ValueRange> multiIndices) {
unsigned loopIdx = 0;
- MLIRContext *ctx = multiIndices[0][0].getContext();
auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);
auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
SmallVector<Range> offsetsSizesAndStrides;
@@ -339,8 +336,8 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
}
SmallVector<Range>
-SliceFromCollapseHelper::getInsertSliceParams(ValueRange tileIndices) {
- MLIRContext *ctx = tileIndices[0].getContext();
+SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx,
+ ValueRange tileIndices) {
auto one = IntegerAttr::get(IndexType::get(ctx), 1);
auto zero = IntegerAttr::get(IndexType::get(ctx), 0);
SmallVector<Range> insertParams;
diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
index d8ca129bf59a..02e2502f9ffd 100644
--- a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
+++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
@@ -162,3 +162,18 @@ func.func @extract_slice_non_sliced_linearized_dim(%input: tensor<3x?x?x11x2xf32
// CHECK: tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0, 0] [1, 1, 1, 11, 2] [1, 1, 1, 1, 1]
return %slice : tensor<?x22xf32>
}
+
+// -----
+
+// CHECK: @no_sliced_linearized_dims(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
+func.func @no_sliced_linearized_dims(%input: tensor<30x11x100xf32>, %offt: index, %size: index) -> tensor<330x?xf32> {
+ %collapsed = tensor.collapse_shape %input [[0, 1], [2]] : tensor<30x11x100xf32> into tensor<330x100xf32>
+ %slice = tensor.extract_slice %collapsed [0, %offt] [330, %size] [1, 1] : tensor<330x100xf32> to tensor<330x?xf32>
+ // CHECK-NOT: scf.for
+ // CHECK: %[[init:.+]] = linalg.init_tensor [330, %[[arg2]]]
+ // CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, %[[arg1]]] [30, 11, %[[arg2]]] [1, 1, 1]
+ // CHECK: %[[c:.+]] = tensor.collapse_shape %[[e]] {{\[}}[0, 1], [2]]
+ // CHECK: %[[res:.+]] = tensor.insert_slice %[[c]] into %[[init]]
+ // CHECK: return %[[res]]
+ return %slice : tensor<330x?xf32>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index f5a7f984ab0a..5dd5d763388a 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -151,6 +151,13 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfFor
auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
SmallVector<Value> lbs(numTiledDims, zero);
SmallVector<Value> steps(numTiledDims, one);
+
+ // Below, we pass out the result of the loop body builder lambda via the
+ // `insertResult` variable. In certain cases, no loops will be created, but
+ // the body builder will still execute. In this case, the results will not
+ // be passed to the LoopNest object.
+ // TODO: remove this workaround if `scf::buildLoopNest` behavior is updated.
+ Value insertResult = nullptr;
scf::LoopNest nest = scf::buildLoopNest(
rewriter, loc, lbs, helper.getIterationSpaceSizes(), steps, dest,
[&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
@@ -159,11 +166,16 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfFor
helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
// Insert the slice into the destination.
- Value result = nestedBuilder.create<tensor::InsertSliceOp>(
+ insertResult = nestedBuilder.create<tensor::InsertSliceOp>(
loc, tile, iterArgs[0], insertParams);
- return {result};
+ return {insertResult};
});
- rewriter.replaceOp(op, nest.getResults()[0]);
+
+ if (!nest.loops.empty())
+ rewriter.replaceOp(op, nest.getResults());
+ else
+ rewriter.replaceOp(op, insertResult);
+
return success();
}
};
More information about the Mlir-commits
mailing list