/****************************************************************************
 *    lib/c/LTLQuery.cpp - This file is part of coala						*
 *																			*
 *    Copyright (C) 2009  Torsten Grote										*
 *																			*
 *    This program is free software; you can redistribute it and/or modify	*
 *    it under the terms of the GNU General Public License as published by	*
 *    the Free Software Foundation; either version 3 of the License, or		*
 *    (at your option) any later version.									*
 *																			*
 *    This program is distributed in the hope that it will be useful,		*
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of		*
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the			*
 *    GNU General Public License for more details.							*
 *																			*
 *    You should have received a copy of the GNU General Public License		*
 *    along with this program; if not, see http://www.gnu.org/licenses		*
 ****************************************************************************/

#include "LTLQuery.h"
#include "Parser.h"

using namespace C;


/*
 * Create an LTL query for the given formula tree.
 * The base classes receive the source line and the type table; the query
 * keeps the (non-copied) pointer to the root of the formula.
 */
LTLQuery::LTLQuery(LTLNode* node, int line, Types* types)
	: Statement(line, types), Query(line, types)
{
	this->node_ = node;
}

// The formula tree in node_ is not released here.
LTLQuery::~LTLQuery() {
}

// Root of the formula tree (the negated tree once print() has run).
LTLNode* LTLQuery::getNode() {
	return this->node_;
}

// Replace the formula root. NOTE(review): the previous tree is not freed.
void LTLQuery::setNode(LTLNode* node) {
	this->node_ = node;
}

/*
 * Emit the ASP encoding of this LTL query into the printer.
 *
 * The formula is negated so that the generated program looks for counter
 * examples; the formula then forms the body of the rule with head
 * "ltl_counter_example" ("ltl_counter_example(t)" in incremental mode,
 * where incremental propagation rules are printed into section 'c' first).
 *
 * Throws std::runtime_error in reverse incremental mode.
 * NOTE(review): node_ is overwritten by its negation, so calling print()
 * twice would negate the formula a second time.
 */
void LTLQuery::print(Printer* p, map<string, FluentAction*>* fluents) {
	if(p->rev) {
		throw std::runtime_error("LTL Queries are not yet supported with reverse incremental option.");
	}

	// auxiliary rules shared by all LTL queries
	if(!p->inc) {
		p->printLTLAuxilliaries();
	}
	else {
		p->printIncLTLAuxilliaries();
	}

	// the direct encoding needs an explicit check of the guessed loop
	if(!p->no_direct_enc) {
		printLoopGuessCheck(p, fluents);
	}

	// work on the negated formula to search for counter examples
	node_ = node_->negate(false);

	string head = "ltl_counter_example :- ";
	if(p->inc) {
		p->setSection('c');
		node_->printIncPropagator(p);
		head = "ltl_counter_example(t) :- ";
	}
	p->add(head);

	// the (negated) formula becomes the body of the counter example rule
	node_->print(p, true, false);
}

/*
 * Direct encoding only: print integrity constraints (into section 'v') that
 * reject the loop guess when a fluent holds at one of the times T-1 (the
 * loop start) and t while its p->neg-negated form holds at the other.
 * The map entry keyed "0" is skipped. In incremental mode each constraint
 * additionally contains time(T).
 */
void LTLQuery::printLoopGuessCheck(Printer* p, map<string, FluentAction*>* fluents) {
	// % check that loop guess was correct
	p->setSection('v');

	// {time of the positive literal, time of the negated literal}
	const char* const pairs[2][2] = {
		{ "T-1", "t" },
		{ "t", "T-1" }
	};

	map<string, FluentAction*>::iterator entry;
	for(entry = fluents->begin(); entry != fluents->end(); ++entry) {
		if(entry->first != "0") {
			FluentAction* fluent = entry->second;

			for(int k = 0; k < 2; ++k) {
				p->add(":- loop_start(T-1), ");
				p->add(fluent->print(p, pairs[k][0]) + ", ");
				if(p->inc) p->add("time(T), ");
				p->add(p->neg + fluent->print(p, pairs[k][1]) + ".\n");
			}
		}
	}
}


/* class LTLNode starts here */

/*
 * Create an operator node with the given children (either may be NULL,
 * e.g. unary operators only use right). Operator nodes carry no content.
 */
