[Mlir-commits] [mlir] 47232be - [mlir][spirv] Fix extended umul expansion for WebGPU

Jakub Kuderski llvmlistbot at llvm.org
Thu Jan 5 15:41:57 PST 2023


Author: Jakub Kuderski
Date: 2023-01-05T18:41:26-05:00
New Revision: 47232bea9e5ecc5ad999f53ec2895d2c07427572

URL: https://github.com/llvm/llvm-project/commit/47232bea9e5ecc5ad999f53ec2895d2c07427572
DIFF: https://github.com/llvm/llvm-project/commit/47232bea9e5ecc5ad999f53ec2895d2c07427572.diff

LOG: [mlir][spirv] Fix extended umul expansion for WebGPU

Fix an off-by-one error in the extended umul expansion for WebGPU.
Revert to the long multiplication algorithm originally added to wide
integer emulation, which was deleted in D139776. It is much easier
to see why it is correct.

Add runtime tests based on the mlir-vulkan-runner. These run both with
and without the umul expansion.

Issue: https://github.com/llvm/llvm-project/issues/59563

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D141085

Added: 
    mlir/test/mlir-vulkan-runner/umul_extended.mlir

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
    mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
    mlir/tools/mlir-vulkan-runner/CMakeLists.txt
    mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 6cf127e17e7db..5f8426b8871ab 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -17,8 +17,13 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 
