[Mlir-commits] [mlir] 3c69bc4 - [mlir][NFC] Remove a few op builders that simply swap parameter order

River Riddle llvmlistbot at llvm.org
Mon Feb 7 19:04:25 PST 2022


Author: River Riddle
Date: 2022-02-07T19:03:57-08:00
New Revision: 3c69bc4d6e99dd1801e0120824963b894060569e

URL: https://github.com/llvm/llvm-project/commit/3c69bc4d6e99dd1801e0120824963b894060569e
DIFF: https://github.com/llvm/llvm-project/commit/3c69bc4d6e99dd1801e0120824963b894060569e.diff

LOG: [mlir][NFC] Remove a few op builders that simply swap parameter order

Differential Revision: https://reviews.llvm.org/D119093

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index bef9178a507a2..b278de529db29 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -73,12 +73,6 @@ class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
     DeclareOpInterfaceMethods<CastOpInterface>]>,
     Arguments<(ins From:$in)>,
     Results<(outs To:$out)> {
-  let builders = [
-    OpBuilder<(ins "Value":$source, "Type":$destType), [{
-      impl::buildCastOp($_builder, $_state, source, destType);
-    }]>
-  ];
-
   let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
 }
 

diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 81839dfde9c5e..79ad1ed7f8046 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -374,11 +374,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [
   let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
   let results = (outs AnyRankedOrUnrankedMemRef:$dest);
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
-  let builders = [
-    OpBuilder<(ins "Value":$source, "Type":$destType), [{
-       impl::buildCastOp($_builder, $_state, source, destType);
-    }]>
-  ];
 
   let extraClassDeclaration = [{
     /// Fold the given CastOp into consumer op.

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 80f50a3996e36..56ee70eba3bfa 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1003,11 +1003,11 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
       switch (conversion) {
       case PrintConversion::ZeroExt64:
         value = rewriter.create<arith::ExtUIOp>(
-            loc, value, IntegerType::get(rewriter.getContext(), 64));
+            loc, IntegerType::get(rewriter.getContext(), 64), value);
         break;
       case PrintConversion::SignExt64:
         value = rewriter.create<arith::ExtSIOp>(
-            loc, value, IntegerType::get(rewriter.getContext(), 64));
+            loc, IntegerType::get(rewriter.getContext(), 64), value);
         break;
       case PrintConversion::None:
         break;

diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
index c4e1632d8023a..02526264ce4f3 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -94,8 +94,8 @@ struct IndexCastOpInterface
         getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
                       layout, sourceType.getMemorySpace());
 
-    replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, source,
-                                                     resultType);
+    replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
+                                                     source);
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 74d6d42e2b9b0..fd98f31e09feb 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -835,15 +835,15 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                                 scalingFactor);
     }
     Value numWorkersIndex =
-        b.create<arith::IndexCastOp>(numWorkerThreadsVal, b.getI32Type());
+        b.create<arith::IndexCastOp>(b.getI32Type(), numWorkerThreadsVal);
     Value numWorkersFloat =
-        b.create<arith::SIToFPOp>(numWorkersIndex, b.getF32Type());
+        b.create<arith::SIToFPOp>(b.getF32Type(), numWorkersIndex);
     Value scaledNumWorkers =
         b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
     Value scaledNumInt =
-        b.create<arith::FPToSIOp>(scaledNumWorkers, b.getI32Type());
+        b.create<arith::FPToSIOp>(b.getI32Type(), scaledNumWorkers);
     Value scaledWorkers =
-        b.create<arith::IndexCastOp>(scaledNumInt, b.getIndexType());
+        b.create<arith::IndexCastOp>(b.getIndexType(), scaledNumInt);
 
     Value maxComputeBlocks = b.create<arith::MaxSIOp>(
         b.create<arith::ConstantIndexOp>(1), scaledWorkers);

diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index d47d6ead0273e..ed9170cdf55ac 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -887,7 +887,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
   auto i32Vec = broadcast(builder.getI32Type(), shape);
 
   // exp2(k)
-  Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec);
+  Value k = builder.create<arith::FPToSIOp>(i32Vec, kF32);
   Value exp2KValue = exp2I32(builder, k);
 
   // exp(x) = exp(y) * exp2(k)
@@ -1042,7 +1042,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
 
   auto i32Vec = broadcast(builder.getI32Type(), shape);
   auto fPToSingedInteger = [&](Value a) -> Value {
-    return builder.create<arith::FPToSIOp>(a, i32Vec);
+    return builder.create<arith::FPToSIOp>(i32Vec, a);
   };
 
   auto modulo4 = [&](Value a) -> Value {

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 21cb8d60f2d3a..da672dc86b66d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -165,7 +165,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
         alloc.alignmentAttr());
     // Insert a cast so we have the same type as the old alloc.
     auto resultCast =
-        rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
+        rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);
 
     rewriter.replaceOp(alloc, {resultCast});
     return success();
@@ -2156,8 +2156,8 @@ class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
       rewriter.replaceOp(subViewOp, subViewOp.source());
       return success();
     }
-    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.source(),
-                                        subViewOp.getType());
+    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
+                                        subViewOp.source());
     return success();
   }
 };
