[Mlir-commits] [mlir] 916f3e1 - [mlir][vector][avx] add AVX dot product to X86Vector dialect with lowering

Aart Bik llvmlistbot at llvm.org
Thu Apr 15 15:01:47 PDT 2021


Author: Aart Bik
Date: 2021-04-15T15:01:39-07:00
New Revision: 916f3e16bd4d9a7c6aca94cc5d0cf5ee55e3a3cb

URL: https://github.com/llvm/llvm-project/commit/916f3e16bd4d9a7c6aca94cc5d0cf5ee55e3a3cb
DIFF: https://github.com/llvm/llvm-project/commit/916f3e16bd4d9a7c6aca94cc5d0cf5ee55e3a3cb.diff

LOG: [mlir][vector][avx] add AVX dot product to X86Vector dialect with lowering

In the long run, we want to unify the dot product codegen solutions across
all target architectures, but this intrinsic enables experimenting with
AVX-specific implementations in the meantime.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D100593

Added: 
    mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-dot.mlir

Modified: 
    mlir/include/mlir/Dialect/X86Vector/X86Vector.td
    mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
    mlir/test/Dialect/X86Vector/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 9f5e7577d2d32..e70ad498cdf0f 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -29,9 +29,11 @@ def X86Vector_Dialect : Dialect {
 // AVX512 op definitions
 //===----------------------------------------------------------------------===//
 
+// Base class for user-facing AVX512 ops (mnemonic prefix "avx512.").
 class AVX512_Op<string mnemonic, list<OpTrait> traits = []> :
   Op<X86Vector_Dialect, "avx512." # mnemonic, traits> {}
 
+// Intrinsic operation used during lowering to LLVM IR.
 class AVX512_IntrOp<string mnemonic, int numResults, list<OpTrait> traits = []> :
   LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
                   "x86_avx512_" # !subst(".", "_", mnemonic),
@@ -46,6 +48,7 @@ class AVX512_IntrOverloadedOp<string mnemonic,
                   /*list<int> overloadedResults=*/[0],
                   /*list<int> overloadedOperands=*/[],
                   traits, /*numResults=*/1>;
+
 //----------------------------------------------------------------------------//
 // MaskCompressOp
 //----------------------------------------------------------------------------//
@@ -271,9 +274,17 @@ def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
 // AVX op definitions
 //===----------------------------------------------------------------------===//
 
+// Base class for user-facing AVX ops (mnemonic prefix "avx.").
 class AVX_Op<string mnemonic, list<OpTrait> traits = []> :
   Op<X86Vector_Dialect, "avx." # mnemonic, traits> {}
 
+// Ops that users may write directly, but whose form sits between the
+// user-level view of the operation and the lower-level LLVM IR intrinsic;
+// note that they share the "avx.intr." mnemonic prefix with AVX_IntrOp.
+class AVX_LowOp<string mnemonic, list<OpTrait> traits = []> :
+  Op<X86Vector_Dialect, "avx.intr." # mnemonic, traits>;
+
+// Intrinsic operation used during lowering to LLVM IR.
 class AVX_IntrOp<string mnemonic, int numResults, list<OpTrait> traits = []> :
   LLVM_IntrOpBase<X86Vector_Dialect, "avx.intr." # mnemonic,
                   "x86_avx_" # !subst(".", "_", mnemonic),
@@ -295,4 +306,39 @@ def RsqrtIntrOp : AVX_IntrOp<"rsqrt.ps.256", 1, [NoSideEffect,
   let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
 }
 
+//----------------------------------------------------------------------------//
+// AVX Dot
+//----------------------------------------------------------------------------//
+
+def DotOp : AVX_LowOp<"dot", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = "AVX 4-way dot product, computed per 128-bit half";
+  let description = [{
+    Computes the 4-way dot products of the lower and higher 128-bit halves of
+    the source vectors and broadcasts the two results to the lower and higher
+    four elements of the destination vector, respectively. Adding one element
+    of the lower half to one element of the higher half in the destination
+    vector yields the full dot product of the two source vectors.
+
+    Example:
+
+    ```mlir
+    %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
+    %1 = vector.extractelement %0[%i0 : i32]: vector<8xf32>
+    %2 = vector.extractelement %0[%i4 : i32]: vector<8xf32>
+    %d = addf %1, %2 : f32
+    ```
+  }];
+  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a,
+                       VectorOfLengthAndType<[8], [F32]>:$b);
+  let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
+  let assemblyFormat = "$a `,` $b attr-dict `:` type($res)";
+}
+
+def DotIntrOp : AVX_IntrOp<"dp.ps.256", 1,
+    [NoSideEffect, AllTypesMatch<["a", "b", "res"]>]> {
+  let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
+  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a,
+                       VectorOfLengthAndType<[8], [F32]>:$b, I8:$c);
+}
+
 #endif // X86VECTOR_OPS

