[Mlir-commits] [mlir] b4baccc - Introduce tensor.insert op to Tensor dialect.
Hanhan Wang
llvmlistbot at llvm.org
Sun Jun 13 13:46:04 PDT 2021
Author: Hanhan Wang
Date: 2021-06-13T13:45:40-07:00
New Revision: b4baccc2a760ea13901f201e6ca326284254d205
URL: https://github.com/llvm/llvm-project/commit/b4baccc2a760ea13901f201e6ca326284254d205
DIFF: https://github.com/llvm/llvm-project/commit/b4baccc2a760ea13901f201e6ca326284254d205.diff
LOG: Introduce tensor.insert op to Tensor dialect.
Add a `tensor.insert` op so that `tensor.extract`/`tensor.insert` work as a pair
in the `scalar` domain, just as `subtensor`/`subtensor_insert` pair up in the
`tensor` domain and `vector.transfer_read`/`vector.transfer_write` pair up in
the `vector` domain.
Reviewed By: silvas
Differential Revision: https://reviews.llvm.org/D104139
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
mlir/test/Dialect/Tensor/invalid.mlir
mlir/test/Dialect/Tensor/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 17141da1b3e88..6b06099257d03 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -183,6 +183,57 @@ def Tensor_GenerateOp : Tensor_Op<"generate",
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// InsertOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_InsertOp : Tensor_Op<"insert",
+    [NoSideEffect,
+     TypesMatchWith<"result type matches type of dest",
+                    "dest", "result",
+                    "$_self.cast<ShapedType>()">,
+     TypesMatchWith<"scalar type matches element type of dest",
+                    "dest", "scalar",
+                    "$_self.cast<ShapedType>().getElementType()">]> {
+  let summary = "element insertion operation";
+  let description = [{
+    The `tensor.insert` op inserts a scalar into a tensor `dest` at the
+    position specified by the operation's indices.
+
+    It returns a copy of `dest` in which the element at those indices is
+    replaced by the value of `scalar`.
+
+    The arity of indices must match the rank of `dest` (e.g., a rank-3 tensor
+    requires 3 indices); this is only verified when `dest` is ranked. The
+    indices must all be of `index` type.
+
+    Example:
+
+    ```mlir
+    %4 = tensor.insert %t into %dest[%1, %2] : tensor<4x4xi32>
+    %5 = tensor.insert %rt into %dest[%1, %2] : tensor<?x?xi32>
+    %6 = tensor.insert %ut into %dest[%1, %2] : tensor<*xi32>
+    ```
+  }];
+
+  let arguments = (ins AnyType:$scalar,
+                       AnyTensor:$dest,
+                       Variadic<Index>:$indices);
+  let results = (outs AnyTensor:$result);
+  let assemblyFormat = [{
+    $scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$scalar, "Value":$dest,
+      CArg<"ValueRange", "{}">:$indices), [{
+      auto resType = dest.getType();
+      build($_builder, $_state, resType, scalar, dest, indices);
+    }]>];
+
+  let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2c9680adbf1b4..9b1592e9dffbd 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -286,6 +286,28 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ExtractElementFromTensorFromElements>(context);
}
+//===----------------------------------------------------------------------===//
+// InsertOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(InsertOp op) {
+  // An unranked destination has no static rank to compare the indices with.
+  auto rankedDest = op.dest().getType().dyn_cast<RankedTensorType>();
+  if (!rankedDest)
+    return success();
+
+  // A ranked destination needs exactly one index per dimension.
+  int64_t numIndices = op.indices().size();
+  if (numIndices != rankedDest.getRank())
+    return op.emitOpError("incorrect number of indices");
+  return success();
+}
+
+OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
+  // This fold only applies when both the inserted scalar (operand 0) and the
+  // destination tensor (operand 1) are known constant attributes.
+  Attribute scalarAttr = operands[0];
+  Attribute destAttr = operands[1];
+  if (!scalarAttr || !destAttr)
+    return {};
+
+  // Writing a splat's own value into it leaves the tensor unchanged, so the
+  // result is the destination constant itself.
+  auto splatAttr = destAttr.dyn_cast<SplatElementsAttr>();
+  if (!splatAttr || splatAttr.getSplatValue() != scalarAttr)
+    return {};
+  return destAttr;
+}
+
//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 478117b325c94..e4f5cc7fe5562 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -96,6 +96,19 @@ func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) {
// -----
+// CHECK-LABEL: func @fold_insert
+func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
+ // Inserting the splat's own value (4.0) into a splat constant folds to that constant.
+ // CHECK-DAG: %[[C4:.+]] = constant dense<4.{{0*}}e+00> : tensor<4xf32>
+ %0 = constant dense<4.0> : tensor<4xf32>
+ %1 = constant 4.0 : f32
+ %ins_1 = tensor.insert %1 into %0[%arg0] : tensor<4xf32>
+ // CHECK-NEXT: return %[[C4]]
+ return %ins_1 : tensor<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @extract_from_tensor.cast
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 {
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 79fef8c0f8e47..edbea9a98a987 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -16,6 +16,14 @@ func @extract_too_many_indices(%arg0: tensor<?xf32>) {
// -----
+func @insert_too_few_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+  // A rank-1 destination indexed with zero indices must be rejected.
+  // expected-error @+1 {{incorrect number of indices}}
+  %0 = tensor.insert %arg0 into %arg1[] : tensor<?xf32>
+  return
+}
+
+// -----
+
func @tensor.from_elements_wrong_result_type() {
// expected-error at +2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}}
%c0 = constant 0 : i32
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 450da06a25938..a8bc69933e9cf 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -22,6 +22,19 @@ func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
return
}
+// Exercises tensor.insert with a ranked and an unranked destination.
+// CHECK-LABEL: func @insert(
+// CHECK-SAME: %[[SCALAR:.*]]: f32
+// CHECK-SAME: %[[INDEX:.*]]: index
+// CHECK-SAME: %[[DEST1:.*]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[DEST2:.*]]: tensor<*xf32>
+func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>, %arg3: tensor<*xf32>) {
+ // Ranked destination: one index per dimension (rank 3, three indices).
+ // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
+ %0 = tensor.insert %arg0 into %arg2[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
+ // Unranked destination: the verifier does not check the index count here.
+ // CHECK: tensor.insert %[[SCALAR]] into %[[DEST2]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<*xf32>
+ %1 = tensor.insert %arg0 into %arg3[%arg1, %arg1, %arg1] : tensor<*xf32>
+ return
+}
+
// CHECK-LABEL: func @tensor.from_elements() {
func @tensor.from_elements() {
%c0 = "std.constant"() {value = 0: index} : () -> index
More information about the Mlir-commits
mailing list