@@ -2177,7 +2177,7 @@ struct SubViewReturnTypeCanonicalizer {
 /// A canonicalizer wrapper to replace SubViewOps.
 struct SubViewCanonicalizer {
   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
-    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
+    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
   }
 };
 
@@ -2422,7 +2422,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
                                              viewOp.getOperand(0),
                                              viewOp.byte_shift(), newOperands);
     // Insert a cast so we have the same type as the old memref type.
-    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
+    rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index 2a839771f97f1..ca490a28332c4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -101,8 +101,8 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
         Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
         size = rewriter.create<memref::LoadOp>(loc, op.shape(), index);
         if (!size.getType().isa<IndexType>())
-          size = rewriter.create<arith::IndexCastOp>(loc, size,
-                                                     rewriter.getIndexType());
+          size = rewriter.create<arith::IndexCastOp>(
+              loc, rewriter.getIndexType(), size);
         sizes[i] = size;
       } else {
         sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 94e87b3b79b7f..07875fcfc727f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -309,7 +309,7 @@ static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
     Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                    ValueRange{ivs[0], idx});
     val =
-        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
+        rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), val);
     rewriter.create<memref::StoreOp>(loc, val, ind, idx);
   }
   return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 72e70ddbc123e..be707b126cf9c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -831,11 +831,11 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
     if (!etp.isa<IndexType>()) {
       if (etp.getIntOrFloatBitWidth() < 32)
         vload = rewriter.create<arith::ExtUIOp>(
-            loc, vload, vectorType(codegen, rewriter.getI32Type()));
+            loc, vectorType(codegen, rewriter.getI32Type()), vload);
       else if (etp.getIntOrFloatBitWidth() < 64 &&
                !codegen.options.enableSIMDIndex32)
         vload = rewriter.create<arith::ExtUIOp>(
-            loc, vload, vectorType(codegen, rewriter.getI64Type()));
+            loc, vectorType(codegen, rewriter.getI64Type()), vload);
     }
     return vload;
   }
@@ -846,9 +846,9 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
   Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
   if (!load.getType().isa<IndexType>()) {
     if (load.getType().getIntOrFloatBitWidth() < 64)
-      load = rewriter.create<arith::ExtUIOp>(loc, load, rewriter.getI64Type());
+      load = rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), load);
     load =
-        rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
+        rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), load);
   }
   return load;
 }
@@ -868,7 +868,7 @@ static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
   Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
   if (auto vtp = i.getType().dyn_cast<VectorType>()) {
     Value inv =
-        rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType());
+        rewriter.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul);
     mul = genVectorInvariantValue(codegen, rewriter, inv);
   }
   return rewriter.create<arith::AddIOp>(loc, mul, i);

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 31e7fb5a07edd..37e077acf06aa 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -671,25 +671,25 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                                            rewriter.getZeroAttr(v0.getType())),
         v0);
   case kTruncF:
-    return rewriter.create<arith::TruncFOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
   case kExtF:
-    return rewriter.create<arith::ExtFOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
   case kCastFS:
-    return rewriter.create<arith::FPToSIOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
   case kCastFU:
-    return rewriter.create<arith::FPToUIOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
   case kCastSF:
-    return rewriter.create<arith::SIToFPOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
   case kCastUF:
-    return rewriter.create<arith::UIToFPOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
   case kCastS:
-    return rewriter.create<arith::ExtSIOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
   case kCastU:
-    return rewriter.create<arith::ExtUIOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
   case kTruncI:
-    return rewriter.create<arith::TruncIOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
   case kBitCast:
-    return rewriter.create<arith::BitcastOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
   // Binary ops.
   case kMulF:
     return rewriter.create<arith::MulFOp>(loc, v0, v1);

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index 1ceebf26abb69..f574713ffb2a4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -255,7 +255,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
       [&](OpBuilder &b, Location loc) {
         Value res = memref;
         if (compatibleMemRefType != xferOp.getShapedType())
-          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
         scf::ValueVector viewAndIndices{res};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                               xferOp.indices().end());
@@ -271,7 +271,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
             alloc);
         b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
         Value casted =
-            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
         scf::ValueVector viewAndIndices{casted};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                               zero);
@@ -309,7 +309,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
       [&](OpBuilder &b, Location loc) {
         Value res = memref;
         if (compatibleMemRefType != xferOp.getShapedType())
-          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
         scf::ValueVector viewAndIndices{res};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                               xferOp.indices().end());
@@ -324,7 +324,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
                 loc, MemRefType::get({}, vector.getType()), alloc));
 
         Value casted =
-            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
         scf::ValueVector viewAndIndices{casted};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                               zero);
@@ -360,7 +360,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
           [&](OpBuilder &b, Location loc) {
             Value res = memref;
             if (compatibleMemRefType != xferOp.getShapedType())
-              res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+              res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
             scf::ValueVector viewAndIndices{res};
             viewAndIndices.insert(viewAndIndices.end(),
                                   xferOp.indices().begin(),
@@ -369,7 +369,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
           },
           [&](OpBuilder &b, Location loc) {
             Value casted =
-                b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+                b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
             scf::ValueVector viewAndIndices{casted};
             viewAndIndices.insert(viewAndIndices.end(),
                                   xferOp.getTransferRank(), zero);


        


More information about the Mlir-commits mailing list