[Mlir-commits] [mlir] a48ff68 - [mlir][LLVMIR] "Modernize" LLVM insert/extract element operations

Jeff Niu llvmlistbot at llvm.org
Thu Aug 18 11:43:12 PDT 2022


Author: Jeff Niu
Date: 2022-08-18T14:43:06-04:00
New Revision: a48ff6888825002285060d46a7ed0f5142576991

URL: https://github.com/llvm/llvm-project/commit/a48ff6888825002285060d46a7ed0f5142576991
DIFF: https://github.com/llvm/llvm-project/commit/a48ff6888825002285060d46a7ed0f5142576991.diff

LOG: [mlir][LLVMIR] "Modernize" LLVM insert/extract element operations

This patch "modernizes" the implementation of these operations by
switching them to declarative assembly formats and type inference traits,
replacing their hand-written parsers, printers, and verifiers.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D131971

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 08feec68b6bc..07900973a302 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -701,17 +701,26 @@ def LLVM_CallOp : LLVM_Op<"call",
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
 
-def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
+def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect,
+    TypesMatchWith<"result type matches vector element type", "vector", "res",
+                   "LLVM::getVectorElementType($_self)">]> {
+  let summary = "Extract an element from an LLVM vector.";
+
   let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position);
   let results = (outs LLVM_Type:$res);
