[Mlir-commits] [mlir] update: added sub operation (PR #126468)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Feb 9 21:36:59 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: kel (kel404x)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/126468.diff


5 Files Affected:

- (modified) mlir/examples/toy/Ch7/include/toy/Ops.td (+25) 
- (modified) mlir/examples/toy/Ch7/mlir/Dialect.cpp (+22) 
- (modified) mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp (+4-1) 
- (modified) mlir/examples/toy/Ch7/mlir/MLIRGen.cpp (+2) 
- (added) mlir/examples/toy/Ch7/test/test.py (+14) 


``````````diff
diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td
index 71ab7b0aeebb9f2..4b5ce9831106ce4 100644
--- a/mlir/examples/toy/Ch7/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch7/include/toy/Ops.td
@@ -135,6 +135,31 @@ def AddOp : Toy_Op<"add",
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// SubOp
+//===----------------------------------------------------------------------===//
+
+def SubOp : Toy_Op<"sub",
+    [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+  let summary = "element-wise subtraction operation";
+  let description = [{
+    The "sub" operation performs element-wise subtraction between two tensors.
+    The shapes of the tensor operands are expected to match.
+  }];
+
+  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+  let results = (outs F64Tensor);
+
+  // Indicate that the operation has a custom parser and printer method.
+  let hasCustomAssemblyFormat = 1;
+
+  // Allow building a SubOp from the two input operands.
+  let builders = [
+    OpBuilder<(ins "Value":$lhs, "Value":$rhs)>
+  ];
+}
+
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 52881db87d86bb1..523d9cf24c24b28 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -662,3 +662,25 @@ mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
   return builder.create<ConstantOp>(loc, type,
                                     llvm::cast<mlir::DenseElementsAttr>(value));
 }
+
+
+//===----------------------------------------------------------------------===//
+// SubOp
+//===----------------------------------------------------------------------===//
+
+void SubOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                  mlir::Value lhs, mlir::Value rhs) {
+  state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
+  state.addOperands({lhs, rhs});
+}
+
+mlir::ParseResult SubOp::parse(mlir::OpAsmParser &parser,
+                              mlir::OperationState &result) {
+  return parseBinaryOp(parser, result);
+}
+
+void SubOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
+
+/// Infer the output shape of the SubOp, this is required by the shape inference
+/// interface.
+void SubOp::inferShapes() { getResult().setType(getLhs().getType()); }
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index bf2bc43301a33e3..e84d77670237578 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -146,6 +146,7 @@ struct BinaryOpLowering : public ConversionPattern {
 };
 using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>;
 using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>;
+using SubOpLowering = BinaryOpLowering<toy::SubOp, arith::SubFOp>;
 
 //===----------------------------------------------------------------------===//
 // ToyToAffine RewritePatterns: Constant operations
@@ -365,7 +366,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
   // Now that the conversion target has been defined, we just need to provide
   // the set of patterns that will lower the Toy operations.
   RewritePatternSet patterns(&getContext());
   patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
-               PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
-      &getContext());
+               PrintOpLowering, ReturnOpLowering, SubOpLowering,
+               TransposeOpLowering>(&getContext());
 
diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
index e554e375209f1c7..f34ed677f6fc38d 100644
--- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
@@ -348,6 +348,8 @@ class MLIRGenImpl {
       return builder.create<AddOp>(location, lhs, rhs);
     case '*':
       return builder.create<MulOp>(location, lhs, rhs);
+    case '-':
+      return builder.create<SubOp>(location, lhs, rhs);
     }
 
     emitError(location, "invalid binary operator '") << binop.getOp() << "'";
diff --git a/mlir/examples/toy/Ch7/test/test.py b/mlir/examples/toy/Ch7/test/test.py
new file mode 100644
index 000000000000000..3336825bf2761b9
--- /dev/null
+++ b/mlir/examples/toy/Ch7/test/test.py
@@ -0,0 +1,14 @@
+
+
+def main() {
+  var a = [10,20,30];
+  var b = [40,50,60];
+
+  var result = b - a;
+  print(result);
+ 
+}
+
+
+  
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/126468


More information about the Mlir-commits mailing list