[Mlir-commits] [mlir] 3145427 - [mlir][NFC] Replace all usages of PatternMatchResult with LogicalResult
River Riddle
llvmlistbot at llvm.org
Tue Mar 17 20:23:18 PDT 2020
Author: River Riddle
Date: 2020-03-17T20:21:32-07:00
New Revision: 3145427dd73f0ee16dac4044890e2e2d2cae5040
URL: https://github.com/llvm/llvm-project/commit/3145427dd73f0ee16dac4044890e2e2d2cae5040
DIFF: https://github.com/llvm/llvm-project/commit/3145427dd73f0ee16dac4044890e2e2d2cae5040.diff
LOG: [mlir][NFC] Replace all usages of PatternMatchResult with LogicalResult
This also replaces usages of matchSuccess/matchFailure with success/failure respectively.
Differential Revision: https://reviews.llvm.org/D76313
Added:
Modified:
mlir/docs/DialectConversion.md
mlir/docs/QuickstartRewrites.md
mlir/docs/Tutorials/Toy/Ch-3.md
mlir/docs/Tutorials/Toy/Ch-5.md
mlir/examples/toy/Ch3/mlir/ToyCombine.cpp
mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp
mlir/lib/Dialect/AffineOps/AffineOps.cpp
mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
mlir/test/lib/TestDialect/TestDialect.cpp
mlir/test/lib/TestDialect/TestPatterns.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 8af0e4fb0b25..ec02b7274a46 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -247,7 +247,7 @@ struct MyConversionPattern : public ConversionPattern {
/// The `matchAndRewrite` hooks on ConversionPatterns take an additional
/// `operands` parameter, containing the remapped operands of the original
/// operation.
- virtual PatternMatchResult
+ virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const;
};
diff --git a/mlir/docs/QuickstartRewrites.md b/mlir/docs/QuickstartRewrites.md
index e0370e487181..71e8f0164fc9 100644
--- a/mlir/docs/QuickstartRewrites.md
+++ b/mlir/docs/QuickstartRewrites.md
@@ -171,8 +171,8 @@ struct ConvertTFLeakyRelu : public RewritePattern {
ConvertTFLeakyRelu(MLIRContext *context)
: RewritePattern("tf.LeakyRelu", 1, context) {}
- PatternMatchResult match(Operation *op) const override {
- return matchSuccess();
+ LogicalResult match(Operation *op) const override {
+ return success();
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
@@ -188,12 +188,12 @@ struct ConvertTFLeakyRelu : public RewritePattern {
ConvertTFLeakyRelu(MLIRContext *context)
: RewritePattern("tf.LeakyRelu", 1, context) {}
- PatternMatchResult matchAndRewrite(Operation *op,
+ LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
op, op->getResult(0).getType(), op->getOperand(0),
/*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
- return matchSuccess();
+ return success();
}
};
```
diff --git a/mlir/docs/Tutorials/Toy/Ch-3.md b/mlir/docs/Tutorials/Toy/Ch-3.md
index 9d5911761263..f8c8357442c8 100644
--- a/mlir/docs/Tutorials/Toy/Ch-3.md
+++ b/mlir/docs/Tutorials/Toy/Ch-3.md
@@ -86,7 +86,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// This method is attempting to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. It is expected
/// to interact with it to perform any changes to the IR from here.
- mlir::PatternMatchResult
+ mlir::LogicalResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
@@ -96,11 +96,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)
- return matchFailure();
+ return failure();
// Otherwise, we have a redundant transpose. Use the rewriter.
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
- return matchSuccess();
+ return success();
}
};
```
diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md
index 57a7baf48c97..dbc545c49206 100644
--- a/mlir/docs/Tutorials/Toy/Ch-5.md
+++ b/mlir/docs/Tutorials/Toy/Ch-5.md
@@ -106,7 +106,7 @@ struct TransposeOpLowering : public mlir::ConversionPattern {
/// Match and rewrite the given `toy.transpose` operation, with the given
/// operands that have been remapped from `tensor<...>` to `memref<...>`.
- mlir::PatternMatchResult
+ mlir::LogicalResult
matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value> operands,
mlir::ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
@@ -132,7 +132,7 @@ struct TransposeOpLowering : public mlir::ConversionPattern {
SmallVector<mlir::Value, 2> reverseIvs(llvm::reverse(loopIvs));
return rewriter.create<mlir::AffineLoadOp>(loc, input, reverseIvs);
});
- return matchSuccess();
+ return success();
}
};
```
diff --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp
index da1c36311765..8529ea0f24ee 100644
--- a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp
@@ -35,7 +35,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// This method attempts to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. The pattern is
/// expected to interact with it to perform any changes to the IR from here.
- mlir::PatternMatchResult
+ mlir::LogicalResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
@@ -45,11 +45,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)
- return matchFailure();
+ return failure();
// Otherwise, we have a redundant transpose. Use the rewriter.
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
index 82a1ee0b7e40..0dd38b2c31a4 100644
--- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
@@ -40,7 +40,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// This method attempts to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. The pattern is
/// expected to interact with it to perform any changes to the IR from here.
- mlir::PatternMatchResult
+ mlir::LogicalResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
@@ -50,11 +50,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)
- return matchFailure();
+ return failure();
// Otherwise, we have a redundant transpose. Use the rewriter.
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index 28531e600ffc..9559402708c8 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -103,7 +103,7 @@ struct BinaryOpLowering : public ConversionPattern {
BinaryOpLowering(MLIRContext *ctx)
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
@@ -126,7 +126,7 @@ struct BinaryOpLowering : public ConversionPattern {
// Create the binary operation performed on the loaded values.
return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
});
- return matchSuccess();
+ return success();
}
};
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
@@ -139,8 +139,8 @@ using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(toy::ConstantOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult matchAndRewrite(toy::ConstantOp op,
+ PatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.value();
Location loc = op.getLoc();
@@ -189,7 +189,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// Replace this operation with the generated alloc.
rewriter.replaceOp(op, alloc);
- return matchSuccess();
+ return success();
}
};
@@ -200,16 +200,16 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(toy::ReturnOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult matchAndRewrite(toy::ReturnOp op,
+ PatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
- return matchFailure();
+ return failure();
// We lower "toy.return" directly to "std.return".
rewriter.replaceOpWithNewOp<ReturnOp>(op);
- return matchSuccess();
+ return success();
}
};
@@ -221,7 +221,7 @@ struct TransposeOpLowering : public ConversionPattern {
TransposeOpLowering(MLIRContext *ctx)
: ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
@@ -240,7 +240,7 @@ struct TransposeOpLowering : public ConversionPattern {
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
return rewriter.create<AffineLoadOp>(loc, input, reverseIvs);
});
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
index 82a1ee0b7e40..0dd38b2c31a4 100644
--- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
@@ -40,7 +40,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// This method attempts to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. The pattern is
/// expected to interact with it to perform any changes to the IR from here.
- mlir::PatternMatchResult
+ mlir::LogicalResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
@@ -50,11 +50,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)
- return matchFailure();
+ return failure();
// Otherwise, we have a redundant transpose. Use the rewriter.
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index 28531e600ffc..9559402708c8 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -103,7 +103,7 @@ struct BinaryOpLowering : public ConversionPattern {
BinaryOpLowering(MLIRContext *ctx)
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
@@ -126,7 +126,7 @@ struct BinaryOpLowering : public ConversionPattern {
// Create the binary operation performed on the loaded values.
return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
});
- return matchSuccess();
+ return success();
}
};
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
@@ -139,8 +139,8 @@ using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(toy::ConstantOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult matchAndRewrite(toy::ConstantOp op,
+ PatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.value();
Location loc = op.getLoc();
@@ -189,7 +189,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// Replace this operation with the generated alloc.
rewriter.replaceOp(op, alloc);
- return matchSuccess();
+ return success();
}
};
@@ -200,16 +200,16 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(toy::ReturnOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult matchAndRewrite(toy::ReturnOp op,
+ PatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
- return matchFailure();
+ return failure();
// We lower "toy.return" directly to "std.return".
rewriter.replaceOpWithNewOp<ReturnOp>(op);
- return matchSuccess();
+ return success();
}
};
@@ -221,7 +221,7 @@ struct TransposeOpLowering : public ConversionPattern {
TransposeOpLowering(MLIRContext *ctx)
: ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
@@ -240,7 +240,7 @@ struct TransposeOpLowering : public ConversionPattern {
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
return rewriter.create<AffineLoadOp>(loc, input, reverseIvs);
});
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index 439009acd886..5455738dff2a 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -41,7 +41,7 @@ class PrintOpLowering : public ConversionPattern {
explicit PrintOpLowering(MLIRContext *context)
: ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
@@ -91,7 +91,7 @@ class PrintOpLowering : public ConversionPattern {
// Notify the rewriter that this operation has been removed.
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
private:
diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
index 82a1ee0b7e40..0dd38b2c31a4 100644
--- a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
@@ -40,7 +40,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// This method attempts to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. The pattern is
/// expected to interact with it to perform any changes to the IR from here.
- mlir::PatternMatchResult
+ mlir::LogicalResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
@@ -50,11 +50,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)
- return matchFailure();
+ return failure();
// Otherwise, we have a redundant transpose. Use the rewriter.
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index 28531e600ffc..9559402708c8 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -103,7 +103,7 @@ struct BinaryOpLowering : public ConversionPattern {
BinaryOpLowering(MLIRContext *ctx)
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
@@ -126,7 +126,7 @@ struct BinaryOpLowering : public ConversionPattern {
// Create the binary operation performed on the loaded values.
return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
});
- return matchSuccess();
+ return success();
}
};
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
@@ -139,8 +139,8 @@ using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(toy::ConstantOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult matchAndRewrite(toy::ConstantOp op,
+ PatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.value();
Location loc = op.getLoc();
@@ -189,7 +189,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// Replace this operation with the generated alloc.
rewriter.replaceOp(op, alloc);
- return matchSuccess();
+ return success();
}
};
@@ -200,16 +200,16 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(toy::ReturnOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult matchAndRewrite(toy::ReturnOp op,
+ PatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
- return matchFailure();
+ return failure();
// We lower "toy.return" directly to "std.return".
rewriter.replaceOpWithNewOp<ReturnOp>(op);
- return matchSuccess();
+ return success();
}
};
@@ -221,7 +221,7 @@ struct TransposeOpLowering : public ConversionPattern {
TransposeOpLowering(MLIRContext *ctx)
: ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
@@ -240,7 +240,7 @@ struct TransposeOpLowering : public ConversionPattern {
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
return rewriter.create<AffineLoadOp>(loc, input, reverseIvs);
});
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 439009acd886..5455738dff2a 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -41,7 +41,7 @@ class PrintOpLowering : public ConversionPattern {
explicit PrintOpLowering(MLIRContext *context)
: ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
@@ -91,7 +91,7 @@ class PrintOpLowering : public ConversionPattern {
// Notify the rewriter that this operation has been removed.
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
private:
diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
index 9f0795f936cd..fafc3876db27 100644
--- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
@@ -58,7 +58,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// This method attempts to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. The pattern is
/// expected to interact with it to perform any changes to the IR from here.
- mlir::PatternMatchResult
+ mlir::LogicalResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
@@ -68,11 +68,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)
- return matchFailure();
+ return failure();
// Otherwise, we have a redundant transpose. Use the rewriter.
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
index 6470f21a5c8e..c080ff2066d0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
@@ -53,7 +53,7 @@ class TileAndFuseLinalgOp<
"if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, op, {" #
StrJoinInt<sizes>.result # "}, {" # StrJoinInt<operandIndices>.result # "}," #
" \"" # value # "\")))" #
- " return matchFailure();">;
+ " return failure();">;
//===----------------------------------------------------------------------===//
// Linalg tiling patterns.
@@ -70,22 +70,22 @@ class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> :
"if (failed(tileLinalgOpAndSetMarker($_builder, op, {" #
StrJoinInt<sizes>.result # "}, \"" # value # "\", {" #
StrJoinInt<permutation>.result # "})))" #
- " return matchFailure();">;
+ " return failure();">;
//===----------------------------------------------------------------------===//
// Linalg to loop patterns.
//===----------------------------------------------------------------------===//
class LinalgOpToLoops<string OpType> : NativeCodeCall<
"if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " #
- " return matchFailure();">;
+ " return failure();">;
class LinalgOpToParallelLoops<string OpType> : NativeCodeCall<
"if (failed(linalgOpToParallelLoops<" # OpType # ">($_builder, op))) " #
- " return matchFailure();">;
+ " return failure();">;
class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
"if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " #
- " return matchFailure();">;
+ " return failure();">;
//===----------------------------------------------------------------------===//
// Linalg to vector patterns precondition and DRR.
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index c36e8ab5aed1..9882ce933834 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -54,10 +54,6 @@ class PatternBenefit {
unsigned short representation;
};
-/// This is the type returned by a pattern match.
-/// TODO: Replace usages with LogicalResult directly.
-using PatternMatchResult = LogicalResult;
-
//===----------------------------------------------------------------------===//
// Pattern class
//===----------------------------------------------------------------------===//
@@ -85,20 +81,10 @@ class Pattern {
/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind().
- virtual PatternMatchResult match(Operation *op) const = 0;
+ virtual LogicalResult match(Operation *op) const = 0;
virtual ~Pattern() {}
- //===--------------------------------------------------------------------===//
- // Helper methods to simplify pattern implementations
- //===--------------------------------------------------------------------===//
-
- /// Return a result, indicating that no match was found.
- PatternMatchResult matchFailure() const { return failure(); }
-
- /// This method indicates that a match was found.
- PatternMatchResult matchSuccess() const { return success(); }
-
protected:
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching.
@@ -130,22 +116,19 @@ class RewritePattern : public Pattern {
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
/// Attempt to match against code rooted at the specified operation,
- /// which is the same operation code as getRootKind(). On failure, this
- /// returns a None value. On success, it returns a (possibly null)
- /// pattern-specific state wrapped in an Optional. This state is passed back
- /// into the rewrite function if this match is selected.
- PatternMatchResult match(Operation *op) const override;
+ /// which is the same operation code as getRootKind().
+ LogicalResult match(Operation *op) const override;
/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind(). If successful, this
/// function will automatically perform the rewrite.
- virtual PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const {
+ virtual LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
if (succeeded(match(op))) {
rewrite(op, rewriter);
- return matchSuccess();
+ return success();
}
- return matchFailure();
+ return failure();
}
/// Return a list of operations that may be generated when rewriting an
@@ -182,11 +165,11 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), rewriter);
}
- PatternMatchResult match(Operation *op) const final {
+ LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
}
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const final {
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), rewriter);
}
@@ -195,16 +178,16 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
- virtual PatternMatchResult match(SourceOp op) const {
+ virtual LogicalResult match(SourceOp op) const {
llvm_unreachable("must override match or matchAndRewrite");
}
- virtual PatternMatchResult matchAndRewrite(SourceOp op,
- PatternRewriter &rewriter) const {
+ virtual LogicalResult matchAndRewrite(SourceOp op,
+ PatternRewriter &rewriter) const {
if (succeeded(match(op))) {
rewrite(op, rewriter);
- return matchSuccess();
+ return success();
}
- return matchFailure();
+ return failure();
}
};
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 7db8355e4177..776007347c5e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -235,18 +235,18 @@ class ConversionPattern : public RewritePattern {
}
/// Hook for derived classes to implement combined matching and rewriting.
- virtual PatternMatchResult
+ virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (failed(match(op)))
- return matchFailure();
+ return failure();
rewrite(op, operands, rewriter);
- return matchSuccess();
+ return success();
}
/// Attempt to match and rewrite the IR root at the specified operation.
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const final;
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const final;
private:
using RewritePattern::rewrite;
@@ -266,7 +266,7 @@ struct OpConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
@@ -282,13 +282,13 @@ struct OpConversionPattern : public ConversionPattern {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
- virtual PatternMatchResult
+ virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (failed(match(op)))
- return matchFailure();
+ return failure();
rewrite(op, operands, rewriter);
- return matchSuccess();
+ return success();
}
private:
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 6c84268df8d7..9c100a280a64 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -297,15 +297,15 @@ class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
public:
using OpRewritePattern<AffineMinOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(AffineMinOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(AffineMinOp op,
+ PatternRewriter &rewriter) const override {
Value reduced =
lowerAffineMapMin(rewriter, op.getLoc(), op.map(), op.operands());
if (!reduced)
- return matchFailure();
+ return failure();
rewriter.replaceOp(op, reduced);
- return matchSuccess();
+ return success();
}
};
@@ -313,15 +313,15 @@ class AffineMaxLowering : public OpRewritePattern<AffineMaxOp> {
public:
using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(AffineMaxOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(AffineMaxOp op,
+ PatternRewriter &rewriter) const override {
Value reduced =
lowerAffineMapMax(rewriter, op.getLoc(), op.map(), op.operands());
if (!reduced)
- return matchFailure();
+ return failure();
rewriter.replaceOp(op, reduced);
- return matchSuccess();
+ return success();
}
};
@@ -330,10 +330,10 @@ class AffineTerminatorLowering : public OpRewritePattern<AffineTerminatorOp> {
public:
using OpRewritePattern<AffineTerminatorOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(AffineTerminatorOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(AffineTerminatorOp op,
+ PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<loop::YieldOp>(op);
- return matchSuccess();
+ return success();
}
};
@@ -341,8 +341,8 @@ class AffineForLowering : public OpRewritePattern<AffineForOp> {
public:
using OpRewritePattern<AffineForOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(AffineForOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(AffineForOp op,
+ PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value lowerBound = lowerAffineLowerBound(op, rewriter);
Value upperBound = lowerAffineUpperBound(op, rewriter);
@@ -351,7 +351,7 @@ class AffineForLowering : public OpRewritePattern<AffineForOp> {
f.region().getBlocks().clear();
rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end());
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
};
@@ -359,8 +359,8 @@ class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
public:
using OpRewritePattern<AffineIfOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(AffineIfOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(AffineIfOp op,
+ PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
// Now we just have to handle the condition logic.
@@ -381,7 +381,7 @@ class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
operandsRef.take_front(numDims),
operandsRef.drop_front(numDims));
if (!affResult)
- return matchFailure();
+ return failure();
auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge;
Value cmpVal =
rewriter.create<CmpIOp>(loc, pred, affResult, zeroConstant);
@@ -402,7 +402,7 @@ class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
// Ok, we're done!
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
};
@@ -412,15 +412,15 @@ class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
public:
using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(AffineApplyOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(AffineApplyOp op,
+ PatternRewriter &rewriter) const override {
auto maybeExpandedMap =
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
llvm::to_vector<8>(op.getOperands()));
if (!maybeExpandedMap)
- return matchFailure();
+ return failure();
rewriter.replaceOp(op, *maybeExpandedMap);
- return matchSuccess();
+ return success();
}
};
@@ -431,18 +431,18 @@ class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
public:
using OpRewritePattern<AffineLoadOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(AffineLoadOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(AffineLoadOp op,
+ PatternRewriter &rewriter) const override {
// Expand affine map from 'affineLoadOp'.
SmallVector<Value, 8> indices(op.getMapOperands());
auto resultOperands =
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
if (!resultOperands)
- return matchFailure();
+ return failure();
// Build std.load memref[expandedMap.results].
rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *resultOperands);
- return matchSuccess();
+ return success();
}
};
@@ -453,20 +453,20 @@ class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
public:
using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(AffinePrefetchOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(AffinePrefetchOp op,
+ PatternRewriter &rewriter) const override {
// Expand affine map from 'affinePrefetchOp'.
SmallVector<Value, 8> indices(op.getMapOperands());
auto resultOperands =
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
if (!resultOperands)
- return matchFailure();
+ return failure();
// Build std.prefetch memref[expandedMap.results].
rewriter.replaceOpWithNewOp<PrefetchOp>(
op, op.memref(), *resultOperands, op.isWrite(),
op.localityHint().getZExtValue(), op.isDataCache());
- return matchSuccess();
+ return success();
}
};
@@ -477,19 +477,19 @@ class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
public:
using OpRewritePattern<AffineStoreOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(AffineStoreOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(AffineStoreOp op,
+ PatternRewriter &rewriter) const override {
// Expand affine map from 'affineStoreOp'.
SmallVector<Value, 8> indices(op.getMapOperands());
auto maybeExpandedMap =
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
if (!maybeExpandedMap)
- return matchFailure();
+ return failure();
// Build std.store valueToStore, memref[expandedMap.results].
rewriter.replaceOpWithNewOp<StoreOp>(op, op.getValueToStore(),
op.getMemRef(), *maybeExpandedMap);
- return matchSuccess();
+ return success();
}
};
@@ -500,8 +500,8 @@ class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
public:
using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(AffineDmaStartOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(AffineDmaStartOp op,
+ PatternRewriter &rewriter) const override {
SmallVector<Value, 8> operands(op.getOperands());
auto operandsRef = llvm::makeArrayRef(operands);
@@ -510,26 +510,26 @@ class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
rewriter, op.getLoc(), op.getSrcMap(),
operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1));
if (!maybeExpandedSrcMap)
- return matchFailure();
+ return failure();
// Expand affine map for DMA destination memref.
auto maybeExpandedDstMap = expandAffineMap(
rewriter, op.getLoc(), op.getDstMap(),
operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1));
if (!maybeExpandedDstMap)
- return matchFailure();
+ return failure();
// Expand affine map for DMA tag memref.
auto maybeExpandedTagMap = expandAffineMap(
rewriter, op.getLoc(), op.getTagMap(),
operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1));
if (!maybeExpandedTagMap)
- return matchFailure();
+ return failure();
// Build std.dma_start operation with affine map results.
rewriter.replaceOpWithNewOp<DmaStartOp>(
op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
*maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
*maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
- return matchSuccess();
+ return success();
}
};
@@ -540,19 +540,19 @@ class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
public:
using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(AffineDmaWaitOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(AffineDmaWaitOp op,
+ PatternRewriter &rewriter) const override {
// Expand affine map for DMA tag memref.
SmallVector<Value, 8> indices(op.getTagIndices());
auto maybeExpandedTagMap =
expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices);
if (!maybeExpandedTagMap)
- return matchFailure();
+ return failure();
// Build std.dma_wait operation with affine map results.
rewriter.replaceOpWithNewOp<DmaWaitOp>(
op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
index f925481ae7a0..3ab1fc48cf72 100644
--- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -46,7 +46,7 @@ struct GPUIndexIntrinsicOpLowering : public ConvertToLLVMPattern {
indexBitwidth(getIndexBitWidth(lowering_)) {}
// Convert the kernel arguments to an LLVM type, preserve the rest.
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@@ -63,7 +63,7 @@ struct GPUIndexIntrinsicOpLowering : public ConvertToLLVMPattern {
newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
break;
default:
- return matchFailure();
+ return failure();
}
if (indexBitwidth > 32) {
@@ -75,7 +75,7 @@ struct GPUIndexIntrinsicOpLowering : public ConvertToLLVMPattern {
}
rewriter.replaceOp(op, {newOp});
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index fafdb9d1b90e..eb5628d18e46 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -34,7 +34,7 @@ struct OpToFuncCallLowering : public ConvertToLLVMPattern {
lowering_.getDialect()->getContext(), lowering_),
f32Func(f32Func), f64Func(f64Func) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
using LLVM::LLVMFuncOp;
@@ -49,13 +49,13 @@ struct OpToFuncCallLowering : public ConvertToLLVMPattern {
LLVMType funcType = getFunctionType(resultType, operands);
StringRef funcName = getFunctionName(resultType);
if (funcName.empty())
- return matchFailure();
+ return failure();
LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
auto callOp = rewriter.create<LLVM::CallOp>(
op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands);
rewriter.replaceOp(op, {callOp.getResult(0)});
- return matchSuccess();
+ return success();
}
private:
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 67c335e629fe..e929caac6133 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -51,7 +51,7 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
/// !llvm<"{ float, i1 }">
/// %shfl_pred = llvm.extractvalue %shfl[1 : index] :
/// !llvm<"{ float, i1 }">
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
@@ -84,7 +84,7 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
loc, predTy, shfl, rewriter.getIndexArrayAttr(1));
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
- return matchSuccess();
+ return success();
}
};
@@ -94,7 +94,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
typeConverter.getDialect()->getContext(),
typeConverter) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
assert(operands.empty() && "func op is not expected to have operands");
@@ -219,7 +219,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
signatureConversion);
rewriter.eraseOp(gpuFuncOp);
- return matchSuccess();
+ return success();
}
};
@@ -229,11 +229,11 @@ struct GPUReturnOpLowering : public ConvertToLLVMPattern {
typeConverter.getDialect()->getContext(),
typeConverter) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
index 3c07097db542..533ef7f53b92 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -26,7 +26,7 @@ class ForOpConversion final : public SPIRVOpLowering<loop::ForOp> {
public:
using SPIRVOpLowering<loop::ForOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -37,7 +37,7 @@ class IfOpConversion final : public SPIRVOpLowering<loop::IfOp> {
public:
using SPIRVOpLowering<loop::IfOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(loop::IfOp IfOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -47,11 +47,11 @@ class TerminatorOpConversion final : public SPIRVOpLowering<loop::YieldOp> {
public:
using SPIRVOpLowering<loop::YieldOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(loop::YieldOp terminatorOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.eraseOp(terminatorOp);
- return matchSuccess();
+ return success();
}
};
@@ -62,7 +62,7 @@ class LaunchConfigConversion : public SPIRVOpLowering<SourceOp> {
public:
using SPIRVOpLowering<SourceOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -75,7 +75,7 @@ class WorkGroupSizeConversion : public SPIRVOpLowering<gpu::BlockDimOp> {
public:
using SPIRVOpLowering<gpu::BlockDimOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(gpu::BlockDimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -85,7 +85,7 @@ class GPUFuncOpConversion final : public SPIRVOpLowering<gpu::GPUFuncOp> {
public:
using SPIRVOpLowering<gpu::GPUFuncOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
@@ -98,7 +98,7 @@ class GPUModuleConversion final : public SPIRVOpLowering<gpu::GPUModuleOp> {
public:
using SPIRVOpLowering<gpu::GPUModuleOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -109,7 +109,7 @@ class GPUReturnOpConversion final : public SPIRVOpLowering<gpu::ReturnOp> {
public:
using SPIRVOpLowering<gpu::ReturnOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -120,7 +120,7 @@ class GPUReturnOpConversion final : public SPIRVOpLowering<gpu::ReturnOp> {
// loop::ForOp.
//===----------------------------------------------------------------------===//
-PatternMatchResult
+LogicalResult
ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// loop::ForOp can be lowered to the structured control flow represented by
@@ -186,14 +186,14 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands,
rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
rewriter.eraseOp(forOp);
- return matchSuccess();
+ return success();
}
//===----------------------------------------------------------------------===//
// loop::IfOp.
//===----------------------------------------------------------------------===//
-PatternMatchResult
+LogicalResult
IfOpConversion::matchAndRewrite(loop::IfOp ifOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// When lowering `loop::IfOp` we explicitly create a selection header block
@@ -238,7 +238,7 @@ IfOpConversion::matchAndRewrite(loop::IfOp ifOp, ArrayRef<Value> operands,
elseBlock, ArrayRef<Value>());
rewriter.eraseOp(ifOp);
- return matchSuccess();
+ return success();
}
//===----------------------------------------------------------------------===//
@@ -261,36 +261,36 @@ static Optional<int32_t> getLaunchConfigIndex(Operation *op) {
}
template <typename SourceOp, spirv::BuiltIn builtin>
-PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
+LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto index = getLaunchConfigIndex(op);
if (!index)
- return this->matchFailure();
+ return failure();
// SPIR-V invocation builtin variables are a vector of type <3xi32>
auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter);
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
op, rewriter.getIntegerType(32), spirvBuiltin,
rewriter.getI32ArrayAttr({index.getValue()}));
- return this->matchSuccess();
+ return success();
}
-PatternMatchResult WorkGroupSizeConversion::matchAndRewrite(
+LogicalResult WorkGroupSizeConversion::matchAndRewrite(
gpu::BlockDimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto index = getLaunchConfigIndex(op);
if (!index)
- return matchFailure();
+ return failure();
auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
auto val = workGroupSizeAttr.getValue<int32_t>(index.getValue());
auto convertedType = typeConverter.convertType(op.getResult().getType());
if (!convertedType)
- return matchFailure();
+ return failure();
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
op, convertedType, IntegerAttr::get(convertedType, val));
- return matchSuccess();
+ return success();
}
//===----------------------------------------------------------------------===//
@@ -343,11 +343,11 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter,
return newFuncOp;
}
-PatternMatchResult GPUFuncOpConversion::matchAndRewrite(
+LogicalResult GPUFuncOpConversion::matchAndRewrite(
gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!gpu::GPUDialect::isKernel(funcOp))
- return matchFailure();
+ return failure();
SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
for (auto argNum : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
@@ -358,22 +358,22 @@ PatternMatchResult GPUFuncOpConversion::matchAndRewrite(
auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
if (!entryPointAttr) {
funcOp.emitRemark("match failure: missing 'spv.entry_point_abi' attribute");
- return matchFailure();
+ return failure();
}
spirv::FuncOp newFuncOp = lowerAsEntryFunction(
funcOp, typeConverter, rewriter, entryPointAttr, argABI);
if (!newFuncOp)
- return matchFailure();
+ return failure();
newFuncOp.removeAttr(Identifier::get(gpu::GPUDialect::getKernelFuncAttrName(),
rewriter.getContext()));
- return matchSuccess();
+ return success();
}
//===----------------------------------------------------------------------===//
// ModuleOp with gpu.module.
//===----------------------------------------------------------------------===//
-PatternMatchResult GPUModuleConversion::matchAndRewrite(
+LogicalResult GPUModuleConversion::matchAndRewrite(
gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto spvModule = rewriter.create<spirv::ModuleOp>(
@@ -389,21 +389,21 @@ PatternMatchResult GPUModuleConversion::matchAndRewrite(
// legalized later.
spvModuleRegion.back().erase();
rewriter.eraseOp(moduleOp);
- return matchSuccess();
+ return success();
}
//===----------------------------------------------------------------------===//
// GPU return inside kernel functions to SPIR-V return.
//===----------------------------------------------------------------------===//
-PatternMatchResult GPUReturnOpConversion::matchAndRewrite(
+LogicalResult GPUReturnOpConversion::matchAndRewrite(
gpu::ReturnOp returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!operands.empty())
- return matchFailure();
+ return failure();
rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
- return matchSuccess();
+ return success();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 05e593ba300c..577b134fa5ed 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -130,7 +130,7 @@ class RangeOpConversion : public ConvertToLLVMPattern {
explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto rangeOp = cast<RangeOp>(op);
@@ -146,7 +146,7 @@ class RangeOpConversion : public ConvertToLLVMPattern {
desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
rewriter.replaceOp(op, desc);
- return matchSuccess();
+ return success();
}
};
@@ -160,14 +160,14 @@ class ReshapeOpConversion : public ConvertToLLVMPattern {
: ConvertToLLVMPattern(ReshapeOp::getOperationName(), context,
lowering_) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto reshapeOp = cast<ReshapeOp>(op);
MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
if (!dstType.hasStaticShape())
- return matchFailure();
+ return failure();
int64_t offset;
SmallVector<int64_t, 4> strides;
@@ -175,7 +175,7 @@ class ReshapeOpConversion : public ConvertToLLVMPattern {
if (failed(res) || llvm::any_of(strides, [](int64_t val) {
return ShapedType::isDynamicStrideOrOffset(val);
}))
- return matchFailure();
+ return failure();
edsc::ScopedContext context(rewriter, op->getLoc());
ReshapeOpOperandAdaptor adaptor(operands);
@@ -189,7 +189,7 @@ class ReshapeOpConversion : public ConvertToLLVMPattern {
for (auto en : llvm::enumerate(strides))
desc.setConstantStride(en.index(), en.value());
rewriter.replaceOp(op, {desc});
- return matchSuccess();
+ return success();
}
};
@@ -204,7 +204,7 @@ class SliceOpConversion : public ConvertToLLVMPattern {
explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
edsc::ScopedContext context(rewriter, op->getLoc());
@@ -247,7 +247,7 @@ class SliceOpConversion : public ConvertToLLVMPattern {
// Corner case, no sizes or strides: early return the descriptor.
if (sliceOp.getShapedType().getRank() == 0)
- return rewriter.replaceOp(op, {desc}), matchSuccess();
+ return rewriter.replaceOp(op, {desc}), success();
Value zero = llvm_constant(
int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
@@ -279,7 +279,7 @@ class SliceOpConversion : public ConvertToLLVMPattern {
}
rewriter.replaceOp(op, {desc});
- return matchSuccess();
+ return success();
}
};
@@ -297,7 +297,7 @@ class TransposeOpConversion : public ConvertToLLVMPattern {
: ConvertToLLVMPattern(TransposeOp::getOperationName(), context,
lowering_) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Initialize the common boilerplate and alloca at the top of the FuncOp.
@@ -308,7 +308,7 @@ class TransposeOpConversion : public ConvertToLLVMPattern {
auto transposeOp = cast<TransposeOp>(op);
// No permutation, early exit.
if (transposeOp.permutation().isIdentity())
- return rewriter.replaceOp(op, {baseDesc}), matchSuccess();
+ return rewriter.replaceOp(op, {baseDesc}), success();
BaseViewConversionHelper desc(
typeConverter.convertType(transposeOp.getShapedType()));
@@ -330,7 +330,7 @@ class TransposeOpConversion : public ConvertToLLVMPattern {
}
rewriter.replaceOp(op, {desc});
- return matchSuccess();
+ return success();
}
};
@@ -340,11 +340,11 @@ class YieldOpConversion : public ConvertToLLVMPattern {
explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
- return matchSuccess();
+ return success();
}
};
} // namespace
@@ -416,15 +416,15 @@ class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
public:
using OpRewritePattern<LinalgOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(LinalgOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(LinalgOp op,
+ PatternRewriter &rewriter) const override {
auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
if (!libraryCallName)
- return this->matchFailure();
+ return failure();
rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
- return this->matchSuccess();
+ return success();
}
};
@@ -434,22 +434,22 @@ template <> class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
public:
using OpRewritePattern<CopyOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(CopyOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(CopyOp op,
+ PatternRewriter &rewriter) const override {
auto inputPerm = op.inputPermutation();
if (inputPerm.hasValue() && !inputPerm->isIdentity())
- return matchFailure();
+ return failure();
auto outputPerm = op.outputPermutation();
if (outputPerm.hasValue() && !outputPerm->isIdentity())
- return matchFailure();
+ return failure();
auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
if (!libraryCallName)
- return matchFailure();
+ return failure();
rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
- return matchSuccess();
+ return success();
}
};
@@ -460,12 +460,12 @@ class LinalgOpConversion<IndexedGenericOp>
public:
using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(IndexedGenericOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(IndexedGenericOp op,
+ PatternRewriter &rewriter) const override {
auto libraryCallName =
getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
if (!libraryCallName)
- return this->matchFailure();
+ return failure();
// TODO(pifon, ntv): Use induction variables values instead of zeros, when
// IndexedGenericOp is tiled.
@@ -483,7 +483,7 @@ class LinalgOpConversion<IndexedGenericOp>
}
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
ArrayRef<Type>{}, operands);
- return this->matchSuccess();
+ return success();
}
};
@@ -495,8 +495,8 @@ class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
public:
using OpRewritePattern<CopyOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(CopyOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(CopyOp op,
+ PatternRewriter &rewriter) const override {
Value in = op.input(), out = op.output();
// If either inputPerm or outputPerm are non-identities, insert transposes.
@@ -511,10 +511,10 @@ class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
// If nothing was transposed, fail and let the conversion kick in.
if (in == op.input() && out == op.output())
- return matchFailure();
+ return failure();
rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
index 759812633136..90ec624a8cc1 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -54,7 +54,7 @@ class SingleWorkgroupReduction final
static Optional<linalg::RegionMatcher::BinaryOpKind>
matchAsPerformingReduction(linalg::GenericOp genericOp);
- PatternMatchResult
+ LogicalResult
matchAndRewrite(linalg::GenericOp genericOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -109,7 +109,7 @@ SingleWorkgroupReduction::matchAsPerformingReduction(
return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp);
}
-PatternMatchResult SingleWorkgroupReduction::matchAndRewrite(
+LogicalResult SingleWorkgroupReduction::matchAndRewrite(
linalg::GenericOp genericOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
Operation *op = genericOp.getOperation();
@@ -118,19 +118,19 @@ PatternMatchResult SingleWorkgroupReduction::matchAndRewrite(
auto binaryOpKind = matchAsPerformingReduction(genericOp);
if (!binaryOpKind)
- return matchFailure();
+ return failure();
// Query the shader interface for local workgroup size to make sure the
// invocation configuration fits with the input memref's shape.
DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp);
if (!localSize)
- return matchFailure();
+ return failure();
if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
- return matchFailure();
+ return failure();
if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1),
[](const APInt &size) { return !size.isOneValue(); }))
- return matchFailure();
+ return failure();
// TODO(antiagainst): Query the target environment to make sure the current
// workload fits in a local workgroup.
@@ -195,7 +195,7 @@ PatternMatchResult SingleWorkgroupReduction::matchAndRewrite(
spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, &rewriter);
rewriter.eraseOp(genericOp);
- return matchSuccess();
+ return success();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
index a16c4a0c5cfb..8f7c76c921e1 100644
--- a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
+++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
@@ -98,8 +98,8 @@ struct LoopToStandardPass : public OperationPass<LoopToStandardPass> {
struct ForLowering : public OpRewritePattern<ForOp> {
using OpRewritePattern<ForOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(ForOp forOp,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(ForOp forOp,
+ PatternRewriter &rewriter) const override;
};
// Create a CFG subgraph for the loop.if operation (including its "then" and
@@ -147,20 +147,20 @@ struct ForLowering : public OpRewritePattern<ForOp> {
struct IfLowering : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(IfOp ifOp,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(IfOp ifOp,
+ PatternRewriter &rewriter) const override;
};
struct ParallelLowering : public OpRewritePattern<mlir::loop::ParallelOp> {
using OpRewritePattern<mlir::loop::ParallelOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(mlir::loop::ParallelOp parallelOp,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(mlir::loop::ParallelOp parallelOp,
+ PatternRewriter &rewriter) const override;
};
} // namespace
-PatternMatchResult
-ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
+LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
+ PatternRewriter &rewriter) const {
Location loc = forOp.getLoc();
// Start by splitting the block containing the 'loop.for' into two parts.
@@ -189,7 +189,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
auto step = forOp.step();
auto stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
if (!stepped)
- return matchFailure();
+ return failure();
SmallVector<Value, 8> loopCarried;
loopCarried.push_back(stepped);
@@ -202,7 +202,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
Value lowerBound = forOp.lowerBound();
Value upperBound = forOp.upperBound();
if (!lowerBound || !upperBound)
- return matchFailure();
+ return failure();
// The initial values of loop-carried values is obtained from the operands
// of the loop operation.
@@ -222,11 +222,11 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
// The result of the loop operation is the values of the condition block
// arguments except the induction variable on the last iteration.
rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
- return matchSuccess();
+ return success();
}
-PatternMatchResult
-IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
+LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
+ PatternRewriter &rewriter) const {
auto loc = ifOp.getLoc();
// Start by splitting the block containing the 'loop.if' into two parts.
@@ -265,10 +265,10 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
// Ok, we're done!
rewriter.eraseOp(ifOp);
- return matchSuccess();
+ return success();
}
-PatternMatchResult
+LogicalResult
ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
PatternRewriter &rewriter) const {
Location loc = parallelOp.getLoc();
@@ -344,7 +344,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
rewriter.replaceOp(parallelOp, loopResults);
- return matchSuccess();
+ return success();
}
void mlir::populateLoopToStdConversionPatterns(
diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
index 9ba3f40d24e9..8023226bc300 100644
--- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
+++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
@@ -497,8 +497,8 @@ namespace {
struct ParallelToGpuLaunchLowering : public OpRewritePattern<ParallelOp> {
using OpRewritePattern<ParallelOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(ParallelOp parallelOp,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(ParallelOp parallelOp,
+ PatternRewriter &rewriter) const override;
};
struct MappingAnnotation {
@@ -742,7 +742,7 @@ static LogicalResult processParallelLoop(ParallelOp parallelOp,
/// the actual loop bound. This only works if an static upper bound for the
/// dynamic loop bound can be defived, currently via analyzing `affine.min`
/// operations.
-PatternMatchResult
+LogicalResult
ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
PatternRewriter &rewriter) const {
// Create a launch operation. We start with bound one for all grid/block
@@ -761,7 +761,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
SmallVector<Operation *, 16> worklist;
if (failed(processParallelLoop(parallelOp, launchOp, cloningMap, worklist,
launchBounds, rewriter)))
- return matchFailure();
+ return failure();
// Whether we have seen any side-effects. Reset when leaving an inner scope.
bool seenSideeffects = false;
@@ -778,13 +778,13 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
// Before entering a nested scope, make sure there have been no
// sideeffects until now.
if (seenSideeffects)
- return matchFailure();
+ return failure();
// A nested loop.parallel needs insertion of code to compute indices.
// Insert that now. This will also update the worklist with the loops
// body.
if (failed(processParallelLoop(nestedParallel, launchOp, cloningMap,
worklist, launchBounds, rewriter)))
- return matchFailure();
+ return failure();
} else if (op == launchOp.getOperation()) {
// Found our sentinel value. We have finished the operations from one
// nesting level, pop one level back up.
@@ -802,7 +802,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
clone->getNumRegions() != 0;
// If we are no longer in the innermost scope, sideeffects are disallowed.
if (seenSideeffects && leftNestingScope)
- return matchFailure();
+ return failure();
}
}
@@ -812,7 +812,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
launchOp.setOperand(std::get<0>(bound), std::get<1>(bound));
rewriter.eraseOp(parallelOp);
- return matchSuccess();
+ return success();
}
void mlir::populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns,
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 3e99eb598dfc..250ff3682653 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -946,7 +946,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
bool emitCWrappers)
: FuncOpConversionBase(dialect, converter), emitWrappers(emitCWrappers) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<FuncOp>(op);
@@ -962,7 +962,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
}
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
private:
@@ -976,7 +976,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
struct BarePtrFuncOpConversion : public FuncOpConversionBase {
using FuncOpConversionBase::FuncOpConversionBase;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<FuncOp>(op);
@@ -990,7 +990,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
if (newFuncOp.getBody().empty()) {
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
// Promote bare pointers from MemRef arguments to a MemRef descriptor struct
@@ -1017,7 +1017,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
}
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
};
@@ -1109,7 +1109,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
unsigned numResults = op->getNumResults();
@@ -1119,7 +1119,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
packedType =
this->typeConverter.packFunctionResults(op->getResultTypes());
if (!packedType)
- return this->matchFailure();
+ return failure();
}
auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
@@ -1127,10 +1127,10 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
// If the operation produced 0 or 1 result, return them immediately.
if (numResults == 0)
- return rewriter.eraseOp(op), this->matchSuccess();
+ return rewriter.eraseOp(op), success();
if (numResults == 1)
return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
- this->matchSuccess();
+ success();
// Otherwise, it had been converted to an operation producing a structure.
// Extract individual results from the structure and return them as list.
@@ -1143,7 +1143,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
rewriter.getI64ArrayAttr(i)));
}
rewriter.replaceOp(op, results);
- return this->matchSuccess();
+ return success();
}
};
@@ -1207,7 +1207,7 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
ValidateOpCount<SourceOp, OpCount>();
@@ -1221,7 +1221,7 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
// Cannot convert ops if their operands are not of LLVM type.
for (Value operand : operands) {
if (!operand || !operand.getType().isa<LLVM::LLVMType>())
- return this->matchFailure();
+ return failure();
}
auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
@@ -1230,7 +1230,7 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
auto newOp = rewriter.create<TargetOp>(
op->getLoc(), operands[0].getType(), operands, op->getAttrs());
rewriter.replaceOp(op, newOp.getResult());
- return this->matchSuccess();
+ return success();
}
if (succeeded(HandleMultidimensionalVectors(
@@ -1240,8 +1240,8 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
operands, op->getAttrs());
},
rewriter)))
- return this->matchSuccess();
- return this->matchFailure();
+ return success();
+ return failure();
}
};
@@ -1381,24 +1381,24 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
: LLVMLegalizationPattern<AllocOp>(dialect_, converter),
useAlloca(useAlloca) {}
- PatternMatchResult match(Operation *op) const override {
+ LogicalResult match(Operation *op) const override {
MemRefType type = cast<AllocOp>(op).getType();
if (isSupportedMemRefType(type))
- return matchSuccess();
+ return success();
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(type, strides, offset);
if (failed(successStrides))
- return matchFailure();
+ return failure();
// Dynamic strides are ok if they can be deduced from dynamic sizes (which
// is guaranteed when succeeded(successStrides)). Dynamic offset however can
// never be alloc'ed.
if (offset == MemRefType::getDynamicStrideOrOffset())
- return matchFailure();
+ return failure();
- return matchSuccess();
+ return success();
}
void rewrite(Operation *op, ArrayRef<Value> operands,
@@ -1574,7 +1574,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
using Super = CallOpInterfaceLowering<CallOpType>;
using Base = LLVMLegalizationPattern<CallOpType>;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<CallOpType> transformed(operands);
@@ -1595,7 +1595,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
if (numResults != 0) {
if (!(packedResult =
this->typeConverter.packFunctionResults(resultTypes)))
- return this->matchFailure();
+ return failure();
}
auto promoted = this->typeConverter.promoteMemRefDescriptors(
@@ -1606,7 +1606,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
// If < 2 results, packing did not do anything and we can just return.
if (numResults < 2) {
rewriter.replaceOp(op, newOp.getResults());
- return this->matchSuccess();
+ return success();
}
// Otherwise, it had been converted to an operation producing a structure.
@@ -1624,7 +1624,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
}
rewriter.replaceOp(op, results);
- return this->matchSuccess();
+ return success();
}
};
@@ -1647,11 +1647,11 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
: LLVMLegalizationPattern<DeallocOp>(dialect_, converter),
useAlloca(useAlloca) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (useAlloca)
- return rewriter.eraseOp(op), matchSuccess();
+ return rewriter.eraseOp(op), success();
assert(operands.size() == 1 && "dealloc takes one operand");
OperandAdaptor<DeallocOp> transformed(operands);
@@ -1673,7 +1673,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
memref.allocatedPtr(rewriter, op->getLoc()));
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
- return matchSuccess();
+ return success();
}
bool useAlloca;
@@ -1683,7 +1683,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
using LLVMLegalizationPattern<RsqrtOp>::LLVMLegalizationPattern;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<RsqrtOp> transformed(operands);
@@ -1691,7 +1691,7 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
if (!operandType)
- return matchFailure();
+ return failure();
auto loc = op->getLoc();
auto resultType = *op->result_type_begin();
@@ -1709,12 +1709,12 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
}
auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
- return this->matchSuccess();
+ return success();
}
auto vectorType = resultType.dyn_cast<VectorType>();
if (!vectorType)
- return this->matchFailure();
+ return failure();
if (succeeded(HandleMultidimensionalVectors(
op, operands, typeConverter,
@@ -1732,8 +1732,8 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
sqrt);
},
rewriter)))
- return this->matchSuccess();
- return this->matchFailure();
+ return success();
+ return failure();
}
};
@@ -1741,7 +1741,7 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
@@ -1753,7 +1753,7 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
if (!operandType)
- return matchFailure();
+ return failure();
std::string functionName;
if (operandType.isFloatTy())
@@ -1761,7 +1761,7 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
else if (operandType.isDoubleTy())
functionName = "tanh";
else
- return matchFailure();
+ return failure();
// Get a reference to the tanh function, inserting it if necessary.
Operation *tanhFunc =
@@ -1783,14 +1783,14 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, operandType, rewriter.getSymbolRefAttr(tanhLLVMFunc),
transformed.operand());
- return matchSuccess();
+ return success();
}
};
struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
- PatternMatchResult match(Operation *op) const override {
+ LogicalResult match(Operation *op) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
Type srcType = memRefCastOp.getOperand().getType();
Type dstType = memRefCastOp.getType();
@@ -1801,8 +1801,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
MemRefType targetType = memRefCastOp.getType().cast<MemRefType>();
return (isSupportedMemRefType(targetType) &&
isSupportedMemRefType(sourceType))
- ? matchSuccess()
- : matchFailure();
+ ? success()
+ : failure();
}
// At least one of the operands is unranked type
@@ -1812,8 +1812,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
// Unranked to unranked cast is disallowed
return !(srcType.isa<UnrankedMemRefType>() &&
dstType.isa<UnrankedMemRefType>())
- ? matchSuccess()
- : matchFailure();
+ ? success()
+ : failure();
}
void rewrite(Operation *op, ArrayRef<Value> operands,
@@ -1886,17 +1886,17 @@ struct DialectCastOpLowering
: public LLVMLegalizationPattern<LLVM::DialectCastOp> {
using LLVMLegalizationPattern<LLVM::DialectCastOp>::LLVMLegalizationPattern;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto castOp = cast<LLVM::DialectCastOp>(op);
OperandAdaptor<LLVM::DialectCastOp> transformed(operands);
if (transformed.in().getType() !=
typeConverter.convertType(castOp.getType())) {
- return matchFailure();
+ return failure();
}
rewriter.replaceOp(op, transformed.in());
- return matchSuccess();
+ return success();
}
};
@@ -1905,7 +1905,7 @@ struct DialectCastOpLowering
struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dimOp = cast<DimOp>(op);
@@ -1922,7 +1922,7 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
// Use constant for static size.
rewriter.replaceOp(
op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
- return matchSuccess();
+ return success();
}
};
@@ -1934,10 +1934,9 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern;
using Base = LoadStoreOpLowering<Derived>;
- PatternMatchResult match(Operation *op) const override {
+ LogicalResult match(Operation *op) const override {
MemRefType type = cast<Derived>(op).getMemRefType();
- return isSupportedMemRefType(type) ? this->matchSuccess()
- : this->matchFailure();
+ return isSupportedMemRefType(type) ? success() : failure();
}
// Given subscript indices and array sizes in row-major order,
@@ -2010,7 +2009,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
using Base::Base;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loadOp = cast<LoadOp>(op);
@@ -2020,7 +2019,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule());
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
- return matchSuccess();
+ return success();
}
};
@@ -2029,7 +2028,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
using Base::Base;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto type = cast<StoreOp>(op).getMemRefType();
@@ -2039,7 +2038,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
transformed.indices(), rewriter, getModule());
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
dataPtr);
- return matchSuccess();
+ return success();
}
};
@@ -2048,7 +2047,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
using Base::Base;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto prefetchOp = cast<PrefetchOp>(op);
@@ -2072,7 +2071,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
rewriter.replaceOpWithNewOp<LLVM::Prefetch>(op, dataPtr, isWrite,
localityHint, isData);
- return matchSuccess();
+ return success();
}
};
@@ -2083,7 +2082,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
IndexCastOpOperandAdaptor transformed(operands);
@@ -2104,7 +2103,7 @@ struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
else
rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
transformed.in());
- return matchSuccess();
+ return success();
}
};
@@ -2118,7 +2117,7 @@ static LLVMPredType convertCmpPredicate(StdPredType pred) {
struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto cmpiOp = cast<CmpIOp>(op);
@@ -2130,14 +2129,14 @@ struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
transformed.lhs(), transformed.rhs());
- return matchSuccess();
+ return success();
}
};
struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto cmpfOp = cast<CmpFOp>(op);
@@ -2149,7 +2148,7 @@ struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
transformed.lhs(), transformed.rhs());
- return matchSuccess();
+ return success();
}
};
@@ -2189,12 +2188,12 @@ struct OneToOneLLVMTerminatorLowering
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TargetOp>(op, operands, op->getSuccessors(),
op->getAttrs());
- return this->matchSuccess();
+ return success();
}
};
@@ -2207,7 +2206,7 @@ struct OneToOneLLVMTerminatorLowering
struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
unsigned numArguments = op->getNumOperands();
@@ -2216,12 +2215,12 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
if (numArguments == 0) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, ArrayRef<Type>(), ArrayRef<Value>(), op->getAttrs());
- return matchSuccess();
+ return success();
}
if (numArguments == 1) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, ArrayRef<Type>(), operands.front(), op->getAttrs());
- return matchSuccess();
+ return success();
}
// Otherwise, we need to pack the arguments into an LLVM struct type before
@@ -2237,7 +2236,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, ArrayRef<Type>(), packed,
op->getAttrs());
- return matchSuccess();
+ return success();
}
};
@@ -2256,13 +2255,13 @@ struct CondBranchOpLowering
struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto splatOp = cast<SplatOp>(op);
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
if (!resultType || resultType.getRank() != 1)
- return matchFailure();
+ return failure();
// First insert it into an undef vector so we can shuffle it.
auto vectorType = typeConverter.convertType(splatOp.getType());
@@ -2280,7 +2279,7 @@ struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
// Shuffle the value across the desired number of elements.
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef, zeroAttrs);
- return matchSuccess();
+ return success();
}
};
@@ -2290,14 +2289,14 @@ struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto splatOp = cast<SplatOp>(op);
OperandAdaptor<SplatOp> adaptor(operands);
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
if (!resultType || resultType.getRank() == 1)
- return matchFailure();
+ return failure();
// First insert it into an undef vector so we can shuffle it.
auto loc = op->getLoc();
@@ -2305,7 +2304,7 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
if (!llvmArrayTy || !llvmVectorTy)
- return matchFailure();
+ return failure();
// Construct returned value.
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
@@ -2332,7 +2331,7 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
position);
});
rewriter.replaceOp(op, desc);
- return matchSuccess();
+ return success();
}
};
@@ -2344,7 +2343,7 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
using LLVMLegalizationPattern<SubViewOp>::LLVMLegalizationPattern;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@@ -2376,7 +2375,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
auto targetDescTy = typeConverter.convertType(viewMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!sourceElementTy || !targetDescTy)
- return matchFailure();
+ return failure();
// Currently, only rank > 0 and full or no operands are supported. Fail to
// convert otherwise.
@@ -2385,22 +2384,22 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
(!dynamicOffsets.empty() && rank != dynamicOffsets.size()) ||
(!dynamicSizes.empty() && rank != dynamicSizes.size()) ||
(!dynamicStrides.empty() && rank != dynamicStrides.size()))
- return matchFailure();
+ return failure();
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
if (failed(successStrides))
- return matchFailure();
+ return failure();
// Fail to convert if neither a dynamic nor static offset is available.
if (dynamicOffsets.empty() &&
offset == MemRefType::getDynamicStrideOrOffset())
- return matchFailure();
+ return failure();
// Create the descriptor.
if (!operands.front().getType().isa<LLVM::LLVMType>())
- return matchFailure();
+ return failure();
MemRefDescriptor sourceMemRef(operands.front());
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
@@ -2460,7 +2459,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
}
rewriter.replaceOp(op, {targetMemRef});
- return matchSuccess();
+ return success();
}
};
@@ -2505,7 +2504,7 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
return createIndexConstant(rewriter, loc, 1);
}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@@ -2520,14 +2519,13 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
typeConverter.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
if (!targetDescTy)
return op->emitWarning("Target descriptor type not converted to LLVM"),
- matchFailure();
+ failure();
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
if (failed(successStrides))
- return op->emitWarning("cannot cast to non-strided shape"),
- matchFailure();
+ return op->emitWarning("cannot cast to non-strided shape"), failure();
// Create the descriptor.
MemRefDescriptor sourceMemRef(adaptor.source());
@@ -2560,12 +2558,11 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
// Early exit for 0-D corner case.
if (viewMemRefType.getRank() == 0)
- return rewriter.replaceOp(op, {targetMemRef}), matchSuccess();
+ return rewriter.replaceOp(op, {targetMemRef}), success();
// Fields 4 and 5: Update sizes and strides.
if (strides.back() != 1)
- return op->emitWarning("cannot cast to non-contiguous shape"),
- matchFailure();
+ return op->emitWarning("cannot cast to non-contiguous shape"), failure();
Value stride = nullptr, nextSize = nullptr;
// Drop the dynamic stride from the operand list, if present.
ArrayRef<Value> sizeOperands(sizeAndOffsetOperands);
@@ -2583,7 +2580,7 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
}
rewriter.replaceOp(op, {targetMemRef});
- return matchSuccess();
+ return success();
}
};
@@ -2591,7 +2588,7 @@ struct AssumeAlignmentOpLowering
: public LLVMLegalizationPattern<AssumeAlignmentOp> {
using LLVMLegalizationPattern<AssumeAlignmentOp>::LLVMLegalizationPattern;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<AssumeAlignmentOp> transformed(operands);
@@ -2622,7 +2619,7 @@ struct AssumeAlignmentOpLowering
rewriter.create<LLVM::AndOp>(op->getLoc(), ptrValue, mask), zero));
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
};
@@ -2657,13 +2654,13 @@ namespace {
struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
using Base::Base;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto atomicOp = cast<AtomicRMWOp>(op);
auto maybeKind = matchSimpleAtomicOp(atomicOp);
if (!maybeKind)
- return matchFailure();
+ return failure();
OperandAdaptor<AtomicRMWOp> adaptor(operands);
auto resultType = adaptor.value().getType();
auto memRefType = atomicOp.getMemRefType();
@@ -2672,7 +2669,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
op, resultType, *maybeKind, dataPtr, adaptor.value(),
LLVM::AtomicOrdering::acq_rel);
- return matchSuccess();
+ return success();
}
};
@@ -2706,13 +2703,13 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
using Base::Base;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto atomicOp = cast<AtomicRMWOp>(op);
auto maybeKind = matchSimpleAtomicOp(atomicOp);
if (maybeKind)
- return matchFailure();
+ return failure();
LLVM::FCmpPredicate predicate;
switch (atomicOp.kind()) {
@@ -2723,7 +2720,7 @@ struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
predicate = LLVM::FCmpPredicate::olt;
break;
default:
- return matchFailure();
+ return failure();
}
OperandAdaptor<AtomicRMWOp> adaptor(operands);
@@ -2779,7 +2776,7 @@ struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
// The 'result' of the atomic_rmw op is the newly loaded value.
rewriter.replaceOp(op, {newLoaded});
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index a1a599b0ab08..310dcd8a86bd 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -31,7 +31,7 @@ class ConstantCompositeOpConversion final : public SPIRVOpLowering<ConstantOp> {
public:
using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(ConstantOp constCompositeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -45,7 +45,7 @@ class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> {
public:
using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -55,7 +55,7 @@ class CmpFOpConversion final : public SPIRVOpLowering<CmpFOp> {
public:
using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -65,7 +65,7 @@ class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> {
public:
using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -81,14 +81,14 @@ class IntegerOpConversion final : public SPIRVOpLowering<StdOp> {
public:
using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto resultType =
this->typeConverter.convertType(operation.getResult().getType());
rewriter.template replaceOpWithNewOp<SPIRVOp>(
operation, resultType, operands, ArrayRef<NamedAttribute>());
- return this->matchSuccess();
+ return success();
}
};
@@ -100,7 +100,7 @@ class LoadOpConversion final : public SPIRVOpLowering<LoadOp> {
public:
using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -111,7 +111,7 @@ class ReturnOpConversion final : public SPIRVOpLowering<ReturnOp> {
public:
using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -121,7 +121,7 @@ class ReturnOpConversion final : public SPIRVOpLowering<ReturnOp> {
class SelectOpConversion final : public SPIRVOpLowering<SelectOp> {
public:
using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -134,7 +134,7 @@ class StoreOpConversion final : public SPIRVOpLowering<StoreOp> {
public:
using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -145,22 +145,22 @@ class StoreOpConversion final : public SPIRVOpLowering<StoreOp> {
// ConstantOp with composite type.
//===----------------------------------------------------------------------===//
-PatternMatchResult ConstantCompositeOpConversion::matchAndRewrite(
+LogicalResult ConstantCompositeOpConversion::matchAndRewrite(
ConstantOp constCompositeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto compositeType =
constCompositeOp.getResult().getType().dyn_cast<RankedTensorType>();
if (!compositeType)
- return matchFailure();
+ return failure();
auto spirvCompositeType = typeConverter.convertType(compositeType);
if (!spirvCompositeType)
- return matchFailure();
+ return failure();
auto linearizedElements =
constCompositeOp.value().dyn_cast<DenseElementsAttr>();
if (!linearizedElements)
- return matchFailure();
+ return failure();
// If composite type has rank greater than one, then perform linearization.
if (compositeType.getRank() > 1) {
@@ -171,24 +171,24 @@ PatternMatchResult ConstantCompositeOpConversion::matchAndRewrite(
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
constCompositeOp, spirvCompositeType, linearizedElements);
- return matchSuccess();
+ return success();
}
//===----------------------------------------------------------------------===//
// ConstantOp with index type.
//===----------------------------------------------------------------------===//
-PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
+LogicalResult ConstantIndexOpConversion::matchAndRewrite(
ConstantOp constIndexOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!constIndexOp.getResult().getType().isa<IndexType>()) {
- return matchFailure();
+ return failure();
}
// The attribute has index type which is not directly supported in
// SPIR-V. Get the integer value and create a new IntegerAttr.
auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>();
if (!constAttr) {
- return matchFailure();
+ return failure();
}
// Use the bitwidth set in the value attribute to decide the result type
@@ -197,7 +197,7 @@ PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
auto constVal = constAttr.getValue();
auto constValType = constAttr.getType().dyn_cast<IndexType>();
if (!constValType) {
- return matchFailure();
+ return failure();
}
auto spirvConstType =
typeConverter.convertType(constIndexOp.getResult().getType());
@@ -205,14 +205,14 @@ PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
spirvConstVal);
- return matchSuccess();
+ return success();
}
//===----------------------------------------------------------------------===//
// CmpFOp
//===----------------------------------------------------------------------===//
-PatternMatchResult
+LogicalResult
CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
CmpFOpOperandAdaptor cmpFOpOperands(operands);
@@ -223,7 +223,7 @@ CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
rewriter.replaceOpWithNewOp<spirvOp>(cmpFOp, cmpFOp.getResult().getType(), \
cmpFOpOperands.lhs(), \
cmpFOpOperands.rhs()); \
- return matchSuccess();
+ return success();
// Ordered.
DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp);
@@ -245,14 +245,14 @@ CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
default:
break;
}
- return matchFailure();
+ return failure();
}
//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
-PatternMatchResult
+LogicalResult
CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
CmpIOpOperandAdaptor cmpIOpOperands(operands);
@@ -263,7 +263,7 @@ CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
cmpIOpOperands.lhs(), \
cmpIOpOperands.rhs()); \
- return matchSuccess();
+ return success();
DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
@@ -278,14 +278,14 @@ CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
#undef DISPATCH
}
- return matchFailure();
+ return failure();
}
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//
-PatternMatchResult
+LogicalResult
LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
LoadOpOperandAdaptor loadOperands(operands);
@@ -293,42 +293,42 @@ LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
typeConverter, loadOp.memref().getType().cast<MemRefType>(),
loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
- return matchSuccess();
+ return success();
}
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
-PatternMatchResult
+LogicalResult
ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (returnOp.getNumOperands()) {
- return matchFailure();
+ return failure();
}
rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
- return matchSuccess();
+ return success();
}
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
-PatternMatchResult
+LogicalResult
SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
SelectOpOperandAdaptor selectOperands(operands);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
selectOperands.true_value(),
selectOperands.false_value());
- return matchSuccess();
+ return success();
}
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
-PatternMatchResult
+LogicalResult
StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
StoreOpOperandAdaptor storeOperands(operands);
@@ -338,7 +338,7 @@ StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
rewriter);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
storeOperands.value());
- return matchSuccess();
+ return success();
}
namespace {
diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
index efa7bb0306f8..3af01e025801 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -26,8 +26,8 @@ class LoadOpOfSubViewFolder final : public OpRewritePattern<LoadOp> {
public:
using OpRewritePattern<LoadOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(LoadOp loadOp,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(LoadOp loadOp,
+ PatternRewriter &rewriter) const override;
};
/// Merges subview operation with store operation.
@@ -35,8 +35,8 @@ class StoreOpOfSubViewFolder final : public OpRewritePattern<StoreOp> {
public:
using OpRewritePattern<StoreOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(StoreOp storeOp,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(StoreOp storeOp,
+ PatternRewriter &rewriter) const override;
};
} // namespace
@@ -107,43 +107,43 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
// Folding SubViewOp and LoadOp.
//===----------------------------------------------------------------------===//
-PatternMatchResult
+LogicalResult
LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
PatternRewriter &rewriter) const {
auto subViewOp = dyn_cast_or_null<SubViewOp>(loadOp.memref().getDefiningOp());
if (!subViewOp) {
- return matchFailure();
+ return failure();
}
SmallVector<Value, 4> sourceIndices;
if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
loadOp.indices(), sourceIndices)))
- return matchFailure();
+ return failure();
rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
sourceIndices);
- return matchSuccess();
+ return success();
}
//===----------------------------------------------------------------------===//
// Folding SubViewOp and StoreOp.
//===----------------------------------------------------------------------===//
-PatternMatchResult
+LogicalResult
StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
PatternRewriter &rewriter) const {
auto subViewOp =
dyn_cast_or_null<SubViewOp>(storeOp.memref().getDefiningOp());
if (!subViewOp) {
- return matchFailure();
+ return failure();
}
SmallVector<Value, 4> sourceIndices;
if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
storeOp.indices(), sourceIndices)))
- return matchFailure();
+ return failure();
rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
subViewOp.source(), sourceIndices);
- return matchSuccess();
+ return success();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 459c7243fd46..2b3020d046a2 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -133,13 +133,13 @@ class VectorBroadcastOpConversion : public ConvertToLLVMPattern {
: ConvertToLLVMPattern(vector::BroadcastOp::getOperationName(), context,
typeConverter) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto broadcastOp = cast<vector::BroadcastOp>(op);
VectorType dstVectorType = broadcastOp.getVectorType();
if (typeConverter.convertType(dstVectorType) == nullptr)
- return matchFailure();
+ return failure();
// Rewrite when the full vector type can be lowered (which
// implies all 'reduced' types can be lowered too).
auto adaptor = vector::BroadcastOpOperandAdaptor(operands);
@@ -149,7 +149,7 @@ class VectorBroadcastOpConversion : public ConvertToLLVMPattern {
op, expandRanks(adaptor.source(), // source value to be expanded
op->getLoc(), // location of original broadcast
srcVectorType, dstVectorType, rewriter));
- return matchSuccess();
+ return success();
}
private:
@@ -284,7 +284,7 @@ class VectorMatmulOpConversion : public ConvertToLLVMPattern {
: ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
typeConverter) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto matmulOp = cast<vector::MatmulOp>(op);
@@ -293,7 +293,7 @@ class VectorMatmulOpConversion : public ConvertToLLVMPattern {
op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
matmulOp.rhs_columns());
- return matchSuccess();
+ return success();
}
};
@@ -304,7 +304,7 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
: ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
typeConverter) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto reductionOp = cast<vector::ReductionOp>(op);
@@ -335,8 +335,8 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
op, llvmType, operands[0]);
else
- return matchFailure();
- return matchSuccess();
+ return failure();
+ return success();
} else if (eltType.isF32() || eltType.isF64()) {
// Floating-point reductions: add/mul/min/max
@@ -364,10 +364,10 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
op, llvmType, operands[0]);
else
- return matchFailure();
- return matchSuccess();
+ return failure();
+ return success();
}
- return matchFailure();
+ return failure();
}
};
@@ -378,7 +378,7 @@ class VectorShuffleOpConversion : public ConvertToLLVMPattern {
: ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
typeConverter) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@@ -392,7 +392,7 @@ class VectorShuffleOpConversion : public ConvertToLLVMPattern {
// Bail if result type cannot be lowered.
if (!llvmType)
- return matchFailure();
+ return failure();
// Get rank and dimension sizes.
int64_t rank = vectorType.getRank();
@@ -406,7 +406,7 @@ class VectorShuffleOpConversion : public ConvertToLLVMPattern {
Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
rewriter.replaceOp(op, shuffle);
- return matchSuccess();
+ return success();
}
// For all other cases, insert the individual values individually.
@@ -425,7 +425,7 @@ class VectorShuffleOpConversion : public ConvertToLLVMPattern {
llvmType, rank, insPos++);
}
rewriter.replaceOp(op, insert);
- return matchSuccess();
+ return success();
}
};
@@ -436,7 +436,7 @@ class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
: ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
context, typeConverter) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
@@ -446,11 +446,11 @@ class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
// Bail if result type cannot be lowered.
if (!llvmType)
- return matchFailure();
+ return failure();
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
op, llvmType, adaptor.vector(), adaptor.position());
- return matchSuccess();
+ return success();
}
};
@@ -461,7 +461,7 @@ class VectorExtractOpConversion : public ConvertToLLVMPattern {
: ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
typeConverter) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@@ -474,14 +474,14 @@ class VectorExtractOpConversion : public ConvertToLLVMPattern {
// Bail if result type cannot be lowered.
if (!llvmResultType)
- return matchFailure();
+ return failure();
// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, llvmResultType, adaptor.vector(), positionArrayAttr);
rewriter.replaceOp(op, extracted);
- return matchSuccess();
+ return success();
}
// Potential extraction of 1-D vector from array.
@@ -505,7 +505,7 @@ class VectorExtractOpConversion : public ConvertToLLVMPattern {
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(op, extracted);
- return matchSuccess();
+ return success();
}
};
@@ -530,17 +530,17 @@ class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
: ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
typeConverter) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::FMAOpOperandAdaptor(operands);
vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
VectorType vType = fmaOp.getVectorType();
if (vType.getRank() != 1)
- return matchFailure();
+ return failure();
rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
adaptor.acc());
- return matchSuccess();
+ return success();
}
};
@@ -551,7 +551,7 @@ class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
: ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
context, typeConverter) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
@@ -561,11 +561,11 @@ class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
// Bail if result type cannot be lowered.
if (!llvmType)
- return matchFailure();
+ return failure();
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
- return matchSuccess();
+ return success();
}
};
@@ -576,7 +576,7 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern {
: ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
typeConverter) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@@ -589,7 +589,7 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern {
// Bail if result type cannot be lowered.
if (!llvmResultType)
- return matchFailure();
+ return failure();
// One-shot insertion of a vector into an array (only requires insertvalue).
if (sourceType.isa<VectorType>()) {
@@ -597,7 +597,7 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern {
loc, llvmResultType, adaptor.dest(), adaptor.source(),
positionArrayAttr);
rewriter.replaceOp(op, inserted);
- return matchSuccess();
+ return success();
}
// Potential extraction of 1-D vector from array.
@@ -632,7 +632,7 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern {
}
rewriter.replaceOp(op, inserted);
- return matchSuccess();
+ return success();
}
};
@@ -661,11 +661,11 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
using OpRewritePattern<FMAOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(FMAOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(FMAOp op,
+ PatternRewriter &rewriter) const override {
auto vType = op.getVectorType();
if (vType.getRank() < 2)
- return matchFailure();
+ return failure();
auto loc = op.getLoc();
auto elemType = vType.getElementType();
@@ -680,7 +680,7 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
desc = rewriter.create<InsertOp>(loc, fma, desc, i);
}
rewriter.replaceOp(op, desc);
- return matchSuccess();
+ return success();
}
};
@@ -704,19 +704,19 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(InsertStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
if (op.offsets().getValue().empty())
- return matchFailure();
+ return failure();
auto loc = op.getLoc();
int64_t rankDiff = dstType.getRank() - srcType.getRank();
assert(rankDiff >= 0);
if (rankDiff == 0)
- return matchFailure();
+ return failure();
int64_t rankRest = dstType.getRank() - rankDiff;
// Extract / insert the subvector of matching rank and InsertStridedSlice
@@ -735,7 +735,7 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
op, stridedSliceInnerOp.getResult(), op.dest(),
getI64SubArray(op.offsets(), /*dropFront=*/0,
/*dropFront=*/rankRest));
- return matchSuccess();
+ return success();
}
};
@@ -753,22 +753,22 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(InsertStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
if (op.offsets().getValue().empty())
- return matchFailure();
+ return failure();
int64_t rankDiff = dstType.getRank() - srcType.getRank();
assert(rankDiff >= 0);
if (rankDiff != 0)
- return matchFailure();
+ return failure();
if (srcType == dstType) {
rewriter.replaceOp(op, op.source());
- return matchSuccess();
+ return success();
}
int64_t offset =
@@ -813,7 +813,7 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
}
rewriter.replaceOp(op, res);
- return matchSuccess();
+ return success();
}
};
@@ -824,7 +824,7 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
: ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
typeConverter) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@@ -837,18 +837,18 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
// Only static shape casts supported atm.
if (!sourceMemRefType.hasStaticShape() ||
!targetMemRefType.hasStaticShape())
- return matchFailure();
+ return failure();
auto llvmSourceDescriptorTy =
operands[0].getType().dyn_cast<LLVM::LLVMType>();
if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
- return matchFailure();
+ return failure();
MemRefDescriptor sourceMemRef(operands[0]);
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
- return matchFailure();
+ return failure();
int64_t offset;
SmallVector<int64_t, 4> strides;
@@ -866,7 +866,7 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
}
// Only contiguous source tensors supported atm.
if (failed(successStrides) || !isContiguous)
- return matchFailure();
+ return failure();
auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
@@ -901,7 +901,7 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
}
rewriter.replaceOp(op, {desc});
- return matchSuccess();
+ return success();
}
};
@@ -924,7 +924,7 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
//
// TODO(ajcbik): rely solely on libc in future? something else?
//
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto printOp = cast<vector::PrintOp>(op);
@@ -932,7 +932,7 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
Type printType = printOp.getPrintType();
if (typeConverter.convertType(printType) == nullptr)
- return matchFailure();
+ return failure();
// Make sure element type has runtime support (currently just Float/Double).
VectorType vectorType = printType.dyn_cast<VectorType>();
@@ -948,13 +948,13 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
else if (eltType.isF64())
printer = getPrintDouble(op);
else
- return matchFailure();
+ return failure();
// Unroll vector into elementary print calls.
emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
emitCall(rewriter, op->getLoc(), getPrintNewline(op));
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
private:
@@ -1047,8 +1047,8 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
public:
using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(StridedSliceOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(StridedSliceOp op,
+ PatternRewriter &rewriter) const override {
auto dstType = op.getResult().getType().cast<VectorType>();
assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
@@ -1089,7 +1089,7 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
res = insertOne(rewriter, loc, extracted, res, idx);
}
rewriter.replaceOp(op, {res});
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp
index 5736c1d747f4..b16f02ef6b9c 100644
--- a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp
+++ b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp
@@ -198,8 +198,8 @@ struct VectorTransferRewriter : public RewritePattern {
}
/// Performs the rewrite.
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
};
/// Lowers TransferReadOp into a combination of:
@@ -246,7 +246,7 @@ struct VectorTransferRewriter : public RewritePattern {
/// Performs the rewrite.
template <>
-PatternMatchResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
+LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
using namespace mlir::edsc::op;
@@ -282,7 +282,7 @@ PatternMatchResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
// 3. Propagate.
rewriter.replaceOp(op, vectorValue.getValue());
- return matchSuccess();
+ return success();
}
/// Lowers TransferWriteOp into a combination of:
@@ -304,7 +304,7 @@ PatternMatchResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
/// TODO(ntv): implement alternatives to clipping.
/// TODO(ntv): support non-data-parallel operations.
template <>
-PatternMatchResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
+LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
using namespace edsc::op;
@@ -340,7 +340,7 @@ PatternMatchResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
(std_dealloc(tmp)); // vexing parse...
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
} // namespace
diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
index f47fe62963f2..0b8795947e06 100644
--- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp
+++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -727,8 +727,8 @@ struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
AffineMap map, ArrayRef<Value> mapOperands) const;
- PatternMatchResult matchAndRewrite(AffineOpTy affineOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(AffineOpTy affineOp,
+ PatternRewriter &rewriter) const override {
static_assert(std::is_same<AffineOpTy, AffineLoadOp>::value ||
std::is_same<AffineOpTy, AffinePrefetchOp>::value ||
std::is_same<AffineOpTy, AffineStoreOp>::value ||
@@ -743,10 +743,10 @@ struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
composeAffineMapAndOperands(&map, &resultOperands);
if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
resultOperands.begin()))
- return this->matchFailure();
+ return failure();
replaceAffineOp(rewriter, affineOp, map, resultOperands);
- return this->matchSuccess();
+ return success();
}
};
@@ -1405,13 +1405,13 @@ namespace {
struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
using OpRewritePattern<AffineForOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(AffineForOp forOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(AffineForOp forOp,
+ PatternRewriter &rewriter) const override {
// Check that the body only contains a terminator.
if (!has_single_element(*forOp.getBody()))
- return matchFailure();
+ return failure();
rewriter.eraseOp(forOp);
- return matchSuccess();
+ return success();
}
};
} // end anonymous namespace
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
index 4084416a8894..2a876a332ea6 100644
--- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
@@ -111,8 +111,8 @@ namespace {
struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(DequantizeCastOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(DequantizeCastOp op,
+ PatternRewriter &rewriter) const override {
Type inputType = op.arg().getType();
Type outputType = op.getResult().getType();
@@ -121,16 +121,16 @@ struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
Type expressedOutputType = inputElementType.castToExpressedType(inputType);
if (expressedOutputType != outputType) {
// Not a valid uniform cast.
- return matchFailure();
+ return failure();
}
Value dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
if (!dequantizedValue) {
- return matchFailure();
+ return failure();
}
rewriter.replaceOp(op, dequantizedValue);
- return matchSuccess();
+ return success();
}
};
@@ -313,40 +313,40 @@ namespace {
struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
using OpRewritePattern<RealAddEwOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(RealAddEwOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(RealAddEwOp op,
+ PatternRewriter &rewriter) const override {
const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
op.clamp_max());
if (!info.isValid()) {
- return matchFailure();
+ return failure();
}
// Try all of the permutations we support.
if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
- return matchSuccess();
+ return success();
}
- return matchFailure();
+ return failure();
}
};
struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
using OpRewritePattern<RealMulEwOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(RealMulEwOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(RealMulEwOp op,
+ PatternRewriter &rewriter) const override {
const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
op.clamp_max());
if (!info.isValid()) {
- return matchFailure();
+ return failure();
}
// Try all of the permutations we support.
if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
- return matchSuccess();
+ return success();
}
- return matchFailure();
+ return failure();
}
};
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index ae5c981aad7e..d12611f33b9c 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -380,8 +380,8 @@ struct GpuAllReduceConversion : public RewritePattern {
explicit GpuAllReduceConversion(MLIRContext *context)
: RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
auto funcOp = cast<gpu::GPUFuncOp>(op);
auto callback = [&](gpu::AllReduceOp reduceOp) {
GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
@@ -391,7 +391,7 @@ struct GpuAllReduceConversion : public RewritePattern {
};
while (funcOp.walk(callback).wasInterrupted()) {
}
- return matchSuccess();
+ return success();
}
};
} // namespace
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index cf95212982c9..8da000fa5260 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -534,10 +534,10 @@ namespace {
struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(GenericOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(GenericOp op,
+ PatternRewriter &rewriter) const override {
if (!op.hasTensorSemantics())
- return matchFailure();
+ return failure();
// Find the first operand that is defined by another generic op on tensors.
for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) {
@@ -551,9 +551,9 @@ struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
if (!fusedOp)
continue;
rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults());
- return matchSuccess();
+ return success();
}
- return matchFailure();
+ return failure();
}
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 22f53b2ab8bc..4dc41e2c87ae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -531,13 +531,13 @@ class LinalgRewritePattern : public RewritePattern {
explicit LinalgRewritePattern(MLIRContext *context)
: RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
using Impl = LinalgOpToLoopsImpl<LoopType, IndexedValueType, ConcreteOp>;
if (failed(Impl::doit(op, rewriter)))
- return matchFailure();
+ return failure();
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
};
@@ -595,26 +595,26 @@ struct FoldAffineOp : public RewritePattern {
FoldAffineOp(MLIRContext *context)
: RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
auto map = affineApplyOp.getAffineMap();
if (map.getNumResults() != 1 || map.getNumInputs() > 1)
- return matchFailure();
+ return failure();
AffineExpr expr = map.getResult(0);
if (map.getNumInputs() == 0) {
if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
- return matchSuccess();
+ return success();
}
- return matchFailure();
+ return failure();
}
if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
rewriter.replaceOp(op, op->getOperand(0));
- return matchSuccess();
+ return success();
}
- return matchFailure();
+ return failure();
}
};
} // namespace
diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
index a415e58002a1..2598e8cf5013 100644
--- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
@@ -30,8 +30,8 @@ class ConvertConstPass : public FunctionPass<ConvertConstPass> {
struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(QuantizeCastOp qbarrier,
+ PatternRewriter &rewriter) const override;
};
} // end anonymous namespace
@@ -39,14 +39,14 @@ struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
/// Matches a [constant] -> [qbarrier] where the qbarrier results type is
/// quantized and the operand type is quantizable.
-PatternMatchResult
+LogicalResult
QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
PatternRewriter &rewriter) const {
Attribute value;
// Is the operand a constant?
if (!matchPattern(qbarrier.arg(), m_Constant(&value))) {
- return matchFailure();
+ return failure();
}
// Does the qbarrier convert to a quantized type. This will not be true
@@ -56,10 +56,10 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
QuantizedType quantizedElementType =
QuantizedType::getQuantizedElementType(qbarrierResultType);
if (!quantizedElementType) {
- return matchFailure();
+ return failure();
}
if (!QuantizedType::castToStorageType(qbarrierResultType)) {
- return matchFailure();
+ return failure();
}
// Is the operand type compatible with the expressed type of the quantized
@@ -67,20 +67,20 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
// from and to a quantized type).
if (!quantizedElementType.isCompatibleExpressedType(
qbarrier.arg().getType())) {
- return matchFailure();
+ return failure();
}
// Is the constant value a type expressed in a way that we support?
if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() &&
!value.isa<SparseElementsAttr>()) {
- return matchFailure();
+ return failure();
}
Type newConstValueType;
auto newConstValue =
quantizeAttr(value, quantizedElementType, newConstValueType);
if (!newConstValue) {
- return matchFailure();
+ return failure();
}
// When creating the new const op, use a fused location that combines the
@@ -92,7 +92,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(),
newConstOp);
- return matchSuccess();
+ return success();
}
void ConvertConstPass::runOnFunction() {
diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
index 8d3097e0717c..c921aeafda90 100644
--- a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
@@ -35,16 +35,16 @@ class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
: OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
- PatternMatchResult matchAndRewrite(FakeQuantOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(FakeQuantOp op,
+ PatternRewriter &rewriter) const override {
// TODO: If this pattern comes up more frequently, consider adding core
// support for failable rewrites.
if (failableRewrite(op, rewriter)) {
*hadFailure = true;
- return Pattern::matchFailure();
+ return failure();
}
- return Pattern::matchSuccess();
+ return success();
}
private:
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
index 68c7c018b584..f378047f36ea 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
@@ -88,13 +88,13 @@ struct CombineChainedAccessChain
: public OpRewritePattern<spirv::AccessChainOp> {
using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
+ PatternRewriter &rewriter) const override {
auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
accessChainOp.base_ptr().getDefiningOp());
if (!parentAccessChainOp) {
- return matchFailure();
+ return failure();
}
// Combine indices.
@@ -105,7 +105,7 @@ struct CombineChainedAccessChain
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
accessChainOp, parentAccessChainOp.base_ptr(), indices);
- return matchSuccess();
+ return success();
}
};
} // end anonymous namespace
@@ -291,24 +291,24 @@ struct ConvertSelectionOpToSelect
: public OpRewritePattern<spirv::SelectionOp> {
using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
+ PatternRewriter &rewriter) const override {
auto *op = selectionOp.getOperation();
auto &body = op->getRegion(0);
// Verifier allows an empty region for `spv.selection`.
if (body.empty()) {
- return matchFailure();
+ return failure();
}
// Check that region consists of 4 blocks:
// header block, `true` block, `false` block and merge block.
if (std::distance(body.begin(), body.end()) != 4) {
- return matchFailure();
+ return failure();
}
auto *headerBlock = selectionOp.getHeaderBlock();
if (!onlyContainsBranchConditionalOp(headerBlock)) {
- return matchFailure();
+ return failure();
}
auto brConditionalOp =
@@ -319,7 +319,7 @@ struct ConvertSelectionOpToSelect
auto *mergeBlock = selectionOp.getMergeBlock();
if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
- return matchFailure();
+ return failure();
auto trueValue = getSrcValue(trueBlock);
auto falseValue = getSrcValue(falseBlock);
@@ -335,7 +335,7 @@ struct ConvertSelectionOpToSelect
// `spv.selection` is not needed anymore.
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
private:
@@ -345,9 +345,8 @@ struct ConvertSelectionOpToSelect
// 2. Each `spv.Store` uses the same pointer and the same memory attributes.
// 3. A control flow goes into the given merge block from the given
// conditional blocks.
- PatternMatchResult canCanonicalizeSelection(Block *trueBlock,
- Block *falseBlock,
- Block *mergeBlock) const;
+ LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
+ Block *mergeBlock) const;
bool onlyContainsBranchConditionalOp(Block *block) const {
return std::next(block->begin()) == block->end() &&
@@ -382,12 +381,12 @@ struct ConvertSelectionOpToSelect
}
};
-PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
+LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
// Each block must consists of 2 operations.
if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
(std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
- return matchFailure();
+ return failure();
}
auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
@@ -399,7 +398,7 @@ PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
!falseBrBranchOp) {
- return matchFailure();
+ return failure();
}
// Check that each `spv.Store` uses the same pointer, memory access
@@ -407,15 +406,15 @@ PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
!isSameAttrList(trueBrStoreOp, falseBrStoreOp) ||
!isValidType(trueBrStoreOp.value().getType())) {
- return matchFailure();
+ return failure();
}
if ((trueBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock) ||
(falseBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock)) {
- return matchFailure();
+ return failure();
}
- return matchSuccess();
+ return success();
}
} // end anonymous namespace
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 3cf50466e072..4adabdaa597e 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -177,25 +177,25 @@ class FuncOpConversion final : public SPIRVOpLowering<FuncOp> {
public:
using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
-PatternMatchResult
+LogicalResult
FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto fnType = funcOp.getType();
// TODO(antiagainst): support converting functions with one result.
if (fnType.getNumResults())
- return matchFailure();
+ return failure();
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
for (auto argType : enumerate(funcOp.getType().getInputs())) {
auto convertedType = typeConverter.convertType(argType.value());
if (!convertedType)
- return matchFailure();
+ return failure();
signatureConverter.addInputs(argType.index(), convertedType);
}
@@ -216,7 +216,7 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
newFuncOp.end());
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
rewriter.eraseOp(funcOp);
- return matchSuccess();
+ return success();
}
void mlir::populateBuiltinFuncToSPIRVPatterns(
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
index d137f45d8361..0645408398b6 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
@@ -27,8 +27,8 @@ class SPIRVGlobalVariableOpLayoutInfoDecoration
public:
using OpRewritePattern<spirv::GlobalVariableOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(spirv::GlobalVariableOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(spirv::GlobalVariableOp op,
+ PatternRewriter &rewriter) const override {
spirv::StructType::LayoutInfo structSize = 0;
VulkanLayoutUtils::Size structAlignment = 1;
SmallVector<NamedAttribute, 4> globalVarAttrs;
@@ -50,7 +50,7 @@ class SPIRVGlobalVariableOpLayoutInfoDecoration
rewriter.replaceOpWithNewOp<spirv::GlobalVariableOp>(
op, TypeAttr::get(decoratedType), globalVarAttrs);
- return matchSuccess();
+ return success();
}
};
@@ -59,15 +59,15 @@ class SPIRVAddressOfOpLayoutInfoDecoration
public:
using OpRewritePattern<spirv::AddressOfOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(spirv::AddressOfOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(spirv::AddressOfOp op,
+ PatternRewriter &rewriter) const override {
auto spirvModule = op.getParentOfType<spirv::ModuleOp>();
auto varName = op.variable();
auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
op, varOp.type(), rewriter.getSymbolRefAttr(varName));
- return matchSuccess();
+ return success();
}
};
} // namespace
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index ba8dee752782..4dbc54ecfca2 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -138,7 +138,7 @@ namespace {
class ProcessInterfaceVarABI final : public SPIRVOpLowering<spirv::FuncOp> {
public:
using SPIRVOpLowering<spirv::FuncOp>::SPIRVOpLowering;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -151,13 +151,13 @@ class LowerABIAttributesPass final
};
} // namespace
-PatternMatchResult ProcessInterfaceVarABI::matchAndRewrite(
+LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
spirv::FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!funcOp.getAttrOfType<spirv::EntryPointABIAttr>(
spirv::getEntryPointABIAttrName())) {
// TODO(ravishankarm) : Non-entry point functions are not handled.
- return matchFailure();
+ return failure();
}
TypeConverter::SignatureConversion signatureConverter(
funcOp.getType().getNumInputs());
@@ -171,12 +171,12 @@ PatternMatchResult ProcessInterfaceVarABI::matchAndRewrite(
// to pass around scalar/vector values and return a scalar/vector. For now
// non-entry point functions are not handled in this ABI lowering and will
// produce an error.
- return matchFailure();
+ return failure();
}
auto var =
createGlobalVariableForArg(funcOp, rewriter, argType.index(), abiInfo);
if (!var) {
- return matchFailure();
+ return failure();
}
OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
@@ -207,7 +207,7 @@ PatternMatchResult ProcessInterfaceVarABI::matchAndRewrite(
signatureConverter.getConvertedTypes(), llvm::None));
rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
});
- return matchSuccess();
+ return success();
}
void LowerABIAttributesPass::runOnOperation() {
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 0560a0f19526..0c6d4647b933 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -313,14 +313,14 @@ namespace {
struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
using OpRewritePattern<AllocOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(AllocOp alloc,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(AllocOp alloc,
+ PatternRewriter &rewriter) const override {
// Check to see if any dimensions operands are constants. If so, we can
// substitute and drop them.
if (llvm::none_of(alloc.getOperands(), [](Value operand) {
return matchPattern(operand, m_ConstantIndex());
}))
- return matchFailure();
+ return failure();
auto memrefType = alloc.getType();
@@ -364,7 +364,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
alloc.getType());
rewriter.replaceOp(alloc, {resultCast});
- return matchSuccess();
+ return success();
}
};
@@ -373,13 +373,13 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
using OpRewritePattern<AllocOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(AllocOp alloc,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(AllocOp alloc,
+ PatternRewriter &rewriter) const override {
if (alloc.use_empty()) {
rewriter.eraseOp(alloc);
- return matchSuccess();
+ return success();
}
- return matchFailure();
+ return failure();
}
};
} // end anonymous namespace.
@@ -461,18 +461,18 @@ namespace {
struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
using OpRewritePattern<BranchOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(BranchOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(BranchOp op,
+ PatternRewriter &rewriter) const override {
// Check that the successor block has a single predecessor.
Block *succ = op.getDest();
Block *opParent = op.getOperation()->getBlock();
if (succ == opParent || !has_single_element(succ->getPredecessors()))
- return matchFailure();
+ return failure();
// Merge the successor into the current block and erase the branch.
rewriter.mergeBlocks(succ, opParent, op.getOperands());
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
};
} // end anonymous namespace.
@@ -545,18 +545,18 @@ struct SimplifyIndirectCallWithKnownCallee
: public OpRewritePattern<CallIndirectOp> {
using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(CallIndirectOp indirectCall,
+ PatternRewriter &rewriter) const override {
// Check that the callee is a constant callee.
SymbolRefAttr calledFn;
if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
- return matchFailure();
+ return failure();
// Replace with a direct call.
rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
indirectCall.getResultTypes(),
indirectCall.getArgOperands());
- return matchSuccess();
+ return success();
}
};
} // end anonymous namespace.
@@ -733,20 +733,20 @@ namespace {
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(CondBranchOp condbr,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(CondBranchOp condbr,
+ PatternRewriter &rewriter) const override {
if (matchPattern(condbr.getCondition(), m_NonZero())) {
// True branch taken.
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
condbr.getTrueOperands());
- return matchSuccess();
+ return success();
} else if (matchPattern(condbr.getCondition(), m_Zero())) {
// False branch taken.
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
condbr.getFalseOperands());
- return matchSuccess();
+ return success();
}
- return matchFailure();
+ return failure();
}
};
} // end anonymous namespace.
@@ -958,21 +958,21 @@ namespace {
struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
using OpRewritePattern<DeallocOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(DeallocOp dealloc,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(DeallocOp dealloc,
+ PatternRewriter &rewriter) const override {
// Check that the memref operand's defining operation is an AllocOp.
Value memref = dealloc.memref();
if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
- return matchFailure();
+ return failure();
// Check that all of the uses of the AllocOp are other DeallocOps.
for (auto *user : memref.getUsers())
if (!isa<DeallocOp>(user))
- return matchFailure();
+ return failure();
// Erase the dealloc operation.
rewriter.eraseOp(dealloc);
- return matchSuccess();
+ return success();
}
};
} // end anonymous namespace.
@@ -2003,8 +2003,8 @@ class SubViewOpShapeFolder final : public OpRewritePattern<SubViewOp> {
public:
using OpRewritePattern<SubViewOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(SubViewOp subViewOp,
+ PatternRewriter &rewriter) const override {
MemRefType subViewType = subViewOp.getType();
// Follow all or nothing approach for shapes for now. If all the operands
// for sizes are constants then fold it into the type of the result memref.
@@ -2012,7 +2012,7 @@ class SubViewOpShapeFolder final : public OpRewritePattern<SubViewOp> {
llvm::any_of(subViewOp.sizes(), [](Value operand) {
return !matchPattern(operand, m_ConstantIndex());
})) {
- return matchFailure();
+ return failure();
}
SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes());
for (auto size : llvm::enumerate(subViewOp.sizes())) {
@@ -2028,7 +2028,7 @@ class SubViewOpShapeFolder final : public OpRewritePattern<SubViewOp> {
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
subViewOp.getType());
- return matchSuccess();
+ return success();
}
};
@@ -2037,10 +2037,10 @@ class SubViewOpStrideFolder final : public OpRewritePattern<SubViewOp> {
public:
using OpRewritePattern<SubViewOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(SubViewOp subViewOp,
+ PatternRewriter &rewriter) const override {
if (subViewOp.getNumStrides() == 0) {
- return matchFailure();
+ return failure();
}
// Follow all or nothing approach for strides for now. If all the operands
// for strides are constants then fold it into the strides of the result
@@ -2056,7 +2056,7 @@ class SubViewOpStrideFolder final : public OpRewritePattern<SubViewOp> {
llvm::any_of(subViewOp.strides(), [](Value stride) {
return !matchPattern(stride, m_ConstantIndex());
})) {
- return matchFailure();
+ return failure();
}
SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides());
@@ -2077,7 +2077,7 @@ class SubViewOpStrideFolder final : public OpRewritePattern<SubViewOp> {
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
subViewOp.getType());
- return matchSuccess();
+ return success();
}
};
@@ -2086,10 +2086,10 @@ class SubViewOpOffsetFolder final : public OpRewritePattern<SubViewOp> {
public:
using OpRewritePattern<SubViewOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(SubViewOp subViewOp,
+ PatternRewriter &rewriter) const override {
if (subViewOp.getNumOffsets() == 0) {
- return matchFailure();
+ return failure();
}
// Follow all or nothing approach for offsets for now. If all the operands
// for offsets are constants then fold it into the offset of the result
@@ -2106,7 +2106,7 @@ class SubViewOpOffsetFolder final : public OpRewritePattern<SubViewOp> {
llvm::any_of(subViewOp.offsets(), [](Value stride) {
return !matchPattern(stride, m_ConstantIndex());
})) {
- return matchFailure();
+ return failure();
}
auto staticOffset = baseOffset;
@@ -2128,7 +2128,7 @@ class SubViewOpOffsetFolder final : public OpRewritePattern<SubViewOp> {
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
subViewOp.getType());
- return matchSuccess();
+ return success();
}
};
@@ -2347,18 +2347,18 @@ namespace {
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
using OpRewritePattern<ViewOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(ViewOp viewOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(ViewOp viewOp,
+ PatternRewriter &rewriter) const override {
// Return if none of the operands are constants.
if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
return matchPattern(operand, m_ConstantIndex());
}))
- return matchFailure();
+ return failure();
// Get result memref type.
auto memrefType = viewOp.getType();
if (memrefType.getAffineMaps().size() > 1)
- return matchFailure();
+ return failure();
auto map = memrefType.getAffineMaps().empty()
? AffineMap::getMultiDimIdentityMap(memrefType.getRank(),
rewriter.getContext())
@@ -2368,7 +2368,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
int64_t oldOffset;
SmallVector<int64_t, 4> oldStrides;
if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
- return matchFailure();
+ return failure();
SmallVector<Value, 4> newOperands;
@@ -2444,27 +2444,27 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
// Insert a cast so we have the same type as the old memref type.
rewriter.replaceOpWithNewOp<MemRefCastOp>(viewOp, newViewOp,
viewOp.getType());
- return matchSuccess();
+ return success();
}
};
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
using OpRewritePattern<ViewOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(ViewOp viewOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(ViewOp viewOp,
+ PatternRewriter &rewriter) const override {
Value memrefOperand = viewOp.getOperand(0);
MemRefCastOp memrefCastOp =
dyn_cast_or_null<MemRefCastOp>(memrefOperand.getDefiningOp());
if (!memrefCastOp)
- return matchFailure();
+ return failure();
Value allocOperand = memrefCastOp.getOperand();
AllocOp allocOp = dyn_cast_or_null<AllocOp>(allocOperand.getDefiningOp());
if (!allocOp)
- return matchFailure();
+ return failure();
rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
viewOp.operands());
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 14831b9f38ae..342ce37ad515 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1145,18 +1145,18 @@ class StridedSliceConstantMaskFolder final
public:
using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(StridedSliceOp stridedSliceOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(StridedSliceOp stridedSliceOp,
+ PatternRewriter &rewriter) const override {
// Return if 'stridedSliceOp' operand is not defined by a ConstantMaskOp.
auto defOp = stridedSliceOp.vector().getDefiningOp();
auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
if (!constantMaskOp)
- return matchFailure();
+ return failure();
// Return if 'stridedSliceOp' has non-unit strides.
if (llvm::any_of(stridedSliceOp.strides(), [](Attribute attr) {
return attr.cast<IntegerAttr>().getInt() != 1;
}))
- return matchFailure();
+ return failure();
// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
@@ -1187,7 +1187,7 @@ class StridedSliceConstantMaskFolder final
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
stridedSliceOp, stridedSliceOp.getResult().getType(),
vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
- return matchSuccess();
+ return success();
}
};
@@ -1619,14 +1619,14 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(CreateMaskOp createMaskOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
+ PatternRewriter &rewriter) const override {
// Return if any of 'createMaskOp' operands are not defined by a constant.
auto is_not_def_by_constant = [](Value operand) {
return !isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
};
if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant))
- return matchFailure();
+ return failure();
// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
for (auto operand : createMaskOp.operands()) {
@@ -1637,7 +1637,7 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
createMaskOp, createMaskOp.getResult().getType(),
vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 38a83e01bcbf..dd47e0c80dc1 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -545,18 +545,18 @@ namespace {
struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(vector::TransferReadOp xferReadOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(vector::TransferReadOp xferReadOp,
+ PatternRewriter &rewriter) const override {
// TODO(andydavis, ntv) Support splitting TransferReadOp with non-identity
// permutation maps. Repurpose code from MaterializeVectors transformation.
if (!isIdentitySuffix(xferReadOp.permutation_map()))
- return matchFailure();
+ return failure();
// Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
Value xferReadResult = xferReadOp.getResult();
auto extractSlicesOp =
dyn_cast<vector::ExtractSlicesOp>(*xferReadResult.getUsers().begin());
if (!xferReadResult.hasOneUse() || !extractSlicesOp)
- return matchFailure();
+ return failure();
// Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
auto sourceVectorType = extractSlicesOp.getSourceVectorType();
@@ -593,7 +593,7 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
rewriter.replaceOpWithNewOp<vector::InsertSlicesOp>(
xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(),
extractSlicesOp.strides());
- return matchSuccess();
+ return success();
}
};
@@ -601,23 +601,23 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(vector::TransferWriteOp xferWriteOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(vector::TransferWriteOp xferWriteOp,
+ PatternRewriter &rewriter) const override {
// TODO(andydavis, ntv) Support splitting TransferWriteOp with non-identity
// permutation maps. Repurpose code from MaterializeVectors transformation.
if (!isIdentitySuffix(xferWriteOp.permutation_map()))
- return matchFailure();
+ return failure();
// Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp);
if (!insertSlicesOp)
- return matchFailure();
+ return failure();
// Get TupleOp operand of 'insertSlicesOp'.
auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
insertSlicesOp.vectors().getDefiningOp());
if (!tupleOp)
- return matchFailure();
+ return failure();
// Get 'sizes' and 'strides' parameters from InsertSlicesOp user.
auto sourceTupleType = insertSlicesOp.getSourceTupleType();
@@ -644,7 +644,7 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
// Erase old 'xferWriteOp'.
rewriter.eraseOp(xferWriteOp);
- return matchSuccess();
+ return success();
}
};
@@ -653,15 +653,15 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
// Check if 'shapeCastOp' has tuple source/result type.
auto sourceTupleType =
shapeCastOp.source().getType().dyn_cast_or_null<TupleType>();
auto resultTupleType =
shapeCastOp.result().getType().dyn_cast_or_null<TupleType>();
if (!sourceTupleType || !resultTupleType)
- return matchFailure();
+ return failure();
assert(sourceTupleType.size() == resultTupleType.size());
// Create single-vector ShapeCastOp for each source tuple element.
@@ -679,7 +679,7 @@ struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
// Replace 'shapeCastOp' with tuple of 'resultElements'.
rewriter.replaceOpWithNewOp<vector::TupleOp>(shapeCastOp, resultTupleType,
resultElements);
- return matchSuccess();
+ return success();
}
};
@@ -702,21 +702,21 @@ struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
// Check if 'shapeCastOp' has vector source/result type.
auto sourceVectorType =
shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
auto resultVectorType =
shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
if (!sourceVectorType || !resultVectorType)
- return matchFailure();
+ return failure();
// Check if shape cast op source operand is also a shape cast op.
auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
shapeCastOp.source().getDefiningOp());
if (!sourceShapeCastOp)
- return matchFailure();
+ return failure();
auto operandSourceVectorType =
sourceShapeCastOp.source().getType().cast<VectorType>();
auto operandResultVectorType =
@@ -725,10 +725,10 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
// Check if shape cast operations invert each other.
if (operandSourceVectorType != resultVectorType ||
operandResultVectorType != sourceVectorType)
- return matchFailure();
+ return failure();
rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
- return matchSuccess();
+ return success();
}
};
@@ -738,30 +738,30 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
using OpRewritePattern<vector::TupleGetOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
+ PatternRewriter &rewriter) const override {
// Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp.
auto extractSlicesOp = dyn_cast_or_null<vector::ExtractSlicesOp>(
tupleGetOp.vectors().getDefiningOp());
if (!extractSlicesOp)
- return matchFailure();
+ return failure();
// Return if 'extractSlicesOp.vector' arg was not defined by InsertSlicesOp.
auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(
extractSlicesOp.vector().getDefiningOp());
if (!insertSlicesOp)
- return matchFailure();
+ return failure();
// Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp.
auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
insertSlicesOp.vectors().getDefiningOp());
if (!tupleOp)
- return matchFailure();
+ return failure();
// Forward Value from 'tupleOp' at 'tupleGetOp.index'.
Value tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
rewriter.replaceOp(tupleGetOp, tupleValue);
- return matchSuccess();
+ return success();
}
};
@@ -778,8 +778,8 @@ class ExtractSlicesOpLowering
public:
using OpRewritePattern<vector::ExtractSlicesOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(vector::ExtractSlicesOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(vector::ExtractSlicesOp op,
+ PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
VectorType vectorType = op.getSourceVectorType();
@@ -806,7 +806,7 @@ class ExtractSlicesOpLowering
}
rewriter.replaceOpWithNewOp<vector::TupleOp>(op, tupleType, tupleValues);
- return matchSuccess();
+ return success();
}
};
@@ -825,8 +825,8 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
public:
using OpRewritePattern<vector::InsertSlicesOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(vector::InsertSlicesOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(vector::InsertSlicesOp op,
+ PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
VectorType vectorType = op.getResultVectorType();
@@ -860,7 +860,7 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
}
rewriter.replaceOp(op, result);
- return matchSuccess();
+ return success();
}
};
@@ -881,8 +881,8 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(vector::OuterProductOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(vector::OuterProductOp op,
+ PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
VectorType rhsType = op.getOperandVectorTypeRHS();
@@ -907,7 +907,7 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
result = rewriter.create<vector::InsertOp>(loc, resType, m, result, pos);
}
rewriter.replaceOp(op, result);
- return matchSuccess();
+ return success();
}
};
@@ -934,11 +934,11 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformsOptions(vectorTransformsOptions) {}
- PatternMatchResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
// TODO(ajcbik): implement masks
if (llvm::size(op.masks()) != 0)
- return matchFailure();
+ return failure();
// TODO(ntv, ajcbik): implement benefits, cost models, separate this out in
// a new pattern.
@@ -977,7 +977,7 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
else
rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
- return matchSuccess();
+ return success();
}
}
@@ -987,7 +987,7 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
int64_t lhsIndex = batchDimMap[0].first;
int64_t rhsIndex = batchDimMap[0].second;
rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
- return matchSuccess();
+ return success();
}
// Collect contracting dimensions.
@@ -1007,7 +1007,7 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
if (lhsContractingDimSet.count(lhsIndex) == 0) {
rewriter.replaceOp(
op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
- return matchSuccess();
+ return success();
}
}
@@ -1018,17 +1018,17 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
if (rhsContractingDimSet.count(rhsIndex) == 0) {
rewriter.replaceOp(
op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
- return matchSuccess();
+ return success();
}
}
// Lower the first remaining reduction dimension.
if (!contractingDimMap.empty()) {
rewriter.replaceOp(op, lowerReduction(op, rewriter));
- return matchSuccess();
+ return success();
}
- return matchFailure();
+ return failure();
}
private:
@@ -1275,12 +1275,12 @@ class ShapeCastOp2DDownCastRewritePattern
public:
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(vector::ShapeCastOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+ PatternRewriter &rewriter) const override {
auto sourceVectorType = op.getSourceVectorType();
auto resultVectorType = op.getResultVectorType();
if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
- return matchFailure();
+ return failure();
auto loc = op.getLoc();
auto elemType = sourceVectorType.getElementType();
@@ -1295,7 +1295,7 @@ class ShapeCastOp2DDownCastRewritePattern
/*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
}
rewriter.replaceOp(op, desc);
- return matchSuccess();
+ return success();
}
};
@@ -1309,12 +1309,12 @@ class ShapeCastOp2DUpCastRewritePattern
public:
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(vector::ShapeCastOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+ PatternRewriter &rewriter) const override {
auto sourceVectorType = op.getSourceVectorType();
auto resultVectorType = op.getResultVectorType();
if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
- return matchFailure();
+ return failure();
auto loc = op.getLoc();
auto elemType = sourceVectorType.getElementType();
@@ -1330,7 +1330,7 @@ class ShapeCastOp2DUpCastRewritePattern
desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
}
rewriter.replaceOp(op, desc);
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index bf1cd3c29b3e..7d9fd5dd3c87 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -44,7 +44,7 @@ void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
"rewrite functions!");
}
-PatternMatchResult RewritePattern::match(Operation *op) const {
+LogicalResult RewritePattern::match(Operation *op) const {
llvm_unreachable("need to implement either match or matchAndRewrite!");
}
diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
index f16c767eea81..0112753c49fc 100644
--- a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
+++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
@@ -35,13 +35,13 @@ class RemoveIdentityOpRewrite : public RewritePattern {
RemoveIdentityOpRewrite(MLIRContext *context)
: RewritePattern(OpTy::getOperationName(), 1, context) {}
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
assert(op->getNumOperands() == 1);
assert(op->getNumResults() == 1);
rewriter.replaceOp(op, op->getOperand(0));
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index ae04e117282a..4b29d2edbe71 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -1010,7 +1010,7 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
//===----------------------------------------------------------------------===//
/// Attempt to match and rewrite the IR root at the specified operation.
-PatternMatchResult
+LogicalResult
ConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
SmallVector<Value, 4> operands;
@@ -1705,7 +1705,7 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
: OpConversionPattern(ctx), converter(converter) {}
/// Hook for derived classes to implement combined matching and rewriting.
- PatternMatchResult
+ LogicalResult
matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
FunctionType type = funcOp.getType();
@@ -1714,12 +1714,12 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
TypeConverter::SignatureConversion result(type.getNumInputs());
for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
- return matchFailure();
+ return failure();
// Convert the original function results.
SmallVector<Type, 1> convertedResults;
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
- return matchFailure();
+ return failure();
// Update the function signature in-place.
rewriter.updateRootInPlace(funcOp, [&] {
@@ -1727,7 +1727,7 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
convertedResults, funcOp.getContext()));
rewriter.applySignatureConversion(&funcOp.getBody(), result);
});
- return matchSuccess();
+ return success();
}
/// The type converter to use when rewriting the signature.
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 0ebd3d8785fc..a91800d68fc0 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -94,32 +94,32 @@ struct ConvertToTargetEnv : public FunctionPass<ConvertToTargetEnv> {
struct ConvertToAtomCmpExchangeWeak : public RewritePattern {
ConvertToAtomCmpExchangeWeak(MLIRContext *context);
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
};
struct ConvertToBitReverse : public RewritePattern {
ConvertToBitReverse(MLIRContext *context);
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
};
struct ConvertToGroupNonUniformBallot : public RewritePattern {
ConvertToGroupNonUniformBallot(MLIRContext *context);
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
};
struct ConvertToModule : public RewritePattern {
ConvertToModule(MLIRContext *context);
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
};
struct ConvertToSubgroupBallot : public RewritePattern {
ConvertToSubgroupBallot(MLIRContext *context);
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
};
} // end anonymous namespace
@@ -145,7 +145,7 @@ ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context)
: RewritePattern("test.convert_to_atomic_compare_exchange_weak_op",
{"spv.AtomicCompareExchangeWeak"}, 1, context) {}
-PatternMatchResult
+LogicalResult
ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
Value ptr = op->getOperand(0);
@@ -159,21 +159,21 @@ ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
spirv::MemorySemantics::AcquireRelease |
spirv::MemorySemantics::AtomicCounterMemory,
spirv::MemorySemantics::Acquire, value, comparator);
- return matchSuccess();
+ return success();
}
ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context)
: RewritePattern("test.convert_to_bit_reverse_op", {"spv.BitReverse"}, 1,
context) {}
-PatternMatchResult
+LogicalResult
ConvertToBitReverse::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
Value predicate = op->getOperand(0);
rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
op, op->getResult(0).getType(), predicate);
- return matchSuccess();
+ return success();
}
ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
@@ -181,39 +181,39 @@ ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
: RewritePattern("test.convert_to_group_non_uniform_ballot_op",
{"spv.GroupNonUniformBallot"}, 1, context) {}
-PatternMatchResult ConvertToGroupNonUniformBallot::matchAndRewrite(
+LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
Value predicate = op->getOperand(0);
rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
- return matchSuccess();
+ return success();
}
ConvertToModule::ConvertToModule(MLIRContext *context)
: RewritePattern("test.convert_to_module_op", {"spv.module"}, 1, context) {}
-PatternMatchResult
+LogicalResult
ConvertToModule::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
op, spirv::AddressingModel::PhysicalStorageBuffer64,
spirv::MemoryModel::Vulkan);
- return matchSuccess();
+ return success();
}
ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
: RewritePattern("test.convert_to_subgroup_ballot_op",
{"spv.SubgroupBallotKHR"}, 1, context) {}
-PatternMatchResult
+LogicalResult
ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
Value predicate = op->getOperand(0);
rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
op, op->getResult(0).getType(), predicate);
- return matchSuccess();
+ return success();
}
namespace mlir {
diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
index da15486ef117..166edf820206 100644
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -283,10 +283,10 @@ struct TestRemoveOpWithInnerOps
: public OpRewritePattern<TestOpWithRegionPattern> {
using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(TestOpWithRegionPattern op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
+ PatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
};
} // end anonymous namespace
diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp
index 997d6090be80..c7235b8cb3a5 100644
--- a/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -141,7 +141,7 @@ struct TestRegionRewriteBlockMovement : public ConversionPattern {
TestRegionRewriteBlockMovement(MLIRContext *ctx)
: ConversionPattern("test.region", 1, ctx) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Inline this region into the parent region.
@@ -155,7 +155,7 @@ struct TestRegionRewriteBlockMovement : public ConversionPattern {
// Drop this operation.
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
};
/// This pattern is a simple pattern that generates a region containing an
@@ -164,8 +164,8 @@ struct TestRegionRewriteUndo : public RewritePattern {
TestRegionRewriteUndo(MLIRContext *ctx)
: RewritePattern("test.region_builder", 1, ctx) {}
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const final {
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const final {
// Create the region operation with an entry block containing arguments.
OperationState newRegion(op->getLoc(), "test.region");
newRegion.addRegion();
@@ -179,7 +179,7 @@ struct TestRegionRewriteUndo : public RewritePattern {
// Drop this operation.
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
};
@@ -191,7 +191,7 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
: ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) {
}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Region ®ion = op->getRegion(0);
@@ -202,12 +202,12 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
for (unsigned i = 0, e = entry->getNumArguments(); i != e; ++i)
if (failed(converter.convertSignatureArg(
i, entry->getArgument(i).getType(), result)))
- return matchFailure();
+ return failure();
// Convert the region signature and just drop the operation.
rewriter.applySignatureConversion(®ion, result);
rewriter.eraseOp(op);
- return matchSuccess();
+ return success();
}
/// The type converter to use when rewriting the signature.
@@ -217,35 +217,35 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
struct TestPassthroughInvalidOp : public ConversionPattern {
TestPassthroughInvalidOp(MLIRContext *ctx)
: ConversionPattern("test.invalid", 1, ctx) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
llvm::None);
- return matchSuccess();
+ return success();
}
};
/// This pattern handles the case of a split return value.
struct TestSplitReturnType : public ConversionPattern {
TestSplitReturnType(MLIRContext *ctx)
: ConversionPattern("test.return", 1, ctx) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Check for a return of F32.
if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
- return matchFailure();
+ return failure();
// Check if the first operation is a cast operation, if it is we use the
// results directly.
auto *defOp = operands[0].getDefiningOp();
if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
- return matchSuccess();
+ return success();
}
// Otherwise, fail to match.
- return matchFailure();
+ return failure();
}
};
@@ -254,52 +254,52 @@ struct TestSplitReturnType : public ConversionPattern {
struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
: ConversionPattern("test.type_producer", 1, ctx) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// If the type is I32, change the type to F32.
if (!Type(*op->result_type_begin()).isSignlessInteger(32))
- return matchFailure();
+ return failure();
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
- return matchSuccess();
+ return success();
}
};
struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
: ConversionPattern("test.type_producer", 1, ctx) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// If the type is F32, change the type to F64.
if (!Type(*op->result_type_begin()).isF32())
return rewriter.notifyMatchFailure(op, "expected single f32 operand");
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
- return matchSuccess();
+ return success();
}
};
struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
: ConversionPattern("test.type_producer", 10, ctx) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Always convert to B16, even though it is not a legal type. This tests
// that values are unmapped correctly.
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
- return matchSuccess();
+ return success();
}
};
struct TestUpdateConsumerType : public ConversionPattern {
TestUpdateConsumerType(MLIRContext *ctx)
: ConversionPattern("test.type_consumer", 1, ctx) {}
- PatternMatchResult
+ LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Verify that the incoming operand has been successfully remapped to F64.
if (!operands[0].getType().isF64())
- return matchFailure();
+ return failure();
rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
- return matchSuccess();
+ return success();
}
};
@@ -312,15 +312,15 @@ struct TestNonRootReplacement : public RewritePattern {
TestNonRootReplacement(MLIRContext *ctx)
: RewritePattern("test.replace_non_root", 1, ctx) {}
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const final {
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const final {
auto resultType = *op->result_type_begin();
auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
rewriter.replaceOp(illegalOp, {legalOp});
rewriter.replaceOp(op, {illegalOp});
- return matchSuccess();
+ return success();
}
};
} // namespace
@@ -475,7 +475,7 @@ struct OneVResOneVOperandOp1Converter
: public OpConversionPattern<OneVResOneVOperandOp1> {
using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
- PatternMatchResult
+ LogicalResult
matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto origOps = op.getOperands();
@@ -490,7 +490,7 @@ struct OneVResOneVOperandOp1Converter
rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
remappedOperands);
- return matchSuccess();
+ return success();
}
};
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index fa57219fe005..217acb62b57e 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -215,7 +215,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
// Skip the operand matching at depth 0 as the pattern rewriter already does.
if (depth != 0) {
// Skip if there is no defining operation (e.g., arguments to function).
- os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
+ os.indent(indent) << formatv("if (!castedOp{0}) return failure();\n",
depth);
}
if (tree.getNumArgs() != op.getNumArgs()) {
@@ -300,7 +300,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
os.indent(indent) << "if (!("
<< std::string(tgfmt(matcher.getConditionTemplate(),
&fmtCtx.withSelf(self)))
- << ")) return matchFailure();\n";
+ << ")) return failure();\n";
}
}
@@ -344,7 +344,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
// should just capture a mlir::Attribute() to signal the missing state.
// That is precisely what getAttr() returns on missing attributes.
} else {
- os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
+ os.indent(indent) << "if (!tblgen_attr) return failure();\n";
}
auto matcher = tree.getArgAsLeaf(argIndex);
@@ -360,7 +360,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
os.indent(indent) << "if (!("
<< std::string(tgfmt(matcher.getConditionTemplate(),
&fmtCtx.withSelf("tblgen_attr")))
- << ")) return matchFailure();\n";
+ << ")) return failure();\n";
}
// Capture the value
@@ -383,7 +383,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
auto &entities = appliedConstraint.entities;
auto condition = constraint.getConditionTemplate();
- auto cmd = "if (!({0})) return matchFailure();\n";
+ auto cmd = "if (!({0})) return failure();\n";
if (isa<TypeConstraint>(constraint)) {
auto self = formatv("({0}.getType())",
@@ -468,7 +468,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
// Emit matchAndRewrite() function.
os << R"(
- PatternMatchResult matchAndRewrite(Operation *op0,
+ LogicalResult matchAndRewrite(Operation *op0,
PatternRewriter &rewriter) const override {
)";
@@ -501,7 +501,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
os.indent(4) << "// Rewrite\n";
emitRewriteLogic();
- os.indent(4) << "return matchSuccess();\n";
+ os.indent(4) << "return success();\n";
os << " };\n";
os << "};\n";
}
More information about the Mlir-commits
mailing list