[Mlir-commits] [mlir] cfaf329 - [mlir][tensor] Disallow unranked tensors for tensor.extract/insert

Matthias Springer llvmlistbot at llvm.org
Thu Oct 27 01:09:39 PDT 2022


Author: Matthias Springer
Date: 2022-10-27T10:09:31+02:00
New Revision: cfaf3292df51090d03c1f98a95668246006813e1

URL: https://github.com/llvm/llvm-project/commit/cfaf3292df51090d03c1f98a95668246006813e1
DIFF: https://github.com/llvm/llvm-project/commit/cfaf3292df51090d03c1f98a95668246006813e1.diff

LOG: [mlir][tensor] Disallow unranked tensors for tensor.extract/insert

When writing a tensor.extract/tensor.insert, the rank of the tensor is implied by the number of specified indices. To extract from or insert into an unranked tensor, first cast it to a ranked tensor type (e.g., with tensor.cast).

Differential Revision: https://reviews.llvm.org/D136756

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir
    mlir/test/Dialect/Tensor/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index f57852bc6c30b..86f3c06fbaccb 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -216,23 +216,20 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
                    "$_self.cast<ShapedType>().getElementType()">]> {
   let summary = "element extraction operation";
   let description = [{
-    The `tensor.extract` op reads a tensor and returns one
-    element from it specified by an index list. The output of the op is a
-    new value with the same type as the elements of the tensor. The
-    arity of indices must match the rank of the accessed value (i.e., if a
-    tensor is of rank 3, then 3 indices are required for the extract. The
-    indices should all be of `index` type.
+    The `tensor.extract` op reads a ranked tensor and returns one element as
+    specified by the given indices. The result of the op is a value with the
+    same type as the elements of the tensor. The arity of indices must match
+    the rank of the accessed value. All indices should be of `index` type.

     Example:

     ```mlir
     %4 = tensor.extract %t[%1, %2] : tensor<4x4xi32>
     %5 = tensor.extract %rt[%1, %2] : tensor<?x?xi32>
-    %6 = tensor.extract %ut[%1, %2] : tensor<*xi32>
     ```
   }];

-  let arguments = (ins AnyTensor:$tensor, Variadic<Index>:$indices);
+  let arguments = (ins AnyRankedTensor:$tensor, Variadic<Index>:$indices);
   let results = (outs AnyType:$result);
   let assemblyFormat = "$tensor `[` $indices `]` attr-dict `:` type($tensor)";

@@ -242,6 +239,7 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
       build($_builder, $_state, resType, tensor, indices);
     }]>];

+  let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
 }
@@ -684,35 +682,33 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
     Pure,
     TypesMatchWith<"result type matches type of dest",
                    "dest", "result",
-                   "$_self.cast<ShapedType>()">,
+                   "$_self">,
     TypesMatchWith<"scalar type matches element type of dest",
                    "dest", "scalar",
                    "$_self.cast<ShapedType>().getElementType()">]> {
   let summary = "element insertion operation";
   let description = [{
-    The `tensor.insert` op writes a tensor into a tensor `dest`as specified by
-    the operation's indices.
+    The `tensor.insert` op inserts a scalar into a ranked tensor `dest` as
+    specified by the operation's indices.

-    It returns a copy of `dest` with the proper slice updated with the value
+    It returns a copy of `dest` with the indexed position updated to the value
     of `scalar`.

-    The arity of indices must match the rank of the tensor `dest` (i.e., if a
-    tensor is of rank 3, then 3 indices are required for the extract. The
-    indices should all be of `index` type.
+    The arity of `indices` must match the rank of the tensor `dest`. All
+    indices should be of `index` type.

     Example:

     ```mlir
     %4 = tensor.insert %t into %dest[%1, %2] : tensor<4x4xi32>
     %5 = tensor.insert %rt into %dest[%1, %2] : tensor<?x?xi32>
-    %6 = tensor.insert %ut into %dest[%1, %2] : tensor<*xi32>
     ```
   }];

   let arguments = (ins AnyType:$scalar,
-                       AnyTensor:$dest,
+                       AnyRankedTensor:$dest,
                        Variadic<Index>:$indices);
-  let results = (outs AnyTensor:$result);
+  let results = (outs AnyRankedTensor:$result);
   let assemblyFormat = [{
     $scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
   }];

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 282a342df1d8a..445e78e295fd1 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -773,6 +773,34 @@ void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // ExtractOp
 //===----------------------------------------------------------------------===//

+namespace {
+
+/// Canonicalizes the following pattern (only if %source is ranked):
+///
+/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
+/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
+///
+/// to
+///
+/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
+struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+                                PatternRewriter &rewriter) const final {
+    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
+    if (!tensorCast)
+      return failure();
+    if (!tensorCast.getSource().getType().isa<RankedTensorType>())
+      return failure();
+    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+        extract, tensorCast.getSource(), extract.getIndices());
+    return success();
+  }
+};
+
+} // namespace
+
 void ExtractOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "extracted");
