[Mlir-commits] [mlir] cdb6eb7 - Update syntax for amx.tile_muli to use two Unit attr to mark the zext case

Mehdi Amini llvmlistbot at llvm.org
Fri Mar 19 21:12:34 PDT 2021


Author: Mehdi Amini
Date: 2021-03-20T04:12:24Z
New Revision: cdb6eb7e8372027e74d6b0fb1258fff37e2b3b5a

URL: https://github.com/llvm/llvm-project/commit/cdb6eb7e8372027e74d6b0fb1258fff37e2b3b5a
DIFF: https://github.com/llvm/llvm-project/commit/cdb6eb7e8372027e74d6b0fb1258fff37e2b3b5a.diff

LOG: Update syntax for amx.tile_muli to use two Unit attr to mark the zext case

This ties each annotation to its operand, and the use of a keyword makes
its meaning more explicit and readable.

Differential Revision: https://reviews.llvm.org/D99001

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/AMX/AMX.td
    mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
    mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Dialect/AMX/invalid.mlir
    mlir/test/Dialect/AMX/legalize-for-llvm.mlir
    mlir/test/Dialect/AMX/roundtrip.mlir
    mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir
    mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 45c63a99e670..24052ed4f24d 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -196,14 +196,14 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
     into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
     combinations (4 bytes packed into dwords in the columns of both the
     source operand tiles; the zero or sign extension is specified with
-    the attributes). The operation is eventually lowered into one of
-    the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud" instructions with
-    the corresponding tile configuration.
+    the attributes and defaults to sign extension). The operation is eventually
+    lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud"
+    instructions with the corresponding tile configuration.
 
     Example:
 
     ```mlir
-      %0 = amx.tile_muli %a, %b, %c [true, true]
+      %0 = amx.tile_muli %a zext, %b zext, %c
         : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
     ```
   }];
@@ -211,7 +211,9 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
   let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
                        VectorOfRankAndType<[2], [I32, I8]>:$rhs,
                        VectorOfRankAndType<[2], [I32, I8]>:$acc,
-                       BoolArrayAttr:$zext);
+                       UnitAttr:$isZextLhs,
+                       UnitAttr:$isZextRhs
+                       );
   let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
   let extraClassDeclaration = [{
     VectorType getLhsVectorType() {
@@ -224,7 +226,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
       return res().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$lhs `,` $rhs `,` $acc $zext attr-dict `:` "
+  let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
                        "type($lhs) `,` type($rhs) `,` type($acc) ";
 }
 

diff  --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 5ebef7efe213..ab98820b2ecb 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -85,8 +85,6 @@ static LogicalResult verify(amx::TileMulFOp op) {
 }
 
 static LogicalResult verify(amx::TileMulIOp op) {
-  if (op.zext().size() != 2)
-    return op.emitOpError("unexpected zext length");
   VectorType aType = op.getLhsVectorType();
   VectorType bType = op.getRhsVectorType();
   VectorType cType = op.getVectorType();

diff  --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 6e082ce790fc..7db57d383ba3 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -191,8 +191,8 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
         getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
     // Replace operation with intrinsic.
     Type resType = typeConverter->convertType(cType);
-    bool zexta = op.zext()[0].cast<BoolAttr>().getValue();
-    bool zextb = op.zext()[1].cast<BoolAttr>().getValue();
+    bool zexta = op.isZextLhs();
+    bool zextb = op.isZextRhs();
     if (zexta && zextb)
       rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
           op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),

diff  --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index b3a7286b526a..6f147cf2851e 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -46,13 +46,3 @@ func @multsize() {
   // expected-error @+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
   %3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32>
 }
-
-// -----
-
-func @zextsize() {
-  %0 = amx.tile_zero : vector<8x8xi8>
-  %1 = amx.tile_zero : vector<8x8xi8>
-  %2 = amx.tile_zero : vector<8x8xi32>
-  // expected-error @+1 {{'amx.tile_muli' op unexpected zext length}}
-  %3 = amx.tile_muli %0, %1, %2 [true] : vector<8x8xi8>, vector<8x8xi8>, vector<8x8xi32>
-}

diff  --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index f88d83d8f311..37382b34972d 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -17,13 +17,13 @@ func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
   %1 = amx.tile_zero : vector<16x64xi8>
   %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
   %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
-  %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
   amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
-  %5 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %5 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
   amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, vector<16x16xi32>
-  %6 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
   amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, vector<16x16xi32>
-  %7 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %7 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
   amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, vector<16x16xi32>
   return
 }

diff  --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir
index 98b8024c194d..93f3ea4a2977 100644
--- a/mlir/test/Dialect/AMX/roundtrip.mlir
+++ b/mlir/test/Dialect/AMX/roundtrip.mlir
@@ -28,14 +28,22 @@ func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
 // CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
 // CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
 // CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into vector<16x16xi32>
-// CHECK: %[[m:.*]] = amx.tile_muli %[[x]], %[[y]], %[[z]] [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
 // CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, vector<16x16xi32>
+// Verify the parsing/printing of the zero-extension (`zext`) annotation.
+// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
+// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
+// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
 func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
   %0 = constant 0 : index
   %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
   %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
   %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
-  %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
   amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
+  // Verify the various `zext` combinations.
+  %5 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %7 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
   return
 }

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir
index dee283c68212..45e9816fa9d6 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir
@@ -24,7 +24,7 @@ func @kernel1(%arg0: memref<16x16xi8>,
   %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
   %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
   %3 = amx.tile_zero : vector<16x4xi32>
-  %4 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+  %4 = amx.tile_muli %1, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
   amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
   return
 }
@@ -36,7 +36,7 @@ func @kernel2(%arg0: memref<16x16xi8>,
   %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
   %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
   %3 = amx.tile_zero : vector<16x4xi32>
-  %4 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+  %4 = amx.tile_muli %1, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
   amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
   return
 }
@@ -48,7 +48,7 @@ func @kernel3(%arg0: memref<16x16xi8>,
   %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
   %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
   %3 = amx.tile_zero : vector<16x4xi32>
-  %4 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+  %4 = amx.tile_muli %1 zext, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
   amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
   return
 }
@@ -60,7 +60,7 @@ func @kernel4(%arg0: memref<16x16xi8>,
   %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
   %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
   %3 = amx.tile_zero : vector<16x4xi32>
-  %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
   amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
   return
 }

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir
index a52f66c640f8..df848a04eae7 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir
@@ -13,7 +13,7 @@ func @kernel1(%arg0: memref<2x8xi8>,
   %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
   %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
   %3 = amx.tile_zero : vector<2x2xi32>
-  %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
   amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
   return
 }
@@ -26,7 +26,7 @@ func @kernel2(%arg0: memref<2x8xi8>,
   %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
   %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
   %3 = amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32>
-  %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
   amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
   return
 }


        


More information about the Mlir-commits mailing list