LTLNode::LTLNode(int type, LTLNode* left, LTLNode* right, int num)
	: type_(type), left_(left), right_(right), content_(NULL), num_(num)
{
}

/*
 * Create a leaf node wrapping a fluent/action; leaves have no children.
 */
LTLNode::LTLNode(int type, FluentAction* content, int num)
	: type_(type), left_(NULL), right_(NULL), content_(content), num_(num)
{
}

// Intentionally does not delete left_, right_ or content_: negate() deletes
// single nodes and relies on their children surviving.
LTLNode::~LTLNode() {
}

int LTLNode::getType() {
	return type_;
}

/*
 * Class name of the wrapped fluent/action: taken from the content of an
 * identifier, or looked up through a negation node. Every other node type
 * yields the empty string.
 */
string LTLNode::getClass() {
	if(type_ == Parser::IDENTIFIER) return content_->getClass();
	if(type_ == Parser::LTLNOT) return right_->getClass();
	return "";
}

// Right (or only) child; NULL for leaf nodes.
LTLNode* LTLNode::getRightNode() {
	return this->right_;
}

/*
 * Print the atom representing this node at time step T into p.
 *
 * - IDENTIFIER: prints content_->print(p, T). With the non-direct encoding
 *   it is wrapped in occurs(...) for actions or holds(...) for fluents; if
 *   sign is false, neg(...) is printed and the function returns early
 *   WITHOUT the closing ", T)" -- the enclosing LTLNOT node appends it.
 * - LTLNOT over an identifier (non-empty class): prints the negated literal
 *   (neg(...) for the non-direct encoding, the p->neg prefix otherwise).
 * - LTLNOT over an LTL operator (empty class): prints "not " followed by the
 *   child's atom and returns the child's result.
 * - any other node (LTL operator): prints its auxiliary predicate
 *   ltl_f_<num_>, with an extra leading argument t in incremental mode and
 *   a "not " prefix if sign is false.
 *
 * With the non-direct encoding every path reaching the end of the function
 * appends the closing ", T)".
 * NOTE(review): an identifier whose class is neither "Action" nor "Fluent"
 * gets no opening predicate but still receives the closing ", T)".
 *
 * Returns true if an auxiliary ltl_f_ predicate was printed (possibly
 * preceded by "not "), false for (negated) identifiers.
 */
bool LTLNode::printContent(Printer* p, string T, bool sign=true) {
	// works and used also for non-content nodes because of empty class
	
	bool is_ltl_function = false;

	if(type_ == Parser::IDENTIFIER) {
		if(p->no_direct_enc) {
			if(getClass() == "Action") p->add("occurs(");
			else if(getClass() == "Fluent") p->add("holds(");

			if(!sign) {
				// negative literal: leave the predicate open, the calling
				// LTLNOT node appends the closing ", T)"
				p->add("neg(" + content_->print(p, T) + ")");
				return false;
			}
			else p->add(content_->print(p, T));
		}
		else p->add(content_->print(p, T));
	}
	else if(type_ == Parser::LTLNOT) {
		if(p->no_direct_enc && getClass() != "") {
			// child prints "holds(neg(...)" / "occurs(neg(...)"; closed below
			right_->printContent(p, T, false);
		}
		else if(getClass() != "") {
			p->add(p->neg);
			right_->printContent(p, T);
		}
		else {
			// negation of an LTL operator: default negation of its predicate
			p->add("not ");
			is_ltl_function = right_->printContent(p, T);

			// end function here because content was printed
			// else the predicate would be ended twice at the end
			return is_ltl_function;
		}
	}
	else {
		is_ltl_function = true;

		// construct name of auxiliary function
		string name = "";
		stringstream ss; ss << num_; ss >> name;
		name = "ltl_f_" + name;

		// special negation for LTL predicate
		if(!sign) p->add("not ");
		
		if(p->no_direct_enc) {
			if(p->inc) p->add("holds(" + name + ", t");
			else p->add("holds(" + name);
		}
		else {
			if(p->inc) p->add(name + "(t, " + T + ")");
			else p->add(name + "(" + T + ")");
		}
	}

	// end predicate for non-direct encoding
	if(p->no_direct_enc) p->add(", " + T + ")");
	
	return is_ltl_function;
}

/*
 * Print the loop condition atom of this node at time T:
 * holds(ltl_c(ltl_f_<num_>), T) with the non-direct encoding,
 * ltl_c(ltl_f_<num_>, T) otherwise.
 */
