[Mlir-commits] [mlir] 3c69bc4 - [mlir][NFC] Remove a few op builders that simply swap parameter order
River Riddle
llvmlistbot at llvm.org
Mon Feb 7 19:04:25 PST 2022
Author: River Riddle
Date: 2022-02-07T19:03:57-08:00
New Revision: 3c69bc4d6e99dd1801e0120824963b894060569e
URL: https://github.com/llvm/llvm-project/commit/3c69bc4d6e99dd1801e0120824963b894060569e
DIFF: https://github.com/llvm/llvm-project/commit/3c69bc4d6e99dd1801e0120824963b894060569e.diff
LOG: [mlir][NFC] Remove a few op builders that simply swap parameter order
Differential Revision: https://reviews.llvm.org/D119093
Added:
Modified:
mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index bef9178a507a2..b278de529db29 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -73,12 +73,6 @@ class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins From:$in)>,
Results<(outs To:$out)> {
- let builders = [
- OpBuilder<(ins "Value":$source, "Type":$destType), [{
- impl::buildCastOp($_builder, $_state, source, destType);
- }]>
- ];
-
let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
}
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 81839dfde9c5e..79ad1ed7f8046 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -374,11 +374,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [
let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
let results = (outs AnyRankedOrUnrankedMemRef:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
- let builders = [
- OpBuilder<(ins "Value":$source, "Type":$destType), [{
- impl::buildCastOp($_builder, $_state, source, destType);
- }]>
- ];
let extraClassDeclaration = [{
/// Fold the given CastOp into consumer op.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 80f50a3996e36..56ee70eba3bfa 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1003,11 +1003,11 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
switch (conversion) {
case PrintConversion::ZeroExt64:
value = rewriter.create<arith::ExtUIOp>(
- loc, value, IntegerType::get(rewriter.getContext(), 64));
+ loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::SignExt64:
value = rewriter.create<arith::ExtSIOp>(
- loc, value, IntegerType::get(rewriter.getContext(), 64));
+ loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::None:
break;
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
index c4e1632d8023a..02526264ce4f3 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -94,8 +94,8 @@ struct IndexCastOpInterface
getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
layout, sourceType.getMemorySpace());
- replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, source,
- resultType);
+ replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
+ source);
return success();
}
};
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 74d6d42e2b9b0..fd98f31e09feb 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -835,15 +835,15 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
scalingFactor);
}
Value numWorkersIndex =
- b.create<arith::IndexCastOp>(numWorkerThreadsVal, b.getI32Type());
+ b.create<arith::IndexCastOp>(b.getI32Type(), numWorkerThreadsVal);
Value numWorkersFloat =
- b.create<arith::SIToFPOp>(numWorkersIndex, b.getF32Type());
+ b.create<arith::SIToFPOp>(b.getF32Type(), numWorkersIndex);
Value scaledNumWorkers =
b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
Value scaledNumInt =
- b.create<arith::FPToSIOp>(scaledNumWorkers, b.getI32Type());
+ b.create<arith::FPToSIOp>(b.getI32Type(), scaledNumWorkers);
Value scaledWorkers =
- b.create<arith::IndexCastOp>(scaledNumInt, b.getIndexType());
+ b.create<arith::IndexCastOp>(b.getIndexType(), scaledNumInt);
Value maxComputeBlocks = b.create<arith::MaxSIOp>(
b.create<arith::ConstantIndexOp>(1), scaledWorkers);
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index d47d6ead0273e..ed9170cdf55ac 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -887,7 +887,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
auto i32Vec = broadcast(builder.getI32Type(), shape);
// exp2(k)
- Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec);
+ Value k = builder.create<arith::FPToSIOp>(i32Vec, kF32);
Value exp2KValue = exp2I32(builder, k);
// exp(x) = exp(y) * exp2(k)
@@ -1042,7 +1042,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
auto i32Vec = broadcast(builder.getI32Type(), shape);
auto fPToSingedInteger = [&](Value a) -> Value {
- return builder.create<arith::FPToSIOp>(a, i32Vec);
+ return builder.create<arith::FPToSIOp>(i32Vec, a);
};
auto modulo4 = [&](Value a) -> Value {
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 21cb8d60f2d3a..da672dc86b66d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -165,7 +165,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
alloc.alignmentAttr());
// Insert a cast so we have the same type as the old alloc.
auto resultCast =
- rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
+ rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);
rewriter.replaceOp(alloc, {resultCast});
return success();
@@ -2156,8 +2156,8 @@ class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
rewriter.replaceOp(subViewOp, subViewOp.source());
return success();
}
- rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.source(),
- subViewOp.getType());
+ rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
+ subViewOp.source());
return success();
}
};
@@ -2177,7 +2177,7 @@ struct SubViewReturnTypeCanonicalizer {
/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
- rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
+ rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
}
};
@@ -2422,7 +2422,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
viewOp.getOperand(0),
viewOp.byte_shift(), newOperands);
// Insert a cast so we have the same type as the old memref type.
- rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
+ rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
return success();
}
};
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index 2a839771f97f1..ca490a28332c4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -101,8 +101,8 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
size = rewriter.create<memref::LoadOp>(loc, op.shape(), index);
if (!size.getType().isa<IndexType>())
- size = rewriter.create<arith::IndexCastOp>(loc, size,
- rewriter.getIndexType());
+ size = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getIndexType(), size);
sizes[i] = size;
} else {
sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 94e87b3b79b7f..07875fcfc727f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -309,7 +309,7 @@ static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
ValueRange{ivs[0], idx});
val =
- rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
+ rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), val);
rewriter.create<memref::StoreOp>(loc, val, ind, idx);
}
return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 72e70ddbc123e..be707b126cf9c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -831,11 +831,11 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
if (!etp.isa<IndexType>()) {
if (etp.getIntOrFloatBitWidth() < 32)
vload = rewriter.create<arith::ExtUIOp>(
- loc, vload, vectorType(codegen, rewriter.getI32Type()));
+ loc, vectorType(codegen, rewriter.getI32Type()), vload);
else if (etp.getIntOrFloatBitWidth() < 64 &&
!codegen.options.enableSIMDIndex32)
vload = rewriter.create<arith::ExtUIOp>(
- loc, vload, vectorType(codegen, rewriter.getI64Type()));
+ loc, vectorType(codegen, rewriter.getI64Type()), vload);
}
return vload;
}
@@ -846,9 +846,9 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
if (!load.getType().isa<IndexType>()) {
if (load.getType().getIntOrFloatBitWidth() < 64)
- load = rewriter.create<arith::ExtUIOp>(loc, load, rewriter.getI64Type());
+ load = rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), load);
load =
- rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
+ rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), load);
}
return load;
}
@@ -868,7 +868,7 @@ static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
if (auto vtp = i.getType().dyn_cast<VectorType>()) {
Value inv =
- rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType());
+ rewriter.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul);
mul = genVectorInvariantValue(codegen, rewriter, inv);
}
return rewriter.create<arith::AddIOp>(loc, mul, i);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 31e7fb5a07edd..37e077acf06aa 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -671,25 +671,25 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
rewriter.getZeroAttr(v0.getType())),
v0);
case kTruncF:
- return rewriter.create<arith::TruncFOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
case kExtF:
- return rewriter.create<arith::ExtFOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
case kCastFS:
- return rewriter.create<arith::FPToSIOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
case kCastFU:
- return rewriter.create<arith::FPToUIOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
case kCastSF:
- return rewriter.create<arith::SIToFPOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
case kCastUF:
- return rewriter.create<arith::UIToFPOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
case kCastS:
- return rewriter.create<arith::ExtSIOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
case kCastU:
- return rewriter.create<arith::ExtUIOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
case kTruncI:
- return rewriter.create<arith::TruncIOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
case kBitCast:
- return rewriter.create<arith::BitcastOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
// Binary ops.
case kMulF:
return rewriter.create<arith::MulFOp>(loc, v0, v1);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index 1ceebf26abb69..f574713ffb2a4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -255,7 +255,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
[&](OpBuilder &b, Location loc) {
Value res = memref;
if (compatibleMemRefType != xferOp.getShapedType())
- res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+ res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
xferOp.indices().end());
@@ -271,7 +271,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
alloc);
b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
Value casted =
- b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+ b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
zero);
@@ -309,7 +309,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
[&](OpBuilder &b, Location loc) {
Value res = memref;
if (compatibleMemRefType != xferOp.getShapedType())
- res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+ res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
xferOp.indices().end());
@@ -324,7 +324,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
loc, MemRefType::get({}, vector.getType()), alloc));
Value casted =
- b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+ b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
zero);
@@ -360,7 +360,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
[&](OpBuilder &b, Location loc) {
Value res = memref;
if (compatibleMemRefType != xferOp.getShapedType())
- res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+ res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(),
xferOp.indices().begin(),
@@ -369,7 +369,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
},
[&](OpBuilder &b, Location loc) {
Value casted =
- b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+ b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(),
xferOp.getTransferRank(), zero);
More information about the Mlir-commits mailing list