[Mlir-commits] [mlir] 413fbb0 - [mlir][scf] Retain existing attributes in scf.for transforms

Lei Zhang llvmlistbot at llvm.org
Wed May 25 07:53:10 PDT 2022


Author: Lei Zhang
Date: 2022-05-25T10:53:02-04:00
New Revision: 413fbb045d714bbb1f1f3887104ccbc4b7b395c2

URL: https://github.com/llvm/llvm-project/commit/413fbb045d714bbb1f1f3887104ccbc4b7b395c2
DIFF: https://github.com/llvm/llvm-project/commit/413fbb045d714bbb1f1f3887104ccbc4b7b395c2.diff

LOG: [mlir][scf] Retain existing attributes in scf.for transforms

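These attributes can carry useful information; for example, pipelines
might use them to organize and chain patterns. Previously, the scf.for
rewrites in canonicalization and bufferization created the replacement
loop without copying the original op's attributes, so they were lost.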

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D126320

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/bufferize.mlir
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 3236017a5986c..9d3e61e69d9a5 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -660,6 +660,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
     scf::ForOp newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), newIterArgs);
+    newForOp->setAttrs(forOp->getAttrs());
     Block &newBlock = newForOp.getRegion().front();
 
     // Replace the null placeholders with newly constructed values.
@@ -802,6 +803,7 @@ static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
   scf::ForOp newForOp = rewriter.create<scf::ForOp>(
       forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
       forOp.getStep(), newIterOperands);
+  newForOp->setAttrs(forOp->getAttrs());
   Block &newBlock = newForOp.getRegion().front();
   SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
                                              newBlock.getArguments().end());

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index d0c3c381dbfa1..2a6a95e5e0d1c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -491,6 +491,7 @@ struct ForOpInterface
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), initArgs);
+    newForOp->setAttrs(forOp->getAttrs());
     ValueRange initArgsRange(initArgs);
     TypeRange initArgsTypes(initArgsRange);
     Block *loopBody = &newForOp.getLoopBody().front();

diff  --git a/mlir/test/Dialect/SCF/bufferize.mlir b/mlir/test/Dialect/SCF/bufferize.mlir
index eb795aeb5eb1f..6193101e9264d 100644
--- a/mlir/test/Dialect/SCF/bufferize.mlir
+++ b/mlir/test/Dialect/SCF/bufferize.mlir
@@ -30,14 +30,14 @@ func.func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) ->
 // CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<f32>
 // CHECK:           %[[RESULT_MEMREF:.*]] = scf.for %[[VAL_6:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER:.*]] = %[[MEMREF]]) -> (memref<f32>) {
 // CHECK:             scf.yield %[[ITER]] : memref<f32>
-// CHECK:           }
+// CHECK:           } {some_attr}
 // CHECK:           %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_9:.*]] : memref<f32>
 // CHECK:           return %[[VAL_8]] : tensor<f32>
 // CHECK:         }
 func.func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f32> {
   %ret = scf.for %iv = %lb to %ub step %step iter_args(%iter = %arg0) -> tensor<f32> {
     scf.yield %iter : tensor<f32>
-  }
+  } {some_attr}
   return %ret : tensor<f32>
 }
 

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index cca2f439bd70f..8e087fc0f38a4 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -372,7 +372,7 @@ func.func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i
   %r:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %a, %2 = %b) -> (i32, i32, i32) {
     %c = func.call @make_i32() : () -> (i32)
     scf.yield %0, %c, %2 : i32, i32, i32
-  }
+  } {some_attr}
   return %r#0, %r#1, %r#2 : i32, i32, i32
 }
 
@@ -382,7 +382,7 @@ func.func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i
 //  CHECK-NEXT:     %[[r1:.*]] = scf.for {{.*}} iter_args(%arg4 = %[[a]]) -> (i32) {
 //  CHECK-NEXT:       %[[c:.*]] = func.call @make_i32() : () -> i32
 //  CHECK-NEXT:       scf.yield %[[c]] : i32
-//  CHECK-NEXT:     }
+//  CHECK-NEXT:     } {some_attr}
 //  CHECK-NEXT:     return %[[a]], %[[r1]], %[[b]] : i32, i32, i32
 
 // -----
@@ -846,11 +846,12 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32
 //       CHECK:   %[[DONE:.*]] = func.call @do(%[[CAST]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
 //       CHECK:   %[[UNCAST:.*]] = tensor.cast %[[DONE]] : tensor<?x?xf32> to tensor<32x1024xf32>
 //       CHECK:   scf.yield %[[UNCAST]] : tensor<32x1024xf32>
+//       CHECK: } {some_attr}
   %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
   %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
     %2 = func.call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
     scf.yield %2 : tensor<?x?xf32>
-  }
+  } {some_attr}
 //   CHECK-NOT: tensor.cast
 //       CHECK: %[[RES:.*]] = tensor.insert_slice %[[FOR_RES]] into %[[T1]][0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
 //       CHECK: return %[[RES]] : tensor<1024x1024xf32>


        


More information about the Mlir-commits mailing list