#!/usr/bin/env python2.2

import sys
sys.path.append("./BioGraph")
import string
from Graph import *
from BioDatabases import *
import MySQLdb

class RouteFinder:
    """
    The main class controlling the route finding between two proteins.

    Links around the start and target proteins are fetched from every
    configured BioDatabase and merged into one Graph. The configured
    filters are applied, and the graph is then searched for the shortest
    route from the start protein to the target protein.
    """
    def __init__(self, start=None, target=None, databases=None, depth=3):
        """
        start     -- name of the start protein (can be set later).
        target    -- name of the target protein (can be set later).
        databases -- sequence of BioDatabase objects to query. It is copied,
                     so add_database()/rm_database() never modify the
                     caller's list.
        depth     -- depth passed to each database's get_all_links().
        """
        # A mutable default ([]) is evaluated once and shared by every
        # instance. Since add_database() appends to it, build a fresh list.
        if databases is None:
            databases = []
        self.databases = list(databases)
        self.start = start
        self.target = target
        self.graph = None       # built lazily by init_graph()
        self.depth = depth
        self.filters = []

    def set_start_protein(self, start):
        """Set the name of the protein the route starts from."""
        self.start = start

    def set_target_protein(self, target):
        """Set the name of the protein the route should reach."""
        self.target = target

    def add_database(self, db):
        """
        Append a BioDatabase to the list of databases to query.

        Raises TypeError if db is not a BioDatabase instance.
        """
        if not isinstance(db, BioDatabase):
            raise TypeError("Must be passed a BioDatabase object")
        self.databases.append(db)

    def clear_databases(self):
        """Remove all configured databases."""
        self.databases = []

    def rm_database(self, db):
        """
        Remove a previously added database.

        Raises Exception if db is not in the database list.
        """
        if db not in self.databases:
            raise Exception("Database not in database list")
        self.databases.remove(db)

    def add_filter(self, filter):
        """
        Append a Filter to be applied to the graph in init_graph().

        Raises TypeError if filter is not a Filter instance.
        """
        if not isinstance(filter, Filter):
            raise TypeError("Must be passed a Filter object")
        self.filters.append(filter)

    def clear_filters(self):
        """Remove all configured filters."""
        self.filters = []

    def rm_filter(self, filter):
        """
        Remove a previously added filter.

        Raises Exception if filter is not in the filter list.
        """
        if filter not in self.filters:
            # Was `Exceptions` (undefined), which raised a NameError instead.
            raise Exception("Filter not in filter list")
        self.filters.remove(filter)

    def find_routes(self):
        """
        Return the shortest route from the start to the target protein.

        Builds the graph first if init_graph() has not been called yet.
        If a NodeNotFoundError is raised, a message string is returned
        instead of a route. This happens while looking up the target node
        or during the route search.
        """
        if self.graph is None:
            self.init_graph()
        # see if target node is in graph
        try:
            # mode="error" presumably makes get_node raise NodeNotFoundError
            # for a missing node rather than creating it -- TODO confirm.
            self.graph.get_node(self.target, mode="error")
            return self.graph.find_shortest_route(self.start, self.target)
        except NodeNotFoundError:
            return "Nodes not linked on depth level %s" % self.depth

    def init_graph(self):
        """
        Build a new Graph from the links around the start and target proteins.

        For every database, the links of both proteins up to self.depth are
        fetched and their endpoints and edges are added to the graph.
        The configured filters are then applied. Progress is printed to
        stdout, and any previously built graph is replaced.
        """
        self.graph = Graph()
        for db in self.databases:
            # get all links for start and target protein, up to depth
            print("Getting all links for start protein %s" % self.start)
            links = db.get_all_links(self.start, depth=self.depth)
            print("Getting all links for target proteins %s" % self.target)
            links += db.get_all_links(self.target, depth=self.depth)
            for link in links:
                # mode="new" presumably returns the existing node/edge or
                # creates it -- TODO confirm against Graph.
                node1 = self.graph.get_node(link.node1_name,
                                            link.get_node1_attributes(),
                                            mode="new")
                node2 = self.graph.get_node(link.node2_name,
                                            link.get_node2_attributes(),
                                            mode="new")
                edge = self.graph.get_edge(node1, node2,
                                           link.get_edge_attributes(),
                                           mode="new")
                self.graph.add_edge(edge)
        # apply the filtering here
        print("Applying filtering")
        self.graph.filter(self.filters)
        print("Done!")
        print(self.graph)

    def output_graph(self, output_target=None):
        """
        Write the graph in DOT format to output_target.

        Builds the graph first if init_graph() has not been called yet.
        Raises ValueError if output_target is None.
        """
        if output_target is None:
            raise ValueError("output_target needed to print graph")
        if self.graph is None:
            self.init_graph()
        # Was the undefined global `graph`, which always raised a NameError.
        self.graph.output_dot_file(output_target=output_target)


if __name__ == "__main__":
    # Interactive entry point: read a start and a target protein name from
    # stdin and print the route found through a StringDB database.
    string_db = StringDB()
    rf = RouteFinder(databases=[string_db])
    print("Please enter a start protein:")
    # strip() drops readline()'s trailing newline plus any stray spaces
    # around the typed protein name.
    start = sys.stdin.readline().strip()
    rf.set_start_protein(start)
    print("Please enter a target protein:")
    target = sys.stdin.readline().strip()
    rf.set_target_protein(target)
    # find_routes() builds the graph on first use, so there is no separate
    # init_graph() call here.
    print(rf.find_routes())