+
+  let builders = [
+    OpBuilder<(ins "Value":$vector, "Value":$position,
+                   CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+  ];
+
+  let assemblyFormat = [{
+    $vector `[` $position `:` type($position) `]` attr-dict `:` type($vector)
+  }];
+
   string llvmBuilder = [{
     $res = builder.CreateExtractElement($vector, $position);
   }];
-  let builders = [
-    OpBuilder<(ins "Value":$vector, "Value":$position,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
-  let hasCustomAssemblyFormat = 1;
-  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -746,16 +755,26 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> {
 // InsertElementOp
 //===----------------------------------------------------------------------===//
 
-def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> {
+def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect,
+    TypesMatchWith<"argument type matches vector element type", "vector",
+                   "value", "LLVM::getVectorElementType($_self)">,
+    AllTypesMatch<["res", "vector"]>]> {
+  let summary = "Insert an element into an LLVM vector.";
+
   let arguments = (ins LLVM_AnyVector:$vector, LLVM_PrimitiveType:$value,
                        AnyInteger:$position);
   let results = (outs LLVM_AnyVector:$res);
+
+  let builders = [LLVM_OneResultOpBuilder];
+
+  let assemblyFormat = [{
+    $value `,` $vector `[` $position `:` type($position) `]` attr-dict `:`
+    type($vector)
+  }];
+
   string llvmBuilder = [{
     $res = builder.CreateInsertElement($vector, $value, $position);
   }];
-  let builders = [LLVM_OneResultOpBuilder];
-  let hasCustomAssemblyFormat = 1;
-  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index cf3139ec3cc7..35c215c07a9f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1365,10 +1365,10 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 //===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::ExtractElementOp.
+// ExtractElementOp
 //===----------------------------------------------------------------------===//
-// Expects vector to be of wrapped LLVM vector type and position to be of
-// wrapped LLVM i32 type.
+
+/// Expects vector to be an LLVM vector type and position to be an integer type.
 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
                                    Value vector, Value position,
                                    ArrayRef<NamedAttribute> attrs) {
@@ -1378,49 +1378,6 @@ void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
   result.addAttributes(attrs);
 }
 
-void ExtractElementOp::print(OpAsmPrinter &p) {
-  p << ' ' << getVector() << "[" << getPosition() << " : "
-    << getPosition().getType() << "]";
-  p.printOptionalAttrDict((*this)->getAttrs());
-  p << " : " << getVector().getType();
-}
-
-// <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
-//                 attribute-dict? `:` type
-ParseResult ExtractElementOp::parse(OpAsmParser &parser,
-                                    OperationState &result) {
-  SMLoc loc;
-  OpAsmParser::UnresolvedOperand vector, position;
-  Type type, positionType;
-  if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) ||
-      parser.parseLSquare() || parser.parseOperand(position) ||
-      parser.parseColonType(positionType) || parser.parseRSquare() ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(vector, type, result.operands) ||
-      parser.resolveOperand(position, positionType, result.operands))
-    return failure();
-  if (!LLVM::isCompatibleVectorType(type))
-    return parser.emitError(
-        loc, "expected LLVM dialect-compatible vector type for operand #1");
-  result.addTypes(LLVM::getVectorElementType(type));
-  return success();
-}
-
-LogicalResult ExtractElementOp::verify() {
-  Type vectorType = getVector().getType();
-  if (!LLVM::isCompatibleVectorType(vectorType))
-    return emitOpError("expected LLVM dialect-compatible vector type for "
-                       "operand #1, got")
-           << vectorType;
-  Type valueType = LLVM::getVectorElementType(vectorType);
-  if (valueType != getRes().getType())
-    return emitOpError() << "Type mismatch: extracting from " << vectorType
-                         << " should produce " << valueType
-                         << " but this op returns " << getRes().getType();
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // ExtractValueOp
 //===----------------------------------------------------------------------===//
@@ -1530,57 +1487,6 @@ void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
         container, builder.getAttr<DenseI64ArrayAttr>(position));
 }
 
-//===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::InsertElementOp.
-//===----------------------------------------------------------------------===//
-
-void InsertElementOp::print(OpAsmPrinter &p) {
-  p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : "
-    << getPosition().getType() << "]";
-  p.printOptionalAttrDict((*this)->getAttrs());
-  p << " : " << getVector().getType();
-}
-
-// <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
-//                 attribute-dict? `:` type
-ParseResult InsertElementOp::parse(OpAsmParser &parser,
-                                   OperationState &result) {
-  SMLoc loc;
-  OpAsmParser::UnresolvedOperand vector, value, position;
-  Type vectorType, positionType;
-  if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) ||
-      parser.parseComma() || parser.parseOperand(vector) ||
-      parser.parseLSquare() || parser.parseOperand(position) ||
-      parser.parseColonType(positionType) || parser.parseRSquare() ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(vectorType))
-    return failure();
-
-  if (!LLVM::isCompatibleVectorType(vectorType))
-    return parser.emitError(
-        loc, "expected LLVM dialect-compatible vector type for operand #1");
-  Type valueType = LLVM::getVectorElementType(vectorType);
-  if (!valueType)
-    return failure();
-
-  if (parser.resolveOperand(vector, vectorType, result.operands) ||
-      parser.resolveOperand(value, valueType, result.operands) ||
-      parser.resolveOperand(position, positionType, result.operands))
-    return failure();
-
-  result.addTypes(vectorType);
-  return success();
-}
-
-LogicalResult InsertElementOp::verify() {
-  Type valueType = LLVM::getVectorElementType(getVector().getType());
-  if (valueType != getValue().getType())
-    return emitOpError() << "Type mismatch: cannot insert "
-                         << getValue().getType() << " into "
-                         << getVector().getType();
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // InsertValueOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 9ec25d17aa0f..bf56f4fece53 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -458,14 +458,14 @@ func.func @extractvalue_wrong_nesting() {
 // -----
 
 func.func @invalid_vector_type_1(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
-  // expected-error at +1 {{expected LLVM dialect-compatible vector type for operand #1}}
+  // expected-error at +1 {{'vector' must be LLVM dialect-compatible vector}}
   %0 = llvm.extractelement %arg2[%arg1 : i32] : f32
 }
 
 // -----
 
 func.func @invalid_vector_type_2(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
-  // expected-error at +1 {{expected LLVM dialect-compatible vector type for operand #1}}
+  // expected-error at +1 {{'vector' must be LLVM dialect-compatible vector}}
   %0 = llvm.insertelement %arg2, %arg2[%arg1 : i32] : f32
 }
 
@@ -479,7 +479,7 @@ func.func @invalid_vector_type_3(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
 // -----
 
 func.func @invalid_vector_type_4(%a : vector<4xf32>, %idx : i32) -> vector<4xf32> {
-  // expected-error at +1 {{'llvm.extractelement' op Type mismatch: extracting from 'vector<4xf32>' should produce 'f32' but this op returns 'vector<4xf32>'}}
+  // expected-error at +1 {{failed to verify that result type matches vector element type}}
   %b = "llvm.extractelement"(%a, %idx) : (vector<4xf32>, i32) -> vector<4xf32>
   return %b : vector<4xf32>
 }
@@ -487,7 +487,7 @@ func.func @invalid_vector_type_4(%a : vector<4xf32>, %idx : i32) -> vector<4xf32
 // -----
 
 func.func @invalid_vector_type_5(%a : vector<4xf32>, %idx : i32) -> vector<4xf32> {
-  // expected-error at +1 {{'llvm.insertelement' op Type mismatch: cannot insert 'vector<4xf32>' into 'vector<4xf32>'}}
+  // expected-error at +1 {{failed to verify that argument type matches vector element type}}
   %b = "llvm.insertelement"(%a, %a, %idx) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32>
   return %b : vector<4xf32>
 }


        


More information about the Mlir-commits mailing list