[Mlir-commits] [mlir] f2e945a - Revert "[mlir][tensor] Fix insert_slice + tensor cast overflow"

Nicolas Vasilache llvmlistbot at llvm.org
Fri Dec 10 14:54:02 PST 2021


Author: Nicolas Vasilache
Date: 2021-12-10T22:53:52Z
New Revision: f2e945a393511bd79d045a3dd9854264c07bb99f

URL: https://github.com/llvm/llvm-project/commit/f2e945a393511bd79d045a3dd9854264c07bb99f
DIFF: https://github.com/llvm/llvm-project/commit/f2e945a393511bd79d045a3dd9854264c07bb99f.diff

LOG: Revert "[mlir][tensor] Fix insert_slice + tensor cast overflow"

This reverts commit 5601821daec72b221631cfd6175760557281d602.

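The behavior of accepting offsets/sizes/strides that cover only a prefix of the tensor rank (and canonically completing the remaining dimensions) is actually obsolete and should not be reintroduced.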
Reverting.

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 49cfec663ef77..edddfb86e5539 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1417,11 +1417,11 @@ struct InsertSliceOpSourceCastInserter final
       return failure();
     SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
                                      srcType.getShape().end());
-    // Offsets / sizes / strides can be a subprefix of the rank; take only the
-    // leading dimensions.
-    for (auto en : llvm::enumerate(insertSliceOp.getMixedSizes()))
-      if (Optional<int64_t> constInt = getConstantIntValue(en.value()))
-        newSrcShape[en.index()] = *constInt;
+    for (int64_t i = 0; i < srcType.getRank(); ++i) {
+      if (Optional<int64_t> constInt =
+              getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
+        newSrcShape[i] = *constInt;
+    }
 
     RankedTensorType newSrcType =
         RankedTensorType::get(newSrcShape, srcType.getElementType());

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 50fda25cce265..fc9abe439b8a2 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -536,21 +536,6 @@ func @insert_tensor_cast_on_insert_slice_src(
 
 // -----
 
-// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src_prefix(
-// CHECK-SAME:      %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
-//      CHECK:    %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x?xf32>
-//      CHECK:    %[[r:.*]] =  tensor.insert_slice %[[cast]] into %[[arg1]][0, 1] [64, 5] [1, 1] : tensor<64x5x?xf32> into tensor<?x?x?xf32>
-//      CHECK:    return %[[r]]
-func @insert_tensor_cast_on_insert_slice_src_prefix(
-    %arg0 : tensor<?x5x?xf32>,  %arg1 : tensor<?x?x?xf32>, %sz0: index, %sz2: index) -> tensor<?x?x?xf32> {
-  %c64 = arith.constant 64: index
-  %r = tensor.insert_slice %arg0 into %arg1[0, 1] [%c64, 5] [1, 1]
-    : tensor<?x5x?xf32> into tensor<?x?x?xf32>
-  return %r : tensor<?x?x?xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @fold_extract_insert
 //  CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32>
 func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {


        


More information about the Mlir-commits mailing list