[Mlir-commits] [mlir] acb69f3 - [mlir] Change ConvertOpToLLVMPattern::matchAndRewrite argument to concrete operand type.

Christian Sigg llvmlistbot at llvm.org
Sat Nov 28 04:09:36 PST 2020


Author: Christian Sigg
Date: 2020-11-28T13:09:25+01:00
New Revision: acb69f3b7c83f411c08b77d75f2e812faf3cb83f

URL: https://github.com/llvm/llvm-project/commit/acb69f3b7c83f411c08b77d75f2e812faf3cb83f
DIFF: https://github.com/llvm/llvm-project/commit/acb69f3b7c83f411c08b77d75f2e812faf3cb83f.diff

LOG: [mlir] Change ConvertOpToLLVMPattern::matchAndRewrite argument to concrete operand type.

Reviewed By: herhut, ftynse

Differential Revision: https://reviews.llvm.org/D92111

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
    mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/lib/Transforms/TestConvertCallOp.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 919a93ac84a2..70db4c1510bf 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -564,14 +564,47 @@ class ConvertToLLVMPattern : public ConversionPattern {
 
 /// Utility class for operation conversions targeting the LLVM dialect that
 /// match exactly one source operation.
-template <typename OpTy>
+template <typename SourceOp>
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
   ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                          PatternBenefit benefit = 1)
-      : ConvertToLLVMPattern(OpTy::getOperationName(),
+      : ConvertToLLVMPattern(SourceOp::getOperationName(),
                              &typeConverter.getContext(), typeConverter,
                              benefit) {}
+
+  /// Wrappers around the RewritePattern methods that pass the derived op type.
+  void rewrite(Operation *op, ArrayRef<Value> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    rewrite(cast<SourceOp>(op), operands, rewriter);
+  }
+  LogicalResult match(Operation *op) const final {
+    return match(cast<SourceOp>(op));
+  }
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+  }
+
+  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// overridden by the derived pattern class.
+  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
+                       ConversionPatternRewriter &rewriter) const {
+    llvm_unreachable("must override rewrite or matchAndRewrite");
+  }
+  virtual LogicalResult match(SourceOp op) const {
+    llvm_unreachable("must override match or matchAndRewrite");
+  }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const {
+    if (succeeded(match(op))) {
+      rewrite(op, operands, rewriter);
+      return success();
+    }
+    return failure();
+  }
 };
 
 namespace LLVM {
@@ -604,7 +637,7 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
   /// Converts the type of the result to an LLVM type, pass operands as is,
   /// preserve attributes.
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
                                          operands, this->typeConverter,
@@ -621,7 +654,7 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
   using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     static_assert(
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,

diff  --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index d625db95e976..cb7644cb7202 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -163,7 +163,7 @@ class ConvertHostRegisterOpToGpuRuntimeCallPattern
 
 private:
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -205,7 +205,7 @@ class ConvertWaitOpToGpuRuntimeCallPattern
 
 private:
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -219,7 +219,7 @@ class ConvertWaitAsyncOpToGpuRuntimeCallPattern
 
 private:
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -251,7 +251,7 @@ class ConvertLaunchFuncOpToGpuRuntimeCallPattern
                                    Location loc, OpBuilder &builder) const;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 
   llvm::SmallString<32> gpuBinaryAnnotation;
@@ -321,14 +321,15 @@ isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
 }
 
 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
+    gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
+  auto *op = hostRegisterOp.getOperation();
   if (failed(areAllLLVMTypes(op, operands, rewriter)))
     return failure();
 
   Location loc = op->getLoc();
 
-  auto memRefType = cast<gpu::HostRegisterOp>(op).value().getType();
+  auto memRefType = hostRegisterOp.value().getType();
   auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
   auto elementSize = getSizeInBytes(loc, elementType, rewriter);
 
@@ -412,19 +413,19 @@ LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
 // afterwards. In case this isn't correct, we will get a runtime error.
 // Eventually, we will have a pass that guarantees this property.
 LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
+    gpu::WaitOp waitOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  if (cast<gpu::WaitOp>(op).asyncToken())
-    return rewriter.notifyMatchFailure(op, "Cannot convert async op.");
+  if (waitOp.asyncToken())
+    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
 
-  Location loc = op->getLoc();
+  Location loc = waitOp.getLoc();
 
   for (auto asyncDependency : operands)
     streamSynchronizeCallBuilder.create(loc, rewriter, {asyncDependency});
   for (auto asyncDependency : operands)
     streamDestroyCallBuilder.create(loc, rewriter, {asyncDependency});
 
-  rewriter.eraseOp(op);
+  rewriter.eraseOp(waitOp);
   return success();
 }
 
@@ -435,23 +436,23 @@ LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
 // assumes that there is no other use between the definition and this op, and
 // the plan is to have a pass that guarantees this property.
 LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
+    gpu::WaitOp waitOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  if (!cast<gpu::WaitOp>(op).asyncToken())
-    return rewriter.notifyMatchFailure(op, "Can only convert async op.");
+  if (!waitOp.asyncToken())
+    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
 