void LTLNode::printLTLCondition(Printer* p, string T) {
	string digits = "";
	stringstream converter;
	converter << num_;
	converter >> digits;
	const string name = "ltl_f_" + digits;

	const string condition = p->no_direct_enc
		? "holds(ltl_c(" + name + "), " + T + ")"
		: "ltl_c(" + name + ", " + T + ")";
	p->add(condition);
}

/*
 * Print the rules for this node, dispatching on its type.
 *
 * top_node:  true when the node is the body of the counter example rule.
 * time_node: passed to the operator printers; controls whether a rule for
 *            the generic time step is emitted in addition to time 0.
 *
 * Identifiers and negations terminate the current rule with ".\n"; a
 * negated LTL operator additionally gets its own definitions printed.
 * Throws std::runtime_error for implication, equivalence and unknown types.
 */
void LTLNode::print(Printer* p, bool top_node, bool time_node) {
	switch(type_) {
		case Parser::IDENTIFIER:
			printContent(p, "0");
			p->add(".\n");
			return;

		case Parser::LTLNOT: {
			const bool negatesOperator = printContent(p, "0", false);
			p->add(".\n");
			if(negatesOperator) right_->check_and_print(p, time_node);
			return;
		}

		case Parser::AND:   printAND(p, top_node, time_node); return;
		case Parser::LTLOR: printOR(p, top_node, time_node);  return;
		case Parser::X:     printX(p, top_node, time_node);   return;
		case Parser::F:     printF(p, top_node, time_node);   return;
		case Parser::G:     printG(p, top_node, time_node);   return;
		case Parser::U:     printU(p, top_node, time_node);   return;
		case Parser::R:     printR(p, top_node, time_node);   return;

		case Parser::IMPL:
			p->add("->");
			throw std::runtime_error("LTL Queries are not yet supported with implication.");

		case Parser::EQUIV:
			p->add("=");
			throw std::runtime_error("LTL Queries are not yet supported with equations.");

		default:
			cerr << "\ntype: " << type_;
			throw std::runtime_error("Unknown LTL Query element.");
	}
}

/*
 * Print the defining rules of a child subformula.
 *
 * Identifiers need no auxiliary rules. A negation node has no rules of its
 * own, but when it negates an LTL operator the parent's rule references
 * that operator as "not ltl_f_N(...)", so the operator's definitions must
 * still be emitted (just like print() does for a top-level negation).
 * Every other (operator) node is printed as a non-top node.
 */
void LTLNode::check_and_print(Printer* p, bool time_node) {
	if(type_ == Parser::IDENTIFIER) return;

	if(type_ == Parser::LTLNOT) {
		// recurse through the negation; stops again at an identifier
		right_->check_and_print(p, time_node);
		return;
	}

	print(p, false, time_node);
}

// TODO simplify LTL formula

/*
 * Print the rules of a conjunction.
 *
 * At the top node no head is printed: LTLQuery::print has already written
 * the counter example head and the conjunction forms its body. Otherwise
 * the rule "this(T) :- left(T), right(T)." is printed for T = 0 and, for
 * nested time nodes, again for the generic step ("t" in incremental mode,
 * "T" otherwise). The operands' own definitions follow.
 */
void LTLNode::printAND(Printer* p, bool top_node, bool time_node) {
	string steps[2];
	int stepCount = 0;
	steps[stepCount++] = "0";
	if(!top_node && time_node) {
		steps[stepCount++] = p->inc ? "t" : "T";
	}

	for(int k = 0; k < stepCount; ++k) {
		const string& T = steps[k];

		if(!top_node) {
			printContent(p, T);
			p->add(" :- ");
		}

		left_->printContent(p, T);
		p->add(", ");
		right_->printContent(p, T);
		p->add(".\n");
	}

	left_->check_and_print(p, time_node);
	right_->check_and_print(p, time_node);
}

/*
 * Print the rules of a disjunction.
 *
 * At the top node the disjunction's own atom at time 0 completes the
 * counter example rule. Then "this(T) :- left(T)." and
 * "this(T) :- right(T)." are printed for T = 0 and, for nested time
 * nodes, again for the generic step ("t" in incremental mode, "T"
 * otherwise). The operands' own definitions follow.
 */
