[Mlir-commits] [mlir] aec9e20 - [mlir] introduce type constraints for operands of LLVM dialect operations

Alex Zinenko llvmlistbot at llvm.org
Fri Sep 4 01:02:08 PDT 2020


Author: Alex Zinenko
Date: 2020-09-04T10:01:59+02:00
New Revision: aec9e20a3e9a4f25a5b1e07816c95f970300d918

URL: https://github.com/llvm/llvm-project/commit/aec9e20a3e9a4f25a5b1e07816c95f970300d918
DIFF: https://github.com/llvm/llvm-project/commit/aec9e20a3e9a4f25a5b1e07816c95f970300d918.diff

LOG: [mlir] introduce type constraints for operands of LLVM dialect operations

Historically, the operations in MLIR's LLVM dialect only checked that their
operands were of LLVM dialect type, without any more detailed constraints. This
was due to LLVM dialect types wrapping LLVM IR types and having clunky
verification methods. With the new first-class type modeling, it is possible to
define type constraints similarly to other dialects and use them to enforce
some correctness rules in op verifiers instead of having LLVM assert during
translation to LLVM IR. This hardening uncovered several places where MLIR was
producing LLVM dialect operations that cannot exist in LLVM IR.

Depends On D85900

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D85901

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 288031c598ff..0ae6267cb67c 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -21,7 +21,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 // Type constraint accepting standard integers, indices and wrapped LLVM integer
 // types.
 def IntLikeOrLLVMInt : TypeConstraint<
-  Or<[AnySignlessInteger.predicate, Index.predicate, LLVMInt.predicate]>,
+  Or<[AnySignlessInteger.predicate, Index.predicate,
+      LLVM_AnyInteger.predicate]>,
   "integer, index or LLVM dialect equivalent">;
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 1f0eb6aab58a..10755a436115 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -17,6 +17,10 @@
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
+//===----------------------------------------------------------------------===//
+// LLVM Dialect.
+//===----------------------------------------------------------------------===//
+
 def LLVM_Dialect : Dialect {
   let name = "llvm";
   let cppNamespace = "LLVM";
@@ -38,34 +42,108 @@ def LLVM_Dialect : Dialect {
   }];
 }
 
-// LLVM IR type wrapped in MLIR.
+//===----------------------------------------------------------------------===//
+// LLVM dialect type constraints.
+//===----------------------------------------------------------------------===//
+
+// LLVM dialect type.
 def LLVM_Type : DialectType<LLVM_Dialect,
                             CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
                             "LLVM dialect type">;
 