-  Location loc = op->getLoc();
+  Location loc = waitOp.getLoc();
 
   auto insertionPoint = rewriter.saveInsertionPoint();
   SmallVector<Value, 1> events;
-  for (auto pair : llvm::zip(op->getOperands(), operands)) {
+  for (auto pair : llvm::zip(waitOp.asyncDependencies(), operands)) {
     auto token = std::get<0>(pair);
     if (auto *defOp = token.getDefiningOp()) {
       rewriter.setInsertionPointAfter(defOp);
     } else {
       // If we can't find the defining op, we record the event at block start,
       // which is late and therefore misses parallelism, but still valid.
-      rewriter.setInsertionPointToStart(op->getBlock());
+      rewriter.setInsertionPointToStart(waitOp.getOperation()->getBlock());
     }
     auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
     auto stream = std::get<1>(pair);
@@ -464,7 +465,7 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
     streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
   for (auto event : events)
     eventDestroyCallBuilder.create(loc, rewriter, {event});
-  rewriter.replaceOp(op, {stream});
+  rewriter.replaceOp(waitOp, {stream});
 
   return success();
 }
@@ -564,23 +565,21 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
 // If the op is async, the stream corresponds to the (single) async dependency
 // as well as the async token the op produces.
 LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
+    gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  if (failed(areAllLLVMTypes(op, operands, rewriter)))
+  if (failed(areAllLLVMTypes(launchOp, operands, rewriter)))
     return failure();
 
-  auto launchOp = cast<gpu::LaunchFuncOp>(op);
-
   if (launchOp.asyncDependencies().size() > 1)
     return rewriter.notifyMatchFailure(
-        op, "Cannot convert with more than one async dependency.");
+        launchOp, "Cannot convert with more than one async dependency.");
 
   // Fail when the synchronous version of the op has async dependencies. The
   // lowering destroys the stream, and we do not want to check that there is no
   // use of the stream after this op.
   if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
     return rewriter.notifyMatchFailure(
-        op, "Cannot convert non-async op with async dependencies.");
+        launchOp, "Cannot convert non-async op with async dependencies.");
 
   Location loc = launchOp.getLoc();
 
@@ -612,7 +611,8 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
       loc, rewriter, {module.getResult(0), kernelName});
   auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                                 rewriter.getI32IntegerAttr(0));
-  auto adaptor = gpu::LaunchFuncOpAdaptor(operands, op->getAttrDictionary());
+  auto adaptor = gpu::LaunchFuncOpAdaptor(
+      operands, launchOp.getOperation()->getAttrDictionary());
   Value stream =
       adaptor.asyncDependencies().empty()
           ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
@@ -620,23 +620,24 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
   // Create array of pointers to kernel arguments.
   auto kernelParams = generateParamsArray(launchOp, operands, rewriter);
   auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
-  launchKernelCallBuilder.create(
-      loc, rewriter,
-      {function.getResult(0), launchOp.gridSizeX(), launchOp.gridSizeY(),
-       launchOp.gridSizeZ(), launchOp.blockSizeX(), launchOp.blockSizeY(),
-       launchOp.blockSizeZ(), /*sharedMemBytes=*/zero, stream, kernelParams,
-       /*extra=*/nullpointer});
+  launchKernelCallBuilder.create(loc, rewriter,
+                                 {function.getResult(0), launchOp.gridSizeX(),
+                                  launchOp.gridSizeY(), launchOp.gridSizeZ(),
+                                  launchOp.blockSizeX(), launchOp.blockSizeY(),
+                                  launchOp.blockSizeZ(),
+                                  /*sharedMemBytes=*/zero, stream, kernelParams,
+                                  /*extra=*/nullpointer});
 
   if (launchOp.asyncToken()) {
     // Async launch: make dependent ops use the same stream.
-    rewriter.replaceOp(op, {stream});
+    rewriter.replaceOp(launchOp, {stream});
   } else {
     // Synchronize with host and destroy stream. This must be the stream created
     // above (with no other uses) because we check that the synchronous version
     // does not have any async dependencies.
     streamSynchronizeCallBuilder.create(loc, rewriter, stream);
     streamDestroyCallBuilder.create(loc, rewriter, stream);
-    rewriter.eraseOp(op);
+    rewriter.eraseOp(launchOp);
   }
   moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));
 

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index c34198e48d6f..525a5be24485 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -151,9 +151,9 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
   using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    gpu::LaunchFuncOp launchOp = cast<gpu::LaunchFuncOp>(op);
