[Mlir-commits] [mlir] [mlir] Change `tensor.extract/insert` to take static/dynamic indices. (PR #104488)

Yi Zhang llvmlistbot at llvm.org
Thu Aug 15 13:08:49 PDT 2024


https://github.com/cathyzhyi updated https://github.com/llvm/llvm-project/pull/104488

>From dba24f6714b8e5b3790ef3a8461b77bc7dd7e6b2 Mon Sep 17 00:00:00 2001
From: Yi Zhang <cathyzhyi at google.com>
Date: Thu, 15 Aug 2024 15:08:22 -0400
Subject: [PATCH] [mlir] Change `tensor.extract/insert` to take static/dynamic
 indices.

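This changes the ODS of the `tensor.extract` and `tensor.insert` ops so that
their indices may be any mix of static (attribute) and dynamic (SSA value)
indices. New builder methods are added, and the verifiers and the
`tensor.extract` folder are updated. The Shape dialect's
`ExtractFromShapeOfExtentTensor` pattern (`tensor.extract` of `shape.shape_of`
-> `tensor.dim`) is also ported from DRR to C++ so that it handles static
indices.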
---
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  52 +++++++-
 mlir/lib/Dialect/Shape/IR/Shape.cpp           |  26 ++++
 .../Dialect/Shape/IR/ShapeCanonicalization.td |   6 -
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 114 ++++++++++++++++--
 mlir/test/Dialect/Shape/canonicalize.mlir     |  13 ++
 mlir/test/Dialect/Tensor/canonicalize.mlir    |  12 +-
 mlir/test/Dialect/Tensor/invalid.mlir         |  38 +++++-
 mlir/test/Dialect/Tensor/ops.mlir             |   6 +
 8 files changed, 242 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cafc3d91fd1e9d..997d0ccb28d769 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -332,12 +332,37 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
     ```mlir
     %4 = tensor.extract %t[%1, %2] : tensor<4x4xi32>
     %5 = tensor.extract %rt[%1, %2] : tensor<?x?xi32>
+    %6 = tensor.extract %rt[3, 4] : tensor<?x?xi32>
+    %7 = tensor.extract %rt[%1, 4] : tensor<?x?xi32>
     ```
   }];
 
-  let arguments = (ins AnyRankedTensor:$tensor, Variadic<Index>:$indices);
+  let arguments = (ins
+    AnyRankedTensor:$tensor,
+    Variadic<Index>:$indices,
+    DenseI64ArrayAttr:$static_indices
+  );
   let results = (outs AnyType:$result);
-  let assemblyFormat = "$tensor `[` $indices `]` attr-dict `:` type($tensor)";
+  let assemblyFormat = [{
+    $tensor ``
+    custom<DynamicIndexList>($indices, $static_indices)
+    attr-dict `:` type($tensor)
+  }];
+
+  let builders = [
+    // Mixed static/dynamic indices; result type inferred from element type.
+    OpBuilder<(ins "Value":$tensor, "ArrayRef<OpFoldResult>":$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Mixed static/dynamic indices with an explicit result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$tensor, "ArrayRef<OpFoldResult>":$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Dynamic indices only; result type inferred from element type.
+    OpBuilder<(ins "Value":$tensor, CArg<"ValueRange", "{}">:$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Dynamic indices only, with an explicit result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$tensor, CArg<"ValueRange", "{}">:$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+  ];
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
@@ -808,16 +833,35 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
 
   let arguments = (ins AnyType:$scalar,
                        AnyRankedTensor:$dest,
-                       Variadic<Index>:$indices);
+                       Variadic<Index>:$indices,
+                       DenseI64ArrayAttr:$static_indices
+  );
   let results = (outs AnyRankedTensor:$result);
   let assemblyFormat = [{
-    $scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
+    $scalar `into`
+    $dest `` custom<DynamicIndexList>($indices, $static_indices)
+    attr-dict `:` type($dest)
   }];
 
   let extraClassDeclaration = [{
     MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
   }];
 
+  let builders = [
+    // Mixed static/dynamic indices; result type inferred from `dest`.
+    OpBuilder<(ins "Value":$scalar, "Value":$dest, "ArrayRef<OpFoldResult>":$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Mixed static/dynamic indices with an explicit result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$scalar, "Value":$dest, "ArrayRef<OpFoldResult>":$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Dynamic indices only; result type inferred from `dest`.
+    OpBuilder<(ins "Value":$scalar, "Value":$dest, CArg<"ValueRange", "{}">:$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Dynamic indices only, with an explicit result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$scalar, "Value":$dest, CArg<"ValueRange", "{}">:$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+  ];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 8eb8e579954faa..89184f2162c2c4 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1736,6 +1736,58 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
   }
 };
 
+// Canonicalize
+// ```
+// %0 = shape.shape_of %arg : tensor<?x?xf32> -> tensor<2xindex>
+// %1 = tensor.extract %0[%i] : tensor<2xindex>
+// ```
+// to
+// ```
+// %1 = tensor.dim %arg, %i : tensor<?x?xf32>
+// ```
+// `shape.shape_of` produces a 1-D extent tensor, so the extract has exactly
+// one index. A static index is first materialized as a constant.
+struct ExtractFromShapeOfExtentTensor
+    : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto tensorShapeOfOp = op.getTensor().getDefiningOp<shape::ShapeOfOp>();
+    if (!tensorShapeOfOp)
+      return rewriter.notifyMatchFailure(op, "producer is not shape.shape_of");
+
+    // `tensor.dim` needs an unranked tensor or a ranked tensor of non-zero
+    // rank, whereas `shape.shape_of` may also be applied to non-tensor
+    // operands (e.g. memrefs). Bail out instead of building invalid IR.
+    auto argType =
+        llvm::dyn_cast<TensorType>(tensorShapeOfOp.getArg().getType());
+    if (!argType || (argType.hasRank() && argType.getRank() == 0))
+      return rewriter.notifyMatchFailure(
+          op, "shape_of operand is not a tensor of non-zero rank");
+
+    // Use the single index, materializing it as a constant if it is static.
+    Value index;
+    int64_t staticIndex = op.getStaticIndices().front();
+    if (ShapedType::isDynamic(staticIndex)) {
+      index = op.getIndices().front();
+    } else {
+      Type indexType = rewriter.getIndexType();
+      Operation *constantOp =
+          tensorShapeOfOp->getDialect()->materializeConstant(
+              rewriter, IntegerAttr::get(indexType, staticIndex), indexType,
+              op.getLoc());
+      if (!constantOp)
+        return rewriter.notifyMatchFailure(
+            op, "failed to materialize the static index as a constant");
+      index = constantOp->getResult(0);
+    }
+    rewriter.replaceOpWithNewOp<tensor::DimOp>(op, tensorShapeOfOp.getArg(),
+                                               index);
+    return success();
+  }
+};
+
 // Canonicalize
 // ```
 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index cb294ae2978fce..e135105d6980b6 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -44,9 +44,3 @@ def SizeToIndexToSizeCanonicalization : Pat<
 def TensorCastConstShape : Pat <
   (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
   [(HasStaticShape $res)]>;
-
-// tensor.extract from shape_of -> tensor.dim. We can take the first index
-// because shape_of always returns a 1D tensor.
-def ExtractFromShapeOfExtentTensor : Pat<
-  (Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices),
-  (Tensor_DimOp $arg, (TakeFront $indices))>;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e11c6aaccf74dd..73dc98ee93ed45 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -39,6 +39,23 @@ using llvm::divideCeilSigned;
 using llvm::divideFloorSigned;
 using llvm::mod;
 
+/// Returns success if `staticIndices` has one entry per dimension of the
+/// ranked type of `tensor` and every `ShapedType::kDynamic` entry in it is
+/// matched by exactly one value in `dynamicIndices`.
+static LogicalResult
+checkTensorRankMatchIndices(Value tensor, ValueRange dynamicIndices,
+                            ArrayRef<int64_t> staticIndices) {
+  auto tensorType = llvm::cast<RankedTensorType>(tensor.getType());
+  int64_t dynamicIndexCount = llvm::count_if(
+      staticIndices, [](int64_t idx) { return ShapedType::isDynamic(idx); });
+  // Compare in int64_t to avoid a signed/unsigned mismatch with size().
+  bool rankMatches =
+      tensorType.getRank() == static_cast<int64_t>(staticIndices.size());
+  bool dynamicCountMatches =
+      dynamicIndexCount == static_cast<int64_t>(dynamicIndices.size());
+  return success(rankMatches && dynamicCountMatches);
+}
+
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
@@ -1120,10 +1133,51 @@ void ExtractOp::getAsmResultNames(
   setNameFn(getResult(), "extracted");
 }
 
+/// Build an ExtractOp with mixed static and dynamic indices. The result type
+/// is inferred as the element type of `tensor`.
+void ExtractOp::build(OpBuilder &b, OperationState &result, Value tensor,
+                      ArrayRef<OpFoldResult> indices,
+                      ArrayRef<NamedAttribute> attrs) {
+  Type resultType = llvm::cast<TensorType>(tensor.getType()).getElementType();
+  build(b, result, resultType, tensor, indices, attrs);
+}
+
+/// Build an ExtractOp with mixed static and dynamic indices and an explicit
+/// result type. Attribute indices are stored in `static_indices`; value
+/// indices become operands and are marked there as `ShapedType::kDynamic`.
+void ExtractOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                      Value tensor, ArrayRef<OpFoldResult> indices,
+                      ArrayRef<NamedAttribute> attrs) {
+  SmallVector<int64_t> staticIndices;
+  SmallVector<Value> dynamicIndices;
+  dispatchIndexOpFoldResults(indices, dynamicIndices, staticIndices);
+  result.addAttributes(attrs);
+  build(b, result, resultType, tensor, dynamicIndices,
+        b.getDenseI64ArrayAttr(staticIndices));
+}
+
+/// Build an ExtractOp with dynamic (SSA value) indices only and an explicit
+/// result type.
+void ExtractOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                      Value tensor, ValueRange indices,
+                      ArrayRef<NamedAttribute> attrs) {
+  SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
+      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, resultType, tensor, indicesValues, attrs);
+}
+
+/// Build an ExtractOp with dynamic (SSA value) indices only. The result type
+/// is inferred as the element type of `tensor`.
+void ExtractOp::build(OpBuilder &b, OperationState &result, Value tensor,
+                      ValueRange indices, ArrayRef<NamedAttribute> attrs) {
+  Type resultType = llvm::cast<TensorType>(tensor.getType()).getElementType();
+  build(b, result, resultType, tensor, indices, attrs);
+}
+
 LogicalResult ExtractOp::verify() {
   // Verify the # indices match if we have a ranked type.
-  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
-  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
+  if (failed(checkTensorRankMatchIndices(getTensor(), getIndices(),
+                                         getStaticIndices())))
     return emitOpError("incorrect number of indices for extract_element");
   return success();
 }
@@ -1137,12 +1189,21 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
 
   // Collect the constant indices into the tensor.
   SmallVector<uint64_t, 8> indices;
-  for (Attribute indice : adaptor.getIndices()) {
-    if (!indice || !llvm::isa<IntegerAttr>(indice))
-      return {};
-    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
+  // Static entries are used as-is; each `ShapedType::kDynamic` entry takes the
+  // next dynamic operand, whose folded value must be an IntegerAttr.
+  auto dynamicIndicesIt = adaptor.getIndices().begin();
+  for (int64_t staticIndex : getStaticIndices()) {
+    if (!ShapedType::isDynamic(staticIndex)) {
+      indices.push_back(staticIndex);
+      continue;
+    }
+    Attribute folded = *dynamicIndicesIt++;
+    auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(folded);
+    if (!indexAttr)
+      return {};
+    indices.push_back(indexAttr.getInt());
   }
 
   // Fold extract(from_elements(...)).
   if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
     auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
@@ -1354,10 +1412,50 @@ void InsertOp::getAsmResultNames(
   setNameFn(getResult(), "inserted");
 }
 
+/// Build an InsertOp with mixed static and dynamic indices. The result type
+/// is inferred as the type of `dest`.
+void InsertOp::build(OpBuilder &b, OperationState &result, Value scalar,
+                     Value dest, ArrayRef<OpFoldResult> indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  build(b, result, dest.getType(), scalar, dest, indices, attrs);
+}
+
+/// Build an InsertOp with mixed static and dynamic indices and an explicit
+/// result type. Attribute indices are stored in `static_indices`; value
+/// indices become operands and are marked there as `ShapedType::kDynamic`.
+void InsertOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                     Value scalar, Value dest, ArrayRef<OpFoldResult> indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  SmallVector<int64_t> staticIndices;
+  SmallVector<Value> dynamicIndices;
+  dispatchIndexOpFoldResults(indices, dynamicIndices, staticIndices);
+  result.addAttributes(attrs);
+  build(b, result, resultType, scalar, dest, dynamicIndices,
+        b.getDenseI64ArrayAttr(staticIndices));
+}
+
+/// Build an InsertOp with dynamic (SSA value) indices only and an explicit
+/// result type.
+void InsertOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                     Value scalar, Value dest, ValueRange indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
+      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, resultType, scalar, dest, indicesValues, attrs);
+}
+
+/// Build an InsertOp with dynamic (SSA value) indices only. The result type
+/// is inferred as the type of `dest`.
+void InsertOp::build(OpBuilder &b, OperationState &result, Value scalar,
+                     Value dest, ValueRange indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  build(b, result, dest.getType(), scalar, dest, indices, attrs);
+}
+
 LogicalResult InsertOp::verify() {
   // Verify the # indices match if we have a ranked type.
-  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
-  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
+  if (failed(checkTensorRankMatchIndices(getDest(), getIndices(),
+                                         getStaticIndices())))
     return emitOpError("incorrect number of indices");
   return success();
 }
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 5b98a7790debf2..8c04e574dbc518 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1519,6 +1519,19 @@ func.func @extract_shapeof(%arg0 : tensor<?x?xf64>) -> index {
  return %result : index
 }
 
+// -----
+
+// CHECK-LABEL: func @extract_shapeof_static_index
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<?x?xf64>
+func.func @extract_shapeof_static_index(%arg0 : tensor<?x?xf64>) -> index {
+// CHECK:        %[[C1:.*]] = arith.constant 1 : index
+ %shape = shape.shape_of %arg0 : tensor<?x?xf64> -> tensor<2xindex>
+// CHECK:        %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+ %result = tensor.extract %shape[1] : tensor<2xindex>
+// CHECK:        return %[[DIM]]
+ return %result : index
+}
+
 
 // -----
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 4b8efde78cc23c..8f7c7478669b4f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -137,11 +137,12 @@ func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1
 // -----
 
 // CHECK-LABEL: func @fold_extract
-func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
+func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, i32, complex<f32>) {
   %const_0 = arith.constant 0 : index
   %const_1 = arith.constant 1 : index
   %const_3 = arith.constant 3 : index
   // CHECK-DAG: [[C64:%.+]] = arith.constant 64 : i32
+  // CHECK-DAG: [[CNEG1:%.+]] = arith.constant -1 : i32
   // CHECK-DAG: [[C0:%.+]] = arith.constant 0.{{0*}}e+00 : f16
   // CHECK-DAG: [[CM2:%.+]] = arith.constant -2.{{0*}}e+00 : f16
 
@@ -162,13 +163,16 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
   %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
   %ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
 
+  // Fold an extract into a dense tensor with mixed dynamic and static indices.
+  %ext_5 = tensor.extract %3[%const_1, 0, 2] : tensor<2x1x4xi32>
+
   // Fold an extract into a complex constant.
   // CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex<f32>
   %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
-  %ext_5 = tensor.extract %4[] : tensor<complex<f32>>
+  %ext_6 = tensor.extract %4[] : tensor<complex<f32>>
 
-  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
-  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
+  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[CNEG1]], [[C5]]
+  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5, %ext_6 : f32, f16, f16, i32, i32, complex<f32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 41b6529f64afa3..8c594ddacb8d33 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -64,7 +64,7 @@ func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) {
 
 // -----
 
-func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
+func.func @extract_too_few_indices(%arg0: tensor<?xf32>) {
   // expected-error@+1 {{incorrect number of indices for extract_element}}
   %0 = tensor.extract %arg0[] : tensor<?xf32>
   return
@@ -72,7 +72,24 @@ func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
 
 // -----
 
-func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+func.func @extract_too_many_static_indices(%arg0: tensor<?xf32>) {
+  // expected-error@+1 {{incorrect number of indices for extract_element}}
+  %0 = tensor.extract %arg0[2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @extract_too_many_mixed_indices(%arg0: tensor<?xf32>) {
+  %c1 = arith.constant 1 : index
+  // expected-error@+1 {{incorrect number of indices for extract_element}}
+  %0 = tensor.extract %arg0[%c1, 2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @insert_too_few_indices(%arg0: f32, %arg1: tensor<?xf32>) {
   // expected-error@+1 {{incorrect number of indices}}
   %0 = tensor.insert %arg0 into %arg1[] : tensor<?xf32>
   return
@@ -80,6 +97,23 @@ func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
 
 // -----
 
+func.func @insert_too_many_static_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+  // expected-error@+1 {{incorrect number of indices}}
+  %0 = tensor.insert %arg0 into %arg1[2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @insert_too_many_mixed_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+  %c1 = arith.constant 1 : index
+  // expected-error@+1 {{incorrect number of indices}}
+  %0 = tensor.insert %arg0 into %arg1[%c1, 2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
 func.func @tensor.from_elements_wrong_result_type() {
   // expected-error@+2 {{'tensor.from_elements' invalid kind of type specified}}
   %c0 = arith.constant 0 : i32
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 378137a14b59ff..0a4cd08239c5b4 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -58,6 +58,12 @@ func.func @empty_with_encoding(%sz: index) -> tensor<5x?x6xf32, "foo"> {
 func.func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
   // CHECK: tensor.extract %[[TENSOR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
   %0 = tensor.extract %arg0[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
+
+  // CHECK: tensor.extract %[[TENSOR]][%[[INDEX]], 2, 3] : tensor<?x?x?xf32>
+  %1 = tensor.extract %arg0[%arg1, 2, 3] : tensor<?x?x?xf32>
+
+  // CHECK: tensor.extract %[[TENSOR]][1, 2, 3] : tensor<?x?x?xf32>
+  %2 = tensor.extract %arg0[1, 2, 3] : tensor<?x?x?xf32>
   return
 }
 
@@ -70,6 +76,12 @@ func.func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
 func.func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>) {
   // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
   %0 = tensor.insert %arg0 into %arg2[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
+
+  // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], 2, 3] : tensor<?x?x?xf32>
+  %1 = tensor.insert %arg0 into %arg2[%arg1, 2, 3] : tensor<?x?x?xf32>
+
+  // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][1, 2, 3] : tensor<?x?x?xf32>
+  %2 = tensor.insert %arg0 into %arg2[1, 2, 3] : tensor<?x?x?xf32>
   return
 }
 
 



More information about the Mlir-commits mailing list