[Mlir-commits] [mlir] cfaf329 - [mlir][tensor] Disallow unranked tensors for tensor.extract/insert
Matthias Springer
llvmlistbot at llvm.org
Thu Oct 27 01:09:39 PDT 2022
Author: Matthias Springer
Date: 2022-10-27T10:09:31+02:00
New Revision: cfaf3292df51090d03c1f98a95668246006813e1
URL: https://github.com/llvm/llvm-project/commit/cfaf3292df51090d03c1f98a95668246006813e1
DIFF: https://github.com/llvm/llvm-project/commit/cfaf3292df51090d03c1f98a95668246006813e1.diff
LOG: [mlir][tensor] Disallow unranked tensors for tensor.extract/insert
When writing a tensor.extract/tensor.insert, the rank of the tensor is implied by the number of specified indices. To extract from or insert into an unranked tensor, the tensor should first be cast to a ranked tensor type.
Differential Revision: https://reviews.llvm.org/D136756
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
mlir/test/Dialect/Tensor/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index f57852bc6c30b..86f3c06fbaccb 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -216,23 +216,20 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
"$_self.cast<ShapedType>().getElementType()">]> {
let summary = "element extraction operation";
let description = [{
- The `tensor.extract` op reads a tensor and returns one
- element from it specified by an index list. The output of the op is a
- new value with the same type as the elements of the tensor. The
- arity of indices must match the rank of the accessed value (i.e., if a
- tensor is of rank 3, then 3 indices are required for the extract. The
- indices should all be of `index` type.
+ The `tensor.extract` op reads a ranked tensor and returns one element as
+ specified by the given indices. The result of the op is a value with the
+ same type as the elements of the tensor. The arity of indices must match
+ the rank of the accessed value. All indices should be of `index` type.
Example:
```mlir
%4 = tensor.extract %t[%1, %2] : tensor<4x4xi32>
%5 = tensor.extract %rt[%1, %2] : tensor<?x?xi32>
- %6 = tensor.extract %ut[%1, %2] : tensor<*xi32>
```
}];
- let arguments = (ins AnyTensor:$tensor, Variadic<Index>:$indices);
+ let arguments = (ins AnyRankedTensor:$tensor, Variadic<Index>:$indices);
let results = (outs AnyType:$result);
let assemblyFormat = "$tensor `[` $indices `]` attr-dict `:` type($tensor)";
@@ -242,6 +239,7 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
build($_builder, $_state, resType, tensor, indices);
}]>];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
@@ -684,35 +682,33 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
Pure,
TypesMatchWith<"result type matches type of dest",
"dest", "result",
- "$_self.cast<ShapedType>()">,
+ "$_self">,
TypesMatchWith<"scalar type matches element type of dest",
"dest", "scalar",
"$_self.cast<ShapedType>().getElementType()">]> {
let summary = "element insertion operation";
let description = [{
- The `tensor.insert` op writes a tensor into a tensor `dest`as specified by
- the operation's indices.
+ The `tensor.insert` op inserts a scalar into a ranked tensor `dest` as
+ specified by the operation's indices.
- It returns a copy of `dest` with the proper slice updated with the value
+ It returns a copy of `dest` with the indexed position updated to the value
of `scalar`.
- The arity of indices must match the rank of the tensor `dest` (i.e., if a
- tensor is of rank 3, then 3 indices are required for the extract. The
- indices should all be of `index` type.
+ The arity of `indices` must match the rank of the tensor `dest`. All
+ indices should be of `index` type.
Example:
```mlir
%4 = tensor.insert %t into %dest[%1, %2] : tensor<4x4xi32>
%5 = tensor.insert %rt into %dest[%1, %2] : tensor<?x?xi32>
- %6 = tensor.insert %ut into %dest[%1, %2] : tensor<*xi32>
```
}];
let arguments = (ins AnyType:$scalar,
- AnyTensor:$dest,
+ AnyRankedTensor:$dest,
Variadic<Index>:$indices);
- let results = (outs AnyTensor:$result);
+ let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
}];
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 282a342df1d8a..445e78e295fd1 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -773,6 +773,34 @@ void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
// ExtractOp
//===----------------------------------------------------------------------===//
+namespace {
+
+/// If %source has a ranked tensor type, canonicalizes the pattern of the form
+///
+/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
+/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
+///
+/// to
+///
+/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
+struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
+ using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+ PatternRewriter &rewriter) const final {
+ auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
+ if (!tensorCast)
+ return failure();
+ if (!tensorCast.getSource().getType().isa<RankedTensorType>())
+ return failure();
+ rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+ extract, tensorCast.getSource(), extract.getIndices());
+ return success();
+ }
+};
+
+} // namespace
+
void ExtractOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "extracted");
@@ -780,10 +808,9 @@ void ExtractOp::getAsmResultNames(
LogicalResult ExtractOp::verify() {
// Verify the # indices match if we have a ranked type.
- if (auto tensorType = getTensor().getType().dyn_cast<RankedTensorType>())
- if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
- return emitOpError("incorrect number of indices for extract_element");
-
+ auto tensorType = getTensor().getType().cast<RankedTensorType>();
+ if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
+ return emitOpError("incorrect number of indices for extract_element");
return success();
}
@@ -833,6 +860,11 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
return {};
}
+void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<ExtractFromTensorCast>(context);
+}
+
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
@@ -1009,9 +1041,9 @@ void InsertOp::getAsmResultNames(
LogicalResult InsertOp::verify() {
// Verify the # indices match if we have a ranked type.
- if (auto destType = getDest().getType().dyn_cast<RankedTensorType>())
- if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
- return emitOpError("incorrect number of indices");
+ auto destType = getDest().getType().cast<RankedTensorType>();
+ if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
+ return emitOpError("incorrect number of indices");
return success();
}
@@ -1181,36 +1213,12 @@ struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
}
};
-/// Canonicalizes the pattern of the form
-///
-/// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32>
-/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
-///
-/// to
-///
-/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
-struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
- using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::ExtractOp extract,
- PatternRewriter &rewriter) const final {
- auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
- if (!tensorCast)
- return failure();
-
- rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
- extract, tensorCast.getSource(), extract.getIndices());
- return success();
- }
-};
-
} // namespace
void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- // TODO: Move extract patterns to tensor::ExtractOp.
- results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
- StaticTensorGenerate>(context);
+ // TODO: Move ExtractFromTensorGenerate to tensor::ExtractOp.
+ results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 28bcabc60dbd9..9cddfd88735ab 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -115,12 +115,12 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
// -----
// CHECK-LABEL: func @extract_from_tensor.cast
-// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
-func.func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 {
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32>
+func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK-NOT: tensor.cast
- %casted = tensor.cast %tensor : tensor<*xf32> to tensor<?xf32>
+ %casted = tensor.cast %tensor : tensor<9xf32> to tensor<?xf32>
// CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
%result = tensor.extract %casted[%c0] : tensor<?xf32>
return %result : f32
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index f9a81e8c490b4..aadf6ab90250d 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -34,12 +34,9 @@ func.func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
// CHECK-SAME: %[[SCALAR:.*]]: f32
// CHECK-SAME: %[[INDEX:.*]]: index
// CHECK-SAME: %[[DEST1:.*]]: tensor<?x?x?xf32>
-// CHECK-SAME: %[[DEST2:.*]]: tensor<*xf32>
-func.func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>, %arg3: tensor<*xf32>) {
+func.func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>) {
// CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
%0 = tensor.insert %arg0 into %arg2[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
- // CHECK: tensor.insert %[[SCALAR]] into %[[DEST2]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<*xf32>
- %1 = tensor.insert %arg0 into %arg3[%arg1, %arg1, %arg1] : tensor<*xf32>
return
}
More information about the Mlir-commits
mailing list