[Mlir-commits] [mlir] [mlir] Change `tensor.extract/insert` to take static/dynamic indices. (PR #104488)
Yi Zhang
llvmlistbot at llvm.org
Thu Aug 15 12:36:56 PDT 2024
https://github.com/cathyzhyi created https://github.com/llvm/llvm-project/pull/104488
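This changes the ODS of the `tensor.extract` and `tensor.insert` ops so that their indices can be a mix of static (attribute) and dynamic (SSA value) indices. New builder methods are added, and the verifiers, folders, and canonicalizers are updated accordingly. One of the canonicalization patterns of `shape.shape_of` is also updated.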
>From 0cd08bf39376f35c6b8efcaeb2879284fbfb71a0 Mon Sep 17 00:00:00 2001
From: Yi Zhang <cathyzhyi at google.com>
Date: Thu, 15 Aug 2024 15:08:22 -0400
Subject: [PATCH] [mlir] Change `tensor.extract/insert` to take static/dynamic
indices.
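This changes the ODS of the `tensor.extract` and `tensor.insert` ops so that
their indices can be a mix of static and dynamic indices. New builder methods
are added, and the verifiers, folders, and canonicalizers are updated
accordingly. One of the canonicalization patterns of `shape.shape_of` is also
updated.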
---
.../mlir/Dialect/Tensor/IR/TensorOps.td | 52 +++++++-
mlir/lib/Dialect/Shape/IR/Shape.cpp | 26 ++++
.../Dialect/Shape/IR/ShapeCanonicalization.td | 6 -
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 117 ++++++++++++++++--
mlir/test/Dialect/Shape/canonicalize.mlir | 13 ++
mlir/test/Dialect/Tensor/canonicalize.mlir | 12 +-
mlir/test/Dialect/Tensor/invalid.mlir | 38 +++++-
mlir/test/Dialect/Tensor/ops.mlir | 6 +
8 files changed, 245 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cafc3d91fd1e9d..997d0ccb28d769 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -332,12 +332,37 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
```mlir
%4 = tensor.extract %t[%1, %2] : tensor<4x4xi32>
%5 = tensor.extract %rt[%1, %2] : tensor<?x?xi32>
+ %6 = tensor.extract %rt[3, 4] : tensor<?x?xi32>
+ %7 = tensor.extract %rt[%1, 4] : tensor<?x?xi32>
```
}];
- let arguments = (ins AnyRankedTensor:$tensor, Variadic<Index>:$indices);
+ let arguments = (ins
+ AnyRankedTensor:$tensor,
+ // SSA values for the dynamic entries of `static_indices`, in order.
+ Variadic<Index>:$indices,
+ // One entry per tensor dimension; ShapedType::kDynamic marks an entry whose
+ // value is supplied by the next operand in `indices`.
+ DenseI64ArrayAttr:$static_indices
+ );
let results = (outs AnyType:$result);
- let assemblyFormat = "$tensor `[` $indices `]` attr-dict `:` type($tensor)";
+ // Indices print as one bracketed list mixing both kinds, e.g. `%t[%i, 4]`.
+ let assemblyFormat = [{
+ $tensor ``
+ custom<DynamicIndexList>($indices, $static_indices)
+ attr-dict `:` type($tensor)
+ }];
+
+  let builders = [
+    // Build an ExtractOp from mixed static/dynamic indices; the result type is
+    // inferred as the element type of `tensor`.
+    OpBuilder<(ins "Value":$tensor, "ArrayRef<OpFoldResult>":$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an ExtractOp from mixed static/dynamic indices with an explicit
+    // result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$tensor,
+      "ArrayRef<OpFoldResult>":$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an ExtractOp from dynamic (SSA) indices only; the result type is
+    // inferred as the element type of `tensor`.
+    OpBuilder<(ins "Value":$tensor, CArg<"ValueRange", "{}">:$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an ExtractOp from dynamic (SSA) indices only with an explicit
+    // result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$tensor,
+      CArg<"ValueRange", "{}">:$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+  ];
let hasCanonicalizer = 1;
let hasFolder = 1;
@@ -808,16 +833,35 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
let arguments = (ins AnyType:$scalar,
AnyRankedTensor:$dest,
- Variadic<Index>:$indices);
+ // SSA values for the dynamic entries of `static_indices`, in order.
+ Variadic<Index>:$indices,
+ // One entry per dimension of `dest`; ShapedType::kDynamic marks an entry
+ // whose value is supplied by the next operand in `indices`.
+ DenseI64ArrayAttr:$static_indices
+ );
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
- $scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
+ $scalar `into`
+ $dest `` custom<DynamicIndexList>($indices, $static_indices)
+ attr-dict `:` type($dest)
}];
let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
}];
+  let builders = [
+    // Build an InsertOp from mixed static/dynamic indices; the result type is
+    // inferred as the type of `dest`.
+    OpBuilder<(ins "Value":$scalar, "Value":$dest,
+      "ArrayRef<OpFoldResult>":$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an InsertOp from mixed static/dynamic indices with an explicit
+    // result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$scalar, "Value":$dest,
+      "ArrayRef<OpFoldResult>":$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an InsertOp from dynamic (SSA) indices only; the result type is
+    // inferred as the type of `dest`.
+    OpBuilder<(ins "Value":$scalar, "Value":$dest,
+      CArg<"ValueRange", "{}">:$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an InsertOp from dynamic (SSA) indices only with an explicit
+    // result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$scalar, "Value":$dest,
+      CArg<"ValueRange", "{}">:$indices,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+  ];
+
let hasFolder = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 8eb8e579954faa..89184f2162c2c4 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1736,6 +1736,32 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
}
};
+/// Canonicalize `tensor.extract` of a `shape.shape_of` result into a
+/// `tensor.dim` of the shape_of operand. The extent tensor is 1-D, so the
+/// extract carries exactly one index, which may be static or dynamic.
+struct ExtractFromShapeOfExtentTensor
+    : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto shapeOfOp = op.getTensor().getDefiningOp<shape::ShapeOfOp>();
+    if (!shapeOfOp)
+      return rewriter.notifyMatchFailure(op, "producer is not shape.shape_of");
+
+    Value source = shapeOfOp.getArg();
+    // tensor.dim requires an unranked or non-0-D ranked tensor operand; bail
+    // out instead of building invalid IR for anything else.
+    auto sourceType = llvm::dyn_cast<TensorType>(source.getType());
+    if (!sourceType || (sourceType.hasRank() && sourceType.getRank() == 0))
+      return rewriter.notifyMatchFailure(op,
+                                         "operand unsupported by tensor.dim");
+
+    int64_t staticIndex = op.getStaticIndices().front();
+    // For a static index use tensor.dim's int64_t builder, which creates the
+    // index constant itself; no dialect constant materializer (which may
+    // return null) is involved.
+    if (!ShapedType::isDynamic(staticIndex))
+      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, source, staticIndex);
+    else
+      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, source,
+                                                 op.getIndices().front());
+    return success();
+  }
+};
+
// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index cb294ae2978fce..e135105d6980b6 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -44,9 +44,3 @@ def SizeToIndexToSizeCanonicalization : Pat<
def TensorCastConstShape : Pat <
(Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
[(HasStaticShape $res)]>;
-
-// tensor.extract from shape_of -> tensor.dim. We can take the first index
-// because shape_of always returns a 1D tensor.
-def ExtractFromShapeOfExtentTensor : Pat<
- (Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices),
- (Tensor_DimOp $arg, (TakeFront $indices))>;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e11c6aaccf74dd..bb4d3eccc7377f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -27,7 +28,9 @@
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <optional>
@@ -39,6 +42,19 @@ using llvm::divideCeilSigned;
using llvm::divideFloorSigned;
using llvm::mod;
+/// Returns success if the rank of `tensor` equals the number of entries in
+/// `staticIndices` and every ShapedType::kDynamic entry there is backed by
+/// exactly one value in `dynamicIndices`.
+static LogicalResult
+checkTensorRankMatchIndices(Value tensor, ValueRange dynamicIndices,
+                            ArrayRef<int64_t> staticIndices) {
+  auto tensorType = llvm::cast<RankedTensorType>(tensor.getType());
+  int64_t numDynamic = llvm::count(staticIndices, ShapedType::kDynamic);
+  // Compare as int64_t: getRank() is signed while size() is unsigned.
+  return success(
+      tensorType.getRank() == static_cast<int64_t>(staticIndices.size()) &&
+      numDynamic == static_cast<int64_t>(dynamicIndices.size()));
+}
+
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
@@ -1120,10 +1136,49 @@ void ExtractOp::getAsmResultNames(
setNameFn(getResult(), "extracted");
}
+/// Builds an ExtractOp from a mix of constant and SSA indices; the result
+/// type is taken from the element type of `tensor`.
+void ExtractOp::build(OpBuilder &b, OperationState &result, Value tensor,
+                      ArrayRef<OpFoldResult> indices,
+                      ArrayRef<NamedAttribute> attrs) {
+  auto tensorType = llvm::cast<TensorType>(tensor.getType());
+  build(b, result, tensorType.getElementType(), tensor, indices, attrs);
+}
+
+// Build an ExtractOp from mixed static/dynamic indices with an explicit result
+// type. dispatchIndexOpFoldResults splits the indices into the
+// `static_indices` attribute and the dynamic `indices` operands.
+void ExtractOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                      Value tensor, ArrayRef<OpFoldResult> indices,
+                      ArrayRef<NamedAttribute> attrs) {
+  SmallVector<int64_t> staticIndices;
+  SmallVector<Value> dynamicIndices;
+  dispatchIndexOpFoldResults(indices, dynamicIndices, staticIndices);
+  result.addAttributes(attrs);
+  build(b, result, resultType, tensor, dynamicIndices,
+        b.getDenseI64ArrayAttr(staticIndices));
+}
+
+// Build an ExtractOp from dynamic (SSA) indices only, with an explicit result
+// type. Forwards to the mixed-index builder.
+void ExtractOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                      Value tensor, ValueRange indices,
+                      ArrayRef<NamedAttribute> attrs) {
+  SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
+      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, resultType, tensor, indicesValues, attrs);
+}
+
+/// Builds an ExtractOp from SSA indices only; the result type is the element
+/// type of `tensor`.
+void ExtractOp::build(OpBuilder &b, OperationState &result, Value tensor,
+                      ValueRange indices, ArrayRef<NamedAttribute> attrs) {
+  Type elementType = llvm::cast<TensorType>(tensor.getType()).getElementType();
+  build(b, result, elementType, tensor, indices, attrs);
+}
+
LogicalResult ExtractOp::verify() {
-  // Verify the # indices match if we have a ranked type.
-  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
-  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
+  // `tensor` must be indexed by exactly one (static or dynamic) index per
+  // dimension.
+  LogicalResult indicesMatchRank = checkTensorRankMatchIndices(
+      getTensor(), getIndices(), getStaticIndices());
+  if (failed(indicesMatchRank))
    return emitOpError("incorrect number of indices for extract_element");
  return success();
}
@@ -1137,12 +1192,18 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Collect the constant indices into the tensor.
SmallVector<uint64_t, 8> indices;
- for (Attribute indice : adaptor.getIndices()) {
- if (!indice || !llvm::isa<IntegerAttr>(indice))
- return {};
- indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
+ auto dynamicIndicesIt = adaptor.getIndices().begin();
+ for (int64_t i : getStaticIndices()) {
+ if (i != ShapedType::kDynamic) {
+ indices.push_back(i);
+ } else {
+ Attribute indice = *dynamicIndicesIt;
+ if (!indice || !llvm::isa<IntegerAttr>(indice))
+ return {};
+ indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
+ dynamicIndicesIt++;
+ }
}
-
// Fold extract(from_elements(...)).
if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
@@ -1354,10 +1415,48 @@ void InsertOp::getAsmResultNames(
setNameFn(getResult(), "inserted");
}
+// Build an InsertOp from mixed static/dynamic indices; the result type is the
+// type of `dest`.
+void InsertOp::build(OpBuilder &b, OperationState &result, Value scalar,
+                     Value dest, ArrayRef<OpFoldResult> indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  build(b, result, dest.getType(), scalar, dest, indices, attrs);
+}
+
+// Build an InsertOp from mixed static/dynamic indices with an explicit result
+// type. dispatchIndexOpFoldResults splits the indices into the
+// `static_indices` attribute and the dynamic `indices` operands.
+void InsertOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                     Value scalar, Value dest, ArrayRef<OpFoldResult> indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  SmallVector<int64_t> staticIndices;
+  SmallVector<Value> dynamicIndices;
+  dispatchIndexOpFoldResults(indices, dynamicIndices, staticIndices);
+  result.addAttributes(attrs);
+  build(b, result, resultType, scalar, dest, dynamicIndices,
+        b.getDenseI64ArrayAttr(staticIndices));
+}
+
+// Build an InsertOp from dynamic (SSA) indices only, with an explicit result
+// type. Forwards to the mixed-index builder.
+void InsertOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                     Value scalar, Value dest, ValueRange indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
+      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, resultType, scalar, dest, indicesValues, attrs);
+}
+
+/// Builds an InsertOp from SSA indices only; the result has the same type as
+/// `dest`.
+void InsertOp::build(OpBuilder &b, OperationState &result, Value scalar,
+                     Value dest, ValueRange indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  Type destType = dest.getType();
+  build(b, result, destType, scalar, dest, indices, attrs);
+}
+
LogicalResult InsertOp::verify() {
-  // Verify the # indices match if we have a ranked type.
-  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
-  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
+  // `dest` must be indexed by exactly one (static or dynamic) index per
+  // dimension.
+  LogicalResult indicesMatchRank = checkTensorRankMatchIndices(
+      getDest(), getIndices(), getStaticIndices());
+  if (failed(indicesMatchRank))
    return emitOpError("incorrect number of indices");
  return success();
}
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 5b98a7790debf2..8c04e574dbc518 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1519,6 +1519,19 @@ func.func @extract_shapeof(%arg0 : tensor<?x?xf64>) -> index {
return %result : index
}
+// -----
+
+// CHECK-LABEL: func @extract_shapeof_static_index
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<?x?xf64>
+func.func @extract_shapeof_static_index(%arg0 : tensor<?x?xf64>) -> index {
+// CHECK:         %[[C1:.*]] = arith.constant 1 : index
+  %shape = shape.shape_of %arg0 : tensor<?x?xf64> -> tensor<2xindex>
+// CHECK:         %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+  %result = tensor.extract %shape[1] : tensor<2xindex>
+// CHECK:         return %[[DIM]]
+  return %result : index
+}
+
// -----
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 4b8efde78cc23c..8f7c7478669b4f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -137,11 +137,12 @@ func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1
// -----
// CHECK-LABEL: func @fold_extract
-func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
+func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, i32, complex<f32>) {
  %const_0 = arith.constant 0 : index
  %const_1 = arith.constant 1 : index
  %const_3 = arith.constant 3 : index
  // CHECK-DAG: [[C64:%.+]] = arith.constant 64 : i32
+  // CHECK-DAG: [[CNEG1:%.+]] = arith.constant -1 : i32
  // CHECK-DAG: [[C0:%.+]] = arith.constant 0.{{0*}}e+00 : f16
  // CHECK-DAG: [[CM2:%.+]] = arith.constant -2.{{0*}}e+00 : f16
@@ -162,13 +163,16 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
  %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
  %ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
+  // Fold an extract into a dense tensor with mixed dynamic and static indices;
+  // element [1, 0, 2] of %3 is -1.
+  %ext_5 = tensor.extract %3[%const_1, 0, 2] : tensor<2x1x4xi32>
+
  // Fold an extract into a complex constant.
  // CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex<f32>
  %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
-  %ext_5 = tensor.extract %4[] : tensor<complex<f32>>
+  %ext_6 = tensor.extract %4[] : tensor<complex<f32>>
-  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
-  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
+  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[CNEG1]], [[C5]]
+  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5, %ext_6 : f32, f16, f16, i32, i32, complex<f32>
}
// -----
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 41b6529f64afa3..8c594ddacb8d33 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -64,7 +64,7 @@ func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) {
// -----
-func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
+func.func @extract_too_few_indices(%arg0: tensor<?xf32>) {
  // expected-error@+1 {{incorrect number of indices for extract_element}}
  %0 = tensor.extract %arg0[] : tensor<?xf32>
  return
@@ -72,7 +72,24 @@ func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
// -----
-func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+func.func @extract_too_many_static_indices(%arg0: tensor<?xf32>) {
+  // expected-error@+1 {{incorrect number of indices for extract_element}}
+  %0 = tensor.extract %arg0[2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @extract_too_many_mixed_indices(%arg0: tensor<?xf32>) {
+  %c1 = arith.constant 1 : index
+  // expected-error@+1 {{incorrect number of indices for extract_element}}
+  %0 = tensor.extract %arg0[%c1, 2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @insert_too_few_indices(%arg0: f32, %arg1: tensor<?xf32>) {
  // expected-error@+1 {{incorrect number of indices}}
  %0 = tensor.insert %arg0 into %arg1[] : tensor<?xf32>
  return
@@ -80,6 +97,23 @@ func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
// -----
+func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+  // expected-error@+1 {{incorrect number of indices}}
+  %0 = tensor.insert %arg0 into %arg1[2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @insert_too_many_mixed_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+  %c1 = arith.constant 1 : index
+  // expected-error@+1 {{incorrect number of indices}}
+  %0 = tensor.insert %arg0 into %arg1[%c1, 2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
func.func @tensor.from_elements_wrong_result_type() {
  // expected-error@+2 {{'tensor.from_elements' invalid kind of type specified}}
  %c0 = arith.constant 0 : i32
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 378137a14b59ff..0a4cd08239c5b4 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -58,6 +58,9 @@ func.func @empty_with_encoding(%sz: index) -> tensor<5x?x6xf32, "foo"> {
func.func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
  // CHECK: tensor.extract %[[TENSOR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
  %0 = tensor.extract %arg0[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
+
+  // CHECK: tensor.extract %[[TENSOR]][%[[INDEX]], 2, 3] : tensor<?x?x?xf32>
+  %1 = tensor.extract %arg0[%arg1, 2, 3] : tensor<?x?x?xf32>
+
+  // CHECK: tensor.extract %[[TENSOR]][1, 2, 3] : tensor<?x?x?xf32>
+  %2 = tensor.extract %arg0[1, 2, 3] : tensor<?x?x?xf32>
  return
}
@@ -70,6 +73,9 @@ func.func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
func.func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>) {
  // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
  %0 = tensor.insert %arg0 into %arg2[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
+
+  // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], 2, 3] : tensor<?x?x?xf32>
+  %1 = tensor.insert %arg0 into %arg2[%arg1, 2, 3] : tensor<?x?x?xf32>
+
+  // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][1, 2, 3] : tensor<?x?x?xf32>
+  %2 = tensor.insert %arg0 into %arg2[1, 2, 3] : tensor<?x?x?xf32>
  return
}
More information about the Mlir-commits
mailing list