-// Type constraint accepting only wrapped LLVM integer types.
-def LLVMInt : TypeConstraint<
-  And<[LLVM_Type.predicate,
-       CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>,
-  "LLVM dialect integer">;
+// Type constraint accepting LLVM integer types.
+def LLVM_AnyInteger : Type<
+  CPred<"$_self.isa<::mlir::LLVM::LLVMIntegerType>()">,
+  "LLVM integer type">;
+
+// Type constraints accepting LLVM integer type of a specific width.
+class LLVM_IntBase<int width> :
+    Type<And<[
+        LLVM_AnyInteger.predicate,
+        CPred<"$_self.cast<::mlir::LLVM::LLVMIntegerType>().getBitWidth() == "
+              # width>]>,
+        "LLVM " # width # "-bit integer type">,
+    BuildableType<
+        "::mlir::LLVM::LLVMIntegerType::get($_builder.getContext(), "
+        # width # ")">;
+
+def LLVM_i1 : LLVM_IntBase<1>;
+def LLVM_i8 : LLVM_IntBase<8>;
+def LLVM_i32 : LLVM_IntBase<32>;
 
-def LLVMIntBase : TypeConstraint<
+// Type constraint accepting LLVM primitive types, i.e. all types except void
+// and function.
+def LLVM_PrimitiveType : Type<
   And<[LLVM_Type.predicate,
-       CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>,
-  "LLVM dialect integer">;
-
-// Integer type of a specific width.
-class LLVMI<int width>
-    : Type<And<[
-        LLVM_Type.predicate,
-        CPred<
-         "$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy(" # width # ")">]>,
-         "LLVM dialect " # width # "-bit integer">,
-      BuildableType<
-        "::mlir::LLVM::LLVMType::getIntNTy($_builder.getContext(),"
-         # width # ")">;
-
-def LLVMI1 : LLVMI<1>;
+       CPred<"!$_self.isa<::mlir::LLVM::LLVMVoidType, "
+                         "::mlir::LLVM::LLVMFunctionType>()">]>,
+  "primitive LLVM type">;
+
+// Type constraint accepting any LLVM floating point type.
+def LLVM_AnyFloat : Type<
+  CPred<"$_self.isa<::mlir::LLVM::LLVMBFloatType, "
+                   "::mlir::LLVM::LLVMHalfType, "
+                   "::mlir::LLVM::LLVMFloatType, "
+                   "::mlir::LLVM::LLVMDoubleType>()">,
+  "floating point LLVM type">;
+
+// Type constraint accepting any LLVM pointer type.
+def LLVM_AnyPointer : Type<CPred<"$_self.isa<::mlir::LLVM::LLVMPointerType>()">,
+                          "LLVM pointer type">;
+
+// Type constraint accepting LLVM pointer type with an additional constraint
+// on the element type.
+class LLVM_PointerTo<Type pointee> : Type<
+  And<[LLVM_AnyPointer.predicate,
+       SubstLeaves<
+         "$_self",
+         "$_self.cast<::mlir::LLVM::LLVMPointerType>().getElementType()",
+         pointee.predicate>]>,
+  "LLVM pointer to " # pointee.description>;
+
+// Type constraint accepting any LLVM structure type.
+def LLVM_AnyStruct : Type<CPred<"$_self.isa<::mlir::LLVM::LLVMStructType>()">,
+                         "LLVM structure type">;
+
+// Type constraint accepting opaque LLVM structure type.
+def LLVM_OpaqueStruct : Type<
+  And<[LLVM_AnyStruct.predicate,
+       CPred<"$_self.cast<::mlir::LLVM::LLVMStructType>().isOpaque()">]>>;
+
+// Type constraint accepting any LLVM type that can be loaded or stored, i.e. a
+// type that has size (not void, function or opaque struct type).
+def LLVM_LoadableType : Type<
+  And<[LLVM_PrimitiveType.predicate, Neg<LLVM_OpaqueStruct.predicate>]>,
+  "LLVM type with size">;
+
+// Type constraint accepting any LLVM aggregate type, i.e. structure or array.
+def LLVM_AnyAggregate : Type<
+  CPred<"$_self.isa<::mlir::LLVM::LLVMStructType, "
+                   "::mlir::LLVM::LLVMArrayType>()">,
+  "LLVM aggregate type">;
+
+// Type constraint accepting any LLVM non-aggregate type, i.e. not structure or
+// array.
+def LLVM_AnyNonAggregate : Type<Neg<LLVM_AnyAggregate.predicate>,
+                               "LLVM non-aggregate type">;
+
+// Type constraint accepting any LLVM vector type.
+def LLVM_AnyVector : Type<CPred<"$_self.isa<::mlir::LLVM::LLVMVectorType>()">,
+                         "LLVM vector type">;
+
+// Type constraint accepting an LLVM vector type with an additional constraint
+// on the vector element type.
+class LLVM_VectorOf<Type element> : Type<
+  And<[LLVM_AnyVector.predicate,
+       SubstLeaves<
+         "$_self",
+         "$_self.cast<::mlir::LLVM::LLVMVectorType>().getElementType()",
+         element.predicate>]>,
+  "LLVM vector of " # element.description>;
+
+// Type constraint accepting a constrained type, or a vector of such types.
+class LLVM_ScalarOrVectorOf<Type element> :
+    AnyTypeOf<[element, LLVM_VectorOf<element>]>;
 
 // Base class for LLVM operations. Defines the interface to the llvm::IRBuilder
 // used to translate to LLVM IR proper.
@@ -85,6 +163,10 @@ class LLVM_OpBase<Dialect dialect, string mnemonic, list<OpTrait> traits = []> :
   string llvmBuilder = "";
 }
 
+//===----------------------------------------------------------------------===//
+// Base classes for LLVM dialect operations.
+//===----------------------------------------------------------------------===//
+
 // Base class for LLVM operations. All operations get an "llvm." prefix in
 // their name automatically. LLVM operations have either zero or one result,
 // this class is specialized below for both cases and should not be used

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b1dd7b1af030..b5bf4ac77972 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -87,39 +87,50 @@ class LLVM_TerminatorOp<string mnemonic, list<OpTrait> traits = []> :
     LLVM_Op<mnemonic, !listconcat(traits, [Terminator])>;
 
 // Class for arithmetic binary operations.
-class LLVM_ArithmeticOp<string mnemonic, string builderFunc,
-                        list<OpTrait> traits = []> :
+class LLVM_ArithmeticOpBase<Type type, string mnemonic,
+                            string builderFunc, list<OpTrait> traits = []> :
     LLVM_OneResultOp<mnemonic,
            !listconcat([NoSideEffect, SameOperandsAndResultType], traits)>,
-    Arguments<(ins LLVM_Type:$lhs, LLVM_Type:$rhs)>,
+    Arguments<(ins LLVM_ScalarOrVectorOf<type>:$lhs,
+                   LLVM_ScalarOrVectorOf<type>:$rhs)>,
     LLVM_Builder<"$res = builder." # builderFunc # "($lhs, $rhs);"> {
-  let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
+  let parser =
+      [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
   let printer = [{ mlir::impl::printOneResultOp(this->getOperation(), p); }];
 }
-class LLVM_UnaryArithmeticOp<string mnemonic, string builderFunc,
-                        list<OpTrait> traits = []> :
+class LLVM_IntArithmeticOp<string mnemonic, string builderFunc,
+                           list<OpTrait> traits = []> :
+    LLVM_ArithmeticOpBase<LLVM_AnyInteger, mnemonic, builderFunc, traits>;
+class LLVM_FloatArithmeticOp<string mnemonic, string builderFunc,
+                             list<OpTrait> traits = []> :
+    LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, builderFunc, traits>;
+
+// Class for arithmetic unary operations.
+class LLVM_UnaryArithmeticOp<Type type, string mnemonic,
+                             string builderFunc, list<OpTrait> traits = []> :
     LLVM_OneResultOp<mnemonic,
            !listconcat([NoSideEffect, SameOperandsAndResultType], traits)>,
-    Arguments<(ins LLVM_Type:$operand)>,
+    Arguments<(ins type:$operand)>,
     LLVM_Builder<"$res = builder." # builderFunc # "($operand);"> {
-  let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
+  let parser =
+      [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
   let printer = [{ mlir::impl::printOneResultOp(this->getOperation(), p); }];
 }
 
 // Integer binary operations.
-def LLVM_AddOp : LLVM_ArithmeticOp<"add", "CreateAdd", [Commutative]>;
-def LLVM_SubOp : LLVM_ArithmeticOp<"sub", "CreateSub">;
-def LLVM_MulOp : LLVM_ArithmeticOp<"mul", "CreateMul", [Commutative]>;
-def LLVM_UDivOp : LLVM_ArithmeticOp<"udiv", "CreateUDiv">;
-def LLVM_SDivOp : LLVM_ArithmeticOp<"sdiv", "CreateSDiv">;
-def LLVM_URemOp : LLVM_ArithmeticOp<"urem", "CreateURem">;
-def LLVM_SRemOp : LLVM_ArithmeticOp<"srem", "CreateSRem">;
-def LLVM_AndOp : LLVM_ArithmeticOp<"and", "CreateAnd">;
-def LLVM_OrOp : LLVM_ArithmeticOp<"or", "CreateOr">;
-def LLVM_XOrOp : LLVM_ArithmeticOp<"xor", "CreateXor">;
-def LLVM_ShlOp : LLVM_ArithmeticOp<"shl", "CreateShl">;
-def LLVM_LShrOp : LLVM_ArithmeticOp<"lshr", "CreateLShr">;
-def LLVM_AShrOp : LLVM_ArithmeticOp<"ashr", "CreateAShr">;
+def LLVM_AddOp : LLVM_IntArithmeticOp<"add", "CreateAdd", [Commutative]>;
+def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "CreateSub">;
+def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "CreateMul", [Commutative]>;
+def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "CreateUDiv">;
+def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "CreateSDiv">;
+def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "CreateURem">;
+def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "CreateSRem">;
+def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "CreateAnd">;
+def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "CreateOr">;
+def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "CreateXor">;
+def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "CreateShl">;
+def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "CreateLShr">;
+def LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "CreateAShr">;
 
 // Predicate for integer comparisons.
 def ICmpPredicateEQ  : I64EnumAttrCase<"eq", 0>;
@@ -143,8 +154,9 @@ def ICmpPredicate : I64EnumAttr<
 
 // Other integer operations.
 def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>,
-                  Arguments<(ins ICmpPredicate:$predicate, LLVM_Type:$lhs,
-                             LLVM_Type:$rhs)> {
+                  Arguments<(ins ICmpPredicate:$predicate,
+                                 LLVM_ScalarOrVectorOf<LLVM_AnyInteger>:$lhs,
+                                 LLVM_ScalarOrVectorOf<LLVM_AnyInteger>:$rhs)> {
   let llvmBuilder = [{
     $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
   }];
@@ -189,8 +201,9 @@ def FCmpPredicate : I64EnumAttr<
 
 // Other integer operations.
 def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]>,
-                  Arguments<(ins FCmpPredicate:$predicate, LLVM_Type:$lhs,
-                             LLVM_Type:$rhs)> {
+                  Arguments<(ins FCmpPredicate:$predicate,
+                                 LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$lhs,
+                                 LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs)> {
   let llvmBuilder = [{
     $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
   }];
@@ -205,12 +218,13 @@ def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]>,
 }
 
 // Floating point binary operations.
-def LLVM_FAddOp : LLVM_ArithmeticOp<"fadd", "CreateFAdd">;
-def LLVM_FSubOp : LLVM_ArithmeticOp<"fsub", "CreateFSub">;
-def LLVM_FMulOp : LLVM_ArithmeticOp<"fmul", "CreateFMul">;
-def LLVM_FDivOp : LLVM_ArithmeticOp<"fdiv", "CreateFDiv">;
-def LLVM_FRemOp : LLVM_ArithmeticOp<"frem", "CreateFRem">;
-def LLVM_FNegOp : LLVM_UnaryArithmeticOp<"fneg", "CreateFNeg">;
+def LLVM_FAddOp : LLVM_FloatArithmeticOp<"fadd", "CreateFAdd">;
+def LLVM_FSubOp : LLVM_FloatArithmeticOp<"fsub", "CreateFSub">;
+def LLVM_FMulOp : LLVM_FloatArithmeticOp<"fmul", "CreateFMul">;
+def LLVM_FDivOp : LLVM_FloatArithmeticOp<"fdiv", "CreateFDiv">;
+def LLVM_FRemOp : LLVM_FloatArithmeticOp<"frem", "CreateFRem">;
+def LLVM_FNegOp : LLVM_UnaryArithmeticOp<LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
+                                         "fneg", "CreateFNeg">;
 
 // Common code definition that is used to verify and set the alignment attribute
 // of LLVM ops that accept such an attribute.
@@ -241,7 +255,8 @@ class MemoryOpWithAlignmentAndAttributes : MemoryOpWithAlignmentBase {
 def LLVM_AllocaOp :
     MemoryOpWithAlignmentBase,
     LLVM_OneResultOp<"alloca">,
-    Arguments<(ins LLVM_Type:$arraySize, OptionalAttr<I64Attr>:$alignment)> {
+    Arguments<(ins LLVM_AnyInteger:$arraySize,
+               OptionalAttr<I64Attr>:$alignment)> {
   string llvmBuilder = [{
     auto *inst = builder.CreateAlloca(
       $_resultType->getPointerElementType(), $arraySize);
@@ -259,8 +274,11 @@ def LLVM_AllocaOp :
   let parser = [{ return parseAllocaOp(parser, result); }];
   let printer = [{ printAllocaOp(p, *this); }];
 }
+
 def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
-                 Arguments<(ins LLVM_Type:$base, Variadic<LLVM_Type>:$indices)>,
+                 Arguments<(ins LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$base,
+                                Variadic<LLVM_ScalarOrVectorOf<
+                                    LLVM_AnyInteger>>:$indices)>,
                  LLVM_Builder<"$res = builder.CreateGEP($base, $indices);"> {
   let assemblyFormat = [{
     $base `[` $indices `]` attr-dict `:` functional-type(operands, results)
@@ -269,7 +287,7 @@ def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
 def LLVM_LoadOp :
     MemoryOpWithAlignmentAndAttributes,
     LLVM_OneResultOp<"load">,
-    Arguments<(ins LLVM_Type:$addr,
+    Arguments<(ins LLVM_PointerTo<LLVM_LoadableType>:$addr,
                    OptionalAttr<I64Attr>:$alignment,
                    UnitAttr:$volatile_,
                    UnitAttr:$nontemporal)> {
@@ -296,8 +314,8 @@ def LLVM_LoadOp :
 def LLVM_StoreOp :
     MemoryOpWithAlignmentAndAttributes,
     LLVM_ZeroResultOp<"store">,
-    Arguments<(ins LLVM_Type:$value,
-                   LLVM_Type:$addr,
+    Arguments<(ins LLVM_LoadableType:$value,
+                   LLVM_PointerTo<LLVM_LoadableType>:$addr,
                    OptionalAttr<I64Attr>:$alignment,
                    UnitAttr:$volatile_,
                    UnitAttr:$nontemporal)> {
@@ -314,28 +332,41 @@ def LLVM_StoreOp :
 }
 
 // Casts.
-class LLVM_CastOp<string mnemonic, string builderFunc,
+class LLVM_CastOp<string mnemonic, string builderFunc, Type type,
                   list<OpTrait> traits = []> :
     LLVM_OneResultOp<mnemonic,
            !listconcat([NoSideEffect], traits)>,
-    Arguments<(ins LLVM_Type:$arg)>,
+    Arguments<(ins type:$arg)>,
     LLVM_Builder<"$res = builder." # builderFunc # "($arg, $_resultType);"> {
   let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
   let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
 }
-def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast">;
-def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast">;
-def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr">;
-def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt">;
-def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt">;
-def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt">;
-def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc">;
-def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP">;
-def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "CreateUIToFP">;
-def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "CreateFPToSI">;
-def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "CreateFPToUI">;
-def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt">;
-def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc">;
+def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast",
+                                 LLVM_AnyNonAggregate>;
+def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast",
+                                       LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
+def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr",
+                                  LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
+def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt",
+                                  LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
+def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt",
+                              LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
+def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt",
+                              LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
+def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc",
+                               LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
+def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP",
+                                LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
+def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "CreateUIToFP",
+                                LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
+def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "CreateFPToSI",
+                                LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
+def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "CreateFPToUI",
+                                LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
+def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt",
+                                LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
+def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc",
+                                 LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
 
 // Call-related operations.
 def LLVM_InvokeOp : LLVM_Op<"invoke", [
@@ -404,8 +435,8 @@ def LLVM_CallOp : LLVM_Op<"call">,
   let printer = [{ printCallOp(p, *this); }];
 }
 def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]>,
-                          Arguments<(ins LLVM_Type:$vector,
-                                     LLVM_Type:$position)> {
+                          Arguments<(ins LLVM_AnyVector:$vector,
+                                         LLVM_AnyInteger:$position)> {
   string llvmBuilder = [{
     $res = builder.CreateExtractElement($vector, $position);
   }];
@@ -416,8 +447,8 @@ def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]>,
   let printer = [{ printExtractElementOp(p, *this); }];
 }
 def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>,
-                          Arguments<(ins LLVM_Type:$container,
-                                     ArrayAttr:$position)> {
+                          Arguments<(ins LLVM_AnyAggregate:$container,
+                                         ArrayAttr:$position)> {
   string llvmBuilder = [{
     $res = builder.CreateExtractValue($container, extractPosition($position));
   }];
@@ -425,8 +456,9 @@ def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>,
   let printer = [{ printExtractValueOp(p, *this); }];
 }
 def LLVM_InsertElementOp : LLVM_OneResultOp<"insertelement", [NoSideEffect]>,
-                         Arguments<(ins LLVM_Type:$vector, LLVM_Type:$value,
-                                    LLVM_Type:$position)> {
+                         Arguments<(ins LLVM_AnyVector:$vector,
+                                        LLVM_PrimitiveType:$value,
+                                        LLVM_AnyInteger:$position)> {
   string llvmBuilder = [{
     $res = builder.CreateInsertElement($vector, $value, $position);
   }];
@@ -434,8 +466,9 @@ def LLVM_InsertElementOp : LLVM_OneResultOp<"insertelement", [NoSideEffect]>,
   let printer = [{ printInsertElementOp(p, *this); }];
 }
 def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>,
-                         Arguments<(ins LLVM_Type:$container, LLVM_Type:$value,
-                                    ArrayAttr:$position)> {
+                         Arguments<(ins LLVM_AnyAggregate:$container,
+                                        LLVM_PrimitiveType:$value,
+                                        ArrayAttr:$position)> {
   string llvmBuilder = [{
     $res = builder.CreateInsertValue($container, $value,
                                      extractPosition($position));
@@ -451,7 +484,7 @@ def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>,
 }
 def LLVM_ShuffleVectorOp
     : LLVM_OneResultOp<"shufflevector", [NoSideEffect]>,
-      Arguments<(ins LLVM_Type:$v1, LLVM_Type:$v2, ArrayAttr:$mask)> {
+      Arguments<(ins LLVM_AnyVector:$v1, LLVM_AnyVector:$v2, ArrayAttr:$mask)> {
   string llvmBuilder = [{
       SmallVector<unsigned, 4> position = extractPosition($mask);
       SmallVector<int, 4> mask(position.begin(), position.end());
@@ -478,8 +511,9 @@ def LLVM_ShuffleVectorOp
 def LLVM_SelectOp
     : LLVM_OneResultOp<"select",
           [NoSideEffect, AllTypesMatch<["trueValue", "falseValue", "res"]>]>,
-      Arguments<(ins LLVM_Type:$condition, LLVM_Type:$trueValue,
-                 LLVM_Type:$falseValue)>,
+      Arguments<(ins LLVM_ScalarOrVectorOf<LLVM_i1>:$condition,
+                     LLVM_Type:$trueValue,
+                     LLVM_Type:$falseValue)>,
       LLVM_Builder<
           "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
   let builders = [OpBuilder<
@@ -508,7 +542,7 @@ def LLVM_BrOp : LLVM_TerminatorOp<"br",
 def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
     [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
      NoSideEffect]> {
-  let arguments = (ins LLVMI1:$condition,
+  let arguments = (ins LLVM_i1:$condition,
                    Variadic<LLVM_Type>:$trueDestOperands,
                    Variadic<LLVM_Type>:$falseDestOperands,
                    OptionalAttr<ElementsAttr>:$branch_weights);
@@ -1090,9 +1124,11 @@ def AtomicOrdering : I64EnumAttr<
   let cppNamespace = "::mlir::LLVM";
 }
 
+def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyInteger]>;
+
 def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw">,
-    Arguments<(ins AtomicBinOp:$bin_op, LLVM_Type:$ptr, LLVM_Type:$val,
-                   AtomicOrdering:$ordering)>,
+    Arguments<(ins AtomicBinOp:$bin_op, LLVM_PointerTo<LLVM_AtomicRMWType>:$ptr,
+               LLVM_AtomicRMWType:$val, AtomicOrdering:$ordering)>,
     Results<(outs LLVM_Type:$res)> {
   let llvmBuilder = [{
     $res = builder.CreateAtomicRMW(getLLVMAtomicBinOp($bin_op), $ptr, $val,
@@ -1103,8 +1139,11 @@ def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw">,
   let verifier = "return ::verify(*this);";
 }
 
+def LLVM_AtomicCmpXchgType : AnyTypeOf<[LLVM_AnyInteger, LLVM_AnyPointer]>;
+
 def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg">,
-    Arguments<(ins LLVM_Type:$ptr, LLVM_Type:$cmp, LLVM_Type:$val,
+    Arguments<(ins LLVM_PointerTo<LLVM_AtomicCmpXchgType>:$ptr,
+                   LLVM_AtomicCmpXchgType:$cmp, LLVM_AtomicCmpXchgType:$val,
                    AtomicOrdering:$success_ordering,
                    AtomicOrdering:$failure_ordering)>,
     Results<(outs LLVM_Type:$res)> {

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 96d8459e5b3d..63bd10c2e6f1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1533,8 +1533,6 @@ static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
 
 static LogicalResult verify(AtomicRMWOp op) {
   auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
-  if (!ptrType.isPointerTy())
-    return op.emitOpError("expected LLVM IR pointer type for operand #0");
   auto valType = op.val().getType().cast<LLVM::LLVMType>();
   if (valType != ptrType.getPointerElementTy())
     return op.emitOpError("expected LLVM IR element type for operand #0 to "

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index e27650b3297d..a89287b764e5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -440,7 +440,8 @@ LogicalResult LLVMStructType::setBody(ArrayRef<LLVMType> types, bool isPacked) {
 bool LLVMStructType::isPacked() { return getImpl()->isPacked(); }
 bool LLVMStructType::isIdentified() { return getImpl()->isIdentified(); }
 bool LLVMStructType::isOpaque() {
-  return getImpl()->isOpaque() || !getImpl()->isInitialized();
+  return getImpl()->isIdentified() &&
+         (getImpl()->isOpaque() || !getImpl()->isInitialized());
 }
 bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
 StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); }

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 1f8b1600873c..c19795e98b68 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -394,7 +394,7 @@ func @nvvm_invalid_mma_7(%a0 : !llvm.vec<2 x half>, %a1 : !llvm.vec<2 x half>,
 
 // CHECK-LABEL: @atomicrmw_expected_ptr
 func @atomicrmw_expected_ptr(%f32 : !llvm.float) {
-  // expected-error@+1 {{expected LLVM IR pointer type for operand #0}}
+  // expected-error@+1 {{operand #0 must be LLVM pointer to floating point LLVM type or LLVM integer type}}
   %0 = "llvm.atomicrmw"(%f32, %f32) {bin_op=11, ordering=1} : (!llvm.float, !llvm.float) -> !llvm.float
   llvm.return
 }
@@ -448,7 +448,7 @@ func @atomicrmw_expected_int(%f32_ptr : !llvm.ptr<float>, %f32 : !llvm.float) {
 
 // CHECK-LABEL: @cmpxchg_expected_ptr
 func @cmpxchg_expected_ptr(%f32_ptr : !llvm.ptr<float>, %f32 : !llvm.float) {
-  // expected-error@+1 {{expected LLVM IR pointer type for operand #0}}
+  // expected-error@+1 {{op operand #0 must be LLVM pointer to LLVM integer type or LLVM pointer type}}
   %0 = "llvm.cmpxchg"(%f32, %f32, %f32) {success_ordering=2,failure_ordering=2} : (!llvm.float, !llvm.float, !llvm.float) -> !llvm.struct<(float, i1)>
   llvm.return
 }


        


More information about the Mlir-commits mailing list