// sokobo: expression.cpp
#include <cmath>
#include <cstdlib>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

#include "include/expression.h"

10// =================== Constant Implementation ===================
11Constant::Constant(double val)
12 : value(val)
13{
14}
15
16std::string Constant::toString() const
17{
18 if (value == static_cast<int>(value)) {
19 return std::to_string(static_cast<int>(value));
20 }
21 return std::to_string(value);
22}
23
24std::shared_ptr<Expression> Constant::derivative(const std::string&) const
25{
26 return std::make_shared<Constant>(0);
27}
28
29std::shared_ptr<Expression> Constant::simplify() const
30{
31 return std::make_shared<Constant>(value);
32}
33
34double Constant::evaluate(const std::map<std::string, double>&) const
35{
36 return value;
37}
38
39Expression::Type Constant::getType() const
40{
41 return CONSTANT;
42}
43
44std::shared_ptr<Expression> Constant::clone() const
45{
46 return std::make_shared<Constant>(value);
47}
48
49// =================== Variable Implementation ===================
50Variable::Variable(const std::string& n)
51 : name(n)
52{
53}
54
55std::string Variable::toString() const
56{
57 return name;
58}
59
60std::shared_ptr<Expression> Variable::derivative(const std::string& var) const
61{
62 if (name == var) {
63 return std::make_shared<Constant>(1);
64 }
65 return std::make_shared<Constant>(0);
66}
67
68std::shared_ptr<Expression> Variable::simplify() const
69{
70 return std::make_shared<Variable>(name);
71}
72
73double Variable::evaluate(const std::map<std::string, double>& vars) const
74{
75 auto it = vars.find(name);
76 if (it != vars.end()) {
77 return it->second;
78 }
79 throw std::runtime_error("Variable " + name
80 + " not found in evaluation context");
81}
82
83Expression::Type Variable::getType() const
84{
85 return VARIABLE;
86}
87
88std::shared_ptr<Expression> Variable::clone() const
89{
90 return std::make_shared<Variable>(name);
91}
92
93// =================== BinaryOp Implementation ===================
94BinaryOp::BinaryOp(std::shared_ptr<Expression> l,
95 char operation,
96 std::shared_ptr<Expression> r)
97 : left(l)
98 , right(r)
99 , op(operation)
100{
101
102 // Set the enum based on the character operation
103 switch (operation) {
104 case '+':
105 binaryOp = ADD;
106 break;
107 case '-':
108 binaryOp = SUB;
109 break;
110 case '*':
111 binaryOp = MUL;
112 break;
113 case '/':
114 binaryOp = DIV;
115 break;
116 case '^':
117 binaryOp = POW;
118 break;
119 default:
120 throw std::runtime_error("Unknown binary operation: "
121 + std::string(1, operation));
122 }
123}
124
125auto BinaryOp::toString() const -> std::string
126{
127 std::string left_str = left->toString();
128 std::string right_str = right->toString();
129
130 if (left->getType() == BINARY_OP) {
131 left_str = "(" + left_str + ")";
132 }
133 if (right->getType() == BINARY_OP) {
134 right_str = "(" + right_str + ")";
135 }
136
137 return left_str + " " + op + " " + right_str;
138}
139
140std::shared_ptr<Expression> BinaryOp::derivative(const std::string& var) const
141{
142 auto dl = left->derivative(var);
143 auto dr = right->derivative(var);
144
145 switch (op) {
146 case '+':
147 return std::make_shared<BinaryOp>(dl, '+', dr);
148 case '-':
149 return std::make_shared<BinaryOp>(dl, '-', dr);
150 case '*':
151 return std::make_shared<BinaryOp>(
152 std::make_shared<BinaryOp>(dl, '*', right->clone()),
153 '+',
154 std::make_shared<BinaryOp>(left->clone(), '*', dr));
155 case '/':
156 return std::make_shared<BinaryOp>(
157 std::make_shared<BinaryOp>(
158 std::make_shared<BinaryOp>(dl, '*', right->clone()),
159 '-',
160 std::make_shared<BinaryOp>(left->clone(), '*', dr)),
161 '/',
162 std::make_shared<BinaryOp>(right->clone(), '*', right->clone()));
163 case '^':
164 if (right->getType() == CONSTANT) {
165 auto const_right = std::static_pointer_cast<Constant>(right);
166 double n = const_right->getValue();
167 return std::make_shared<BinaryOp>(
168 std::make_shared<BinaryOp>(
169 std::make_shared<Constant>(n),
170 '*',
171 std::make_shared<BinaryOp>(
172 left->clone(), '^', std::make_shared<Constant>(n - 1))),
173 '*',
174 dl);
175 }
176 break;
177 }
178 throw std::runtime_error("Unsupported derivative operation");
179}
180
181std::shared_ptr<Expression> BinaryOp::simplify() const
182{
183 auto s_left = left->simplify();
184 auto s_right = right->simplify();
185
186 // Check for constant folding
187 if (s_left->getType() == CONSTANT && s_right->getType() == CONSTANT) {
188 auto c_left = std::static_pointer_cast<Constant>(s_left);
189 auto c_right = std::static_pointer_cast<Constant>(s_right);
190 double l_val = c_left->getValue();
191 double r_val = c_right->getValue();
192
193 switch (op) {
194 case '+':
195 return std::make_shared<Constant>(l_val + r_val);
196 case '-':
197 return std::make_shared<Constant>(l_val - r_val);
198 case '*':
199 return std::make_shared<Constant>(l_val * r_val);
200 case '/':
201 if (r_val != 0) {
202 return std::make_shared<Constant>(l_val / r_val);
203 }
204 break;
205 case '^':
206 return std::make_shared<Constant>(std::pow(l_val, r_val));
207 }
208 }
209
210 // Algebraic simplifications
211 if (s_left->getType() == CONSTANT) {
212 auto c_left = std::static_pointer_cast<Constant>(s_left);
213 double const l_val = c_left->getValue();
214
215 switch (op) {
216 case '+':
217 if (l_val == 0) {
218 return s_right; // 0 + x = x
219 }
220 break;
221 case '-':
222 if (l_val == 0) {
223 // 0 - x = -x (need to implement unary minus or use -1 * x)
224 return std::make_shared<BinaryOp>(
225 std::make_shared<Constant>(-1), '*', s_right);
226 }
227 break;
228 case '*':
229 if (l_val == 0) {
230 return std::make_shared<Constant>(0); // 0 * x = 0
231 }
232 if (l_val == 1) {
233 return s_right; // 1 * x = x
234 }
235 break;
236 case '/':
237 if (l_val == 0) {
238 return std::make_shared<Constant>(0); // 0 / x = 0
239 }
240 break;
241 default:;
242 }
243 }
244
245 if (s_right->getType() == CONSTANT) {
246 auto c_right = std::static_pointer_cast<Constant>(s_right);
247 double r_val = c_right->getValue();
248
249 switch (op) {
250 case '+':
251 if (r_val == 0) {
252 return s_left; // x + 0 = x
253 }
254 break;
255 case '-':
256 if (r_val == 0) {
257 return s_left; // x - 0 = x
258 }
259 break;
260 case '*':
261 if (r_val == 0) {
262 return std::make_shared<Constant>(0); // x * 0 = 0
263 }
264 if (r_val == 1) {
265 return s_left; // x * 1 = x
266 }
267 break;
268 case '/':
269 if (r_val == 1) {
270 return s_left; // x / 1 = x
271 }
272 break;
273 case '^':
274 if (r_val == 0) {
275 return std::make_shared<Constant>(1); // x^0 = 1
276 }
277 if (r_val == 1) {
278 return s_left; // x^1 = x
279 }
280 break;
281 default:;
282 }
283 }
284
285 return std::make_shared<BinaryOp>(s_left, op, s_right);
286}
287
288double BinaryOp::evaluate(const std::map<std::string, double>& vars) const
289{
290 double const l_val = left->evaluate(vars);
291 double const r_val = right->evaluate(vars);
292
293 switch (op) {
294 case '+':
295 return l_val + r_val;
296 case '-':
297 return l_val - r_val;
298 case '*':
299 return l_val * r_val;
300 case '/':
301 if (r_val == 0) {
302 throw std::runtime_error("Division by zero");
303 }
304 return l_val / r_val;
305 case '^':
306 return std::pow(l_val, r_val);
307 default:;
308 }
309 throw std::runtime_error("Unknown binary operation");
310}
311
312Expression::Type BinaryOp::getType() const
313{
314 return BINARY_OP;
315}
316
317std::shared_ptr<Expression> BinaryOp::clone() const
318{
319 return std::make_shared<BinaryOp>(left->clone(), op, right->clone());
320}
321
322// =================== Function Implementation ===================
323Function::Function(const std::string& fname, std::shared_ptr<Expression> arg)
324 : name(fname)
325 , argument(arg)
326{
327}
328
329std::string Function::toString() const
330{
331 return name + "(" + argument->toString() + ")";
332}
333
334std::shared_ptr<Expression> Function::derivative(const std::string& var) const
335{
336 auto argDeriv = argument->derivative(var);
337
338 if (name == "sin") {
339 return std::make_shared<BinaryOp>(
340 std::make_shared<Function>("cos", argument->clone()), '*', argDeriv);
341 } else if (name == "cos") {
342 return std::make_shared<BinaryOp>(
343 std::make_shared<BinaryOp>(
344 std::make_shared<Constant>(-1),
345 '*',
346 std::make_shared<Function>("sin", argument->clone())),
347 '*',
348 argDeriv);
349 } else if (name == "tan") {
350 return std::make_shared<BinaryOp>(
351 std::make_shared<BinaryOp>(
352 std::make_shared<Function>("sec", argument->clone()),
353 '^',
354 std::make_shared<Constant>(2)),
355 '*',
356 argDeriv);
357 } else if (name == "ln" || name == "log") {
358 return std::make_shared<BinaryOp>(
359 std::make_shared<BinaryOp>(
360 std::make_shared<Constant>(1), '/', argument->clone()),
361 '*',
362 argDeriv);
363 } else if (name == "exp") {
364 return std::make_shared<BinaryOp>(
365 std::make_shared<Function>("exp", argument->clone()), '*', argDeriv);
366 } else if (name == "sqrt") {
367 return std::make_shared<BinaryOp>(
368 std::make_shared<BinaryOp>(
369 std::make_shared<Constant>(1),
370 '/',
371 std::make_shared<BinaryOp>(
372 std::make_shared<Constant>(2),
373 '*',
374 std::make_shared<Function>("sqrt", argument->clone()))),
375 '*',
376 argDeriv);
377 }
378
379 throw std::runtime_error("Derivative not implemented for function: " + name);
380}
381
382std::shared_ptr<Expression> Function::simplify() const
383{
384 auto sArg = argument->simplify();
385
386 // If argument is constant, we can evaluate the function
387 if (sArg->getType() == CONSTANT) {
388 auto constArg = std::static_pointer_cast<Constant>(sArg);
389 double argVal = constArg->getValue();
390
391 try {
392 if (name == "sin") {
393 return std::make_shared<Constant>(std::sin(argVal));
394 }
395 if (name == "cos") {
396 return std::make_shared<Constant>(std::cos(argVal));
397 }
398 if (name == "tan") {
399 return std::make_shared<Constant>(std::tan(argVal));
400 }
401 if (name == "exp") {
402 return std::make_shared<Constant>(std::exp(argVal));
403 }
404 if (name == "ln" || name == "log") {
405 if (argVal > 0) {
406 return std::make_shared<Constant>(std::log(argVal));
407 }
408 }
409 if (name == "sqrt") {
410 if (argVal >= 0) {
411 return std::make_shared<Constant>(std::sqrt(argVal));
412 }
413 }
414 if (name == "abs") {
415 return std::make_shared<Constant>(std::abs(argVal));
416 }
417 if (name == "floor") {
418 return std::make_shared<Constant>(std::floor(argVal));
419 }
420 if (name == "ceil") {
421 return std::make_shared<Constant>(std::ceil(argVal));
422 }
423 } catch (...) {
424 // If evaluation fails, return the original function
425 }
426 }
427
428 return std::make_shared<Function>(name, sArg);
429}
430
431double Function::evaluate(const std::map<std::string, double>& vars) const
432{
433 double argVal = argument->evaluate(vars);
434
435 if (name == "sin") {
436 return std::sin(argVal);
437 }
438 if (name == "cos") {
439 return std::cos(argVal);
440 }
441 if (name == "tan") {
442 return std::tan(argVal);
443 }
444 if (name == "exp") {
445 return std::exp(argVal);
446 }
447 if (name == "ln" || name == "log") {
448 if (argVal <= 0) {
449 throw std::runtime_error("Cannot take logarithm of non-positive number");
450 }
451 return std::log(argVal);
452 }
453 if (name == "sqrt") {
454 if (argVal < 0) {
455 throw std::runtime_error("Cannot take square root of negative number");
456 }
457 return std::sqrt(argVal);
458 }
459 if (name == "abs") {
460 return std::abs(argVal);
461 }
462 if (name == "floor") {
463 return std::floor(argVal);
464 }
465 if (name == "ceil") {
466 return std::ceil(argVal);
467 }
468
469 throw std::runtime_error("Unknown function: " + name);
470}
471
472Expression::Type Function::getType() const
473{
474 return FUNCTION;
475}
476
477std::shared_ptr<Expression> Function::clone() const
478{
479 return std::make_shared<Function>(name, argument->clone());
480}
481
482
483
484// =================== UnaryOpExpression Implementation ===================
485
486
487UnaryOpExpression::UnaryOpExpression(char op,
488 std::shared_ptr<Expression> operand)
489 : operator_(op)
490 , operand_(operand)
491{
492 if (!operand) {
493 throw std::invalid_argument("Operand cannot be null");
494 }
495}
496
497double UnaryOpExpression::evaluate(
498 const std::map<std::string, double>& variables) const
499{
500 if (!operand_) {
501 throw std::runtime_error("Invalid operand in unary expression");
502 }
503
504 double operandValue = operand_->evaluate(variables);
505
506 switch (operator_) {
507 case '+':
508 return operandValue;
509 case '-':
510 return -operandValue;
511 case '!':
512 // Factorial (assuming integer input)
513 if (operandValue < 0 || operandValue != static_cast<int>(operandValue)) {
514 throw std::runtime_error("Factorial requires non-negative integer");
515 }
516 {
517 int n = static_cast<int>(operandValue);
518 double result = 1.0;
519 for (int i = 2; i <= n; ++i) {
520 result *= i;
521 }
522 return result;
523 }
524 default:
525 throw std::runtime_error("Unknown unary operator: "
526 + std::string(1, operator_));
527 }
528}
529
530std::string UnaryOpExpression::toString() const
531{
532 if (!operand_) {
533 return "null";
534 }
535
536 switch (operator_) {
537 case '+':
538 case '-':
539 return std::string(1, operator_) + "(" + operand_->toString() + ")";
540 case '!':
541 return "(" + operand_->toString() + ")!";
542 default:
543 return std::string(1, operator_) + "(" + operand_->toString() + ")";
544 }
545}
546
547Expression::Type UnaryOpExpression::getType() const
548{
549 return Expression::Type::UNARY_OP;
550}
551
552//std::shared_ptr<Expression> UnaryOpExpression::clone() const
553//{
554// if (!operand_) {
555// return nullptr;
556// }
557// return std::make_shared<UnaryOpExpression>(operator_, operand_->clone());
558//}
559
560char UnaryOpExpression::getOperator() const
561{
562 return operator_;
563}
564
565std::shared_ptr<Expression> UnaryOpExpression::getOperand() const
566{
567 return operand_;
568}
569
570
571
572std::shared_ptr<Expression> UnaryOpExpression::derivative(
573 const std::string& var) const
574{
575 if (!operand_) {
576 throw std::runtime_error("Invalid operand in unary expression derivative");
577 }
578
579 auto operandDeriv = operand_->derivative(var);
580
581 switch (operator_) {
582 case '+':
583 // d/dx(+u) = du/dx
584 return operandDeriv;
585 case '-':
586 // d/dx(-u) = -du/dx
587 return std::make_shared<UnaryOpExpression>('-', operandDeriv);
588 case '!':
589 // Factorial derivative is complex and typically not implemented
590 // for symbolic computation. You might want to throw an error
591 // or implement using gamma function derivatives if needed
592 throw std::runtime_error("Derivative of factorial not implemented");
593 default:
594 throw std::runtime_error("Unknown unary operator for derivative: "
595 + std::string(1, operator_));
596 }
597}
598
599
600
601std::shared_ptr<Expression> UnaryOpExpression::simplify() const
602{
603 if (!operand_) {
604 throw std::runtime_error("Invalid operand in unary expression simplify");
605 }
606
607 auto simplifiedOperand = operand_->simplify();
608
609 // Handle constant folding
610 if (simplifiedOperand->getType() == CONSTANT) {
611 auto constOperand = std::static_pointer_cast<Constant>(simplifiedOperand);
612 double val = constOperand->getValue();
613
614 switch (operator_) {
615 case '+':
616 return std::make_shared<Constant>(val); // +5 = 5
617 case '-':
618 return std::make_shared<Constant>(-val); // -5 = -5
619 case '!':
620 // Only simplify factorial for small non-negative integers
621 if (val >= 0 && val == static_cast<int>(val) && val <= 20) {
622 int n = static_cast<int>(val);
623 double result = 1.0;
624 for (int i = 2; i <= n; ++i) {
625 result *= i;
626 }
627 return std::make_shared<Constant>(result);
628 }
629 break;
630 }
631 }
632
633 // Handle algebraic simplifications
634 switch (operator_) {
635 case '+':
636 // +x = x (unary plus is redundant)
637 return simplifiedOperand;
638 case '-':
639 // Check for double negation: -(-x) = x
640 if (simplifiedOperand->getType() == UNARY_OP) {
641 auto unaryOperand =
642 std::static_pointer_cast<UnaryOpExpression>(simplifiedOperand);
643 if (unaryOperand->getOperator() == '-') {
644 return unaryOperand->getOperand()->simplify();
645 }
646 }
647 break;
648 }
649
650 return std::make_shared<UnaryOpExpression>(operator_, simplifiedOperand);
651}
652std::shared_ptr<Expression> UnaryOpExpression::clone() const
653{
654 if (!operand_) {
655 return nullptr;
656 }
657 return std::make_shared<UnaryOpExpression>(operator_, operand_->clone());
658}