[Mlir-commits] [mlir] [mlir] add tensor_static.extract/insert to take only static indices. (PR #110550)

llvmlistbot at llvm.org
Mon Sep 30 11:30:57 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-tensor

Author: Yi Zhang (cathyzhyi)




---

Patch is 23.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/110550.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+81) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+149-30) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+136) 
- (modified) mlir/test/Dialect/Tensor/invalid.mlir (+39) 
- (modified) mlir/test/Dialect/Tensor/ops.mlir (+21) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3170115883e2be..8fcc413edf2725 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -344,6 +344,43 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ExtractStaticOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ExtractStaticOp : Tensor_Op<"extract_static", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    Pure,
+    TypesMatchWith<"result type matches element type of tensor",
+                   "tensor", "result",
+                   "::llvm::cast<TensorType>($_self).getElementType()">]> {
+  let summary = "element extraction operation with static indices";
+  let description = [{
+    The same as the `tensor.extract` op, except that `tensor.extract_static`
+    only takes static indices.
+
+    Example:
+
+    ```mlir
+    %4 = tensor.extract_static %t[1, 2] : tensor<4x4xi32>
+    %5 = tensor.extract_static %rt[1, 2] : tensor<?x?xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyRankedTensor:$tensor,
+    DenseI64ArrayAttr:$static_indices
+  );
+
+  let results = (outs AnyType:$result);
+  let assemblyFormat = [{$tensor `` $static_indices attr-dict `:` type($tensor)}];
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
+
 
 //===----------------------------------------------------------------------===//
 // ExtractSliceOp
@@ -822,6 +859,50 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// InsertStaticOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_InsertStaticOp : Tensor_Op<"insert_static", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DestinationStyleOpInterface,
+    Pure,
+    TypesMatchWith<"result type matches type of dest",
+                   "dest", "result",
+                   "$_self">,
+    TypesMatchWith<"scalar type matches element type of dest",
+                   "dest", "scalar",
+                   "::llvm::cast<TensorType>($_self).getElementType()">]> {
+  let summary = "element insertion operation with static indices";
+  let description = [{
+    The same as the `tensor.insert` op, except that `tensor.insert_static`
+    only takes static indices.
+
+    Example:
+
+    ```mlir
+    %4 = tensor.insert_static %t into %dest[1, 2] : tensor<4x4xi32>
+    %5 = tensor.insert_static %rt into %dest[1, 2] : tensor<?x?xi32>
+    ```
+  }];
+
+  let arguments = (ins AnyType:$scalar,
+                       AnyRankedTensor:$dest,
+                       DenseI64ArrayAttr:$static_indices);
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    $scalar `into` $dest `` $static_indices attr-dict `:` type($dest)
+  }];
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
+  }];
+
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
+
 //===----------------------------------------------------------------------===//
 // InsertSliceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1ac96756e22b5e..26d4434a484d61 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -39,6 +39,59 @@ using llvm::divideCeilSigned;
 using llvm::divideFloorSigned;
 using llvm::mod;
 
+namespace {
+template <typename ExtractOpTy>
+OpFoldResult foldExtractFromElementsHelper(ExtractOpTy op,
+                                           FromElementsOp fromElementsOp,
+                                           ArrayRef<uint64_t> indices) {
+  // Fold extract(from_elements(...)).
+  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
+  auto rank = tensorType.getRank();
+  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
+         "rank mismatch");
+  int flatIndex = 0;
+  int stride = 1;
+  for (int i = rank - 1; i >= 0; --i) {
+    flatIndex += indices[i] * stride;
+    stride *= tensorType.getDimSize(i);
+  }
+  // Prevent out of bounds accesses. This can happen in invalid code that
+  // will never execute.
+  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
+      flatIndex < 0)
+    return {};
+  return fromElementsOp.getElements()[flatIndex];
+}
+
+LogicalResult verifyStaticIndicesInBound(RankedTensorType type,
+                                         ArrayRef<int64_t> indices) {
+  ArrayRef<int64_t> shape = type.getShape();
+  for (auto [dim, index] : llvm::zip(shape, indices)) {
+    if (index < 0)
+      return failure();
+    if (ShapedType::isDynamic(dim))
+      continue;
+    if (index >= dim)
+      return failure();
+  }
+  return success();
+}
+
+template <typename InsertOpTy, typename AdapterTy>
+OpFoldResult insertOpFoldHelper(InsertOpTy insert, AdapterTy adaptor) {
+  Attribute scalar = adaptor.getScalar();
+  Attribute dest = adaptor.getDest();
+  if (scalar && dest) {
+    if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest)) {
+      if (scalar == splatDest.getSplatValue<Attribute>())
+        return dest;
+    }
+  }
+  return {};
+}
+
+} // namespace
+
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
@@ -1097,18 +1150,28 @@ namespace {
 /// to
 ///
 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
-struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
-  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+template <typename ExtractOpTy>
+struct ExtractFromTensorCast : public OpRewritePattern<ExtractOpTy> {
+  using OpRewritePattern<ExtractOpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+  LogicalResult matchAndRewrite(ExtractOpTy extract,
                                 PatternRewriter &rewriter) const final {
-    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
+    auto tensorCast =
+        extract.getTensor().template getDefiningOp<tensor::CastOp>();
     if (!tensorCast)
       return failure();
     if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
       return failure();
-    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
-        extract, tensorCast.getSource(), extract.getIndices());
+    Operation *op = extract;
+    if (auto extractOp = llvm::dyn_cast<tensor::ExtractOp>(op)) {
+      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+          extractOp, tensorCast.getSource(), extractOp.getIndices());
+    } else if (auto extractStaticOp =
+                   llvm::dyn_cast<tensor::ExtractStaticOp>(op)) {
+      rewriter.replaceOpWithNewOp<tensor::ExtractStaticOp>(
+          extractStaticOp, tensorCast.getSource(),
+          extractStaticOp.getStaticIndices());
+    }
     return success();
   }
 };
@@ -1145,22 +1208,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
 
   // Fold extract(from_elements(...)).
   if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
-    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
-    auto rank = tensorType.getRank();
-    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
-           "rank mismatch");
-    int flatIndex = 0;
-    int stride = 1;
-    for (int i = rank - 1; i >= 0; --i) {
-      flatIndex += indices[i] * stride;
-      stride *= tensorType.getDimSize(i);
-    }
-    // Prevent out of bounds accesses. This can happen in invalid code that
-    // will never execute.
-    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
-        flatIndex < 0)
-      return {};
-    return fromElementsOp.getElements()[flatIndex];
+    return foldExtractFromElementsHelper<ExtractOp>(*this, fromElementsOp,
+                                                    indices);
   }
 
   // If this is an elements attribute, query the value at the given indices.
@@ -1175,7 +1224,56 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractFromTensorCast>(context);
+  results.add<ExtractFromTensorCast<tensor::ExtractOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// ExtractStaticOp
+//===----------------------------------------------------------------------===//
+
+void ExtractStaticOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "extracted");
+}
+
+LogicalResult ExtractStaticOp::verify() {
+  // Verify the # indices match if we have a ranked type.
+  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
+  if (tensorType.getRank() != static_cast<int64_t>(getStaticIndices().size()))
+    return emitOpError("incorrect number of indices for extract_static");
+  if (failed(verifyStaticIndicesInBound(tensorType, getStaticIndices())))
+    return emitOpError("static index out of bound for extract_static");
+  return success();
+}
+
+OpFoldResult ExtractStaticOp::fold(FoldAdaptor adaptor) {
+  // If this is a splat elements attribute, simply return the value. All of
+  // the elements of a splat attribute are the same.
+  if (Attribute tensor = adaptor.getTensor()) {
+    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
+      return splatTensor.getSplatValue<Attribute>();
+  }
+
+  SmallVector<uint64_t, 8> indices(getStaticIndices());
+  // Fold extract(from_elements(...)).
+  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
+    return foldExtractFromElementsHelper<ExtractStaticOp>(*this, fromElementsOp,
+                                                          indices);
+  }
+
+  // If this is an elements attribute, query the value at the given indices.
+  if (Attribute tensor = adaptor.getTensor()) {
+    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
+    if (elementsAttr && elementsAttr.isValidIndex(indices))
+      return elementsAttr.getValues<Attribute>()[indices];
+  }
+
+  return {};
+}
+
+void ExtractStaticOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  results.add<ExtractFromTensorCast<tensor::ExtractStaticOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1368,13 +1466,34 @@ LogicalResult InsertOp::verify() {
 }
 
 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
-  Attribute scalar = adaptor.getScalar();
-  Attribute dest = adaptor.getDest();
-  if (scalar && dest)
-    if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
-      if (scalar == splatDest.getSplatValue<Attribute>())
-        return dest;
-  return {};
+  return insertOpFoldHelper<InsertOp,
+                            InsertOpGenericAdaptor<ArrayRef<Attribute>>>(
+      *this, adaptor);
+}
+
+//===----------------------------------------------------------------------===//
+// InsertStaticOp
+//===----------------------------------------------------------------------===//
+
+void InsertStaticOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "inserted");
+}
+
+LogicalResult InsertStaticOp::verify() {
+  // Verify the # indices match if we have a ranked type.
+  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
+  if (destType.getRank() != static_cast<int64_t>(getStaticIndices().size()))
+    return emitOpError("incorrect number of indices for insert_static");
+  if (failed(verifyStaticIndicesInBound(destType, getStaticIndices())))
+    return emitOpError("static index out of bound for insert_static");
+  return success();
+}
+
+OpFoldResult InsertStaticOp::fold(FoldAdaptor adaptor) {
+  return insertOpFoldHelper<InsertStaticOp,
+                            InsertStaticOpGenericAdaptor<ArrayRef<Attribute>>>(
+      *this, adaptor);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 86754c1c37536d..25b46e7877cbaa 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -173,6 +173,40 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
 
 // -----
 
+// CHECK-LABEL: func @fold_extract_static
+func.func @fold_extract_static() -> (f32, f16, f16, i32, complex<f32>) {
+  // CHECK-DAG: [[C64:%.+]] = arith.constant 64 : i32
+  // CHECK-DAG: [[C0:%.+]] = arith.constant 0.{{0*}}e+00 : f16
+  // CHECK-DAG: [[CM2:%.+]] = arith.constant -2.{{0*}}e+00 : f16
+
+  // Fold an extract into a splat.
+  // CHECK-DAG: [[C4:%.+]] = arith.constant 4.{{0*}}e+00 : f32
+  %0 = arith.constant dense<4.0> : tensor<4xf32>
+  %ext_1 = tensor.extract_static %0[1] : tensor<4xf32>
+
+  // Fold an extract into a sparse with a sparse index.
+  %1 = arith.constant sparse<[[0, 0, 0], [1, 1, 1]],  [-5.0, -2.0]> : tensor<4x4x4xf16>
+  %ext_2 = tensor.extract_static %1[1, 1, 1] : tensor<4x4x4xf16>
+
+  // Fold an extract into a sparse with a non sparse index.
+  %2 = arith.constant sparse<[[1, 1, 1]],  [-2.0]> : tensor<2x2x2xf16>
+  %ext_3 = tensor.extract_static %2[0, 0, 0] : tensor<2x2x2xf16>
+
+  // Fold an extract into a dense tensor.
+  %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
+  %ext_4 = tensor.extract_static %3[1, 0, 3] : tensor<2x1x4xi32>
+
+  // Fold an extract into a complex constant.
+  // CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex<f32>
+  %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
+  %ext_5 = tensor.extract_static %4[] : tensor<complex<f32>>
+
+  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
+  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_insert
 func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
   // Fold an insert into a splat.
@@ -186,6 +220,19 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @fold_insert_static
+func.func @fold_insert_static() -> (tensor<4xf32>) {
+  // Fold an insert into a splat.
+  // CHECK-DAG: %[[C4:.+]] = arith.constant dense<4.{{0*}}e+00> : tensor<4xf32>
+  %0 = arith.constant dense<4.0> : tensor<4xf32>
+  %1 = arith.constant 4.0 : f32
+  %ins_1 = tensor.insert_static %1 into %0[3] : tensor<4xf32>
+  // CHECK-NEXT: return %[[C4]]
+  return %ins_1 : tensor<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_from_tensor.cast
 // CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32>
 func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
@@ -200,6 +247,18 @@ func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @extract_static_from_tensor.cast
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32>
+func.func @extract_static_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
+  // CHECK-NOT: tensor.cast
+  %casted = tensor.cast %tensor : tensor<9xf32> to tensor<?xf32>
+  // CHECK-NEXT: tensor.extract_static %[[TENSOR]][0]
+  %result = tensor.extract_static %casted[0] : tensor<?xf32>
+  return %result : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_from_tensor.from_elements
 func.func @extract_from_tensor.from_elements(%element : index) -> index {
   // CHECK-SAME: ([[ARG:%.*]]: index)
@@ -212,6 +271,17 @@ func.func @extract_from_tensor.from_elements(%element : index) -> index {
 
 // -----
 
+// CHECK-LABEL: func @extract_static_from_tensor.from_elements
+func.func @extract_static_from_tensor.from_elements(%element : index) -> index {
+  // CHECK-SAME: ([[ARG:%.*]]: index)
+  %tensor = tensor.from_elements %element : tensor<1xindex>
+  %extracted_element = tensor.extract_static %tensor[0] : tensor<1xindex>
+  // CHECK: [[ARG]] : index
+  return %extracted_element : index
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_from_tensor.from_elements_0d
 func.func @extract_from_tensor.from_elements_0d(%element : index) -> index {
   // CHECK-SAME: ([[ARG:%.*]]: index)
@@ -224,6 +294,17 @@ func.func @extract_from_tensor.from_elements_0d(%element : index) -> index {
 
 // -----
 
+// CHECK-LABEL: func @extract_static_from_tensor.from_elements_0d
+func.func @extract_static_from_tensor.from_elements_0d(%element : index) -> index {
+  // CHECK-SAME: ([[ARG:%.*]]: index)
+  %tensor = tensor.from_elements %element : tensor<index>
+  %extracted_element = tensor.extract_static %tensor[] : tensor<index>
+  // CHECK: [[ARG]] : index
+  return %extracted_element : index
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_from_tensor.from_elements_3d
 func.func @extract_from_tensor.from_elements_3d()
     -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
@@ -261,6 +342,61 @@ func.func @extract_from_tensor.from_elements_3d()
   return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
          : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
 }
+
+// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
+// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
+// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
+// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
+// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
+// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
+// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
+// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
+// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
+// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
+// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
+// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01
+
+// CHECK: return %[[F0]], %[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]],
+// CHECK-SAME:   %[[F6]], %[[F7]], %[[F8]], %[[F9]], %[[F10]], %[[F11]]
+
+
+// -----
+
+// CHECK-LABEL: func @extract_static_from_tensor.from_elements_3d
+func.func @extract_static_from_tensor.from_elements_3d()
+    -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
+  %f0 = arith.constant 0.0 : f32
+  %f1 = arith.constant 1.0 : f32
+  %f2 = arith.constant 2.0 : f32
+  %f3 = arith.constant 3.0 : f32
+  %f4 = arith.constant 4.0 : f32
+  %f5 = arith.constant 5.0 : f32
+  %f6 = arith.constant 6.0 : f32
+  %f7 = arith.constant 7.0 : f32
+  %f8 = arith.constant 8.0 : f32
+  %f9 = arith.constant 9.0 : f32
+  %f10 = arith.constant 10.0 : f32
+  %f11 = arith.constant 11.0 : f32
+
+  %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
+         : tensor<3x2x2xf32>
+
+  %r0 = tensor.extract_static %tensor[0, 0, 0] : tensor<3x2x2xf32>
+  %r1 = tensor.extract_static %tensor[0, 0, 1] : tensor<3x2x2xf32>
+  %r2 = tensor.extract_static %tensor[0, 1, 0] : tensor<3x2x2xf32>
+  %r3 = tensor.extract_static %tensor[0, 1, 1] : tensor<3x2x2xf32>
+  %r4 = tensor.extract_static %tensor[1, 0, 0] : tensor<3x2x2xf32>
+  %r5 = tensor.extract_static %tensor[1, 0, 1] : tensor<3x2x2xf32>
+  %r6 = tensor.extract_static %tensor[1, 1, 0] : tensor<3x2x2xf32>
+  %r7 = tensor.extract_static %tensor[1, 1, 1] : tensor<3x2x2xf32>
+  %r8 = tensor.extract_static %tensor[2, 0, 0] : tensor<3x2x2xf32>
+  %r9 = tensor.extract_static %tensor[2, 0, 1] : tensor<3x2x2xf32>
+  %r10 = tensor.extract_static %tensor[2, 1, 0] : tensor<3x2x2xf32>
+  %r11 = tensor.extract_static %tensor[2, 1, 1] : tensor<3x2x2xf32>
+  return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
+         : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
+}
+
 // CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
 // CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
 // CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 84e6c59e403dde..4be9b6a9c87183 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -72,6 +72,22 @@ f...
[truncated]

``````````
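For readers skimming the truncated diff, here is a minimal usage sketch of the two new ops, pieced together from the TableGen definitions and the `ops.mlir` additions above. The function and value names are illustrative only:

```mlir
func.func @static_indexing_sketch(%t: tensor<4x4xi32>, %v: i32)
    -> (i32, tensor<4x4xi32>) {
  // Indices are carried as a DenseI64ArrayAttr, so no index SSA values
  // (and no arith.constant ops) are needed at the use site.
  %e = tensor.extract_static %t[1, 2] : tensor<4x4xi32>
  // Destination-style insert: produces a new tensor value with the
  // element written at the given static position.
  %r = tensor.insert_static %v into %t[3, 0] : tensor<4x4xi32>
  return %e, %r : i32, tensor<4x4xi32>
}
```

Keeping the indices as an attribute rather than operands is presumably what enables the verifier's static bounds checking and the constant folding in this patch without first having to match `arith.constant` index operands.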
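The templated `ExtractFromTensorCast` pattern is now shared between `tensor.extract` and `tensor.extract_static`. A before/after sketch of the rewrite, following the `canonicalize.mlir` tests above:

```mlir
// Input: the extraction goes through a shape-erasing cast.
func.func @before(%arg : tensor<9xf32>) -> f32 {
  %casted = tensor.cast %arg : tensor<9xf32> to tensor<?xf32>
  %result = tensor.extract_static %casted[0] : tensor<?xf32>
  return %result : f32
}

// Expected result of -canonicalize: the extraction targets the cast's
// source directly, and the now-dead cast is removed.
func.func @after(%arg : tensor<9xf32>) -> f32 {
  %result = tensor.extract_static %arg[0] : tensor<9xf32>
  return %result : f32
}
```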
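The `invalid.mlir` hunk is cut off by the truncation, but given `verifyStaticIndicesInBound` above, a negative or too-large static index along a static dimension should now be rejected at verification time. A hypothetical test case in the style of that file:

```mlir
func.func @extract_static_oob(%t: tensor<4xf32>) -> f32 {
  // expected-error@+1 {{static index out of bounds for extract_static}}
  %e = tensor.extract_static %t[4] : tensor<4xf32>
  return %e : f32
}
```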



https://github.com/llvm/llvm-project/pull/110550


More information about the Mlir-commits mailing list