void LTLNode::printOR(Printer* p, bool top_node, bool time_node) {
	if(top_node) {
		printContent(p, "0");
		p->add(".\n");
	}

	string steps[2];
	int stepCount = 0;
	steps[stepCount++] = "0";
	if(!top_node && time_node) {
		steps[stepCount++] = p->inc ? "t" : "T";
	}

	LTLNode* operands[2] = { left_, right_ };
	for(int k = 0; k < stepCount; ++k) {
		const string& T = steps[k];

		// one rule per disjunct, left first
		for(int side = 0; side < 2; ++side) {
			printContent(p, T);
			p->add(" :- ");
			operands[side]->printContent(p, T);
			p->add(".\n");
		}
	}

	left_->check_and_print(p, time_node);
	right_->check_and_print(p, time_node);
}

/*
 * Print the rules of a next operator X(right_).
 *
 * - top node: the atom at time 0 completes the counter example rule.
 * - T = 0:  "X(0) :- right(0+1)."
 * - generic step (only for non-top time nodes; T is "t" in incremental
 *   mode, "T" otherwise):
 *     "X(t) :- after_loop_start(T), right(T)." followed by the step rule,
 *     incremental: "X(T-1) :- right(T), time(T).",
 *     otherwise:   "X(T) :- right(T+1)."
 * The child's definitions are printed with time_node = true.
 * NOTE(review): the loop rule's head uses the literal "t" even when T is
 * "T" (non-incremental mode) -- confirm this is intended.
 */
void LTLNode::printX(Printer* p, bool top_node, bool time_node) {
	string T = "0";
	
	if(top_node) {
		printContent(p, T);
		p->add(".\n");
	}
	
	// loop to print it two times, one for T=0 and one for general T
	for(int i=0; i<=1; ++i) {
		if(i == 1) {
			if(!top_node && time_node) {
				if(p->inc) T = "t";
				else T = "T";
				
				// rule closing the loop at the last time step
				printContent(p, "t");
				p->add(" :- after_loop_start("+T+"), ");
				right_->printContent(p, T);
				p->add(".\n");
			}
			else break;
		}
		
		if(T != "0" && p->inc) {
			printContent(p, "T-1");
			p->add(" :- ");
			right_->printContent(p, "T");
			p->add(", time(T)");
			p->add(".\n");
		}
		else {
			printContent(p, T);
			p->add(" :- ");
			right_->printContent(p, T+"+1");
			p->add(".\n");
		}
	}
	
	right_->check_and_print(p, true);
}

/*
 * Print the rules of an eventually operator F(right_).
 *
 * - top node: the atom at time 0 completes the counter example rule.
 * - For T = 0 and then, unconditionally, for the generic step:
 *     "F(T) :- right(T)." and a propagation rule,
 *     incremental generic step: "F(T-1) :- F(T), time(T).",
 *     otherwise:                "F(T) :- F(T+1)."
 * - Before the generic step a loop rule is printed:
 *     incremental:  "F(t) :- after_loop_start(t-1), F(t-1)."
 *     otherwise:    "F(t+1) :- after_loop_start(T), F(T)."
 * Unlike printAND/printOR/printX, the generic step does not depend on
 * top_node or time_node. The child's definitions use time_node = true.
 */
void LTLNode::printF(Printer* p, bool top_node, bool time_node) {
	string T = "0";
	
	if(top_node) {
		printContent(p, T);
		p->add(".\n");
	}
	
	for(int i=0; i<=1; ++i) {
		if(i == 1) {
			// generic time step, preceded by the loop rule
			if(p->inc) {
				T = "t";
				printContent(p, T);
				p->add(" :- after_loop_start("+T+"-1), ");
				printContent(p, T+"-1");
				p->add(".\n");
			}
			else {
				T = "T";
				printContent(p, "t+1");
				p->add(" :- after_loop_start("+T+"), ");
				printContent(p, T);
				p->add(".\n");
			}
		}
		
		printContent(p, T);
		p->add(" :- ");
		right_->printContent(p, T);
		p->add(".\n");

		if(T != "0" && p->inc) {
			printContent(p, "T-1");
			p->add(" :- ");
			printContent(p, "T");
			p->add(", time(T)");
			p->add(".\n");
		}
		else {
			printContent(p, T);
			p->add(" :- ");
			printContent(p, T+"+1");
			p->add(".\n");
		}
	}
	
	right_->check_and_print(p, true);
}

