[Mlir-commits] [mlir] cd73af9 - [MLIR] Remove LLVM_AnyInteger type constraint

Kiran Chandramohan llvmlistbot at llvm.org
Tue Jun 8 09:21:23 PDT 2021


Author: Kiran Chandramohan
Date: 2021-06-08T17:21:00+01:00
New Revision: cd73af92315ecf25ed47f4991806a054ddfca5ea

URL: https://github.com/llvm/llvm-project/commit/cd73af92315ecf25ed47f4991806a054ddfca5ea
DIFF: https://github.com/llvm/llvm-project/commit/cd73af92315ecf25ed47f4991806a054ddfca5ea.diff

LOG: [MLIR] Remove LLVM_AnyInteger type constraint

LLVM Dialect uses builtin-integer types. The existing LLVM_AnyInteger
type constraint is a dupe of AnyInteger. This patch removes LLVM_AnyInteger
and replaces all usage with AnyInteger.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D103839

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/AMX/AMX.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/test/Dialect/LLVMIR/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 24052ed4f24d0..85611affa80c1 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -239,7 +239,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
 //
 
 def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>,
-  Arguments<(ins LLVM_AnyInteger, LLVM_AnyInteger)>;
+  Arguments<(ins AnyInteger, AnyInteger)>;
 
 //
 // Tile memory operations. Parameters define the tile size,
@@ -248,12 +248,12 @@ def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>,
 //
 
 def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>,
-  Arguments<(ins LLVM_AnyInteger,
-                 LLVM_AnyInteger, LLVM_AnyPointer, LLVM_AnyInteger)>;
+  Arguments<(ins AnyInteger,
+                 AnyInteger, LLVM_AnyPointer, AnyInteger)>;
 
 def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>,
-  Arguments<(ins LLVM_AnyInteger,
-                 LLVM_AnyInteger, LLVM_AnyPointer, LLVM_AnyInteger, LLVM_Type)>;
+  Arguments<(ins AnyInteger,
+                 AnyInteger, LLVM_AnyPointer, AnyInteger, LLVM_Type)>;
 
 //
 // Tile multiplication operations (series of dot products). Parameters
@@ -263,32 +263,32 @@ def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>,
 
 // Dot product of bf16 tiles into f32 tile.
 def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