@@ -780,10 +808,9 @@ void ExtractOp::getAsmResultNames(

 LogicalResult ExtractOp::verify() {
-  // Verify the # indices match if we have a ranked type.
-  if (auto tensorType = getTensor().getType().dyn_cast<RankedTensorType>())
-    if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
-      return emitOpError("incorrect number of indices for extract_element");
-
+  // The tensor is always ranked; verify the # indices match its rank.
+  auto tensorType = getTensor().getType().cast<RankedTensorType>();
+  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
+    return emitOpError("incorrect number of indices for extract_element");
   return success();
 }

@@ -833,6 +860,11 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }

+void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<ExtractFromTensorCast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // FromElementsOp
 //===----------------------------------------------------------------------===//
@@ -1009,9 +1041,9 @@ void InsertOp::getAsmResultNames(

 LogicalResult InsertOp::verify() {
-  // Verify the # indices match if we have a ranked type.
-  if (auto destType = getDest().getType().dyn_cast<RankedTensorType>())
-    if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
-      return emitOpError("incorrect number of indices");
+  // The dest tensor is always ranked; verify the # indices match its rank.
+  auto destType = getDest().getType().cast<RankedTensorType>();
+  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
+    return emitOpError("incorrect number of indices");
   return success();
 }

@@ -1181,36 +1213,12 @@ struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
   }
 };

-/// Canonicalizes the pattern of the form
-///
-/// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32>
-/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
-///
-/// to
-///
-/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
-struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
-  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
-                                PatternRewriter &rewriter) const final {
-    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
-    if (!tensorCast)
-      return failure();
-
-    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
-        extract, tensorCast.getSource(), extract.getIndices());
-    return success();
-  }
-};
-
 } // namespace

 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  // TODO: Move extract patterns to tensor::ExtractOp.
-  results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
-              StaticTensorGenerate>(context);
+  // TODO: Move ExtractFromTensorGenerate to tensor::ExtractOp.
+  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
 }

 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 28bcabc60dbd9..9cddfd88735ab 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -115,12 +115,12 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
 // -----

 // CHECK-LABEL: func @extract_from_tensor.cast
-// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
-func.func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 {
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32>
+func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
   // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
   // CHECK-NOT: tensor.cast
-  %casted = tensor.cast %tensor : tensor<*xf32> to tensor<?xf32>
+  %casted = tensor.cast %tensor : tensor<9xf32> to tensor<?xf32>
   // CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
   %result = tensor.extract %casted[%c0] : tensor<?xf32>
   return %result : f32

diff  --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index f9a81e8c490b4..aadf6ab90250d 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -34,12 +34,9 @@ func.func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
 // CHECK-SAME:                  %[[SCALAR:.*]]: f32
 // CHECK-SAME:                  %[[INDEX:.*]]: index
 // CHECK-SAME:                  %[[DEST1:.*]]: tensor<?x?x?xf32>
-// CHECK-SAME:                  %[[DEST2:.*]]: tensor<*xf32>
-func.func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>, %arg3: tensor<*xf32>) {
+func.func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>) {
   // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
   %0 = tensor.insert %arg0 into %arg2[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
-  // CHECK: tensor.insert %[[SCALAR]] into %[[DEST2]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<*xf32>
-  %1 = tensor.insert %arg0 into %arg3[%arg1, %arg1, %arg1] : tensor<*xf32>
   return
 }
 


        


More information about the Mlir-commits mailing list