[Mlir-commits] [mlir] 5601821 - [mlir][tensor] Fix insert_slice + tensor cast overflow

Nicolas Vasilache llvmlistbot at llvm.org
Fri Dec 10 13:41:30 PST 2021


Author: Nicolas Vasilache
Date: 2021-12-10T21:41:26Z
New Revision: 5601821daec72b221631cfd6175760557281d602

URL: https://github.com/llvm/llvm-project/commit/5601821daec72b221631cfd6175760557281d602
DIFF: https://github.com/llvm/llvm-project/commit/5601821daec72b221631cfd6175760557281d602.diff

LOG: [mlir][tensor] Fix insert_slice + tensor cast overflow

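InsertSliceOp may have prefix semantics: offsets / sizes / strides may cover only
the leading dimensions, and the missing trailing dimensions are inferred directly
from the operand shape.
This revision fixes an overflow (an out-of-bounds read of the op's mixed sizes) that
occurs in such cases when the implementation iterates up to the source rank instead
of over the sizes actually provided.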

Differential Revision: https://reviews.llvm.org/D115549

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index edddfb86e553..49cfec663ef7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1417,11 +1417,11 @@ struct InsertSliceOpSourceCastInserter final
       return failure();
     SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
                                      srcType.getShape().end());
-    for (int64_t i = 0; i < srcType.getRank(); ++i) {
-      if (Optional<int64_t> constInt =
-              getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
-        newSrcShape[i] = *constInt;
-    }
+    // Offsets / sizes / strides can be a subprefix of the rank; take only the
+    // leading dimensions.
+    for (auto en : llvm::enumerate(insertSliceOp.getMixedSizes()))
+      if (Optional<int64_t> constInt = getConstantIntValue(en.value()))
+        newSrcShape[en.index()] = *constInt;
 
     RankedTensorType newSrcType =
         RankedTensorType::get(newSrcShape, srcType.getElementType());

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index fc9abe439b8a..50fda25cce26 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -536,6 +536,21 @@ func @insert_tensor_cast_on_insert_slice_src(
 
 // -----
 
+// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src_prefix(
+// CHECK-SAME:      %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
+//      CHECK:    %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x?xf32>
+//      CHECK:    %[[r:.*]] =  tensor.insert_slice %[[cast]] into %[[arg1]][0, 1] [64, 5] [1, 1] : tensor<64x5x?xf32> into tensor<?x?x?xf32>
+//      CHECK:    return %[[r]]
+func @insert_tensor_cast_on_insert_slice_src_prefix(
+    %arg0 : tensor<?x5x?xf32>,  %arg1 : tensor<?x?x?xf32>, %sz0: index, %sz2: index) -> tensor<?x?x?xf32> {
+  %c64 = arith.constant 64: index
+  %r = tensor.insert_slice %arg0 into %arg1[0, 1] [%c64, 5] [1, 1]
+    : tensor<?x5x?xf32> into tensor<?x?x?xf32>
+  return %r : tensor<?x?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract_insert
 //  CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32>
 func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {


        


More information about the Mlir-commits mailing list