+    auto *op = launchOp.getOperation();
     MLIRContext *context = rewriter.getContext();
     auto module = launchOp.getParentOfType<ModuleOp>();
 

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 49942995fc78..c19f53c4e999 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1396,10 +1396,8 @@ struct FuncOpConversion : public FuncOpConversionBase {
       : FuncOpConversionBase(converter) {}
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto funcOp = cast<FuncOp>(op);
-
     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
     if (!newFuncOp)
       return failure();
@@ -1407,14 +1405,14 @@ struct FuncOpConversion : public FuncOpConversionBase {
     if (typeConverter.getOptions().emitCWrappers ||
         funcOp.getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
       if (newFuncOp.isExternal())
-        wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp,
+        wrapExternalFunction(rewriter, funcOp.getLoc(), typeConverter, funcOp,
                              newFuncOp);
       else
-        wrapForExternalCallers(rewriter, op->getLoc(), typeConverter, funcOp,
+        wrapForExternalCallers(rewriter, funcOp.getLoc(), typeConverter, funcOp,
                                newFuncOp);
     }
 
-    rewriter.eraseOp(op);
+    rewriter.eraseOp(funcOp);
     return success();
   }
 };
@@ -1425,10 +1423,8 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
   using FuncOpConversionBase::FuncOpConversionBase;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto funcOp = cast<FuncOp>(op);
-
     // Store the type of memref-typed arguments before the conversion so that we
     // can promote them to MemRef descriptor at the beginning of the function.
     SmallVector<Type, 8> oldArgTypes =
@@ -1438,7 +1434,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
     if (!newFuncOp)
       return failure();
     if (newFuncOp.getBody().empty()) {
-      rewriter.eraseOp(op);
+      rewriter.eraseOp(funcOp);
       return success();
     }
 
@@ -1471,7 +1467,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
       // TODO: The placeholder is needed to avoid replacing barePtr uses in the
       // MemRef descriptor instructions. We may want to have a utility in the
       // rewriter to properly handle this use case.
-      Location loc = op->getLoc();
+      Location loc = funcOp.getLoc();
       auto placeholder = rewriter.create<LLVM::UndefOp>(loc, memrefTy);
       rewriter.replaceUsesOfBlockArgument(arg, placeholder);
 
@@ -1480,7 +1476,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
       rewriter.replaceOp(placeholder, {desc});
     }
 
-    rewriter.eraseOp(op);
+    rewriter.eraseOp(funcOp);
     return success();
   }
 };
@@ -1711,13 +1707,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
   using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = op->getLoc();
+    auto loc = op.getLoc();
     AssertOp::Adaptor transformed(operands);
 
     // Insert the `abort` declaration if necessary.
