[Mlir-commits] [mlir] cdb6eb7 - Update syntax for amx.tile_muli to use two Unit attr to mark the zext case
Mehdi Amini
llvmlistbot at llvm.org
Fri Mar 19 21:12:34 PDT 2021
Author: Mehdi Amini
Date: 2021-03-20T04:12:24Z
New Revision: cdb6eb7e8372027e74d6b0fb1258fff37e2b3b5a
URL: https://github.com/llvm/llvm-project/commit/cdb6eb7e8372027e74d6b0fb1258fff37e2b3b5a
DIFF: https://github.com/llvm/llvm-project/commit/cdb6eb7e8372027e74d6b0fb1258fff37e2b3b5a.diff
LOG: Update syntax for amx.tile_muli to use two Unit attr to mark the zext case
This ties the annotation to the operand it applies to, and the use of a
keyword makes its meaning more explicit/readable.
Differential Revision: https://reviews.llvm.org/D99001
Added:
Modified:
mlir/include/mlir/Dialect/AMX/AMX.td
mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/AMX/invalid.mlir
mlir/test/Dialect/AMX/legalize-for-llvm.mlir
mlir/test/Dialect/AMX/roundtrip.mlir
mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir
mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 45c63a99e670..24052ed4f24d 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -196,14 +196,14 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
combinations (4 bytes packed into dwords in the columns of both the
source operand tiles; the zero or sign extension is specified with
- the attributes). The operation is eventually lowered into one of
- the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud" instructions with
- the corresponding tile configuration.
+ the attributes and defaults to sign extension). The operation is eventually
+ lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud"
+ instructions with the corresponding tile configuration.
Example:
```mlir
- %0 = amx.tile_muli %a, %b, %c [true, true]
+ %0 = amx.tile_muli %a zext, %b zext, %c
: vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
```
}];
@@ -211,7 +211,9 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
VectorOfRankAndType<[2], [I32, I8]>:$rhs,
VectorOfRankAndType<[2], [I32, I8]>:$acc,
- BoolArrayAttr:$zext);
+ UnitAttr:$isZextLhs,
+ UnitAttr:$isZextRhs
+ );
let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
let extraClassDeclaration = [{
VectorType getLhsVectorType() {
@@ -224,7 +226,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
return res().getType().cast<VectorType>();
}
}];
- let assemblyFormat = "$lhs `,` $rhs `,` $acc $zext attr-dict `:` "
+ let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
"type($lhs) `,` type($rhs) `,` type($acc) ";
}
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 5ebef7efe213..ab98820b2ecb 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -85,8 +85,6 @@ static LogicalResult verify(amx::TileMulFOp op) {
}
static LogicalResult verify(amx::TileMulIOp op) {
- if (op.zext().size() != 2)
- return op.emitOpError("unexpected zext length");
VectorType aType = op.getLhsVectorType();
VectorType bType = op.getRhsVectorType();
VectorType cType = op.getVectorType();
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 6e082ce790fc..7db57d383ba3 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -191,8 +191,8 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
// Replace operation with intrinsic.
Type resType = typeConverter->convertType(cType);
- bool zexta = op.zext()[0].cast<BoolAttr>().getValue();
- bool zextb = op.zext()[1].cast<BoolAttr>().getValue();
+ bool zexta = op.isZextLhs();
+ bool zextb = op.isZextRhs();
if (zexta && zextb)
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index b3a7286b526a..6f147cf2851e 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -46,13 +46,3 @@ func @multsize() {
// expected-error @+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
%3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32>
}
-
-// -----
-
-func @zextsize() {
- %0 = amx.tile_zero : vector<8x8xi8>
- %1 = amx.tile_zero : vector<8x8xi8>
- %2 = amx.tile_zero : vector<8x8xi32>
- // expected-error @+1 {{'amx.tile_muli' op unexpected zext length}}
- %3 = amx.tile_muli %0, %1, %2 [true] : vector<8x8xi8>, vector<8x8xi8>, vector<8x8xi32>
-}
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index f88d83d8f311..37382b34972d 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -17,13 +17,13 @@ func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
%1 = amx.tile_zero : vector<16x64xi8>
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
- %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
- %5 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ %5 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, vector<16x16xi32>
- %6 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, vector<16x16xi32>
- %7 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ %7 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, vector<16x16xi32>
return
}
diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir
index 98b8024c194d..93f3ea4a2977 100644
--- a/mlir/test/Dialect/AMX/roundtrip.mlir
+++ b/mlir/test/Dialect/AMX/roundtrip.mlir
@@ -28,14 +28,22 @@ func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into vector<16x16xi32>
-// CHECK: %[[m:.*]] = amx.tile_muli %[[x]], %[[y]], %[[z]] [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, vector<16x16xi32>
+// Verify the parsing/printing of the zero-extension (`zext`) annotation.
+// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
+// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
+// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
%0 = constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
%3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
- %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
+ // Verify the various `zext` combinations.
+ %5 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ %7 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
return
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir
index dee283c68212..45e9816fa9d6 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir
@@ -24,7 +24,7 @@ func @kernel1(%arg0: memref<16x16xi8>,
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = amx.tile_zero : vector<16x4xi32>
- %4 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+ %4 = amx.tile_muli %1, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
return
}
@@ -36,7 +36,7 @@ func @kernel2(%arg0: memref<16x16xi8>,
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = amx.tile_zero : vector<16x4xi32>
- %4 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+ %4 = amx.tile_muli %1, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
return
}
@@ -48,7 +48,7 @@ func @kernel3(%arg0: memref<16x16xi8>,
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = amx.tile_zero : vector<16x4xi32>
- %4 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+ %4 = amx.tile_muli %1 zext, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
return
}
@@ -60,7 +60,7 @@ func @kernel4(%arg0: memref<16x16xi8>,
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = amx.tile_zero : vector<16x4xi32>
- %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+ %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
return
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir
index a52f66c640f8..df848a04eae7 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir
@@ -13,7 +13,7 @@ func @kernel1(%arg0: memref<2x8xi8>,
%1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%3 = amx.tile_zero : vector<2x2xi32>
- %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
+ %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
return
}
@@ -26,7 +26,7 @@ func @kernel2(%arg0: memref<2x8xi8>,
%1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%3 = amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32>
- %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
+ %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
return
}
More information about the Mlir-commits mailing list