+#include <array>
+#include <cstdint>
+
 namespace mlir {
 namespace spirv {
 #define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
@@ -61,41 +66,62 @@ struct ExpandUMulExtendedPattern final : OpRewritePattern<UMulExtendedOp> {
           loc,
           llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
 
-    // Calculate the 'low' and the 'high' result separately, using long
-    // multiplication:
-    //
-    // lhs = [0   0]  [a   b]
-    // rhs = [0   0]  [c   d]
-    // --lhs * rhs--
-    // =     [    a * c    ]   [    b * d    ] +
-    //       [ 0 ]    [a * d + b * c]    [ 0 ]
+    // Emulate 64-bit multiplication by splitting each input element of type i32
+    // into 2 16-bit digits of type i32. This is so that the intermediate
+    // multiplications and additions do not overflow. We extract these 16-bit
+    // digits from i32 vector elements by masking (low digit) and shifting right
+    // (high digit).
     //
-    // ==> high = (a * c) + (a * d + b * c) >> 16
-    Value low = rewriter.create<IMulOp>(loc, lhs, rhs);
-
+    // The multiplication algorithm used is the standard (long) multiplication.
+    // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
+    // digits. After constant-folding, we end up emitting only 4 multiplications
+    // and 4 additions.
     Value cstLowMask = rewriter.create<ConstantOp>(
         loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
-    auto getLowHalf = [&rewriter, loc, cstLowMask](Value val) {
+    auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
       return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
     };
 
     Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
                                               getScalarOrSplatAttr(argTy, 16));
-    auto getHighHalf = [&rewriter, loc, cst16](Value val) {
+    auto getHighDigit = [&rewriter, loc, cst16](Value val) {
       return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
     };
 
-    Value lhsLow = getLowHalf(lhs);
-    Value lhsHigh = getHighHalf(lhs);
-    Value rhsLow = getLowHalf(rhs);
-    Value rhsHigh = getHighHalf(rhs);
-
-    Value high0 = rewriter.create<IMulOp>(loc, lhsHigh, rhsHigh);
-    Value mid = rewriter.create<IAddOp>(
-        loc, rewriter.create<IMulOp>(loc, lhsHigh, rhsLow),
-        rewriter.create<IMulOp>(loc, lhsLow, rhsHigh));
-    Value high1 = getHighHalf(mid);
-    Value high = rewriter.create<IAddOp>(loc, high0, high1);
+    Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
+                                             getScalarOrSplatAttr(argTy, 0));
+
+    Value lhsLow = getLowDigit(lhs);
+    Value lhsHigh = getHighDigit(lhs);
+    Value rhsLow = getLowDigit(rhs);
+    Value rhsHigh = getHighDigit(rhs);
+
+    std::array<Value, 2> lhsDigits = {lhsLow, lhsHigh};
+    std::array<Value, 2> rhsDigits = {rhsLow, rhsHigh};
+    std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
+
+    for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
+      for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
+        Value &thisResDigit = resultDigits[i + j];
+        Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
+        Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
+        thisResDigit = getLowDigit(current);
+
+        if (i + j + 1 != resultDigits.size()) {
+          Value &nextResDigit = resultDigits[i + j + 1];
+          Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
+                                                      getHighDigit(current));
+          nextResDigit = carry;
+        }
+      }
+    }
+
+    auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
+      Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
+      return rewriter.create<BitwiseOrOp>(loc, low, highBits);
+    };
+    Value low = combineDigits(resultDigits[0], resultDigits[1]);
+    Value high = combineDigits(resultDigits[2], resultDigits[3]);
 
     rewriter.replaceOpWithNewOp<CompositeConstructOp>(
         op, op.getType(), llvm::makeArrayRef({low, high}));

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
index e1899ee975626..d0720a3bbd4d3 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt --split-input-file --verify-diagnostics --spirv-webgpu-prepare %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --verify-diagnostics \
+// RUN:   --spirv-webgpu-prepare --cse %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // spirv.UMulExtended
@@ -10,18 +11,23 @@ spirv.module Logical GLSL450 {
 // CHECK-SAME:       ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32)
 // CHECK-DAG:        [[CSTMASK:%.+]] = spirv.Constant 65535 : i32
 // CHECK-DAG:        [[CST16:%.+]]   = spirv.Constant 16 : i32
-// CHECK-NEXT:       [[RESLOW:%.+]]  = spirv.IMul [[ARG0]], [[ARG1]] : i32
 // CHECK-NEXT:       [[LHSLOW:%.+]]  = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : i32
 // CHECK-NEXT:       [[LHSHI:%.+]]   = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : i32
 // CHECK-NEXT:       [[RHSLOW:%.+]]  = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : i32
 // CHECK-NEXT:       [[RHSHI:%.+]]   = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : i32
-// CHECK-DAG:        [[RESHI0:%.+]]  = spirv.IMul [[LHSHI]], [[RHSHI]] : i32
-// CHECK-DAG:        [[MID0:%.+]]    = spirv.IMul [[LHSHI]], [[RHSLOW]] : i32
-// CHECK-DAG:        [[MID1:%.+]]    = spirv.IMul [[LHSLOW]], [[RHSHI]] : i32
-// CHECK-NEXT:       [[MID:%.+]]     = spirv.IAdd [[MID0]], [[MID1]] : i32
-// CHECK-NEXT:       [[RESHI1:%.+]]  = spirv.ShiftRightLogical [[MID]], [[CST16]] : i32
-// CHECK-NEXT:       [[RESHI:%.+]]   = spirv.IAdd [[RESHI0]], [[RESHI1]] : i32
-// CHECK-NEXT:       [[RES:%.+]]     = spirv.CompositeConstruct [[RESLOW]], [[RESHI]] : (i32, i32) -> !spirv.struct<(i32, i32)>
+// CHECK-DAG:                          spirv.IMul [[LHSLOW]], [[RHSLOW]]
+// CHECK-DAG:                          spirv.IMul [[LHSLOW]], [[RHSHI]]
+// CHECK-DAG:                          spirv.IMul [[LHSHI]],  [[RHSLOW]]
+// CHECK-DAG:                          spirv.IMul [[LHSHI]],  [[RHSHI]]
+// CHECK-DAG:                          spirv.IAdd
+// CHECK-DAG:                          spirv.IAdd
+// CHECK-DAG:                          spirv.IAdd
+// CHECK-DAG:                          spirv.IAdd
+// CHECK:                              spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
+// CHECK:                              spirv.BitwiseOr
+// CHECK:                              spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
+// CHECK:                              spirv.BitwiseOr
+// CHECK:            [[RES:%.+]]     = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (i32, i32) -> !spirv.struct<(i32, i32)>
 // CHECK-NEXT:       spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
 spirv.func @umul_extended_i32(%arg0 : i32, %arg1 : i32) -> !spirv.struct<(i32, i32)> "None" {
   %0 = spirv.UMulExtended %arg0, %arg1 : !spirv.struct<(i32, i32)>
@@ -32,18 +38,23 @@ spirv.func @umul_extended_i32(%arg0 : i32, %arg1 : i32) -> !spirv.struct<(i32, i
 // CHECK-SAME:       ([[ARG0:%.+]]: vector<3xi32>, [[ARG1:%.+]]: vector<3xi32>)
 // CHECK-DAG:        [[CSTMASK:%.+]] = spirv.Constant dense<65535> : vector<3xi32>
 // CHECK-DAG:        [[CST16:%.+]]   = spirv.Constant dense<16> : vector<3xi32>
-// CHECK-NEXT:       [[RESLOW:%.+]]  = spirv.IMul [[ARG0]], [[ARG1]] : vector<3xi32>
 // CHECK-NEXT:       [[LHSLOW:%.+]]  = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : vector<3xi32>
 // CHECK-NEXT:       [[LHSHI:%.+]]   = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : vector<3xi32>
 // CHECK-NEXT:       [[RHSLOW:%.+]]  = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : vector<3xi32>
 // CHECK-NEXT:       [[RHSHI:%.+]]   = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : vector<3xi32>
-// CHECK-DAG:        [[RESHI0:%.+]]  = spirv.IMul [[LHSHI]], [[RHSHI]] : vector<3xi32>
-// CHECK-DAG:        [[MID0:%.+]]    = spirv.IMul [[LHSHI]], [[RHSLOW]] : vector<3xi32>
-// CHECK-DAG:        [[MID1:%.+]]    = spirv.IMul [[LHSLOW]], [[RHSHI]] : vector<3xi32>
-// CHECK-NEXT:       [[MID:%.+]]     = spirv.IAdd [[MID0]], [[MID1]] : vector<3xi32>
-// CHECK-NEXT:       [[RESHI1:%.+]]  = spirv.ShiftRightLogical [[MID]], [[CST16]] : vector<3xi32>
-// CHECK-NEXT:       [[RESHI:%.+]]   = spirv.IAdd [[RESHI0]], [[RESHI1]] : vector<3xi32>
-// CHECK-NEXT:       [[RES:%.+]]     = spirv.CompositeConstruct [[RESLOW]], [[RESHI]]
+// CHECK-DAG:                          spirv.IMul [[LHSLOW]], [[RHSLOW]]
+// CHECK-DAG:                          spirv.IMul [[LHSLOW]], [[RHSHI]]
+// CHECK-DAG:                          spirv.IMul [[LHSHI]],  [[RHSLOW]]
+// CHECK-DAG:                          spirv.IMul [[LHSHI]],  [[RHSHI]]
+// CHECK-DAG:                          spirv.IAdd
+// CHECK-DAG:                          spirv.IAdd
+// CHECK-DAG:                          spirv.IAdd
+// CHECK-DAG:                          spirv.IAdd
+// CHECK:                              spirv.ShiftLeftLogical {{%.+}}, [[CST16]]
+// CHECK:                              spirv.BitwiseOr
+// CHECK:                              spirv.ShiftLeftLogical {{%.+}}, [[CST16]]
+// CHECK:                              spirv.BitwiseOr
+// CHECK-NEXT:       [[RES:%.+]]     = spirv.CompositeConstruct [[RESLOW:%.+]], [[RESHI:%.+]]
 // CHECK-NEXT:       spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 spirv.func @umul_extended_vector_i32(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>)
   -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {

diff  --git a/mlir/test/mlir-vulkan-runner/umul_extended.mlir b/mlir/test/mlir-vulkan-runner/umul_extended.mlir
new file mode 100644
index 0000000000000..5e7b020d13709
--- /dev/null
+++ b/mlir/test/mlir-vulkan-runner/umul_extended.mlir
@@ -0,0 +1,66 @@
+// Make sure that unsigned extended multiplication produces expected results
+// with and without expansion to primitive mul/add ops for WebGPU.
+
+// RUN: mlir-vulkan-runner %s \
+// RUN:  --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \
+// RUN:  --entry-point-result=void | FileCheck %s
+
+// RUN: mlir-vulkan-runner %s --vulkan-runner-spirv-webgpu-prepare \
+// RUN:  --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \
+// RUN:  --entry-point-result=void | FileCheck %s
+
+// CHECK: [0, 1, -2,  1, 1048560, -87620295, -131071, -49]
+// CHECK: [0, 0,  1, -2,       0,     65534, -131070,   6]
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+  gpu.module @kernels {
+    gpu.func @kernel_add(%arg0 : memref<8xi32>, %arg1 : memref<8xi32>, %arg2 : memref<8xi32>, %arg3 : memref<8xi32>)
+      kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
+      %0 = gpu.block_id x
+      %lhs = memref.load %arg0[%0] : memref<8xi32>
+      %rhs = memref.load %arg1[%0] : memref<8xi32>
+      %low, %hi = arith.mului_extended %lhs, %rhs : i32
+      memref.store %low, %arg2[%0] : memref<8xi32>
+      memref.store %hi, %arg3[%0] : memref<8xi32>
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %buf0 = memref.alloc() : memref<8xi32>
+    %buf1 = memref.alloc() : memref<8xi32>
+    %buf2 = memref.alloc() : memref<8xi32>
+    %buf3 = memref.alloc() : memref<8xi32>
+    %i32_0 = arith.constant 0 : i32
+
+    // Initialize output buffers.
+    %buf4 = memref.cast %buf2 : memref<8xi32> to memref<?xi32>
+    %buf5 = memref.cast %buf3 : memref<8xi32> to memref<?xi32>
+    call @fillResource1DInt(%buf4, %i32_0) : (memref<?xi32>, i32) -> ()
+    call @fillResource1DInt(%buf5, %i32_0) : (memref<?xi32>, i32) -> ()
+
+    %idx_0 = arith.constant 0 : index
+    %idx_1 = arith.constant 1 : index
+    %idx_8 = arith.constant 8 : index
+
+    // Initialize input buffers.
+    %lhs_vals = arith.constant dense<[0, 1, -1,  -1,  65535,  65535, -65535,  7]> : vector<8xi32>
+    %rhs_vals = arith.constant dense<[0, 1,  2,  -1,     16,  -1337, -65535, -7]> : vector<8xi32>
+    vector.store %lhs_vals, %buf0[%idx_0] : memref<8xi32>, vector<8xi32>
+    vector.store %rhs_vals, %buf1[%idx_0] : memref<8xi32>, vector<8xi32>
+
+    gpu.launch_func @kernels::@kernel_add
+        blocks in (%idx_8, %idx_1, %idx_1) threads in (%idx_1, %idx_1, %idx_1)
+        args(%buf0 : memref<8xi32>, %buf1 : memref<8xi32>, %buf2 : memref<8xi32>, %buf3 : memref<8xi32>)
+    %buf_low = memref.cast %buf4 : memref<?xi32> to memref<*xi32>
+    %buf_hi = memref.cast %buf5 : memref<?xi32> to memref<*xi32>
+    call @printMemrefI32(%buf_low) : (memref<*xi32>) -> ()
+    call @printMemrefI32(%buf_hi) : (memref<*xi32>) -> ()
+    return
+  }
+  func.func private @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
+  func.func private @printMemrefI32(%ptr : memref<*xi32>)
+}

diff  --git a/mlir/tools/mlir-vulkan-runner/CMakeLists.txt b/mlir/tools/mlir-vulkan-runner/CMakeLists.txt
index cfb443a44fdbd..4f64117c296f9 100644
--- a/mlir/tools/mlir-vulkan-runner/CMakeLists.txt
+++ b/mlir/tools/mlir-vulkan-runner/CMakeLists.txt
@@ -74,6 +74,8 @@ if (MLIR_ENABLE_VULKAN_RUNNER)
     MLIRTargetLLVMIRExport
     MLIRTransforms
     MLIRTranslateLib
+    MLIRVectorDialect
+    MLIRVectorToLLVM
     ${Vulkan_LIBRARY}
   )
 

diff  --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index 9201f500c4317..a1a0bcb885504 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -13,12 +13,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
-#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -30,18 +30,28 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/ExecutionEngine/JitRunner.h"
-#include "mlir/ExecutionEngine/OptUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Export.h"
 #include "llvm/Support/InitLLVM.h"
 #include "llvm/Support/TargetSelect.h"
 
 using namespace mlir;
 
-static LogicalResult runMLIRPasses(Operation *op, JitRunnerOptions &options) {
+namespace {
+struct VulkanRunnerOptions {
+  llvm::cl::OptionCategory category{"mlir-vulkan-runner options"};
+  llvm::cl::opt<bool> spirvWebGPUPrepare{
+      "vulkan-runner-spirv-webgpu-prepare",
+      llvm::cl::desc("Run MLIR transforms used when targetting WebGPU"),
+      llvm::cl::cat(category)};
+};
+} // namespace
+
+static LogicalResult runMLIRPasses(Operation *op,
+                                   VulkanRunnerOptions &options) {
   auto module = dyn_cast<ModuleOp>(op);
   if (!module)
     return op->emitOpError("expected a 'builtin.module' op");
@@ -55,10 +65,13 @@ static LogicalResult runMLIRPasses(Operation *op, JitRunnerOptions &options) {
   OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
   modulePM.addPass(spirv::createLowerABIAttributesPass());
   modulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
+  if (options.spirvWebGPUPrepare)
+    modulePM.addPass(spirv::createSPIRVWebGPUPreparePass());
 
   passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
   LowerToLLVMOptions llvmOptions(module.getContext(), DataLayout(module));
   passManager.addPass(createMemRefToLLVMConversionPass());
+  passManager.addPass(createConvertVectorToLLVMPass());
   passManager.nest<func::FuncOp>().addPass(LLVM::createRequestCWrappersPass());
   passManager.addPass(createConvertFuncToLLVMPass(llvmOptions));
   passManager.addPass(createReconcileUnrealizedCastsPass());
@@ -75,13 +88,21 @@ int main(int argc, char **argv) {
   llvm::InitializeNativeTarget();
   llvm::InitializeNativeTargetAsmPrinter();
 
+  // Initialize runner-specific CLI options. These will be parsed and
+  // initialized in `JitRunnerMain`.
+  VulkanRunnerOptions options;
+  auto runPassesWithOptions = [&options](Operation *op, JitRunnerOptions &) {
+    return runMLIRPasses(op, options);
+  };
+
   mlir::JitRunnerConfig jitRunnerConfig;
-  jitRunnerConfig.mlirTransformer = runMLIRPasses;
+  jitRunnerConfig.mlirTransformer = runPassesWithOptions;
 
   mlir::DialectRegistry registry;
   registry.insert<mlir::arith::ArithDialect, mlir::LLVM::LLVMDialect,
                   mlir::gpu::GPUDialect, mlir::spirv::SPIRVDialect,
-                  mlir::func::FuncDialect, mlir::memref::MemRefDialect>();
+                  mlir::func::FuncDialect, mlir::memref::MemRefDialect,
+                  mlir::vector::VectorDialect>();
   mlir::registerLLVMDialectTranslation(registry);
 
   return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig);

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index e64cbb1e6dcdd..b4a1b7691eb4b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7303,6 +7303,8 @@ cc_binary(
         ":SPIRVDialect",
         ":SPIRVTransforms",
         ":ToLLVMIRTranslation",
+        ":VectorDialect",
+        ":VectorToLLVM",
         "//llvm:Support",
     ],
 )


        


More information about the Mlir-commits mailing list