[Mlir-commits] [mlir] b4baccc - Introduce tensor.insert op to Tensor dialect.

Hanhan Wang llvmlistbot at llvm.org
Sun Jun 13 13:46:04 PDT 2021


Author: Hanhan Wang
Date: 2021-06-13T13:45:40-07:00
New Revision: b4baccc2a760ea13901f201e6ca326284254d205

URL: https://github.com/llvm/llvm-project/commit/b4baccc2a760ea13901f201e6ca326284254d205
DIFF: https://github.com/llvm/llvm-project/commit/b4baccc2a760ea13901f201e6ca326284254d205.diff

LOG: Introduce tensor.insert op to Tensor dialect.

Add the `tensor.insert` op so that `tensor.extract`/`tensor.insert` work as a
pair in the `scalar` domain, just as `subtensor`/`subtensor_insert` pair up in
the `tensor` domain and `vector.transfer_read`/`vector.transfer_write` pair up
in the `vector` domain.

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D104139

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir
    mlir/test/Dialect/Tensor/invalid.mlir
    mlir/test/Dialect/Tensor/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 17141da1b3e88..6b06099257d03 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -183,6 +183,57 @@ def Tensor_GenerateOp : Tensor_Op<"generate",
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// InsertOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_InsertOp : Tensor_Op<"insert",
+    [NoSideEffect,
+     TypesMatchWith<"result type matches type of dest",
+                    "dest", "result",
+                    "$_self.cast<ShapedType>()">,
+     TypesMatchWith<"scalar type matches element type of dest",
+                    "dest", "scalar",
+                    "$_self.cast<ShapedType>().getElementType()">]> {
+  let summary = "element insertion operation";
+  let description = [{
+    The `tensor.insert` op writes the value `scalar` into the tensor `dest` at
+    the position specified by the operation's indices.
+
+    It returns a copy of `dest` with the indexed element updated to the value
+    of `scalar`.
+
+    The arity of indices must match the rank of the tensor `dest` (i.e., if a
+    tensor is of rank 3, then 3 indices are required for the insert). The
+    indices should all be of `index` type.
+
+    Example:
+
+    ```mlir
+    %4 = tensor.insert %s into %dest[%1, %2] : tensor<4x4xi32>
+    %5 = tensor.insert %s into %dest[%1, %2] : tensor<?x?xi32>
+    %6 = tensor.insert %s into %dest[%1, %2] : tensor<*xi32>
+    ```
+  }];
+
+  let arguments = (ins AnyType:$scalar,
+                       AnyTensor:$dest,
+                       Variadic<Index>:$indices);
+  let results = (outs AnyTensor:$result);
+  let assemblyFormat = [{
+    $scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$scalar, "Value":$dest,
+      CArg<"ValueRange", "{}">:$indices), [{
+      auto resType = dest.getType();
+      build($_builder, $_state, resType, scalar, dest, indices);
+    }]>];
+
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2c9680adbf1b4..9b1592e9dffbd 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -286,6 +286,28 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ExtractElementFromTensorFromElements>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// InsertOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(InsertOp op) {
+  // Unranked `dest` is unchecked; a ranked one needs one index per dimension.
+  auto rankedDest = op.dest().getType().dyn_cast<RankedTensorType>();
+  if (!rankedDest || rankedDest.getRank() == int64_t(op.indices().size()))
+    return success();
+  return op.emitOpError("incorrect number of indices");
+}
+
+OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
+  // Writing a splat's own element value back into that splat changes nothing.
+  Attribute scalarAttr = operands[0], destAttr = operands[1];
+  if (!scalarAttr || !destAttr)
+    return {};
+  auto splat = destAttr.dyn_cast<SplatElementsAttr>();
+  bool insertsSameValue = splat && splat.getSplatValue() == scalarAttr;
+  return insertsSameValue ? OpFoldResult(destAttr) : OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // GenerateOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 478117b325c94..e4f5cc7fe5562 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -96,6 +96,19 @@ func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) {
 
 // -----
 
+// CHECK-LABEL: func @fold_insert
+func @fold_insert(%idx : index) -> (tensor<4xf32>) {
+  // Inserting the splat's own element value into a splat constant folds away.
+  // CHECK-DAG: %[[C4:.+]] = constant dense<4.{{0*}}e+00> : tensor<4xf32>
+  %splat = constant dense<4.0> : tensor<4xf32>
+  %elem = constant 4.0 : f32
+  %ins = tensor.insert %elem into %splat[%idx] : tensor<4xf32>
+  // CHECK-NEXT: return %[[C4]]
+  return %ins : tensor<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_from_tensor.cast
 // CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
 func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 {

diff  --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 79fef8c0f8e47..edbea9a98a987 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -16,6 +16,14 @@ func @extract_too_many_indices(%arg0: tensor<?xf32>) {
 
 // -----
 
+func @insert_too_few_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+  // expected-error @+1 {{incorrect number of indices}}
+  %0 = tensor.insert %arg0 into %arg1[] : tensor<?xf32>
+  return
+}
+
+// -----
+
 func @tensor.from_elements_wrong_result_type() {
   // expected-error at +2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}}
   %c0 = constant 0 : i32

diff  --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 450da06a25938..a8bc69933e9cf 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -22,6 +22,19 @@ func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
   return
 }
 
+// CHECK-LABEL:   func @insert(
+// CHECK-SAME:                  %[[SCALAR:.*]]: f32
+// CHECK-SAME:                  %[[INDEX:.*]]: index
+// CHECK-SAME:                  %[[DEST1:.*]]: tensor<?x?x?xf32>
+// CHECK-SAME:                  %[[DEST2:.*]]: tensor<*xf32>
+func @insert(%s: f32, %i: index, %ranked: tensor<?x?x?xf32>, %unranked: tensor<*xf32>) {
+  // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
+  %ins_ranked = tensor.insert %s into %ranked[%i, %i, %i] : tensor<?x?x?xf32>
+  // CHECK: tensor.insert %[[SCALAR]] into %[[DEST2]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<*xf32>
+  %ins_unranked = tensor.insert %s into %unranked[%i, %i, %i] : tensor<*xf32>
+  return
+}
+
 // CHECK-LABEL: func @tensor.from_elements() {
 func @tensor.from_elements() {
   %c0 = "std.constant"() {value = 0: index} : () -> index


        


More information about the Mlir-commits mailing list