-    auto module = op->getParentOfType<ModuleOp>();
+    auto module = op.getParentOfType<ModuleOp>();
     auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
     if (!abortFunc) {
       OpBuilder::InsertionGuard guard(rewriter);
@@ -1754,13 +1750,13 @@ struct CreateComplexOpLowering
   using ConvertOpToLLVMPattern<CreateComplexOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(CreateComplexOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto complexOp = cast<CreateComplexOp>(op);
     CreateComplexOp::Adaptor transformed(operands);
 
     // Pack real and imaginary part in a complex number struct.
-    auto loc = op->getLoc();
+    auto loc = op.getLoc();
     auto structType = typeConverter.convertType(complexOp.getType());
     auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
     complexStruct.setReal(rewriter, loc, transformed.real());
@@ -1775,13 +1771,13 @@ struct ReOpLowering : public ConvertOpToLLVMPattern<ReOp> {
   using ConvertOpToLLVMPattern<ReOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(ReOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     ReOp::Adaptor transformed(operands);
 
     // Extract real part from the complex number struct.
     ComplexStructBuilder complexStruct(transformed.complex());
-    Value real = complexStruct.real(rewriter, op->getLoc());
+    Value real = complexStruct.real(rewriter, op.getLoc());
     rewriter.replaceOp(op, real);
 
     return success();
@@ -1792,13 +1788,13 @@ struct ImOpLowering : public ConvertOpToLLVMPattern<ImOp> {
   using ConvertOpToLLVMPattern<ImOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(ImOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     ImOp::Adaptor transformed(operands);
 
     // Extract imaginary part from the complex number struct.
     ComplexStructBuilder complexStruct(transformed.complex());
-    Value imaginary = complexStruct.imaginary(rewriter, op->getLoc());
+    Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
     rewriter.replaceOp(op, imaginary);
 
     return success();
@@ -1833,9 +1829,8 @@ struct AddCFOpLowering : public ConvertOpToLLVMPattern<AddCFOp> {
   using ConvertOpToLLVMPattern<AddCFOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
+  matchAndRewrite(AddCFOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto op = cast<AddCFOp>(operation);
     auto loc = op.getLoc();
     BinaryComplexOperands arg =
         unpackBinaryComplexOperands<AddCFOp>(op, operands, rewriter);
@@ -1861,9 +1856,8 @@ struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> {
   using ConvertOpToLLVMPattern<SubCFOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
+  matchAndRewrite(SubCFOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto op = cast<SubCFOp>(operation);
     auto loc = op.getLoc();
     BinaryComplexOperands arg =
         unpackBinaryComplexOperands<SubCFOp>(op, operands, rewriter);
@@ -1889,9 +1883,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
   using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
+  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto op = cast<ConstantOp>(operation);
     // If constant refers to a function, convert it to "addressof".
     if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
       auto type = typeConverter.convertType(op.getResult().getType())
@@ -2284,10 +2277,9 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
   using Base = ConvertOpToLLVMPattern<CallOpType>;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(CallOpType callOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     typename CallOpType::Adaptor transformed(operands);
-    auto callOp = cast<CallOpType>(op);
 
     // Pack the result types into a struct.
     Type packedResult = nullptr;
@@ -2301,10 +2293,11 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
     }
 
     auto promoted = this->typeConverter.promoteOperands(
-        op->getLoc(), /*opOperands=*/op->getOperands(), operands, rewriter);
+        callOp.getLoc(), /*opOperands=*/callOp.getOperation()->getOperands(),
+        operands, rewriter);
     auto newOp = rewriter.create<LLVM::CallOp>(
-        op->getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
-        promoted, op->getAttrs());
+        callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
+        promoted, callOp.getAttrs());
 
     SmallVector<Value, 4> results;
     if (numResults < 2) {
@@ -2315,9 +2308,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
       // Extract individual results from the structure and return them as list.
       results.reserve(numResults);
       for (unsigned i = 0; i < numResults; ++i) {
-        auto type = this->typeConverter.convertType(op->getResult(i).getType());
+        auto type =
+            this->typeConverter.convertType(callOp.getResult(i).getType());
         results.push_back(rewriter.create<LLVM::ExtractValueOp>(
-            op->getLoc(), type, newOp.getOperation()->getResult(0),
+            callOp.getLoc(), type, newOp.getOperation()->getResult(0),
             rewriter.getI64ArrayAttr(i)));
       }
     }
@@ -2327,16 +2321,16 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
       // descriptors.
       assert(results.size() == resultTypes.size() &&
              "The number of arguments and types doesn't match");
-      this->typeConverter.promoteBarePtrsToDescriptors(rewriter, op->getLoc(),
-                                                       resultTypes, results);
-    } else if (failed(copyUnrankedDescriptors(rewriter, op->getLoc(),
+      this->typeConverter.promoteBarePtrsToDescriptors(
+          rewriter, callOp.getLoc(), resultTypes, results);
+    } else if (failed(copyUnrankedDescriptors(rewriter, callOp.getLoc(),
                                               this->typeConverter, resultTypes,
                                               results,
                                               /*toDynamic=*/false))) {
       return failure();
     }
 
-    rewriter.replaceOp(op, results);
+    rewriter.replaceOp(callOp, results);
     return success();
   }
 };
@@ -2359,18 +2353,18 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
       : ConvertOpToLLVMPattern<DeallocOp>(converter) {}
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(DeallocOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     assert(operands.size() == 1 && "dealloc takes one operand");
     DeallocOp::Adaptor transformed(operands);
 
     // Insert the `free` declaration if it is not already present.
     auto freeFunc =
-        op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
+        op.getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
     if (!freeFunc) {
       OpBuilder::InsertionGuard guard(rewriter);
       rewriter.setInsertionPointToStart(
-          op->getParentOfType<ModuleOp>().getBody());
+          op.getParentOfType<ModuleOp>().getBody());
       freeFunc = rewriter.create<LLVM::LLVMFuncOp>(
           rewriter.getUnknownLoc(), "free",
           LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(),
@@ -2379,8 +2373,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
 
     MemRefDescriptor memref(transformed.memref());
     Value casted = rewriter.create<LLVM::BitcastOp>(
-        op->getLoc(), getVoidPtrType(),
-        memref.allocatedPtr(rewriter, op->getLoc()));
+        op.getLoc(), getVoidPtrType(),
+        memref.allocatedPtr(rewriter, op.getLoc()));
     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
         op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted);
     return success();
@@ -2410,9 +2404,8 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
   using ConvertOpToLLVMPattern<GlobalMemrefOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(GlobalMemrefOp global, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto global = cast<GlobalMemrefOp>(op);
     MemRefType type = global.type().cast<MemRefType>();
     if (!isSupportedMemRefType(type))
       return failure();
@@ -2434,7 +2427,7 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
     }
 
     rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
-        op, arrayTy, global.constant(), linkage, global.sym_name(),
+        global, arrayTy, global.constant(), linkage, global.sym_name(),
         initialValue, type.getMemorySpace());
     return success();
   }
@@ -2491,7 +2484,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(RsqrtOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     RsqrtOp::Adaptor transformed(operands);
     auto operandType =
@@ -2500,8 +2493,8 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
     if (!operandType)
       return failure();
 
-    auto loc = op->getLoc();
-    auto resultType = *op->result_type_begin();
+    auto loc = op.getLoc();
+    auto resultType = op.getResult().getType();
     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
 
@@ -2524,7 +2517,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
       return failure();
 
     return handleMultidimensionalVectors(
-        op, operands, typeConverter,
+        op.getOperation(), operands, typeConverter,
         [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
           auto splatAttr = SplatElementsAttr::get(
               mlir::VectorType::get({llvmVectorTy.getVectorNumElements()},
@@ -2543,8 +2536,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
   using ConvertOpToLLVMPattern<MemRefCastOp>::ConvertOpToLLVMPattern;
 
-  LogicalResult match(Operation *op) const override {
-    auto memRefCastOp = cast<MemRefCastOp>(op);
+  LogicalResult match(MemRefCastOp memRefCastOp) const override {
     Type srcType = memRefCastOp.getOperand().getType();
     Type dstType = memRefCastOp.getType();
 
@@ -2568,19 +2560,18 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
                : failure();
   }
 
-  void rewrite(Operation *op, ArrayRef<Value> operands,
+  void rewrite(MemRefCastOp memRefCastOp, ArrayRef<Value> operands,
                ConversionPatternRewriter &rewriter) const override {
-    auto memRefCastOp = cast<MemRefCastOp>(op);
     MemRefCastOp::Adaptor transformed(operands);
 
     auto srcType = memRefCastOp.getOperand().getType();
     auto dstType = memRefCastOp.getType();
     auto targetStructType = typeConverter.convertType(memRefCastOp.getType());
-    auto loc = op->getLoc();
+    auto loc = memRefCastOp.getLoc();
 
     // For ranked/ranked case, just keep the original descriptor.
     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
-      return rewriter.replaceOp(op, {transformed.source()});
+      return rewriter.replaceOp(memRefCastOp, {transformed.source()});
 
     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
       // Casting ranked to unranked memref type
@@ -2607,7 +2598,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
       memRefDesc.setRank(rewriter, loc, rankVal);
       // d2 = InsertValueOp d1, voidptr, 1
       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
-      rewriter.replaceOp(op, (Value)memRefDesc);
+      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
 
     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
       // Casting from unranked type to ranked.
@@ -2625,7 +2616,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
               .getResult();
       // struct = LoadOp castPtr
       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
-      rewriter.replaceOp(op, loadOp.getResult());
+      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
     } else {
       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
     }
@@ -2680,17 +2671,17 @@ struct MemRefReinterpretCastOpLowering
   using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(MemRefReinterpretCastOp castOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto castOp = cast<MemRefReinterpretCastOp>(op);
-    MemRefReinterpretCastOp::Adaptor adaptor(operands, op->getAttrDictionary());
+    MemRefReinterpretCastOp::Adaptor adaptor(
+        operands, castOp.getOperation()->getAttrDictionary());
     Type srcType = castOp.source().getType();
 
     Value descriptor;
     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                                adaptor, &descriptor)))
       return failure();
-    rewriter.replaceOp(op, {descriptor});
+    rewriter.replaceOp(castOp, {descriptor});
     return success();
   }
 
@@ -2748,10 +2739,9 @@ struct MemRefReshapeOpLowering
   using ConvertOpToLLVMPattern<MemRefReshapeOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(MemRefReshapeOp reshapeOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto reshapeOp = cast<MemRefReshapeOp>(op);
-
+    auto *op = reshapeOp.getOperation();
     MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
     Type srcType = reshapeOp.source().getType();
 
@@ -2898,15 +2888,14 @@ struct DialectCastOpLowering
   using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(LLVM::DialectCastOp castOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto castOp = cast<LLVM::DialectCastOp>(op);
     LLVM::DialectCastOp::Adaptor transformed(operands);
     if (transformed.in().getType() !=
         typeConverter.convertType(castOp.getType())) {
       return failure();
     }
-    rewriter.replaceOp(op, transformed.in());
+    rewriter.replaceOp(castOp, transformed.in());
     return success();
   }
 };
@@ -2917,19 +2906,18 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
   using ConvertOpToLLVMPattern<DimOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(DimOp dimOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto dimOp = cast<DimOp>(op);
     Type operandType = dimOp.memrefOrTensor().getType();
     if (operandType.isa<UnrankedMemRefType>()) {
-      rewriter.replaceOp(op, {extractSizeOfUnrankedMemRef(operandType, dimOp,
-                                                          operands, rewriter)});
+      rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef(
+                                    operandType, dimOp, operands, rewriter)});
 
       return success();
     }
     if (operandType.isa<MemRefType>()) {
-      rewriter.replaceOp(op, {extractSizeOfRankedMemRef(operandType, dimOp,
-                                                        operands, rewriter)});
+      rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(
+                                    operandType, dimOp, operands, rewriter)});
       return success();
     }
     return failure();
@@ -3006,10 +2994,10 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
   using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(RankOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    Type operandType = cast<RankOp>(op).memrefOrTensor().getType();
+    Location loc = op.getLoc();
+    Type operandType = op.memrefOrTensor().getType();
     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
       UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor());
       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
@@ -3033,8 +3021,8 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
   using ConvertOpToLLVMPattern<Derived>::isSupportedMemRefType;
   using Base = LoadStoreOpLowering<Derived>;
 
-  LogicalResult match(Operation *op) const override {
-    MemRefType type = cast<Derived>(op).getMemRefType();
+  LogicalResult match(Derived op) const override {
+    MemRefType type = op.getMemRefType();
     return isSupportedMemRefType(type) ? success() : failure();
   }
 };
@@ -3045,16 +3033,15 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
   using Base::Base;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loadOp = cast<LoadOp>(op);
     LoadOp::Adaptor transformed(operands);
     auto type = loadOp.getMemRefType();
 
     Value dataPtr =
-        getStridedElementPtr(op->getLoc(), type, transformed.memref(),
+        getStridedElementPtr(loadOp.getLoc(), type, transformed.memref(),
                              transformed.indices(), rewriter);
-    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
+    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
     return success();
   }
 };
@@ -3065,13 +3052,13 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
   using Base::Base;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(StoreOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto type = cast<StoreOp>(op).getMemRefType();
+    auto type = op.getMemRefType();
     StoreOp::Adaptor transformed(operands);
 
     Value dataPtr =
-        getStridedElementPtr(op->getLoc(), type, transformed.memref(),
+        getStridedElementPtr(op.getLoc(), type, transformed.memref(),
                              transformed.indices(), rewriter);
     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
                                                dataPtr);
@@ -3085,29 +3072,26 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
   using Base::Base;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(PrefetchOp prefetchOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto prefetchOp = cast<PrefetchOp>(op);
     PrefetchOp::Adaptor transformed(operands);
     auto type = prefetchOp.getMemRefType();
+    auto loc = prefetchOp.getLoc();
 
-    Value dataPtr =
-        getStridedElementPtr(op->getLoc(), type, transformed.memref(),
-                             transformed.indices(), rewriter);
+    Value dataPtr = getStridedElementPtr(loc, type, transformed.memref(),
+                                         transformed.indices(), rewriter);
 
     // Replace with llvm.prefetch.
     auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
     auto isWrite = rewriter.create<LLVM::ConstantOp>(
-        op->getLoc(), llvmI32Type,
-        rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
+        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
     auto localityHint = rewriter.create<LLVM::ConstantOp>(
-        op->getLoc(), llvmI32Type,
+        loc, llvmI32Type,
         rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
     auto isData = rewriter.create<LLVM::ConstantOp>(
-        op->getLoc(), llvmI32Type,
-        rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
+        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
 
-    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(op, dataPtr, isWrite,
+    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                 localityHint, isData);
     return success();
   }
@@ -3121,10 +3105,9 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
   using ConvertOpToLLVMPattern<IndexCastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(IndexCastOp indexCastOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     IndexCastOpAdaptor transformed(operands);
-    auto indexCastOp = cast<IndexCastOp>(op);
 
     auto targetType =
         this->typeConverter.convertType(indexCastOp.getResult().getType())
@@ -3134,12 +3117,12 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
     unsigned sourceBits = sourceType.getIntegerBitWidth();
 
     if (targetBits == sourceBits)
-      rewriter.replaceOp(op, transformed.in());
+      rewriter.replaceOp(indexCastOp, transformed.in());
     else if (targetBits < sourceBits)
-      rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
+      rewriter.replaceOpWithNewOp<LLVM::TruncOp>(indexCastOp, targetType,
                                                  transformed.in());
     else
-      rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
+      rewriter.replaceOpWithNewOp<LLVM::SExtOp>(indexCastOp, targetType,
                                                 transformed.in());
     return success();
   }
@@ -3156,13 +3139,12 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
   using ConvertOpToLLVMPattern<CmpIOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(CmpIOp cmpiOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto cmpiOp = cast<CmpIOp>(op);
     CmpIOpAdaptor transformed(operands);
 
     rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
-        op, typeConverter.convertType(cmpiOp.getResult().getType()),
+        cmpiOp, typeConverter.convertType(cmpiOp.getResult().getType()),
         rewriter.getI64IntegerAttr(static_cast<int64_t>(
             convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
         transformed.lhs(), transformed.rhs());
@@ -3175,13 +3157,12 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
   using ConvertOpToLLVMPattern<CmpFOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(CmpFOp cmpfOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto cmpfOp = cast<CmpFOp>(op);
     CmpFOpAdaptor transformed(operands);
 
     rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
-        op, typeConverter.convertType(cmpfOp.getResult().getType()),
+        cmpfOp, typeConverter.convertType(cmpfOp.getResult().getType()),
         rewriter.getI64IntegerAttr(static_cast<int64_t>(
             convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
         transformed.lhs(), transformed.rhs());
@@ -3243,10 +3224,10 @@ struct OneToOneLLVMTerminatorLowering
   using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<TargetOp>(op, operands, op->getSuccessors(),
-                                          op->getAttrs());
+    rewriter.replaceOpWithNewOp<TargetOp>(
+        op, operands, op.getOperation()->getSuccessors(), op.getAttrs());
     return success();
   }
 };
@@ -3261,16 +3242,16 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
   using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    unsigned numArguments = op->getNumOperands();
+    Location loc = op.getLoc();
+    unsigned numArguments = op.getNumOperands();
     SmallVector<Value, 4> updatedOperands;
 
     if (typeConverter.getOptions().useBarePtrCallConv) {
       // For the bare-ptr calling convention, extract the aligned pointer to
       // be returned from the memref descriptor.
-      for (auto it : llvm::zip(op->getOperands(), operands)) {
+      for (auto it : llvm::zip(op.getOperation()->getOperands(), operands)) {
         Type oldTy = std::get<0>(it).getType();
         Value newOperand = std::get<1>(it);
         if (oldTy.isa<MemRefType>()) {
@@ -3286,26 +3267,26 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
     } else {
       updatedOperands = llvm::to_vector<4>(operands);
       copyUnrankedDescriptors(rewriter, loc, typeConverter,
-                              op->getOperands().getTypes(), updatedOperands,
+                              op.getOperands().getTypes(), updatedOperands,
                               /*toDynamic=*/true);
     }
 
     // If ReturnOp has 0 or 1 operand, create it and return immediately.
     if (numArguments == 0) {
       rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
-                                                  op->getAttrs());
+                                                  op.getAttrs());
       return success();
     }
     if (numArguments == 1) {
       rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
-          op, TypeRange(), updatedOperands, op->getAttrs());
+          op, TypeRange(), updatedOperands, op.getAttrs());
       return success();
     }
 
     // Otherwise, we need to pack the arguments into an LLVM struct type before
     // returning.
     auto packedType = typeConverter.packFunctionResults(
-        llvm::to_vector<4>(op->getOperandTypes()));
+        llvm::to_vector<4>(op.getOperandTypes()));
 
     Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
     for (unsigned i = 0; i < numArguments; ++i) {
@@ -3314,7 +3295,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
           rewriter.getI64ArrayAttr(i));
     }
     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
-                                                op->getAttrs());
+                                                op.getAttrs());
     return success();
   }
 };
@@ -3335,29 +3316,30 @@ struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto splatOp = cast<SplatOp>(op);
     VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
     if (!resultType || resultType.getRank() != 1)
       return failure();
 
     // First insert it into an undef vector so we can shuffle it.
     auto vectorType = typeConverter.convertType(splatOp.getType());
-    Value undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType);
+    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
     auto zero = rewriter.create<LLVM::ConstantOp>(
-        op->getLoc(), typeConverter.convertType(rewriter.getIntegerType(32)),
+        splatOp.getLoc(),
+        typeConverter.convertType(rewriter.getIntegerType(32)),
         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
 
     auto v = rewriter.create<LLVM::InsertElementOp>(
-        op->getLoc(), vectorType, undef, splatOp.getOperand(), zero);
+        splatOp.getLoc(), vectorType, undef, splatOp.getOperand(), zero);
 
     int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
     SmallVector<int32_t, 4> zeroValues(width, 0);
 
     // Shuffle the value across the desired number of elements.
     ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
-    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef, zeroAttrs);
+    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
+                                                       zeroAttrs);
     return success();
   }
 };
