[Mlir-commits] [mlir] cd0d095 - [mlir][tensor] Check ops generated by InsertSliceOpCastFolder are valid

Ivan Butygin llvmlistbot at llvm.org
Sun Feb 13 10:38:33 PST 2022


Author: Ivan Butygin
Date: 2022-02-13T21:37:31+03:00
New Revision: cd0d095c07b6dc925354cb8a9a54cafe654a6c4d

URL: https://github.com/llvm/llvm-project/commit/cd0d095c07b6dc925354cb8a9a54cafe654a6c4d
DIFF: https://github.com/llvm/llvm-project/commit/cd0d095c07b6dc925354cb8a9a54cafe654a6c4d.diff

LOG: [mlir][tensor] Check ops generated by InsertSliceOpCastFolder are valid

Fixes https://github.com/llvm/llvm-project/issues/53099

Differential Revision: https://reviews.llvm.org/D119663

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 71871ee50381f..a13a274c28e2a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1305,16 +1305,29 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }
 
+static SliceVerificationResult
+verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
+                    ArrayAttr staticOffsets, ArrayAttr staticSizes,
+                    ArrayAttr staticStrides,
+                    ShapedType *expectedType = nullptr) {
+  // insert_slice is the inverse of extract_slice, use the same type inference.
+  auto expected = ExtractSliceOp::inferRankReducedResultType(
+                      srcType.getRank(), dstType.cast<RankedTensorType>(),
+                      extractFromI64ArrayAttr(staticOffsets),
+                      extractFromI64ArrayAttr(staticSizes),
+                      extractFromI64ArrayAttr(staticStrides))
+                      .cast<ShapedType>();
+  if (expectedType)
+    *expectedType = expected;
+  return isRankReducedType(expected, srcType);
+}
+
 /// Verifier for InsertSliceOp.
 LogicalResult InsertSliceOp::verify() {
-  // insert_slice is the inverse of extract_slice, use the same type inference.
-  auto expectedType = ExtractSliceOp::inferRankReducedResultType(
-      getSourceType().getRank(), getType(),
-      extractFromI64ArrayAttr(static_offsets()),
-      extractFromI64ArrayAttr(static_sizes()),
-      extractFromI64ArrayAttr(static_strides()));
+  ShapedType expectedType;
   auto result =
-      isRankReducedType(expectedType.cast<ShapedType>(), getSourceType());
+      verifyInsertSliceOp(getSourceType(), getType(), static_offsets(),
+                          static_sizes(), static_strides(), &expectedType);
   return produceSliceErrorMsg(result, *this, expectedType);
 }
 
@@ -1446,12 +1459,20 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
     if (!sourceCastSource && !destCastSource)
       return failure();
 
+    auto src = (sourceCastSource ? *sourceCastSource : insertSliceOp.source());
+    auto dst = (destCastSource ? *destCastSource : insertSliceOp.dest());
+
+    auto srcType = src.getType().cast<ShapedType>();
+    auto dstType = dst.getType().cast<ShapedType>();
+    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.static_offsets(),
+                            insertSliceOp.static_sizes(),
+                            insertSliceOp.static_strides()) !=
+        SliceVerificationResult::Success)
+      return failure();
+
     Value replacement = rewriter.create<InsertSliceOp>(
-        insertSliceOp.getLoc(),
-        (sourceCastSource ? *sourceCastSource : insertSliceOp.source()),
-        (destCastSource ? *destCastSource : insertSliceOp.dest()),
-        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
-        insertSliceOp.getMixedStrides());
+        insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
+        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
 
     if (replacement.getType() != insertSliceOp.getType()) {
       replacement = rewriter.create<tensor::CastOp>(

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f489c4188523f..4e4bbb8a12672 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1231,3 +1231,18 @@ func @splat_fold() -> tensor<4xf32> {
   // CHECK-NEXT: [[T:%.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32>
   // CHECK-NEXT: return [[T]] : tensor<4xf32>
 }
+
+// -----
+
+// There was an issue in cast + insert_slice folding generating invalid ir.
+// https://github.com/llvm/llvm-project/issues/53099
+// CHECK-LABEL: func @insert_slice_cast
+func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
+  // CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<1x?xf32> to tensor<?x?xf32>
+  %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
+  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
+  // CHECK-SAME: : tensor<?x?xf32> into tensor<?x?xf32>
+  %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
+  // CHECK: return %[[RES]] : tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}


        


More information about the Mlir-commits mailing list