[Mlir-commits] [mlir] cd0d095 - [mlir][tensor] Check ops generated by InsertSliceOpCastFolder are valid
Ivan Butygin
llvmlistbot at llvm.org
Sun Feb 13 10:38:33 PST 2022
Author: Ivan Butygin
Date: 2022-02-13T21:37:31+03:00
New Revision: cd0d095c07b6dc925354cb8a9a54cafe654a6c4d
URL: https://github.com/llvm/llvm-project/commit/cd0d095c07b6dc925354cb8a9a54cafe654a6c4d
DIFF: https://github.com/llvm/llvm-project/commit/cd0d095c07b6dc925354cb8a9a54cafe654a6c4d.diff
LOG: [mlir][tensor] Check ops generated by InsertSliceOpCastFolder are valid
Fixes https://github.com/llvm/llvm-project/issues/53099
Differential Revision: https://reviews.llvm.org/D119663
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 71871ee50381f..a13a274c28e2a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1305,16 +1305,29 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
+static SliceVerificationResult
+verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
+ ArrayAttr staticOffsets, ArrayAttr staticSizes,
+ ArrayAttr staticStrides,
+ ShapedType *expectedType = nullptr) {
+ // insert_slice is the inverse of extract_slice, use the same type inference.
+ auto expected = ExtractSliceOp::inferRankReducedResultType(
+ srcType.getRank(), dstType.cast<RankedTensorType>(),
+ extractFromI64ArrayAttr(staticOffsets),
+ extractFromI64ArrayAttr(staticSizes),
+ extractFromI64ArrayAttr(staticStrides))
+ .cast<ShapedType>();
+ if (expectedType)
+ *expectedType = expected;
+ return isRankReducedType(expected, srcType);
+}
+
/// Verifier for InsertSliceOp.
LogicalResult InsertSliceOp::verify() {
- // insert_slice is the inverse of extract_slice, use the same type inference.
- auto expectedType = ExtractSliceOp::inferRankReducedResultType(
- getSourceType().getRank(), getType(),
- extractFromI64ArrayAttr(static_offsets()),
- extractFromI64ArrayAttr(static_sizes()),
- extractFromI64ArrayAttr(static_strides()));
+ ShapedType expectedType;
auto result =
- isRankReducedType(expectedType.cast<ShapedType>(), getSourceType());
+ verifyInsertSliceOp(getSourceType(), getType(), static_offsets(),
+ static_sizes(), static_strides(), &expectedType);
return produceSliceErrorMsg(result, *this, expectedType);
}
@@ -1446,12 +1459,20 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
if (!sourceCastSource && !destCastSource)
return failure();
+ auto src = (sourceCastSource ? *sourceCastSource : insertSliceOp.source());
+ auto dst = (destCastSource ? *destCastSource : insertSliceOp.dest());
+
+ auto srcType = src.getType().cast<ShapedType>();
+ auto dstType = dst.getType().cast<ShapedType>();
+ if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.static_offsets(),
+ insertSliceOp.static_sizes(),
+ insertSliceOp.static_strides()) !=
+ SliceVerificationResult::Success)
+ return failure();
+
Value replacement = rewriter.create<InsertSliceOp>(
- insertSliceOp.getLoc(),
- (sourceCastSource ? *sourceCastSource : insertSliceOp.source()),
- (destCastSource ? *destCastSource : insertSliceOp.dest()),
- insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
- insertSliceOp.getMixedStrides());
+ insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
+ insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
if (replacement.getType() != insertSliceOp.getType()) {
replacement = rewriter.create<tensor::CastOp>(
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f489c4188523f..4e4bbb8a12672 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1231,3 +1231,18 @@ func @splat_fold() -> tensor<4xf32> {
// CHECK-NEXT: [[T:%.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32>
// CHECK-NEXT: return [[T]] : tensor<4xf32>
}
+
+// -----
+
+// Folding a tensor.cast into tensor.insert_slice used to generate invalid IR.
+// https://github.com/llvm/llvm-project/issues/53099
+// CHECK-LABEL: func @insert_slice_cast
+func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
+ // CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<1x?xf32> to tensor<?x?xf32>
+ %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
+ // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
+ // CHECK-SAME: : tensor<?x?xf32> into tensor<?x?xf32>
+ %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
+ // CHECK: return %[[RES]] : tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
More information about the Mlir-commits
mailing list