@@ -3369,16 +3351,15 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto splatOp = cast<SplatOp>(op);
     SplatOp::Adaptor adaptor(operands);
     VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
     if (!resultType || resultType.getRank() == 1)
       return failure();
 
     // First insert it into an undef vector so we can shuffle it.
-    auto loc = op->getLoc();
+    auto loc = splatOp.getLoc();
     auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, typeConverter);
     auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
     auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
@@ -3409,7 +3390,7 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
       desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, v,
                                                   position);
     });
-    rewriter.replaceOp(op, desc);
+    rewriter.replaceOp(splatOp, desc);
     return success();
   }
 };
@@ -3431,10 +3412,9 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
   using ConvertOpToLLVMPattern<SubViewOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(SubViewOp subViewOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = op->getLoc();
-    auto subViewOp = cast<SubViewOp>(op);
+    auto loc = subViewOp.getLoc();
 
     auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
     auto sourceElementTy =
@@ -3545,7 +3525,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
       j--;
     }
 
-    rewriter.replaceOp(op, {targetMemRef});
+    rewriter.replaceOp(subViewOp, {targetMemRef});
     return success();
   }
 };
@@ -3562,16 +3542,15 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<TransposeOp> {
   using ConvertOpToLLVMPattern<TransposeOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(TransposeOp transposeOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = op->getLoc();
+    auto loc = transposeOp.getLoc();
     TransposeOpAdaptor adaptor(operands);
     MemRefDescriptor viewMemRef(adaptor.in());
 
-    auto transposeOp = cast<TransposeOp>(op);
     // No permutation, early exit.
     if (transposeOp.permutation().isIdentity())
-      return rewriter.replaceOp(op, {viewMemRef}), success();
+      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
 
     auto targetMemRef = MemRefDescriptor::undef(
         rewriter, loc, typeConverter.convertType(transposeOp.getShapedType()));
@@ -3596,7 +3575,7 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<TransposeOp> {
                              viewMemRef.stride(rewriter, loc, sourcePos));
     }
 
-    rewriter.replaceOp(op, {targetMemRef});
+    rewriter.replaceOp(transposeOp, {targetMemRef});
     return success();
   }
 };
