[Mlir-commits] [mlir] 864adf3 - [mlir] Allow empty position in vector.insert and vector.extract

Matthias Springer llvmlistbot at llvm.org
Wed May 12 20:55:41 PDT 2021


Author: Matthias Springer
Date: 2021-05-13T12:54:18+09:00
New Revision: 864adf399e58a6bfd823136fc2cbcfe9dff5b4a8

URL: https://github.com/llvm/llvm-project/commit/864adf399e58a6bfd823136fc2cbcfe9dff5b4a8
DIFF: https://github.com/llvm/llvm-project/commit/864adf399e58a6bfd823136fc2cbcfe9dff5b4a8.diff

LOG: [mlir] Allow empty position in vector.insert and vector.extract

Such ops are no-ops and are folded to their respective `source`/`vector` operand.

Differential Revision: https://reviews.llvm.org/D101879

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 45c0ccaa0928..6c621f93c024 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -764,7 +764,9 @@ def Vector_InsertOp :
       return dest().getType().cast<VectorType>();
     }
   }];
+
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 def Vector_InsertSlicesOp :

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9ecee857e2e5..9db34d7411e9 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -656,6 +656,12 @@ class VectorExtractOpConversion
     if (!llvmResultType)
       return failure();
 
+    // Extract entire vector. Should be handled by folder, but just to be safe.
+    if (positionArrayAttr.empty()) {
+      rewriter.replaceOp(extractOp, adaptor.vector());
+      return success();
+    }
+
     // One-shot extraction of vector from array (only requires extractvalue).
     if (resultType.isa<VectorType>()) {
       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
@@ -762,6 +768,13 @@ class VectorInsertOpConversion
     if (!llvmResultType)
       return failure();
 
+    // Overwrite entire vector with value. Should be handled by folder, but
+    // just to be safe.
+    if (positionArrayAttr.empty()) {
+      rewriter.replaceOp(insertOp, adaptor.source());
+      return success();
+    }
+
     // One-shot insertion of a vector into an array (only requires insertvalue).
     if (sourceType.isa<VectorType>()) {
       Value inserted = rewriter.create<LLVM::InsertValueOp>(

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index c86817cdc3ab..f24b9171203a 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -872,8 +872,6 @@ static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
 
 static LogicalResult verify(vector::ExtractOp op) {
   auto positionAttr = op.position().getValue();
-  if (positionAttr.empty())
-    return op.emitOpError("expected non-empty position attribute");
   if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
     return op.emitOpError(
         "expected position attribute of rank smaller than vector rank");
@@ -1151,6 +1149,8 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
 }
 
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
+  if (position().empty())
+    return vector();
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
   if (succeeded(foldExtractOpFromTranspose(*this)))
@@ -1557,8 +1557,6 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
 
 static LogicalResult verify(InsertOp op) {
   auto positionAttr = op.position().getValue();
-  if (positionAttr.empty())
-    return op.emitOpError("expected non-empty position attribute");
   auto destVectorType = op.getDestVectorType();
   if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
     return op.emitOpError(
@@ -1612,6 +1610,15 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<InsertToShapeCast>(context);
 }
 
+// Folds an insert op with an empty position to its `source` operand. With no
+// position, the op overwrites the entire destination vector, so its result is
+// identical to the source value. Otherwise returns an empty OpFoldResult.
+OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
+  if (position().empty())
+    return source();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // InsertSlicesOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 545a3ac8c463..e06380df3f66 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -80,13 +80,6 @@ func @extract_vector_type(%arg0: index) {
 
 // -----
 
-func @extract_position_empty(%arg0: vector<4x8x16xf32>) {
-  // expected-error@+1 {{expected non-empty position attribute}}
-  %1 = vector.extract %arg0[] : vector<4x8x16xf32>
-}
-
-// -----
-
 func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute of rank smaller than vector}}
   %1 = vector.extract %arg0[0, 0, 0, 0] : vector<4x8x16xf32>
@@ -138,13 +131,6 @@ func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) {
 
 // -----
 
-func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
-  // expected-error@+1 {{expected non-empty position attribute}}
-  %1 = vector.insert %a, %b[] : f32 into vector<4x8x16xf32>
-}
-
-// -----
-
 func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute of rank smaller than dest vector rank}}
   %1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32>

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c3bb8fffbb1a..8beff28ef8a0 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -158,14 +158,16 @@ func @extract_element(%a: vector<16xf32>) -> f32 {
 }
 
 // CHECK-LABEL: @extract
-func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
+func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) {
+  // CHECK: vector.extract {{.*}}[] : vector<4x8x16xf32>
+  %0 = vector.extract %arg0[] : vector<4x8x16xf32>
   // CHECK: vector.extract {{.*}}[3] : vector<4x8x16xf32>
   %1 = vector.extract %arg0[3] : vector<4x8x16xf32>
   // CHECK-NEXT: vector.extract {{.*}}[3, 3] : vector<4x8x16xf32>
   %2 = vector.extract %arg0[3, 3] : vector<4x8x16xf32>
   // CHECK-NEXT: vector.extract {{.*}}[3, 3, 3] : vector<4x8x16xf32>
   %3 = vector.extract %arg0[3, 3, 3] : vector<4x8x16xf32>
-  return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32
+  return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32
 }
 
 // CHECK-LABEL: @insert_element
@@ -185,7 +187,9 @@ func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8
   %2 = vector.insert %b, %res[3, 3] : vector<16xf32> into vector<4x8x16xf32>
   // CHECK: vector.insert %{{.*}}, %{{.*}}[3, 3, 3] : f32 into vector<4x8x16xf32>
   %3 = vector.insert %a, %res[3, 3, 3] : f32 into vector<4x8x16xf32>
-  return %3 : vector<4x8x16xf32>
+  // CHECK: vector.insert %{{.*}}, %{{.*}}[] : vector<4x8x16xf32> into vector<4x8x16xf32>
+  %4 = vector.insert %3, %3[] : vector<4x8x16xf32> into vector<4x8x16xf32>
+  return %4 : vector<4x8x16xf32>
 }
 
 // CHECK-LABEL: @outerproduct


        


More information about the Mlir-commits mailing list