[Mlir-commits] [mlir] 5601821 - [mlir][tensor] Fix insert_slice + tensor cast overflow
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Dec 10 13:41:30 PST 2021
Author: Nicolas Vasilache
Date: 2021-12-10T21:41:26Z
New Revision: 5601821daec72b221631cfd6175760557281d602
URL: https://github.com/llvm/llvm-project/commit/5601821daec72b221631cfd6175760557281d602
DIFF: https://github.com/llvm/llvm-project/commit/5601821daec72b221631cfd6175760557281d602.diff
LOG: [mlir][tensor] Fix insert_slice + tensor cast overflow
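InsertSliceOp may have subprefix semantics: it may specify fewer offsets / sizes / strides
than the rank of its operand, in which case the missing trailing dimensions are automatically
inferred directly from the operand shape.
This revision fixes an overflow (out-of-bounds access into the mixed sizes) that occurs in such
cases when the implementation iterates up to the source rank instead of the number of sizes provided.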
Differential Revision: https://reviews.llvm.org/D115549
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index edddfb86e553..49cfec663ef7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1417,11 +1417,11 @@ struct InsertSliceOpSourceCastInserter final
return failure();
SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
srcType.getShape().end());
- for (int64_t i = 0; i < srcType.getRank(); ++i) {
- if (Optional<int64_t> constInt =
- getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
- newSrcShape[i] = *constInt;
- }
+ // Offsets / sizes / strides can be a subprefix of the rank; take only the
+ // leading dimensions.
+ for (auto en : llvm::enumerate(insertSliceOp.getMixedSizes()))
+ if (Optional<int64_t> constInt = getConstantIntValue(en.value()))
+ newSrcShape[en.index()] = *constInt;
RankedTensorType newSrcType =
RankedTensorType::get(newSrcShape, srcType.getElementType());
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index fc9abe439b8a..50fda25cce26 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -536,6 +536,21 @@ func @insert_tensor_cast_on_insert_slice_src(
// -----
+// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src_prefix(
+// CHECK-SAME: %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
+// CHECK: %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x?xf32>
+// CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1] [64, 5] [1, 1] : tensor<64x5x?xf32> into tensor<?x?x?xf32>
+// CHECK: return %[[r]]
+func @insert_tensor_cast_on_insert_slice_src_prefix(
+ %arg0 : tensor<?x5x?xf32>, %arg1 : tensor<?x?x?xf32>, %sz0: index, %sz2: index) -> tensor<?x?x?xf32> {
+ %c64 = arith.constant 64: index
+ %r = tensor.insert_slice %arg0 into %arg1[0, 1] [%c64, 5] [1, 1]
+ : tensor<?x5x?xf32> into tensor<?x?x?xf32>
+ return %r : tensor<?x?x?xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_extract_insert
// CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32>
func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {
More information about the Mlir-commits
mailing list