@@ -3643,10 +3622,9 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
   }
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(ViewOp viewOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = op->getLoc();
-    auto viewOp = cast<ViewOp>(op);
+    auto loc = viewOp.getLoc();
     ViewOpAdaptor adaptor(operands);
 
     auto viewMemRefType = viewOp.getType();
@@ -3656,14 +3634,14 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
     auto targetDescTy =
         typeConverter.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
     if (!targetDescTy)
-      return op->emitWarning("Target descriptor type not converted to LLVM"),
+      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
              failure();
 
     int64_t offset;
     SmallVector<int64_t, 4> strides;
     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
     if (failed(successStrides))
-      return op->emitWarning("cannot cast to non-strided shape"), failure();
+      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
     assert(offset == 0 && "expected offset to be 0");
 
     // Create the descriptor.
@@ -3695,11 +3673,12 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
 
     // Early exit for 0-D corner case.
     if (viewMemRefType.getRank() == 0)
-      return rewriter.replaceOp(op, {targetMemRef}), success();
+      return rewriter.replaceOp(viewOp, {targetMemRef}), success();
 
     // Fields 4 and 5: Update sizes and strides.
     if (strides.back() != 1)
-      return op->emitWarning("cannot cast to non-contiguous shape"), failure();
+      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
+             failure();
     Value stride = nullptr, nextSize = nullptr;
     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
       // Update size.
@@ -3712,7 +3691,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
       nextSize = size;
     }
 