diff  --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 9e2a743450ff0..ab9e6ff3177cd 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -104,6 +104,25 @@ struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
   }
 };
 
+struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
+  using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern;
+  /// Lowers avx.intr.dot to avx.intr.dp.ps.256 with an all-ones i8 mask.
+  LogicalResult
+  matchAndRewrite(DotOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    DotOp::Adaptor adaptor(operands);
+    auto opType = adaptor.a().getType();
+    Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8);
+    // 0xff: sum all 4 products per 128-bit half, broadcast to all its lanes.
+    auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff));
+    Value mask =
+        rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr);
+    rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.a(), adaptor.b(),
+                                           mask);
+    return success();
+  }
+};
+
 /// An entry associating the "main" AVX512 op with its instantiations for
 /// vectors of 32-bit and 64-bit elements.
 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
@@ -145,7 +164,8 @@ using Registry = RegistryImpl<
 void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   Registry::registerPatterns(converter, patterns);
-  patterns.add<MaskCompressOpConversion, RsqrtOpConversion>(converter);
+  patterns.add<MaskCompressOpConversion, RsqrtOpConversion,
+               DotOpConversion>(converter);
 }
 
 void mlir::configureX86VectorLegalizeForExportTarget(
@@ -155,4 +175,6 @@ void mlir::configureX86VectorLegalizeForExportTarget(
   target.addIllegalOp<MaskCompressOp>();
   target.addLegalOp<RsqrtIntrOp>();
   target.addIllegalOp<RsqrtOp>();
+  target.addLegalOp<DotIntrOp>();
+  target.addIllegalOp<DotOp>();
 }

diff  --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 6f23153d41db2..d25d48e4e7ef3 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -50,3 +50,11 @@ func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
   %0 = x86vector.avx.rsqrt %a : vector<8xf32>
   return %0 : vector<8xf32>
 }
+
+// CHECK-LABEL: func @avx_dot
+func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
+{
+  // CHECK: x86vector.avx.intr.dp.ps.256
+  %dot = x86vector.avx.intr.dot %a, %b : vector<8xf32>
+  return %dot : vector<8xf32>
+}

diff  --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 4dfd934c59385..fdf7ee1366e04 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -54,3 +54,11 @@ func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
   %0 = x86vector.avx.rsqrt %a : vector<8xf32>
   return %0 : vector<8xf32>
 }
+
+// CHECK-LABEL: func @avx_dot
+func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
+{
+  // CHECK: x86vector.avx.intr.dot {{.*}} : vector<8xf32>
+  %dot = x86vector.avx.intr.dot %a, %b : vector<8xf32>
+  return %dot : vector<8xf32>
+}

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-dot.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-dot.mlir
new file mode 100644
index 0000000000000..24fa01f609cdf
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-dot.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-x86vector" -convert-std-to-llvm | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: %lli --entry-function=entry --mattr="avx" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @entry() -> i32 {
+  %c0 = constant 0 : i32
+  %c4 = constant 4 : i32
+
+  %a = std.constant dense<[1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0]> : vector<8xf32>
+  %b = std.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32>
+  %r = x86vector.avx.intr.dot %a, %b : vector<8xf32>
+
+  %lo = vector.extractelement %r[%c0 : i32]: vector<8xf32>
+  %hi = vector.extractelement %r[%c4 : i32]: vector<8xf32>
+  %d = addf %lo, %hi : f32
+  // Lower half: 9+20+33+48 = 110; upper half: 65+84+105+128 = 382.
+  // CHECK: ( 110, 110, 110, 110, 382, 382, 382, 382 )
+  // CHECK: 492
+  vector.print %r : vector<8xf32>
+  vector.print %d : f32
+
+  return %c0 : i32
+}


        


More information about the Mlir-commits mailing list