[Mlir-commits] [mlir] [mlir] add tensor_static.extract/insert to take only static indices. (PR #110550)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 30 11:30:57 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor
Author: Yi Zhang (cathyzhyi)
<details>
<summary>Changes</summary>
---
Patch is 23.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/110550.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+81)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+149-30)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+136)
- (modified) mlir/test/Dialect/Tensor/invalid.mlir (+39)
- (modified) mlir/test/Dialect/Tensor/ops.mlir (+21)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3170115883e2be..8fcc413edf2725 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -344,6 +344,43 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// ExtractStaticOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ExtractStaticOp : Tensor_Op<"extract_static", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    Pure,
+    TypesMatchWith<"result type matches element type of tensor",
+                   "tensor", "result",
+                   "::llvm::cast<TensorType>($_self).getElementType()">]> {
+  let summary = "element extraction operation with static indices";
+  let description = [{
+    The `tensor.extract_static` op reads one element of a ranked tensor at the
+    position given by the `static_indices` attribute. It behaves like
+    `tensor.extract`, except that the indices are compile-time constants stored
+    in an attribute instead of SSA values of `index` type.
+
+    The verifier requires that:
+    * the number of indices equals the rank of the tensor;
+    * every index is non-negative;
+    * every index is smaller than the corresponding dimension size when that
+      dimension is static (indices along dynamic dimensions are only checked
+      for non-negativity).
+
+    Example:
+
+    ```mlir
+    %4 = tensor.extract_static %t[1, 2] : tensor<4x4xi32>
+    %5 = tensor.extract_static %rt[1, 2] : tensor<?x?xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyRankedTensor:$tensor,
+    DenseI64ArrayAttr:$static_indices
+  );
+
+  let results = (outs AnyType:$result);
+  let assemblyFormat = [{$tensor `` $static_indices attr-dict `:` type($tensor)}];
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// ExtractSliceOp
@@ -822,6 +859,50 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// InsertStaticOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_InsertStaticOp : Tensor_Op<"insert_static", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DestinationStyleOpInterface,
+    Pure,
+    TypesMatchWith<"result type matches type of dest",
+                   "dest", "result",
+                   "$_self">,
+    TypesMatchWith<"scalar type matches element type of dest",
+                   "dest", "scalar",
+                   "::llvm::cast<TensorType>($_self).getElementType()">]> {
+  let summary = "element insertion operation with static indices";
+  let description = [{
+    The `tensor.insert_static` op inserts the `scalar` value into the ranked
+    `dest` tensor at the position given by the `static_indices` attribute and
+    returns the updated tensor, which has the same type as `dest`. It behaves
+    like `tensor.insert`, except that the indices are compile-time constants
+    stored in an attribute instead of SSA values of `index` type.
+
+    The verifier requires that:
+    * the number of indices equals the rank of `dest`;
+    * every index is non-negative;
+    * every index is smaller than the corresponding dimension size when that
+      dimension is static (indices along dynamic dimensions are only checked
+      for non-negativity).
+
+    Example (`%s` is an `i32` scalar):
+
+    ```mlir
+    %4 = tensor.insert_static %s into %t[1, 2] : tensor<4x4xi32>
+    %5 = tensor.insert_static %s into %rt[1, 2] : tensor<?x?xi32>
+    ```
+  }];
+
+  let arguments = (ins AnyType:$scalar,
+                       AnyRankedTensor:$dest,
+                       DenseI64ArrayAttr:$static_indices);
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    $scalar `into` $dest `` $static_indices attr-dict `:` type($dest)
+  }];
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
+  }];
+
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// InsertSliceOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1ac96756e22b5e..26d4434a484d61 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -39,6 +39,59 @@ using llvm::divideCeilSigned;
using llvm::divideFloorSigned;
using llvm::mod;
+/// Folds an extraction from a `tensor.from_elements` op at the constant
+/// position `indices` by returning the matching element operand. Returns a
+/// null OpFoldResult when the position is out of bounds.
+///
+/// The first argument is not inspected; it only names the extracting op type
+/// at the call site.
+template <typename ExtractOpTy>
+static OpFoldResult
+foldExtractFromElementsHelper(ExtractOpTy /*op*/, FromElementsOp fromElementsOp,
+                              ArrayRef<uint64_t> indices) {
+  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
+  int64_t rank = tensorType.getRank();
+  assert(static_cast<int64_t>(indices.size()) == rank && "rank mismatch");
+  // Linearize the position in row-major order. Each index is checked against
+  // its own dimension so that an out-of-bounds access (which can appear in
+  // invalid code that never executes) is not folded to an unrelated element.
+  // The comparison is unsigned, so a negative value converted to uint64_t is
+  // rejected too, and flatIndex always stays below the element count.
+  uint64_t flatIndex = 0;
+  uint64_t stride = 1;
+  for (int64_t dim = rank - 1; dim >= 0; --dim) {
+    int64_t dimSize = tensorType.getDimSize(dim);
+    if (ShapedType::isDynamic(dimSize) ||
+        indices[dim] >= static_cast<uint64_t>(dimSize))
+      return {};
+    flatIndex += indices[dim] * stride;
+    stride *= static_cast<uint64_t>(dimSize);
+  }
+  auto elements = fromElementsOp.getElements();
+  if (flatIndex >= elements.size())
+    return {};
+  return elements[flatIndex];
+}
+
+/// Checks that every entry of `indices` is non-negative and, for each static
+/// dimension of `type`, smaller than that dimension's size. Entries for
+/// dynamic dimensions are only checked for non-negativity. Callers are
+/// expected to have checked that there is one index per dimension; otherwise
+/// only the common prefix is examined.
+static LogicalResult verifyStaticIndicesInBound(RankedTensorType type,
+                                                ArrayRef<int64_t> indices) {
+  for (auto [dimSize, index] : llvm::zip(type.getShape(), indices)) {
+    if (index < 0)
+      return failure();
+    if (!ShapedType::isDynamic(dimSize) && index >= dimSize)
+      return failure();
+  }
+  return success();
+}
+
+/// Shared folder for `tensor.insert` and `tensor.insert_static`: inserting a
+/// constant scalar into a constant splat holding that same value leaves the
+/// splat unchanged, so the destination attribute is returned. The op argument
+/// is not inspected.
+template <typename InsertOpTy, typename AdapterTy>
+static OpFoldResult insertOpFoldHelper(InsertOpTy /*insert*/,
+                                       AdapterTy adaptor) {
+  Attribute scalar = adaptor.getScalar();
+  Attribute dest = adaptor.getDest();
+  if (!scalar || !dest)
+    return {};
+  auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest);
+  if (!splatDest || scalar != splatDest.getSplatValue<Attribute>())
+    return {};
+  return dest;
+}
+
+
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
@@ -1097,18 +1150,28 @@ namespace {
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
-struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
-  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+template <typename ExtractOpTy>
+struct ExtractFromTensorCast : public OpRewritePattern<ExtractOpTy> {
+  using OpRewritePattern<ExtractOpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+  LogicalResult matchAndRewrite(ExtractOpTy extract,
                                 PatternRewriter &rewriter) const final {
-    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
+    auto tensorCast =
+        extract.getTensor().template getDefiningOp<tensor::CastOp>();
     if (!tensorCast)
       return failure();
     if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
       return failure();
-    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
-        extract, tensorCast.getSource(), extract.getIndices());
+    // Dispatch at compile time so every instantiation always rewrites (a
+    // pattern must not report success without changing the IR).
+    if constexpr (std::is_same_v<ExtractOpTy, tensor::ExtractStaticOp>) {
+      // The static indices were verified only against the cast *result* type,
+      // which may be dynamic where the source is static (e.g. tensor<4xf32>
+      // cast to tensor<?xf32>). An index that is valid for the result can be
+      // out of bounds for the source, and the rewritten op would then fail
+      // verification, so leave such ops alone.
+      auto sourceType =
+          llvm::cast<RankedTensorType>(tensorCast.getSource().getType());
+      if (failed(verifyStaticIndicesInBound(sourceType,
+                                            extract.getStaticIndices())))
+        return failure();
+      rewriter.replaceOpWithNewOp<tensor::ExtractStaticOp>(
+          extract, tensorCast.getSource(), extract.getStaticIndices());
+    } else {
+      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+          extract, tensorCast.getSource(), extract.getIndices());
+    }
     return success();
   }
 };
@@ -1145,22 +1208,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold extract(from_elements(...)).
if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
- auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
- auto rank = tensorType.getRank();
- assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
- "rank mismatch");
- int flatIndex = 0;
- int stride = 1;
- for (int i = rank - 1; i >= 0; --i) {
- flatIndex += indices[i] * stride;
- stride *= tensorType.getDimSize(i);
- }
- // Prevent out of bounds accesses. This can happen in invalid code that
- // will never execute.
- if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
- flatIndex < 0)
- return {};
- return fromElementsOp.getElements()[flatIndex];
+ return foldExtractFromElementsHelper<ExtractOp>(*this, fromElementsOp,
+ indices);
}
// If this is an elements attribute, query the value at the given indices.
@@ -1175,7 +1224,56 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractFromTensorCast>(context);
+ results.add<ExtractFromTensorCast<tensor::ExtractOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// ExtractStaticOp
+//===----------------------------------------------------------------------===//
+
+void ExtractStaticOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "extracted");
+}
+
+LogicalResult ExtractStaticOp::verify() {
+ // Verify the # indices match if we have a ranked type.
+ auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
+ if (tensorType.getRank() != static_cast<int64_t>(getStaticIndices().size()))
+ return emitOpError("incorrect number of indices for extract_static");
+ if (failed(verifyStaticIndicesInBound(tensorType, getStaticIndices())))
+ return emitOpError("static index out of bound for extract_static");
+ return success();
+}
+
+OpFoldResult ExtractStaticOp::fold(FoldAdaptor adaptor) {
+ // If this is a splat elements attribute, simply return the value. All of
+ // the elements of a splat attribute are the same.
+ if (Attribute tensor = adaptor.getTensor()) {
+ if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
+ return splatTensor.getSplatValue<Attribute>();
+ }
+
+ SmallVector<uint64_t, 8> indices(getStaticIndices());
+ // Fold extract(from_elements(...)).
+ if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
+ return foldExtractFromElementsHelper<ExtractStaticOp>(*this, fromElementsOp,
+ indices);
+ }
+
+ // If this is an elements attribute, query the value at the given indices.
+ if (Attribute tensor = adaptor.getTensor()) {
+ auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
+ if (elementsAttr && elementsAttr.isValidIndex(indices))
+ return elementsAttr.getValues<Attribute>()[indices];
+ }
+
+ return {};
+}
+
+void ExtractStaticOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<ExtractFromTensorCast<tensor::ExtractStaticOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -1368,13 +1466,34 @@ LogicalResult InsertOp::verify() {
}
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
- Attribute scalar = adaptor.getScalar();
- Attribute dest = adaptor.getDest();
- if (scalar && dest)
- if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
- if (scalar == splatDest.getSplatValue<Attribute>())
- return dest;
- return {};
+ return insertOpFoldHelper<InsertOp,
+ InsertOpGenericAdaptor<ArrayRef<Attribute>>>(
+ *this, adaptor);
+}
+
+//===----------------------------------------------------------------------===//
+// InsertStaticOp
+//===----------------------------------------------------------------------===//
+
+void InsertStaticOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "inserted");
+}
+
+LogicalResult InsertStaticOp::verify() {
+ // Verify the # indices match if we have a ranked type.
+ auto destType = llvm::cast<RankedTensorType>(getDest().getType());
+ if (destType.getRank() != static_cast<int64_t>(getStaticIndices().size()))
+ return emitOpError("incorrect number of indices for insert_static");
+ if (failed(verifyStaticIndicesInBound(destType, getStaticIndices())))
+ return emitOpError("static index out of bound for insert_static");
+ return success();
+}
+
+OpFoldResult InsertStaticOp::fold(FoldAdaptor adaptor) {
+ return insertOpFoldHelper<InsertStaticOp,
+ InsertStaticOpGenericAdaptor<ArrayRef<Attribute>>>(
+ *this, adaptor);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 86754c1c37536d..25b46e7877cbaa 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -173,6 +173,40 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
// -----
+// Every tensor.extract_static below reads from a constant tensor (splat,
+// sparse, dense, complex) and is expected to fold to a scalar constant, so
+// the function returns only constants.
+// CHECK-LABEL: func @fold_extract_static
+func.func @fold_extract_static() -> (f32, f16, f16, i32, complex<f32>) {
+ // CHECK-DAG: [[C64:%.+]] = arith.constant 64 : i32
+ // CHECK-DAG: [[C0:%.+]] = arith.constant 0.{{0*}}e+00 : f16
+ // CHECK-DAG: [[CM2:%.+]] = arith.constant -2.{{0*}}e+00 : f16
+
+ // Fold an extract into a splat.
+ // CHECK-DAG: [[C4:%.+]] = arith.constant 4.{{0*}}e+00 : f32
+ %0 = arith.constant dense<4.0> : tensor<4xf32>
+ %ext_1 = tensor.extract_static %0[1] : tensor<4xf32>
+
+ // Fold an extract into a sparse with a sparse index.
+ // [1, 1, 1] is an explicitly stored entry (-2.0), matched by CM2 below.
+ %1 = arith.constant sparse<[[0, 0, 0], [1, 1, 1]], [-5.0, -2.0]> : tensor<4x4x4xf16>
+ %ext_2 = tensor.extract_static %1[1, 1, 1] : tensor<4x4x4xf16>
+
+ // Fold an extract into a sparse with a non sparse index.
+ // [0, 0, 0] is not stored, so it is expected to fold to zero (C0 below).
+ %2 = arith.constant sparse<[[1, 1, 1]], [-2.0]> : tensor<2x2x2xf16>
+ %ext_3 = tensor.extract_static %2[0, 0, 0] : tensor<2x2x2xf16>
+
+ // Fold an extract into a dense tensor.
+ %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
+ %ext_4 = tensor.extract_static %3[1, 0, 3] : tensor<2x1x4xi32>
+
+ // Fold an extract into a complex constant.
+ // CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex<f32>
+ %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
+ %ext_5 = tensor.extract_static %4[] : tensor<complex<f32>>
+
+ // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
+ return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_insert
func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
// Fold an insert into a splat.
@@ -186,6 +220,19 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
// -----
+// Inserting the splat's own value (4.0) into a constant splat leaves it
+// unchanged, so the insert is expected to fold to the original constant.
+// CHECK-LABEL: func @fold_insert_static
+func.func @fold_insert_static() -> (tensor<4xf32>) {
+ // Fold an insert into a splat.
+ // CHECK-DAG: %[[C4:.+]] = arith.constant dense<4.{{0*}}e+00> : tensor<4xf32>
+ %0 = arith.constant dense<4.0> : tensor<4xf32>
+ %1 = arith.constant 4.0 : f32
+ %ins_1 = tensor.insert_static %1 into %0[3] : tensor<4xf32>
+ // CHECK-NEXT: return %[[C4]]
+ return %ins_1 : tensor<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @extract_from_tensor.cast
// CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32>
func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
@@ -200,6 +247,18 @@ func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
// -----
+// A tensor.cast feeding tensor.extract_static is expected to be folded away:
+// the extract reads directly from the un-cast source tensor.
+// CHECK-LABEL: func @extract_static_from_tensor.cast
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32>
+func.func @extract_static_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
+ // CHECK-NOT: tensor.cast
+ %casted = tensor.cast %tensor : tensor<9xf32> to tensor<?xf32>
+ // CHECK-NEXT: tensor.extract_static %[[TENSOR]][0]
+ %result = tensor.extract_static %casted[0] : tensor<?xf32>
+ return %result : f32
+}
+
+// -----
+
// CHECK-LABEL: func @extract_from_tensor.from_elements
func.func @extract_from_tensor.from_elements(%element : index) -> index {
// CHECK-SAME: ([[ARG:%.*]]: index)
@@ -212,6 +271,17 @@ func.func @extract_from_tensor.from_elements(%element : index) -> index {
// -----
+// Extracting the only element of a 1-element from_elements tensor is expected
+// to fold to the function argument that built it.
+// CHECK-LABEL: func @extract_static_from_tensor.from_elements
+func.func @extract_static_from_tensor.from_elements(%element : index) -> index {
+ // CHECK-SAME: ([[ARG:%.*]]: index)
+ %tensor = tensor.from_elements %element : tensor<1xindex>
+ %extracted_element = tensor.extract_static %tensor[0] : tensor<1xindex>
+ // CHECK: [[ARG]] : index
+ return %extracted_element : index
+}
+
+// -----
+
// CHECK-LABEL: func @extract_from_tensor.from_elements_0d
func.func @extract_from_tensor.from_elements_0d(%element : index) -> index {
// CHECK-SAME: ([[ARG:%.*]]: index)
@@ -224,6 +294,17 @@ func.func @extract_from_tensor.from_elements_0d(%element : index) -> index {
// -----
+// Rank-0 variant: an empty index list on a 0-d from_elements tensor is
+// expected to fold to the function argument.
+// CHECK-LABEL: func @extract_static_from_tensor.from_elements_0d
+func.func @extract_static_from_tensor.from_elements_0d(%element : index) -> index {
+ // CHECK-SAME: ([[ARG:%.*]]: index)
+ %tensor = tensor.from_elements %element : tensor<index>
+ %extracted_element = tensor.extract_static %tensor[] : tensor<index>
+ // CHECK: [[ARG]] : index
+ return %extracted_element : index
+}
+
+// -----
+
// CHECK-LABEL: func @extract_from_tensor.from_elements_3d
func.func @extract_from_tensor.from_elements_3d()
-> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
@@ -261,6 +342,61 @@ func.func @extract_from_tensor.from_elements_3d()
return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
: f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
}
+
+// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
+// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
+// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
+// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
+// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
+// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
+// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
+// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
+// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
+// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
+// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
+// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01
+
+// CHECK: return %[[F0]], %[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]],
+// CHECK-SAME: %[[F6]], %[[F7]], %[[F8]], %[[F9]], %[[F10]], %[[F11]]
+
+
+// -----
+
+// Checks row-major linearization of a 3x2x2 from_elements tensor: extracting
+// position [i, j, k] should yield element i*4 + j*2 + k, i.e. %r<n> folds to
+// the constant %f<n>.
+// CHECK-LABEL: func @extract_static_from_tensor.from_elements_3d
+func.func @extract_static_from_tensor.from_elements_3d()
+ -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
+ %f0 = arith.constant 0.0 : f32
+ %f1 = arith.constant 1.0 : f32
+ %f2 = arith.constant 2.0 : f32
+ %f3 = arith.constant 3.0 : f32
+ %f4 = arith.constant 4.0 : f32
+ %f5 = arith.constant 5.0 : f32
+ %f6 = arith.constant 6.0 : f32
+ %f7 = arith.constant 7.0 : f32
+ %f8 = arith.constant 8.0 : f32
+ %f9 = arith.constant 9.0 : f32
+ %f10 = arith.constant 10.0 : f32
+ %f11 = arith.constant 11.0 : f32
+
+ %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
+ : tensor<3x2x2xf32>
+
+ %r0 = tensor.extract_static %tensor[0, 0, 0] : tensor<3x2x2xf32>
+ %r1 = tensor.extract_static %tensor[0, 0, 1] : tensor<3x2x2xf32>
+ %r2 = tensor.extract_static %tensor[0, 1, 0] : tensor<3x2x2xf32>
+ %r3 = tensor.extract_static %tensor[0, 1, 1] : tensor<3x2x2xf32>
+ %r4 = tensor.extract_static %tensor[1, 0, 0] : tensor<3x2x2xf32>
+ %r5 = tensor.extract_static %tensor[1, 0, 1] : tensor<3x2x2xf32>
+ %r6 = tensor.extract_static %tensor[1, 1, 0] : tensor<3x2x2xf32>
+ %r7 = tensor.extract_static %tensor[1, 1, 1] : tensor<3x2x2xf32>
+ %r8 = tensor.extract_static %tensor[2, 0, 0] : tensor<3x2x2xf32>
+ %r9 = tensor.extract_static %tensor[2, 0, 1] : tensor<3x2x2xf32>
+ %r10 = tensor.extract_static %tensor[2, 1, 0] : tensor<3x2x2xf32>
+ %r11 = tensor.extract_static %tensor[2, 1, 1] : tensor<3x2x2xf32>
+ return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
+ : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
+}
+
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 84e6c59e403dde..4be9b6a9c87183 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -72,6 +72,22 @@ f...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/110550
More information about the Mlir-commits
mailing list