-    rewriter.replaceOp(op, {targetMemRef});
+    rewriter.replaceOp(viewOp, {targetMemRef});
     return success();
   }
 };
@@ -3722,11 +3701,12 @@ struct AssumeAlignmentOpLowering
   using ConvertOpToLLVMPattern<AssumeAlignmentOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(AssumeAlignmentOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     AssumeAlignmentOp::Adaptor transformed(operands);
     Value memref = transformed.memref();
-    unsigned alignment = cast<AssumeAlignmentOp>(op).alignment();
+    unsigned alignment = op.alignment();
+    auto loc = op.getLoc();
 
     MemRefDescriptor memRefDescriptor(memref);
     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
@@ -3741,16 +3721,14 @@ struct AssumeAlignmentOpLowering
     // pointer SSA value.
     auto intPtrType =
         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
-    Value zero = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0);
-    Value mask = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType,
-                                         alignment - 1);
-    Value ptrValue =
-        rewriter.create<LLVM::PtrToIntOp>(op->getLoc(), intPtrType, ptr);
+    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
+    Value mask =
+        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
+    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
     rewriter.create<LLVM::AssumeOp>(
-        op->getLoc(),
-        rewriter.create<LLVM::ICmpOp>(
-            op->getLoc(), LLVM::ICmpPredicate::eq,
-            rewriter.create<LLVM::AndOp>(op->getLoc(), ptrValue, mask), zero));
+        loc, rewriter.create<LLVM::ICmpOp>(
+                 loc, LLVM::ICmpPredicate::eq,
+                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
 
     rewriter.eraseOp(op);
     return success();
@@ -3789,9 +3767,10 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
   using Base::Base;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(AtomicRMWOp atomicOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto atomicOp = cast<AtomicRMWOp>(op);
+    if (failed(match(atomicOp)))
+      return failure();
     auto maybeKind = matchSimpleAtomicOp(atomicOp);
     if (!maybeKind)
       return failure();
@@ -3799,10 +3778,10 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
     auto resultType = adaptor.value().getType();
     auto memRefType = atomicOp.getMemRefType();
     auto dataPtr =
-        getStridedElementPtr(op->getLoc(), memRefType, adaptor.memref(),
+        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
                              adaptor.indices(), rewriter);
     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
-        op, resultType, *maybeKind, dataPtr, adaptor.value(),
+        atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
         LLVM::AtomicOrdering::acq_rel);
     return success();
   }