-  Arguments<(ins LLVM_AnyInteger,
-                 LLVM_AnyInteger,
-		 LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+  Arguments<(ins AnyInteger,
+                 AnyInteger,
+		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
 
 // Dot product of i8 tiles into i32 tile (with sign/sign extension).
 def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
-  Arguments<(ins LLVM_AnyInteger,
-                 LLVM_AnyInteger,
-		 LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+  Arguments<(ins AnyInteger,
+                 AnyInteger,
+		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
 
 // Dot product of i8 tiles into i32 tile (with sign/zero extension).
 def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>,
-  Arguments<(ins LLVM_AnyInteger,
-                 LLVM_AnyInteger,
-		 LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+  Arguments<(ins AnyInteger,
+                 AnyInteger,
+		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
 
 // Dot product of i8 tiles into i32 tile (with zero/sign extension).
 def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>,
-  Arguments<(ins LLVM_AnyInteger,
-                 LLVM_AnyInteger,
-		 LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+  Arguments<(ins AnyInteger,
+                 AnyInteger,
+		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
 
 // Dot product of i8 tiles into i32 tile (with zero/zero extension).
 def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
-  Arguments<(ins LLVM_AnyInteger,
-                 LLVM_AnyInteger,
-		 LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+  Arguments<(ins AnyInteger,
+                 AnyInteger,
+		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
 
 #endif // AMX

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 8c83dbc0c9d19..716260f3819d1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -62,11 +62,6 @@ def LLVM_TokenType : Type<
   "LLVM token type">,
   BuildableType<"::mlir::LLVM::LLVMTokenType::get($_builder.getContext())">;
 
-// Type constraint accepting LLVM integer types.
-def LLVM_AnyInteger : Type<
-  CPred<"$_self.isa<::mlir::IntegerType>()">,
-  "LLVM integer type">;
-
 // Type constraint accepting LLVM primitive types, i.e. all types except void
 // and function.
 def LLVM_PrimitiveType : Type<

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 64271d625e175..e1a32e6d6f153 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -129,7 +129,7 @@ class LLVM_ArithmeticOpBase<Type type, string mnemonic,
 }
 class LLVM_IntArithmeticOp<string mnemonic, string builderFunc,
                            list<OpTrait> traits = []> :
-    LLVM_ArithmeticOpBase<LLVM_AnyInteger, mnemonic, builderFunc, traits> {
+    LLVM_ArithmeticOpBase<AnyInteger, mnemonic, builderFunc, traits> {
   let arguments = commonArgs;
 }
 class LLVM_FloatArithmeticOp<string mnemonic, string builderFunc,
@@ -190,8 +190,8 @@ def ICmpPredicate : I64EnumAttr<
 // Other integer operations.
 def LLVM_ICmpOp : LLVM_Op<"icmp", [NoSideEffect]> {
   let arguments = (ins ICmpPredicate:$predicate,
-                   AnyTypeOf<[LLVM_ScalarOrVectorOf<LLVM_AnyInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$lhs,
-                   AnyTypeOf<[LLVM_ScalarOrVectorOf<LLVM_AnyInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$rhs);
+                   AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$lhs,
+                   AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$rhs);
   let results = (outs LLVM_ScalarOrVectorOf<I1>:$res);
   let llvmBuilder = [{
     $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
@@ -290,7 +290,7 @@ class MemoryOpWithAlignmentAndAttributes : MemoryOpWithAlignmentBase {
 
 // Memory-related operations.
 def LLVM_AllocaOp : LLVM_Op<"alloca">, MemoryOpWithAlignmentBase {
-  let arguments = (ins LLVM_AnyInteger:$arraySize,
+  let arguments = (ins AnyInteger:$arraySize,
                    OptionalAttr<I64Attr>:$alignment);
   let results = (outs LLVM_AnyPointer:$res);
   string llvmBuilder = [{
@@ -318,7 +318,7 @@ def LLVM_GEPOp
           "$res = builder.CreateGEP("
           " $base->getType()->getPointerElementType(), $base, $indices);"> {
   let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$base,
-                   Variadic<LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>:$indices);
+                   Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$indices);
   let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$res);
   let builders = [LLVM_OneResultOpBuilder];
   let assemblyFormat = [{
@@ -389,32 +389,32 @@ def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast",
                                        LLVM_ScalarOrVectorOf<LLVM_AnyPointer>,
                                        LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
 def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr",
-                                  LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
+                                  LLVM_ScalarOrVectorOf<AnyInteger>,
                                   LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
 def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt",
                                   LLVM_ScalarOrVectorOf<LLVM_AnyPointer>,
-                                  LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
+                                  LLVM_ScalarOrVectorOf<AnyInteger>>;
 def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt",
-                              LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
-                              LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
+                              LLVM_ScalarOrVectorOf<AnyInteger>,
+                              LLVM_ScalarOrVectorOf<AnyInteger>>;
 def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt",
-                              LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
-                              LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
+                              LLVM_ScalarOrVectorOf<AnyInteger>,
+                              LLVM_ScalarOrVectorOf<AnyInteger>>;
 def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc",
-                               LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
-                               LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
+                               LLVM_ScalarOrVectorOf<AnyInteger>,
+                               LLVM_ScalarOrVectorOf<AnyInteger>>;
 def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP",
-                                LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
+                                LLVM_ScalarOrVectorOf<AnyInteger>,
                                 LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
 def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "CreateUIToFP",
-                                LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
+                                LLVM_ScalarOrVectorOf<AnyInteger>,
                                 LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
 def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "CreateFPToSI",
                                 LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
-                                LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
+                                LLVM_ScalarOrVectorOf<AnyInteger>>;
 def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "CreateFPToUI",
                                 LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
-                                LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
+                                LLVM_ScalarOrVectorOf<AnyInteger>>;
 def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt",
                                 LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
                                 LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
@@ -514,7 +514,7 @@ def LLVM_CallOp : LLVM_Op<"call",
   let printer = [{ printCallOp(p, *this); }];
 }
 def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
-  let arguments = (ins LLVM_AnyVector:$vector, LLVM_AnyInteger:$position);
+  let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position);
   let results = (outs LLVM_Type:$res);
   string llvmBuilder = [{
     $res = builder.CreateExtractElement($vector, $position);
@@ -537,7 +537,7 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> {
 }
 def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> {
   let arguments = (ins LLVM_AnyVector:$vector, LLVM_PrimitiveType:$value,
-                   LLVM_AnyInteger:$position);
+                   AnyInteger:$position);
   let results = (outs LLVM_AnyVector:$res);
   string llvmBuilder = [{
     $res = builder.CreateInsertElement($vector, $value, $position);
@@ -1616,7 +1616,7 @@ def AtomicOrdering : I64EnumAttr<
   let cppNamespace = "::mlir::LLVM";
 }
 
-def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyInteger]>;
+def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, AnyInteger]>;
 
 // FIXME: Need to add alignment attribute to MLIR atomicrmw operation.
 def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw"> {
@@ -1634,7 +1634,7 @@ def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw"> {
   let verifier = "return ::verify(*this);";
 }
 
-def LLVM_AtomicCmpXchgType : AnyTypeOf<[LLVM_AnyInteger, LLVM_AnyPointer]>;
+def LLVM_AtomicCmpXchgType : AnyTypeOf<[AnyInteger, LLVM_AnyPointer]>;
 def LLVM_AtomicCmpXchgResultType : Type<And<[
   LLVM_AnyStruct.predicate,
   CPred<"$_self.cast<::mlir::LLVM::LLVMStructType>().getBody().size() == 2">,

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 6c1f5c0e7f10b..087d10d143980 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -28,9 +28,8 @@ def OpenMP_Dialect : Dialect {
 class OpenMP_Op<string mnemonic, list<OpTrait> traits = []> :
       Op<OpenMP_Dialect, mnemonic, traits>;
 
-// Type which can be constraint accepting standard integers, indices and
-// LLVM integer types.
-def IntLikeType : AnyTypeOf<[AnyInteger, Index, LLVM_AnyInteger]>;
+// Type which can be constraint accepting standard integers and indices.
+def IntLikeType : AnyTypeOf<[AnyInteger, Index]>;
 
 //===----------------------------------------------------------------------===//
 // 2.6 parallel Construct

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index d01a195b8b12d..a28218bdeddb2 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -539,7 +539,7 @@ func @nvvm_invalid_mma_7(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
 // -----
 
 func @atomicrmw_expected_ptr(%f32 : f32) {
-  // expected-error @+1 {{operand #0 must be LLVM pointer to floating point LLVM type or LLVM integer type}}
+  // expected-error @+1 {{operand #0 must be LLVM pointer to floating point LLVM type or integer}}
   %0 = "llvm.atomicrmw"(%f32, %f32) {bin_op=11, ordering=1} : (f32, f32) -> f32
   llvm.return
 }
@@ -587,7 +587,7 @@ func @atomicrmw_expected_int(%f32_ptr : !llvm.ptr<f32>, %f32 : f32) {
 // -----
 
 func @cmpxchg_expected_ptr(%f32_ptr : !llvm.ptr<f32>, %f32 : f32) {
-  // expected-error @+1 {{op operand #0 must be LLVM pointer to LLVM integer type or LLVM pointer type}}
+  // expected-error @+1 {{op operand #0 must be LLVM pointer to integer or LLVM pointer type}}
   %0 = "llvm.cmpxchg"(%f32, %f32, %f32) {success_ordering=2,failure_ordering=2} : (f32, f32, f32) -> !llvm.struct<(f32, i1)>
   llvm.return
 }


        


More information about the Mlir-commits mailing list