/*
 * Print the rules of a globally operator G(right_).
 *
 * - top node: the atom at time 0 completes the counter example rule.
 * - For T = 0 and then, unconditionally, for the generic step
 *   (T = "t" incremental, "T" otherwise):
 *     step rule, incremental generic: "G(T-1) :- right(T-1), G(T), time(T).",
 *                otherwise:           "G(T) :- right(T), G(T+1).",
 *     loop rule, incremental: "G(t) :- loop_exists(t-1), not ltl_c(.., T).",
 *                otherwise:   "G(t+1) :- loop_exists, not ltl_c(.., T)."
 * - The generic step starts with the loop condition
 *     "ltl_c(.., T) :- in_loop(T), not right(T)."
 * The child's definitions use time_node = true.
 */
void LTLNode::printG(Printer* p, bool top_node, bool time_node) {
	string T = "0";
	
	if(top_node) {
		printContent(p, T);
		p->add(".\n");
	}
	
	for(int i=0; i<=1; ++i) {
		if(i == 1) {
			if(p->inc) T = "t";
			else T = "T";
			
			// loop condition: violated if right_ fails somewhere in the loop
			printLTLCondition(p, T);
			p->add(" :- in_loop("+T+"), not ");
			right_->printContent(p, T);
			p->add(".\n");
		}
		
		if(T != "0" && p->inc) {
			printContent(p, "T-1");
			p->add(" :- ");
			right_->printContent(p, "T-1");
			p->add(", ");
			printContent(p, "T");
			p->add(", time(T).\n");
		}
		else {
			printContent(p, T);
			p->add(" :- ");
			right_->printContent(p, T);
			p->add(", ");
			printContent(p, T+"+1");
			p->add(".\n");
		}

		// NOTE(review): printed for both T = 0 and the generic T
		if(p->inc) {
			printContent(p, "t");
			p->add(" :- loop_exists(t-1), not ");
		}
		else {
			printContent(p, "t+1");
			p->add(" :- loop_exists, not ");
		}
		printLTLCondition(p, T);
		p->add(".\n"); // TODO check t+1 and "+T+"
	}
	
	right_->check_and_print(p, true);
}

/*
 * Print the rules of an until operator U(left_, right_).
 *
 * - top node: the atom at time 0 completes the counter example rule.
 * - For T = 0 and then, unconditionally, for the generic step:
 *     "U(T) :- right(T)." and the step rule,
 *     incremental generic: "U(T-1) :- left(T-1), U(T), time(T).",
 *     otherwise:           "U(T) :- left(T), U(T+1)."
 * - Before the generic step a loop rule is printed:
 *     incremental: "U(t) :- after_loop_start(t-1), U(t-1)."
 *     otherwise:   "U(t+1) :- after_loop_start(T), U(T)."
 * Both children's definitions use time_node = true.
 */
void LTLNode::printU(Printer* p, bool top_node, bool time_node) {
	string T = "0";
	
	if(top_node) {
		printContent(p, T);
		p->add(".\n");
	}
	
	for(int i=0; i<=1; ++i) {
		if(i == 1) {
			// generic time step, preceded by the loop rule
			if(p->inc) {
				T = "t";
				printContent(p, T);
				p->add(" :- after_loop_start("+T+"-1), ");
				printContent(p, T+"-1");
				p->add(".\n");
			}
			else {
				T = "T";
				printContent(p, "t+1");
				p->add(" :- after_loop_start("+T+"), ");
				printContent(p, T);
				p->add(".\n");
			}
		}
		
		printContent(p, T);
		p->add(" :- ");
		right_->printContent(p , T);
		p->add(".\n");
		
		if(T != "0" && p->inc) {
			printContent(p, "T-1");
			p->add(" :- ");
			left_->printContent(p, "T-1");
			p->add(", ");
			printContent(p, "T");
			p->add(", time(T).\n");
		}
		else {
			printContent(p, T);
			p->add(" :- ");
			left_->printContent(p, T);
			p->add(", ");
			printContent(p, T+"+1");
			p->add(".\n");
		}
	}
	
	left_->check_and_print(p, true);
	right_->check_and_print(p, true);
}

