[Mlir-commits] [mlir] 5a8a159 - Add verifier for insert/extract element/value on type match between container and inserted/extracted value, and fix vector.shuffle lowering

Mehdi Amini llvmlistbot at llvm.org
Wed Jul 21 15:35:03 PDT 2021


Author: Mehdi Amini
Date: 2021-07-21T22:28:59Z
New Revision: 5a8a159bf5279455b4688f80a6864ca8f37f4b4e

URL: https://github.com/llvm/llvm-project/commit/5a8a159bf5279455b4688f80a6864ca8f37f4b4e
DIFF: https://github.com/llvm/llvm-project/commit/5a8a159bf5279455b4688f80a6864ca8f37f4b4e.diff

LOG: Add verifier for insert/extract element/value on type match between container and inserted/extracted value, and fix vector.shuffle lowering

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D106398

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 7d8c8a5a9600..6d09c50c92a8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -522,6 +522,7 @@ def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
   let builders = [
     OpBuilder<(ins "Value":$vector, "Value":$position,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
+  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return parseExtractElementOp(parser, result); }];
   let printer = [{ printExtractElementOp(p, *this); }];
 }
@@ -532,6 +533,7 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> {
     $res = builder.CreateExtractValue($container, extractPosition($position));
   }];
   let builders = [LLVM_OneResultOpBuilder];
+  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return parseExtractValueOp(parser, result); }];
   let printer = [{ printExtractValueOp(p, *this); }];
   let hasFolder = 1;
@@ -544,6 +546,7 @@ def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> {
     $res = builder.CreateInsertElement($vector, $value, $position);
   }];
   let builders = [LLVM_OneResultOpBuilder];
+  let verifier = [{ return ::verify(*this);  }];
   let parser = [{ return parseInsertElementOp(parser, result); }];
   let printer = [{ printInsertElementOp(p, *this); }];
 }
@@ -560,6 +563,7 @@ def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> {
     [{
       build($_builder, $_state, container.getType(), container, value, position);
     }]>];
+  let verifier = [{ return ::verify(*this);  }];
   let parser = [{ return parseInsertValueOp(parser, result); }];
   let printer = [{ printInsertValueOp(p, *this); }];
 }

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 53ce5ca3d452..3b474113b51e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -541,6 +541,12 @@ class VectorShuffleOpConversion
     }
 
     // For all other cases, insert the individual values individually.
+    // Type of each extracted value: the element type of `llvmType`.
+    Type eltType;
+    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
+      eltType = arrayType.getElementType();
+    else
+      eltType = llvmType.cast<VectorType>().getElementType();
     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
     int64_t insPos = 0;
     for (auto en : llvm::enumerate(maskArrayAttr)) {
@@ -551,7 +557,7 @@ class VectorShuffleOpConversion
         value = adaptor.v2();
       }
       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
-                                 llvmType, rank, extPos);
+                                 eltType, rank, extPos);
       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                          llvmType, rank, insPos++);
     }

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 4a8e0b9db0ee..5b9625b62a02 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -970,6 +970,20 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
   return success();
 }
 
+// Verifies that the result type is the element type of the vector operand.
+static LogicalResult verify(ExtractElementOp op) {
+  Type vecTy = op.vector().getType();
+  if (!LLVM::isCompatibleVectorType(vecTy))
+    return op->emitOpError("expected LLVM dialect-compatible vector type for "
+                           "operand #1, got") << vecTy;
+  Type eltTy = LLVM::getVectorElementType(vecTy);
+  if (eltTy == op.res().getType())
+    return success();
+  return op.emitOpError() << "Type mismatch: extracting from " << vecTy
+                          << " should produce " << eltTy
+                          << " but this op returns " << op.res().getType();
+}
+
 //===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::ExtractValueOp.
 //===----------------------------------------------------------------------===//
@@ -1024,6 +1038,52 @@ static Type getInsertExtractValueElementType(OpAsmParser &parser,
   return llvmType;
 }
 
+// Returns the type found at `positionAttr` inside the LLVM aggregate type
+// `containerType`, or a null type after emitting a diagnostic on `op`.
+static Type getInsertExtractValueElementType(Type containerType,
+                                             ArrayAttr positionAttr,
+                                             Operation *op) {
+  Type llvmType = containerType;
+  if (!isCompatibleType(containerType)) {
+    op->emitError("expected LLVM IR Dialect type, got ") << containerType;
+    return {};
+  }
+
+  // Step into the aggregate one index at a time, checking each index against
+  // the bounds of the array/struct it selects from before using it. Indices
+  // are kept as 64-bit values: narrowing to `int` could wrap an out-of-range
+  // index back into bounds.
+  for (Attribute subAttr : positionAttr) {
+    auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
+    if (!positionElementAttr) {
+      op->emitOpError("expected an array of integer literals, got: ")
+          << subAttr;
+      return {};
+    }
+    int64_t position = positionElementAttr.getInt();
+    if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
+      if (position < 0 ||
+          static_cast<uint64_t>(position) >= arrayType.getNumElements()) {
+        op->emitOpError("position out of bounds: ") << position;
+        return {};
+      }
+      llvmType = arrayType.getElementType();
+    } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
+      if (position < 0 ||
+          static_cast<uint64_t>(position) >= structType.getBody().size()) {
+        op->emitOpError("position out of bounds: ") << position;
+        return {};
+      }
+      llvmType = structType.getBody()[position];
+    } else {
+      op->emitOpError("expected LLVM IR structure/array type, got: ")
+          << llvmType;
+      return {};
+    }
+  }
+  return llvmType;
+}
+
 // <operation> ::= `llvm.extractvalue` ssa-use
 //                 `[` integer-literal (`,` integer-literal)* `]`
 //                 attribute-dict? `:` type
