8#include "include/expression.h"
// Constructs a constant node holding `val`.
// NOTE(review): the initializer/body is not visible in this view; presumably
// it stores val in the `value` member that toString/simplify/clone read — confirm.
11Constant::Constant(
double val)
16std::string Constant::toString()
const
18 if (value ==
static_cast<int>(value)) {
19 return std::to_string(
static_cast<int>(value));
21 return std::to_string(value);
24std::shared_ptr<Expression> Constant::derivative(
const std::string&)
const
26 return std::make_shared<Constant>(0);
29std::shared_ptr<Expression> Constant::simplify()
const
31 return std::make_shared<Constant>(value);
// Evaluates the constant. The variable bindings are not used (the parameter
// is unnamed).
// NOTE(review): the body is not visible in this view; presumably it returns
// `value` — confirm.
34double Constant::evaluate(
const std::map<std::string, double>&)
const
// Reports this node's kind. Other nodes compare it against CONSTANT before
// static_pointer_cast-ing to Constant, so it must stay accurate.
// NOTE(review): body not visible in this view; presumably returns CONSTANT.
39Expression::Type Constant::getType()
const
44std::shared_ptr<Expression> Constant::clone()
const
46 return std::make_shared<Constant>(value);
// Constructs a variable node named `n`.
// NOTE(review): initializer/body not visible in this view; presumably stores n
// in the `name` member read by evaluate/derivative/clone — confirm.
50Variable::Variable(
const std::string& n)
// Renders the variable as text.
// NOTE(review): body not visible in this view; presumably returns `name`.
55std::string Variable::toString()
const
// Differentiates this variable with respect to `var`. One path yields
// Constant(1) and the other Constant(0).
// NOTE(review): the branch condition is not visible in this view; presumably
// Constant(1) is returned when var == name (d/dx x = 1) and 0 otherwise.
60std::shared_ptr<Expression> Variable::derivative(
const std::string& var)
const
63 return std::make_shared<Constant>(1);
65 return std::make_shared<Constant>(0);
68std::shared_ptr<Expression> Variable::simplify()
const
70 return std::make_shared<Variable>(name);
// Looks up this variable's value in `vars`.
// The std::runtime_error below reports a name that is not bound in `vars`.
// NOTE(review): the statement inside the found branch is not visible in this
// view; presumably it returns it->second — confirm.
73double Variable::evaluate(
const std::map<std::string, double>& vars)
const
75 auto it = vars.find(name);
76 if (it != vars.end()) {
79 throw std::runtime_error(
"Variable " + name
80 +
" not found in evaluation context");
// Reports this node's kind.
// NOTE(review): body not visible in this view; presumably returns the
// variable enumerator of Expression::Type — confirm.
83Expression::Type Variable::getType()
const
88std::shared_ptr<Expression> Variable::clone()
const
90 return std::make_shared<Variable>(name);
// Constructs a binary operation node from a left operand, an operator
// character and a right operand (the order used by every
// make_shared<BinaryOp>(lhs, 'op', rhs) call in this file).
// Throws std::runtime_error("Unknown binary operation: <op>") for an operator
// it rejects.
// NOTE(review): the operator parameter line and the validation logic are not
// visible in this view. The accepted set is presumably + - * / ^ (the ones
// BinaryOp::evaluate handles) — confirm.
94BinaryOp::BinaryOp(std::shared_ptr<Expression> l,
96 std::shared_ptr<Expression> r)
120 throw std::runtime_error(
"Unknown binary operation: "
121 + std::string(1, operation));
125auto BinaryOp::toString() const -> std::
string
127 std::string left_str = left->toString();
128 std::string right_str = right->toString();
130 if (left->getType() == BINARY_OP) {
131 left_str =
"(" + left_str +
")";
133 if (right->getType() == BINARY_OP) {
134 right_str =
"(" + right_str +
")";
137 return left_str +
" " + op +
" " + right_str;
// Symbolically differentiates the operation with respect to `var`.
// Both operand derivatives (dl, dr) are computed up front, then combined per
// operator:
//   sum / difference : dl + dr, dl - dr
//   product          : dl*r (+) l*dr           (product rule)
//   quotient         : (dl*r (-) l*dr) (/) r*r (quotient rule)
//   power            : n * l^(n-1) ...         only when the exponent node is
//                                              a Constant
// NOTE(review): the case labels and the operators joining these sub-terms are
// not visible in this view. The rules above are inferred from the visible
// operands; confirm the '+', '-', '/' joins and that the power case multiplies
// by dl (chain rule).
// NOTE(review): the exponent is tested with right->getType() == CONSTANT
// without simplifying first, so e.g. x^(1+1) is not recognized. A non-constant
// exponent presumably reaches the "Unsupported derivative operation" throw —
// confirm.
140std::shared_ptr<Expression> BinaryOp::derivative(
const std::string& var)
const
142 auto dl = left->derivative(var);
143 auto dr = right->derivative(var);
147 return std::make_shared<BinaryOp>(dl,
'+', dr);
149 return std::make_shared<BinaryOp>(dl,
'-', dr);
151 return std::make_shared<BinaryOp>(
152 std::make_shared<BinaryOp>(dl,
'*', right->clone()),
154 std::make_shared<BinaryOp>(left->clone(),
'*', dr));
156 return std::make_shared<BinaryOp>(
157 std::make_shared<BinaryOp>(
158 std::make_shared<BinaryOp>(dl,
'*', right->clone()),
160 std::make_shared<BinaryOp>(left->clone(),
'*', dr)),
162 std::make_shared<BinaryOp>(right->clone(),
'*', right->clone()));
164 if (right->getType() == CONSTANT) {
165 auto const_right = std::static_pointer_cast<Constant>(right);
166 double n = const_right->getValue();
167 return std::make_shared<BinaryOp>(
168 std::make_shared<BinaryOp>(
169 std::make_shared<Constant>(n),
171 std::make_shared<BinaryOp>(
172 left->clone(),
'^', std::make_shared<Constant>(n - 1))),
178 throw std::runtime_error(
"Unsupported derivative operation");
// Simplifies both operands, then tries, in order:
//   1. constant folding when both simplified operands are Constants
//      (+, -, *, /, and std::pow for ^);
//   2. identities with a constant left operand (one rewrites to -1 * right,
//      others collapse to Constant(0));
//   3. identities with a constant right operand (collapsing to Constant(0) or
//      Constant(1));
//   4. otherwise rebuilds the node from the simplified operands.
// NOTE(review): the conditions selecting each identity are not visible in this
// view. They are presumably 0 - x -> -1*x, 0*x / 0/x -> 0, x*0 -> 0, x^0 -> 1 —
// confirm. Collapsing 0/x to 0 ignores x == 0 at evaluation time.
// NOTE(review): lines guarding the folded division are not visible. If nothing
// checks r_val == 0, folding yields inf/nan instead of the "Division by zero"
// error BinaryOp::evaluate throws — confirm.
181std::shared_ptr<Expression> BinaryOp::simplify()
const
183 auto s_left = left->simplify();
184 auto s_right = right->simplify();
187 if (s_left->getType() == CONSTANT && s_right->getType() == CONSTANT) {
188 auto c_left = std::static_pointer_cast<Constant>(s_left);
189 auto c_right = std::static_pointer_cast<Constant>(s_right);
190 double l_val = c_left->getValue();
191 double r_val = c_right->getValue();
195 return std::make_shared<Constant>(l_val + r_val);
197 return std::make_shared<Constant>(l_val - r_val);
199 return std::make_shared<Constant>(l_val * r_val);
202 return std::make_shared<Constant>(l_val / r_val);
206 return std::make_shared<Constant>(std::pow(l_val, r_val));
211 if (s_left->getType() == CONSTANT) {
212 auto c_left = std::static_pointer_cast<Constant>(s_left);
213 double const l_val = c_left->getValue();
224 return std::make_shared<BinaryOp>(
225 std::make_shared<Constant>(-1),
'*', s_right);
230 return std::make_shared<Constant>(0);
238 return std::make_shared<Constant>(0);
245 if (s_right->getType() == CONSTANT) {
246 auto c_right = std::static_pointer_cast<Constant>(s_right);
247 double r_val = c_right->getValue();
262 return std::make_shared<Constant>(0);
275 return std::make_shared<Constant>(1);
285 return std::make_shared<BinaryOp>(s_left, op, s_right);
// Numerically evaluates the operation under the bindings in `vars`.
// Both operands are evaluated first, so an error in either propagates for
// every operator. The result is then +, -, *, / or std::pow of the two values.
// Throws std::runtime_error("Division by zero") on the division path and
// std::runtime_error("Unknown binary operation") for any other operator.
// NOTE(review): the operator dispatch and the division guard condition are
// not visible in this view; the guard is presumably r_val == 0 — confirm.
288double BinaryOp::evaluate(
const std::map<std::string, double>& vars)
const
290 double const l_val = left->evaluate(vars);
291 double const r_val = right->evaluate(vars);
295 return l_val + r_val;
297 return l_val - r_val;
299 return l_val * r_val;
302 throw std::runtime_error(
"Division by zero");
304 return l_val / r_val;
306 return std::pow(l_val, r_val);
309 throw std::runtime_error(
"Unknown binary operation");
// Reports this node's kind. toString compares children against BINARY_OP to
// decide on parenthesization.
// NOTE(review): body not visible in this view; presumably returns BINARY_OP.
312Expression::Type BinaryOp::getType()
const
317std::shared_ptr<Expression> BinaryOp::clone()
const
319 return std::make_shared<BinaryOp>(left->clone(), op, right->clone());
// Constructs a named single-argument function node, e.g. ("sin", x).
// NOTE(review): initializer/body not visible in this view; presumably stores
// fname in `name` and arg in `argument` — confirm whether a null arg is
// rejected.
323Function::Function(
const std::string& fname, std::shared_ptr<Expression> arg)
329std::string Function::toString()
const
331 return name +
"(" + argument->toString() +
")";
// Differentiates f(u) with respect to `var` using the chain rule,
// f'(u) * u', where u' = argument->derivative(var):
//   sin  -> cos(u) * u'
//   cos  -> (-1 * sin(u)) * u'
//   tan  -> sec(u)^2 * u'
//   ln, log -> (1 / u) * u'   ("log" is treated as the natural log, matching
//                              std::log in Function::evaluate)
//   exp  -> exp(u) * u'
//   sqrt -> (1 / (2 * sqrt(u))) * u'
// Any other name throws std::runtime_error("Derivative not implemented for
// function: <name>").
// NOTE(review): several joining operators ('*', '^', '/') are not visible in
// this view. The formulas above are inferred from the visible operands.
// NOTE(review): the tan case builds Function("sec", ...), but no "sec" branch
// is visible in Function::evaluate or Function::simplify. Evaluating this
// derivative may therefore throw "Unknown function: sec". Consider emitting
// 1 / cos(u)^2 instead — confirm.
334std::shared_ptr<Expression> Function::derivative(
const std::string& var)
const
336 auto argDeriv = argument->derivative(var);
339 return std::make_shared<BinaryOp>(
340 std::make_shared<Function>(
"cos", argument->clone()),
'*', argDeriv);
341 }
else if (name ==
"cos") {
342 return std::make_shared<BinaryOp>(
343 std::make_shared<BinaryOp>(
344 std::make_shared<Constant>(-1),
346 std::make_shared<Function>(
"sin", argument->clone())),
349 }
else if (name ==
"tan") {
350 return std::make_shared<BinaryOp>(
351 std::make_shared<BinaryOp>(
352 std::make_shared<Function>(
"sec", argument->clone()),
354 std::make_shared<Constant>(2)),
357 }
else if (name ==
"ln" || name ==
"log") {
358 return std::make_shared<BinaryOp>(
359 std::make_shared<BinaryOp>(
360 std::make_shared<Constant>(1),
'/', argument->clone()),
363 }
else if (name ==
"exp") {
364 return std::make_shared<BinaryOp>(
365 std::make_shared<Function>(
"exp", argument->clone()),
'*', argDeriv);
366 }
else if (name ==
"sqrt") {
367 return std::make_shared<BinaryOp>(
368 std::make_shared<BinaryOp>(
369 std::make_shared<Constant>(1),
371 std::make_shared<BinaryOp>(
372 std::make_shared<Constant>(2),
374 std::make_shared<Function>(
"sqrt", argument->clone()))),
379 throw std::runtime_error(
"Derivative not implemented for function: " + name);
// Simplifies the argument. If the simplified argument is a Constant, the call
// is folded to a Constant via the matching <cmath> function (sin, cos, tan,
// exp, log for "ln"/"log", sqrt, abs, floor, ceil). Otherwise a new Function
// node is built with the simplified argument.
// NOTE(review): lines between the visible folds (e.g. 405 and 410) are not
// visible in this view. They presumably skip folding for out-of-domain
// ln/log and sqrt arguments so that Function::evaluate reports the error —
// confirm.
382std::shared_ptr<Expression> Function::simplify()
const
384 auto sArg = argument->simplify();
387 if (sArg->getType() == CONSTANT) {
388 auto constArg = std::static_pointer_cast<Constant>(sArg);
389 double argVal = constArg->getValue();
393 return std::make_shared<Constant>(std::sin(argVal));
396 return std::make_shared<Constant>(std::cos(argVal));
399 return std::make_shared<Constant>(std::tan(argVal));
402 return std::make_shared<Constant>(std::exp(argVal));
404 if (name ==
"ln" || name ==
"log") {
406 return std::make_shared<Constant>(std::log(argVal));
409 if (name ==
"sqrt") {
411 return std::make_shared<Constant>(std::sqrt(argVal));
415 return std::make_shared<Constant>(std::abs(argVal));
417 if (name ==
"floor") {
418 return std::make_shared<Constant>(std::floor(argVal));
420 if (name ==
"ceil") {
421 return std::make_shared<Constant>(std::ceil(argVal));
428 return std::make_shared<Function>(name, sArg);
// Evaluates the argument under `vars`, then applies the named function:
// sin, cos, tan, exp, log (for both "ln" and "log", i.e. natural log), sqrt,
// abs, floor, ceil.
// Throws std::runtime_error for a non-positive log argument, a negative sqrt
// argument, or an unrecognized function name.
// NOTE(review): the name tests for sin/cos/tan/exp/abs and the exact domain
// conditions are not visible in this view; they are inferred from the error
// messages — confirm.
431double Function::evaluate(
const std::map<std::string, double>& vars)
const
433 double argVal = argument->evaluate(vars);
436 return std::sin(argVal);
439 return std::cos(argVal);
442 return std::tan(argVal);
445 return std::exp(argVal);
447 if (name ==
"ln" || name ==
"log") {
449 throw std::runtime_error(
"Cannot take logarithm of non-positive number");
451 return std::log(argVal);
453 if (name ==
"sqrt") {
455 throw std::runtime_error(
"Cannot take square root of negative number");
457 return std::sqrt(argVal);
460 return std::abs(argVal);
462 if (name ==
"floor") {
463 return std::floor(argVal);
465 if (name ==
"ceil") {
466 return std::ceil(argVal);
469 throw std::runtime_error(
"Unknown function: " + name);
// Reports this node's kind.
// NOTE(review): body not visible in this view; presumably returns the
// function enumerator of Expression::Type — confirm.
472Expression::Type Function::getType()
const
477std::shared_ptr<Expression> Function::clone()
const
479 return std::make_shared<Function>(name, argument->clone());
// Constructs a unary operation node applying `op` to `operand`.
// Throws std::invalid_argument("Operand cannot be null") when the operand is
// rejected.
// NOTE(review): the initializer list and the null-check condition are not
// visible in this view; presumably the check is on a null operand — confirm.
487UnaryOpExpression::UnaryOpExpression(
char op,
488 std::shared_ptr<Expression> operand)
493 throw std::invalid_argument(
"Operand cannot be null");
// Evaluates the operand under `variables`, then applies the operator.
// Negation returns -value. Factorial requires a non-negative integer and
// multiplies 2..n. Any other operator throws
// std::runtime_error("Unknown unary operator: <op>").
// The first throw reports an invalid operand (its condition is not visible
// here; presumably a null operand_).
// NOTE(review): `operandValue != static_cast<int>(operandValue)` is undefined
// behaviour when operandValue is NaN, infinite, or outside int's range
// (e.g. 1e10). Check std::trunc/range before casting.
// NOTE(review): unlike simplify (which folds only n <= 20), no upper bound on
// n is visible here. The accumulator's type is not visible, so confirm it
// cannot overflow.
497double UnaryOpExpression::evaluate(
498 const std::map<std::string, double>& variables)
const
501 throw std::runtime_error(
"Invalid operand in unary expression");
504 double operandValue = operand_->evaluate(variables);
510 return -operandValue;
513 if (operandValue < 0 || operandValue !=
static_cast<int>(operandValue)) {
514 throw std::runtime_error(
"Factorial requires non-negative integer");
517 int n =
static_cast<int>(operandValue);
519 for (
int i = 2; i <= n; ++i) {
525 throw std::runtime_error(
"Unknown unary operator: "
526 + std::string(1, operator_));
// Renders the operation as text. Prefix form is "<op>(<operand>)". The
// postfix form "(<operand>)!" is presumably used for factorial.
// NOTE(review): the switch/case lines choosing between these returns are not
// visible in this view — confirm which operators reach each form.
530std::string UnaryOpExpression::toString()
const
539 return std::string(1, operator_) +
"(" + operand_->toString() +
")";
541 return "(" + operand_->toString() +
")!";
543 return std::string(1, operator_) +
"(" + operand_->toString() +
")";
547Expression::Type UnaryOpExpression::getType()
const
549 return Expression::Type::UNARY_OP;
// Returns this node's operator character (simplify compares it with '-').
// NOTE(review): body not visible in this view; presumably returns operator_.
560char UnaryOpExpression::getOperator()
const
// Returns this node's operand expression.
// NOTE(review): body not visible in this view; presumably returns operand_
// (shared, not cloned) — confirm.
565std::shared_ptr<Expression> UnaryOpExpression::getOperand()
const
// Differentiates the unary operation with respect to `var`.
// Negation yields -(u'). Factorial throws std::runtime_error("Derivative of
// factorial not implemented"). An unknown operator throws
// std::runtime_error("Unknown unary operator for derivative: <op>").
// NOTE(review): the invalid-operand condition and the case labels
// (lines 574-586) are not visible in this view. A unary '+' presumably returns
// u' unchanged — confirm.
572std::shared_ptr<Expression> UnaryOpExpression::derivative(
573 const std::string& var)
const
576 throw std::runtime_error(
"Invalid operand in unary expression derivative");
579 auto operandDeriv = operand_->derivative(var);
587 return std::make_shared<UnaryOpExpression>(
'-', operandDeriv);
592 throw std::runtime_error(
"Derivative of factorial not implemented");
594 throw std::runtime_error(
"Unknown unary operator for derivative: "
595 + std::string(1, operator_));
// Simplifies the operand, then:
//   - constant operand: folds to Constant(val) or Constant(-val). Factorial is
//     folded only for integers 0..20, which keeps the product within exact
//     double/integer range;
//   - operand that is itself a unary '-': returns that inner operand,
//     re-simplified (double-negation elimination);
//   - otherwise rebuilds the node with the simplified operand.
// NOTE(review): `val == static_cast<int>(val)` is evaluated before
// `val <= 20`, so val = 1e10 or +inf reaches an out-of-range cast, which is
// undefined behaviour. Test `val <= 20` first.
// NOTE(review): the guard restricting double-negation elimination to
// operator_ == '-' is not visible in this view (line 641 is missing). If it is
// absent, e.g. (-x)! would simplify to x — confirm.
601std::shared_ptr<Expression> UnaryOpExpression::simplify()
const
604 throw std::runtime_error(
"Invalid operand in unary expression simplify");
607 auto simplifiedOperand = operand_->simplify();
610 if (simplifiedOperand->getType() == CONSTANT) {
611 auto constOperand = std::static_pointer_cast<Constant>(simplifiedOperand);
612 double val = constOperand->getValue();
616 return std::make_shared<Constant>(val);
618 return std::make_shared<Constant>(-val);
621 if (val >= 0 && val ==
static_cast<int>(val) && val <= 20) {
622 int n =
static_cast<int>(val);
624 for (
int i = 2; i <= n; ++i) {
627 return std::make_shared<Constant>(result);
637 return simplifiedOperand;
640 if (simplifiedOperand->getType() == UNARY_OP) {
642 std::static_pointer_cast<UnaryOpExpression>(simplifiedOperand);
643 if (unaryOperand->getOperator() ==
'-') {
644 return unaryOperand->getOperand()->simplify();
650 return std::make_shared<UnaryOpExpression>(operator_, simplifiedOperand);
// Deep-copies the unary operation: the same operator applied to a clone of
// the operand subtree.
// NOTE(review): lines 653-656 are not visible in this view; they may contain
// an operand null-check like the other methods — confirm.
652std::shared_ptr<Expression> UnaryOpExpression::clone()
const
657 return std::make_shared<UnaryOpExpression>(operator_, operand_->clone());