/*
 * Print the rules of a release operator R(left_, right_).
 *
 * - top node: the atom at time 0 completes the counter example rule.
 * - For T = 0 and then, unconditionally, for the generic step:
 *     "R(T) :- right(T), left(T).",
 *     step rule, incremental generic: "R(T-1) :- right(T-1), R(T), time(T).",
 *                otherwise:           "R(T) :- right(T), R(T+1).",
 *     loop rule, incremental: "R(t) :- loop_exists(t-1), not ltl_c(.., T).",
 *                otherwise:   "R(t+1) :- loop_exists, not ltl_c(.., T)."
 * - Before the generic step: the after_loop_start rule (as in printU) and
 *   the loop condition "ltl_c(.., T) :- in_loop(T), not right(T)."
 * Both children's definitions use time_node = true.
 */
void LTLNode::printR(Printer* p, bool top_node, bool time_node) {
	string T = "0";
	
	if(top_node) {
		printContent(p, T);
		p->add(".\n");
	}
	
	for(int i=0; i<=1; ++i) {
		if(i == 1) {
			// generic time step, preceded by the loop rule ...
			if(p->inc) {
				T = "t";
				printContent(p, T);
				p->add(" :- after_loop_start("+T+"-1), ");
				printContent(p, T+"-1");
				p->add(".\n");
			}
			else {
				T = "T";
				printContent(p, "t+1");
				p->add(" :- after_loop_start("+T+"), ");
				printContent(p, T);
				p->add(".\n");
			}

			// ... and the loop condition
			printLTLCondition(p, T);
			p->add(" :- in_loop("+T+"), not ");
			right_->printContent(p, T); // TODO check this holds line
			p->add(".\n");
		}
		
		printContent(p, T);
		p->add(" :- ");
		right_->printContent(p, T);
		p->add(", ");
		left_->printContent(p, T);
		p->add(".\n");

		if(T != "0" && p->inc) {
			printContent(p, "T-1");
			p->add(" :- ");
			right_->printContent(p, "T-1");
			p->add(", ");
			printContent(p, "T");
			p->add(", time(T).\n");
		}
		else {
			printContent(p, T);
			p->add(" :- ");
			right_->printContent(p, T);
			p->add(", ");
			printContent(p, T+"+1");
			p->add(".\n");
		}

		if(p->inc) {
			printContent(p, "t");
			p->add(" :- loop_exists(t-1), not ");
			printLTLCondition(p, T);
			p->add(".\n"); // TODO check t+1 and T
		}
		else {
			printContent(p, "t+1");
			p->add(" :- loop_exists, not ");
			printLTLCondition(p, T);
			p->add(".\n"); // TODO check t+1 and T
		}
	}
	
	left_->check_and_print(p, true);
	right_->check_and_print(p, true);
}

/*
 * Debug dump of the subtree in prefix notation, e.g. " &( a, b) ".
 * The operator symbols and operands go to the printer; the node numbers
 * ("[num]") and unknown type codes are written directly to cout.
 */
void LTLNode::print_debug(Printer* p) {
	const char* symbol = " ?";
	bool unknown = false;

	switch(type_) {
		case Parser::LTLNOT:     symbol = " !";   break;
		case Parser::AND:        symbol = " &";   break;
		case Parser::LTLOR:      symbol = " |";   break;
		case Parser::IMPL:       symbol = " ->";  break;
		case Parser::EQUIV:      symbol = " <->"; break;
		case Parser::X:          symbol = " X";   break;
		case Parser::F:          symbol = " F";   break;
		case Parser::G:          symbol = " G";   break;
		case Parser::U:          symbol = " U";   break;
		case Parser::R:          symbol = " R";   break;
		case Parser::IDENTIFIER: symbol = " ";    break;
		default:                 unknown = true;  break;
	}

	p->add(symbol);
	if(unknown) cout << type_ << "?";

	cout << "[" << num_ << "]";

	const bool isLeaf = (type_ == Parser::IDENTIFIER);
	if(!isLeaf) p->add("(");

	if(right_) {
		// binary nodes print "left,right", unary nodes only the right child
		if(left_) {
			left_->print_debug(p);
			p->add(",");
		}
		right_->print_debug(p);
	}
	else if(content_) {
		p->add(content_->print(p));
	}
	else {
		p->add("ERROR LTL DEBUG PRINT");
	}

	if(!isLeaf) p->add(") ");

	cout.flush();
}

