[Mlir-commits] [mlir] d5cabf8 - Keep attribute when bufferizing `scf.forall` op (#91236)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 7 08:26:35 PDT 2024


Author: srcarroll
Date: 2024-05-07T10:26:30-05:00
New Revision: d5cabf8d89a5f5faa5255283821cb080bebbff86

URL: https://github.com/llvm/llvm-project/commit/d5cabf8d89a5f5faa5255283821cb080bebbff86
DIFF: https://github.com/llvm/llvm-project/commit/d5cabf8d89a5f5faa5255283821cb080bebbff86.diff

LOG: Keep attribute when bufferizing `scf.forall` op (#91236)

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 2a16b10bbaf8e..cf40443ff3839 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1267,6 +1267,9 @@ struct ForallOpInterface
         forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
         /*outputs=*/ValueRange(), forallOp.getMapping());
 
+    // Keep discardable attributes from the original op.
+    newForallOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
+
     rewriter.eraseOp(newForallOp.getBody()->getTerminator());
 
     // Move over block contents of the old op.

diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 485fdd9b0e593..bb9f7dfdba83f 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -499,7 +499,8 @@ func.func @parallel_insert_slice_no_conflict(
         tensor.parallel_insert_slice %8 into %o[5] [%idx] [%c1] :
           tensor<?xf32> into tensor<?xf32>
       }
-  }
+  } {keep_this_attribute}
+  // CHECK: keep_this_attribute
 
   // CHECK: %[[load:.*]] = memref.load %[[arg2]]
   %f = tensor.extract %2[%c0] : tensor<?xf32>


        


More information about the Mlir-commits mailing list