@@ -1062,6 +1122,20 @@ OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+// Verifies that the result type matches the container type at `position`.
+static LogicalResult verify(ExtractValueOp op) {
+  Type containerTy = op.container().getType();
+  Type expectedTy =
+      getInsertExtractValueElementType(containerTy, op.positionAttr(), op);
+  if (!expectedTy)
+    return failure();
+  if (expectedTy == op.res().getType())
+    return success();
+  return op.emitOpError() << "Type mismatch: extracting from " << containerTy
+                          << " should produce " << expectedTy
+                          << " but this op returns " << op.res().getType();
+}
+
 //===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::InsertElementOp.
 //===----------------------------------------------------------------------===//
@@ -1104,6 +1178,14 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser,
   return success();
 }
 
+// Verifies that the inserted value has the vector operand's element type.
+static LogicalResult verify(InsertElementOp op) {
+  Type vecTy = op.vector().getType();
+  if (LLVM::getVectorElementType(vecTy) == op.value().getType())
+    return success();
+  return op.emitOpError() << "Type mismatch: cannot insert "
+                          << op.value().getType() << " into " << vecTy;
+}
 //===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::InsertValueOp.
 //===----------------------------------------------------------------------===//
@@ -1147,6 +1229,20 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser,
   return success();
 }
 
+// Verifies that the inserted value has the type found at `position` in the
+// container; invalid positions are diagnosed by the helper.
+static LogicalResult verify(InsertValueOp op) {
+  Type containerTy = op.container().getType();
+  Type slotTy =
+      getInsertExtractValueElementType(containerTy, op.positionAttr(), op);
+  if (!slotTy)
+    return failure();
+  if (slotTy == op.value().getType())
+    return success();
+  return op.emitOpError() << "Type mismatch: cannot insert "
+                          << op.value().getType() << " into " << containerTy;
+}
+
 //===----------------------------------------------------------------------===//
 // Printing, parsing and verification for LLVM::ReturnOp.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index a28218bdeddb..19b80b86c43f 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -355,6 +355,24 @@ func @insertvalue_wrong_nesting() {
   llvm.insertvalue %a, %b[0,0] : !llvm.struct<(i32)>
 }
 
+// -----
+
+func @insertvalue_invalid_type(%a : !llvm.array<1 x i32>) -> !llvm.array<1 x i32> {
+  // expected-error@+1 {{'llvm.insertvalue' op Type mismatch: cannot insert '!llvm.array<1 x i32>' into '!llvm.array<1 x i32>'}}
+  %b = "llvm.insertvalue"(%a, %a) {position = [0]} : (!llvm.array<1 x i32>, !llvm.array<1 x i32>) -> !llvm.array<1 x i32>
+  return %b : !llvm.array<1 x i32>
+}
+
+// -----
+
+func @extractvalue_invalid_type(%a : !llvm.array<4 x vector<8xf32>>) -> !llvm.array<4 x vector<8xf32>> {
+  // expected-error@+1 {{'llvm.extractvalue' op Type mismatch: extracting from '!llvm.array<4 x vector<8xf32>>' should produce 'vector<8xf32>' but this op returns '!llvm.array<4 x vector<8xf32>>'}}
+  %b = "llvm.extractvalue"(%a) {position = [1]}
+            : (!llvm.array<4 x vector<8xf32>>) -> !llvm.array<4 x vector<8xf32>>
+  return %b : !llvm.array<4 x vector<8xf32>>
+}
+
+
 // -----
 
 func @extractvalue_non_llvm_type(%a : i32, %b : tensor<*xi32>) {
@@ -422,6 +440,22 @@ func @invalid_vector_type_3(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
 
 // -----
 
+func @invalid_vector_type_4(%a : vector<4xf32>, %idx : i32) -> vector<4xf32> {
+  // expected-error@+1 {{'llvm.extractelement' op Type mismatch: extracting from 'vector<4xf32>' should produce 'f32' but this op returns 'vector<4xf32>'}}
+  %b = "llvm.extractelement"(%a, %idx) : (vector<4xf32>, i32) -> vector<4xf32>
+  return %b : vector<4xf32>
+}
+
+// -----
+
+func @invalid_vector_type_5(%a : vector<4xf32>, %idx : i32) -> vector<4xf32> {
+  // expected-error@+1 {{'llvm.insertelement' op Type mismatch: cannot insert 'vector<4xf32>' into 'vector<4xf32>'}}
+  %b = "llvm.insertelement"(%a, %a, %idx) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32>
+  return %b : vector<4xf32>
+}
+
+// -----
+
 func @null_non_llvm_type() {
   // expected-error@+1 {{must be LLVM pointer type, but got 'i32'}}
   llvm.mlir.null : i32


        


More information about the Mlir-commits mailing list