[Mlir-commits] [mlir] 4747e1b - [mlir][Linalg] Fix tensor.extract_slice(linalg.init_tensor) canonicalization for rank-reducing extract
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jul 8 11:14:09 PDT 2021
Author: Nicolas Vasilache
Date: 2021-07-08T18:13:51Z
New Revision: 4747e1b83ba0117a88551a358f8960060ffa7558
URL: https://github.com/llvm/llvm-project/commit/4747e1b83ba0117a88551a358f8960060ffa7558
DIFF: https://github.com/llvm/llvm-project/commit/4747e1b83ba0117a88551a358f8960060ffa7558.diff
LOG: [mlir][Linalg] Fix tensor.extract_slice(linalg.init_tensor) canonicalization for rank-reducing extract
Differential Revision: https://reviews.llvm.org/D105636
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ea12a312d9c0..4ef942d54776 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -772,11 +772,11 @@ struct FoldInitTensorWithExtractSliceOp
PatternRewriter &rewriter) const override {
if (!sliceOp.source().getDefiningOp<linalg::InitTensorOp>())
return failure();
+ // ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved
+ // as well as its result type.
rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
sliceOp, sliceOp.sizes(),
- llvm::to_vector<4>(llvm::map_range(
- sliceOp.static_sizes(),
- [](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); })),
+ sliceOp.result().getType().cast<RankedTensorType>().getShape(),
sliceOp.getSourceType().getElementType());
return success();
}
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 864c79a35756..5f7ad8ddfde2 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -890,3 +890,15 @@ func @init_canonicalize(%i : index) {
return
}
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_init_extract
+func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> {
+ // CHECK: linalg.init_tensor [2] : tensor<2xf32>
+ %a = linalg.init_tensor [%sz, 2] : tensor<?x2xf32>
+
+ // CHECK-NOT: extract
+ %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
+ return %r: tensor<2xf32>
+}
More information about the Mlir-commits mailing list