[Mlir-commits] [mlir] ba79bb4 - [mlir][nvvm] Change MMAShapeAttr to AttrDef

llvmlistbot at llvm.org
Thu Jun 9 15:14:49 PDT 2022


Author: Mogball
Date: 2022-06-09T22:14:45Z
New Revision: ba79bb4973f963e9bd6a007e6508cdc6ec990051

URL: https://github.com/llvm/llvm-project/commit/ba79bb4973f963e9bd6a007e6508cdc6ec990051
DIFF: https://github.com/llvm/llvm-project/commit/ba79bb4973f963e9bd6a007e6508cdc6ec990051.diff

LOG: [mlir][nvvm] Change MMAShapeAttr to AttrDef

MMAShapeAttr was a StructAttr. This change converts it to an AttrDef with integer parameters m, n, and k and a declarative assembly format, so it now prints as #nvvm.shape<m = ..., n = ..., k = ...>.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D127348
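
For reference, the user-visible effect on the nvvm.mma.sync assembly syntax
(taken directly from the updated tests below) is:

    // Old StructAttr form (printed as a sorted dictionary):
    shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}
    // New AttrDef form with a declarative assembly format:
    shape = #nvvm.shape<m = 16, n = 8, k = 16>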

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/include/mlir/IR/Builders.h
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/nvvm.mlir
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 20cb2e47343c0..b3eca74db7329 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -47,6 +47,15 @@ class NVVM_Op<string mnemonic, list<Trait> traits = []> :
   LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM attribute definitions
+//===----------------------------------------------------------------------===//
+
+class NVVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
+    : AttrDef<NVVM_Dialect, attrName, traits> {
+  let mnemonic = attrMnemonic;
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM intrinsic operations
 //===----------------------------------------------------------------------===//
@@ -460,12 +469,10 @@ def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflo
 }
 
 /// Attribute to hold the MMA shape
-def NVVM_MMAShapeAttr : StructAttr<"MMAShapeAttr", NVVM_Dialect, [
-    StructFieldAttr<"m", I32Attr>,
-    StructFieldAttr<"n", I32Attr>,
-    StructFieldAttr<"k", I32Attr>
-  ]> {
+def NVVM_MMAShapeAttr : NVVM_Attr<"MMAShape", "shape"> {
   let summary = "Attribute for MMA operation shape.";
+  let parameters = (ins "int":$m, "int":$n, "int":$k);
+  let assemblyFormat = "`<` struct(params) `>`";
 }
 
 // Returns true if this combination of layout/satf for MMA ops is supported;
@@ -983,7 +990,7 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
   string llvmBuilder = [{
     auto operands = moduleTranslation.lookupValues(opInst.getOperands());
     auto intId = mlir::NVVM::MmaOp::getIntrinsicID(
-        $shape.m().getInt(), $shape.n().getInt(), $shape.k().getInt(), 
+        $shape.getM(), $shape.getN(), $shape.getK(), 
         $b1Op, $intOverflowBehavior,
         $layoutA, $layoutB,
         $multiplicandAPtxType.getValue(), 

diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 2a175bf63e6d8..205a0629ae88d 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -78,10 +78,17 @@ class Builder {
   TupleType getTupleType(TypeRange elementTypes);
   NoneType getNoneType();
 
-  /// Get or construct an instance of the type 'ty' with provided arguments.
+  /// Get or construct an instance of the type `Ty` with provided arguments.
   template <typename Ty, typename... Args>
-  Ty getType(Args... args) {
-    return Ty::get(context, args...);
+  Ty getType(Args &&...args) {
+    return Ty::get(context, std::forward<Args>(args)...);
+  }
+
+  /// Get or construct an instance of the attribute `Attr` with provided
+  /// arguments.
+  template <typename Attr, typename... Args>
+  Attr getAttr(Args &&...args) {
+    return Attr::get(context, std::forward<Args>(args)...);
   }
 
   // Attributes.
@@ -510,8 +517,7 @@ class OpBuilder : public Builder {
   Operation *cloneWithoutRegions(Operation &op) {
     return insert(op.cloneWithoutRegions());
   }
-  template <typename OpT>
-  OpT cloneWithoutRegions(OpT op) {
+  template <typename OpT> OpT cloneWithoutRegions(OpT op) {
     return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
   }
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index f667d54d88433..d217f6bfe02c5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -196,11 +196,8 @@ void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
 
   assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
   MLIRContext *ctx = builder.getContext();
-  Type i32 = builder.getIntegerType(32);
   result.addAttribute(
-      "shape", MMAShapeAttr::get(builder.getIntegerAttr(i32, shape[0]),
-                                 builder.getIntegerAttr(i32, shape[1]),
-                                 builder.getIntegerAttr(i32, shape[2]), ctx));
+      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
 
   result.addOperands(operandA);
   result.addOperands(operandB);
@@ -358,9 +355,8 @@ LogicalResult MmaOp::verify() {
   auto s32x2StructTy =
       LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
 
-  std::array<int64_t, 3> mmaShape{shapeAttr().m().getInt(),
-                                  shapeAttr().n().getInt(),
-                                  shapeAttr().k().getInt()};
+  std::array<int64_t, 3> mmaShape{shapeAttr().getM(), shapeAttr().getN(),
+                                  shapeAttr().getK()};
 
   // These variables define the set of allowed data types for matrices A, B, C,
   // and result.

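A minimal C++ sketch (not part of the commit) of how the converted attribute is
now built and queried; the helper names makeShape and printShape are
illustrative only, and the sketch assumes the headers below provide
NVVM::MMAShapeAttr and the new Builder::getAttr helper shown above:

    #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
    #include "mlir/IR/Builders.h"
    #include "llvm/Support/raw_ostream.h"

    // Construct the attribute; Builder::getAttr forwards its arguments to
    // MMAShapeAttr::get(context, m, n, k).
    static mlir::NVVM::MMAShapeAttr makeShape(mlir::OpBuilder &builder) {
      return builder.getAttr<mlir::NVVM::MMAShapeAttr>(/*m=*/16, /*n=*/8, /*k=*/16);
    }

    // Read it back through the generated parameter accessors, which replace the
    // old StructAttr field getters (shape.m().getInt(), ...).
    static void printShape(mlir::NVVM::MMAShapeAttr shape) {
      llvm::errs() << shape.getM() << "x" << shape.getN() << "x" << shape.getK() << "\n";
    }
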
diff  --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 8a8d6d5bca06d..0d0f7845e24a9 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -12,7 +12,7 @@ func.func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2:
   // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
   // CHECK-NOT llvm.extractvalue
   // CHECK: [[d:%.+]] = nvvm.mma.sync
-  // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n  = 8 : i32}
+  // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 16>
   %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>    
   // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
@@ -30,7 +30,7 @@ func.func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2:
 func.func @m16n8k16_fp16_fp32(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
  // We just need to check the mma instruction and the manipulation of the result.
   // CHECK: [[d:%.+]] = nvvm.mma.sync
-  // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n  = 8 : i32}
+  // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 16>
   // CHECK-SAME: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
   %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>    
   // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
@@ -61,7 +61,7 @@ func.func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: v
   // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
   // CHECK-NOT llvm.extractvalue
   // CHECK: [[d:%.+]] = nvvm.mma.sync
-  // CHECK-SAME: shape = {k = 8 : i32, m = 16 : i32, n  = 8 : i32}
+  // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 8>
   %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>    
   // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
@@ -95,7 +95,7 @@ func.func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: ve
   // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
   // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
-  // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}
+  // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 32>
   %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
   return %d : vector<2x2xi32>
 }
@@ -116,7 +116,7 @@ func.func @m16n8k32_i4(%arg0: vector<2x8xi4>, %arg1: vector<1x8xi4>, %arg2: vect
   // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
   // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
-  // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}
+  // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 32>
   %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<2x8xi4>, vector<1x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
   return %d : vector<2x2xi32>
 }
@@ -143,7 +143,7 @@ func.func @m16n8k64_i4(%arg0: vector<4x8xi4>, %arg1: vector<2x8xi4>, %arg2: vect
   // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
   // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
-  // CHECK-SAME: shape = {k = 64 : i32, m = 16 : i32, n = 8 : i32}
+  // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 64>
   %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 64]} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
   return %d : vector<2x2xi32>
 }
@@ -156,7 +156,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
   // CHECK: llvm.extractvalue
   // CHECK: llvm.extractvalue
   // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
-  // CHECK-SAME: shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}
+  // CHECK-SAME: shape = #nvvm.shape<m = 8, n = 8, k = 4>
   %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
   // CHECK: llvm.mlir.undef : vector<2xf64>
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>    
@@ -217,7 +217,7 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
   // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}, {{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}, {{%.+}}, {{%.+}}]
   // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<tf32>
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
-  // CHECK-SAME: shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}
+  // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
   // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>  
   %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<4x1xf32>) -> vector<4x1xf32>  
   // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][0]

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 876668de9d819..449ffb1bc9ab6 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -541,7 +541,7 @@ func.func @nvvm_invalid_mma_0(%a0 : f16, %a1 : f16,
                          %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
   // expected-error at +1 {{Could not match types for the A operands; expected one of 2xvector<2xf16> but got f16, f16}}
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] 
-    {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+    {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = #nvvm.shape<m = 8, n = 8, k = 4>} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
@@ -553,7 +553,7 @@ func.func @nvvm_invalid_mma_1(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                          %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
   // expected-error at +1 {{Could not match allowed types for the result; expected one of !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> but got !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>}}
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
-    {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
+    {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = #nvvm.shape<m = 8, n = 8, k = 4>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
 }
 
@@ -565,7 +565,7 @@ func.func @nvvm_invalid_mma_2(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                          %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
   // expected-error at +1 {{op requires attribute 'layoutA'}}
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] 
-    {shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+    {shape = #nvvm.shape<m = 8, n = 8, k = 4>}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
@@ -575,7 +575,7 @@ func.func @nvvm_invalid_mma_3(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                          %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                          %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
   // expected-error at +1 {{unimplemented variant for MMA shape <8, 8, 16>}}
-  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = #nvvm.shape<m = 8, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
@@ -588,7 +588,7 @@ func.func @nvvm_invalid_mma_8(%a0 : i32, %a1 : i32,
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
      multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,     
-     shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index c978d773d9591..283127382df1b 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -85,7 +85,7 @@ func.func @nvvm_mma_m8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf
   // CHECK: nvvm.mma.sync
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
-     shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+     shape = #nvvm.shape<m = 8, n = 8, k = 4>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
@@ -96,19 +96,19 @@ func.func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}]
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
-     shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+     shape = #nvvm.shape<m = 8, n = 8, k = 4>} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 }
 
 // CHECK-LABEL: @nvvm_mma_m8n8k16_s8_s8
 func.func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32,
                              %c0 : i32, %c1 : i32) {
-  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
+  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 8, n = 8, k = 16>} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
   %0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
      multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
      intOverflowBehavior=#nvvm.mma_int_overflow<wrapped>,
-     shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
+     shape = #nvvm.shape<m = 8, n = 8, k = 16>} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
   llvm.return %0 : !llvm.struct<(i32, i32)>
 }
 
@@ -119,7 +119,7 @@ func.func @nvvm_mma_m16n8k8_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   // CHECK: nvvm.mma.sync A[%{{.*}}, %{{.*}}] B[%{{.*}}] C[%{{.*}}, %{{.*}}] {{{.*}}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
-     shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
@@ -128,10 +128,10 @@ func.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                 %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                 %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                                 %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
-  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
-     shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
@@ -140,10 +140,10 @@ func.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                 %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                 %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                                 %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
-  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
   %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
-     shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
@@ -152,10 +152,10 @@ func.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                 %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                 %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                                 %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
-  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
-     shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
@@ -164,10 +164,10 @@ func.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                 %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                 %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                                 %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
-  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
   %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
-     shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
@@ -175,23 +175,23 @@ func.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
 func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
                                      %b0 : i32,
                                      %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
-  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 4>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
      multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>,
-     shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 4>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
 // CHECK-LABEL: @nvvm_mma_m16n8k16_s8_s8
 func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
                               %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
-  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
      multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
      intOverflowBehavior=#nvvm.mma_int_overflow<wrapped>,
-     shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
@@ -199,12 +199,12 @@ func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
 func.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
                                 %b0 : i32,
                                 %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
-  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
      multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>,
      intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>,
-     shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
@@ -212,11 +212,11 @@ func.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
 func.func @nvvm_mma_m16n8k256_b1_b1(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
                                %b0 : i32, %b1 : i32,
                                %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
-  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = {k = 256 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = #nvvm.shape<m = 16, n = 8, k = 256>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
   %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
      multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
-     b1Op = #nvvm.mma_b1op<xor_popc>, shape = {k = 256 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+     b1Op = #nvvm.mma_b1op<xor_popc>, shape = #nvvm.shape<m = 16, n = 8, k = 256>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
@@ -224,12 +224,12 @@ func.func @nvvm_mma_m16n8k256_b1_b1(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
 func.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
                                %b0 : i32,
                                %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
-  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
      multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
      b1Op = #nvvm.mma_b1op<xor_popc>,
-     shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
@@ -237,11 +237,11 @@ func.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
 func.func @nvvm_mma_m8n8k128_b1_b1(%a0 : i32,
                               %b0 : i32,
                               %c0 : i32, %c1 : i32) {
-  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = {k = 128 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
+  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = #nvvm.shape<m = 8, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
   %0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
      multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
-     b1Op = #nvvm.mma_b1op<xor_popc>, shape = {k = 128 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32)>
+     b1Op = #nvvm.mma_b1op<xor_popc>, shape = #nvvm.shape<m = 8, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32,i32)>
   llvm.return %0 : !llvm.struct<(i32,i32)>
 }
 
@@ -249,12 +249,12 @@ func.func @nvvm_mma_m8n8k128_b1_b1(%a0 : i32,
 func.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32,
                                %b0 : i32,
                                %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
-  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
      multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>,
      intOverflowBehavior=#nvvm.mma_int_overflow<wrapped>,
-     shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 

diff  --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index a66560d0e0da8..a90fabe3461d9 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -107,7 +107,7 @@ llvm.func @nvvm_mma_mn8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2x
                     %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> {
   // CHECK: call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.col.f32.f32
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
-  {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {m = 8 : i32, n = 8 : i32, k = 4 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 8, n = 8, k = 4>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
@@ -118,7 +118,7 @@ llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                 %c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
   // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f16
   %0 = nvvm.mma.sync A[ %a0, %a1, %a2, %a3 ] B[ %b0, %b1 ] C[ %c0, %c1 ]
-    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}}
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 16, n = 8, k = 16>}
      : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
@@ -132,7 +132,7 @@ llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f16
   %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
-     shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
@@ -145,7 +145,7 @@ llvm.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f32
   %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
-     shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
@@ -158,7 +158,7 @@ llvm.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f32
   %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
-     shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
@@ -171,7 +171,7 @@ llvm.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32,
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
      multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
      intOverflowBehavior=#nvvm.mma_int_overflow<wrapped>,
-     shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
@@ -184,7 +184,7 @@ llvm.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
      multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>,
      intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>,
-     shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
@@ -196,7 +196,7 @@ llvm.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
      multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
-     b1Op = #nvvm.mma_b1op<xor_popc>, shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+     b1Op = #nvvm.mma_b1op<xor_popc>, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
@@ -209,7 +209,7 @@ llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32,
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
      multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>,
      intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>,
-     shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
@@ -220,7 +220,7 @@ llvm.func @nvvm_mma_m8n8k4_f64_f64(%a0 : f64,
   // CHECK: call { double, double } @llvm.nvvm.mma.m8n8k4.row.col.f64
   %0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
-     shape = {m = 8 : i32, n = 8 : i32, k = 4 : i32}} : (f64, f64, f64) -> !llvm.struct<(f64, f64)>
+     shape = #nvvm.shape<m = 8, n = 8, k = 4>} : (f64, f64, f64) -> !llvm.struct<(f64, f64)>
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
 
@@ -232,7 +232,7 @@ llvm.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
      multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>,
-     shape = {m = 16 : i32, n = 8 : i32, k = 4 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+     shape = #nvvm.shape<m = 16, n = 8, k = 4>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 


        