@@ -3840,11 +3819,10 @@ struct GenericAtomicRMWOpLowering
   using Base::Base;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(GenericAtomicRMWOp atomicOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto atomicOp = cast<GenericAtomicRMWOp>(op);
 
-    auto loc = op->getLoc();
+    auto loc = atomicOp.getLoc();
     GenericAtomicRMWOp::Adaptor adaptor(operands);
     LLVM::LLVMType valueType =
         typeConverter.convertType(atomicOp.getResult().getType())
@@ -3908,7 +3886,7 @@ struct GenericAtomicRMWOpLowering
                  std::next(opsToMoveEnd), rewriter);
 
     // The 'result' of the atomic_rmw op is the newly loaded value.
-    rewriter.replaceOp(op, {newLoaded});
+    rewriter.replaceOp(atomicOp, {newLoaded});
 
     return success();
   }

diff  --git a/mlir/test/lib/Transforms/TestConvertCallOp.cpp b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
index a612738c5dcc..61062c7938fe 100644
--- a/mlir/test/lib/Transforms/TestConvertCallOp.cpp
+++ b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
@@ -25,7 +25,7 @@ class TestTypeProducerOpConverter
       test::TestTypeProducerOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(test::TestTypeProducerOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<LLVM::NullOp>(op, getVoidPtrType());
     return success();


        


More information about the Mlir-commits mailing list