/*
 * Negate the formula rooted at this node and return the new root.
 *
 * Negations are pushed towards the leaves using the LTL dualities
 * (AND <-> OR, F <-> G, U <-> R, X is self-dual); an identifier is wrapped
 * in a fresh LTLNOT node. An existing LTLNOT node is removed instead and its
 * child is returned untouched (negate(true)); three stacked negations are
 * collapsed by removing two of them and negating the third.
 *
 * negation: true when the caller has just removed a negation directly above
 *           this node; the node is then returned unchanged.
 *
 * Operator nodes are modified in place. LTLNOT nodes are released with
 * delete (they must have been allocated with new), so callers must only use
 * the returned pointer afterwards.
 * Implications and equivalences are not handled yet and are returned as is.
 * Throws std::runtime_error for unknown node types.
 */
LTLNode* LTLNode::negate(bool negation=false) {
	// if negation of parent node was just removed, do nothing
	if(negation) return this;

	switch(type_) {
		case Parser::LTLNOT: {
			// Read the children BEFORE deleting anything: accessing right_
			// after "delete this" (or through a deleted node) is undefined
			// behaviour. ~LTLNode does not free children, so they survive.
			LTLNode* child = right_;

			// if three negations in a row, delete first two and negate third one
			if(child->getType() == Parser::LTLNOT && child->getRightNode()->getType() == Parser::LTLNOT) {
				LTLNode* third = child->getRightNode();
				delete child;
				delete this;
				return third->negate();
			}

			// drop this negation; the child is kept as it is
			delete this;
			return child->negate(true);
		}
		case Parser::AND:
			type_ = Parser::LTLOR;
			left_ = left_->negate();
			right_ = right_->negate();
			return this;
		case Parser::LTLOR:
			type_ = Parser::AND;
			left_ = left_->negate();
			right_ = right_->negate();
			return this;
		case Parser::IMPL:
			// TODO
			return this;
		case Parser::EQUIV:
			// TODO
			return this;
		case Parser::X:
			right_ = right_->negate();
			return this;
		case Parser::F:
			type_ = Parser::G;
			right_ = right_->negate();
			return this;
		case Parser::G:
			type_ = Parser::F;
			right_ = right_->negate();
			return this;
		case Parser::U:
			type_ = Parser::R;
			left_ = left_->negate();
			right_ = right_->negate();
			return this;
		case Parser::R:
			type_ = Parser::U;
			left_ = left_->negate();
			right_ = right_->negate();
			return this;
		case Parser::IDENTIFIER:
			// assumes negation nodes can all have 0 as num_
			return new LTLNode(Parser::LTLNOT, NULL, this, 0);
		default:
			throw std::runtime_error("Unknown LTL symbol in negation.");
	}
}

/*
 * Incremental mode: print the propagation rule of every LTL operator in
 * this subtree (pre-order: this node first, then left, then right).
 * Negations are transparent; identifiers, implications and equivalences
 * print nothing.
 */
void LTLNode::printIncPropagator(Printer* p) {
	switch(type_) {
		case Parser::LTLNOT:
			right_->printIncPropagator(p);
			break;

		case Parser::X:
		case Parser::F:
		case Parser::G:
			// unary operators
			printIncPropagatorForThisNode(p);
			right_->printIncPropagator(p);
			break;

		case Parser::AND:
		case Parser::LTLOR:
		case Parser::U:
		case Parser::R:
			// binary operators
			printIncPropagatorForThisNode(p);
			left_->printIncPropagator(p);
			right_->printIncPropagator(p);
			break;

		default:
			break;
	}
}

/*
 * Print "this(T) :- prev, time(T).", where prev is this node's auxiliary
 * predicate carrying t-1 as its first argument:
 * holds(ltl_f_<num_>, t-1, T) with the non-direct encoding,
 * ltl_f_<num_>(t-1, T) otherwise.
 */
void LTLNode::printIncPropagatorForThisNode(Printer* p) {
	printContent(p, "T");
	p->add(" :- ");

	// construct name of auxiliary function
	string digits = "";
	stringstream converter;
	converter << num_;
	converter >> digits;
	const string name = "ltl_f_" + digits;

	const string previous = p->no_direct_enc
		? "holds(" + name + ", t-1, T)"
		: name + "(t-1, T)";
	p->add(previous);

	p->add(", time(T).\n");
}
