"""
File        : GraphCluster.py
Author      : Pablo Boixeda & Ramon Aragues
Creation    : 4.2005
Contents    : Implements a Graph where the nodes are clusters composed of node attributes of another type of graph
Called from : Programs that implements graph clustering

=======================================================================================================

Implements a Graph where the nodes are clusters composed of node attributes of another type of graph

"""

# GraphCluster.py: Implements a Graph where the nodes are clusters composed of node attributes of another type of graph
#
# Copyright (C) 2005  Ramon Aragues
# author email: ramon.aragues@upf.edu and boliva@imim.es
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#    http://www.gnu.org/copyleft/gpl.html
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
# University Pompeu Fabra, hereby disclaims all copyright
# interest in the program 'PIANA'
# (software for working with protein-protein interaction networks) written 
# by Ramon Aragues

import sys

from Graph import *
from GraphNode import *
from GraphClusterNodeAttribute import *

# Debug switches for this module (0 = off, any true value = on).
#   verbose          -> GraphCluster.initialize_from_graph reports its start on stderr
#   verbose_detailed -> GraphCluster.create_grouped_node reports every join on stderr
verbose, verbose_detailed = 0, 0

#-----------------------------------------------------------------------------------------------
class GraphCluster(Graph):
#-----------------------------------------------------------------------------------------------
    """
    Graph whose nodes are clusters.

    Each node of a GraphCluster is a standard GraphNode whose attribute is a
    GraphClusterNodeAttribute. That attribute holds a list of elements, and each
    element is a node attribute of the graph being clustered.
    """

    def __init__(self, graph_id=None):
        """
        Method used to initialise ClusterGraph object

        "graph_id" is passed to Graph.__init__ as the identifier of the graph
        """
        # NOTE(review): not read anywhere in this class; presumably used by subclasses or by Clustering
        self.old_node_correspondence = {}

        # counter used by _get_new_node_id() to hand out consecutive cluster ids (0, 1, 2, ...)
        self.node_id_counter = 0

        Graph.__init__(self, graphID=graph_id)

    def _get_new_node_id(self):
        """
        This method is created to give an id to new nodes

        Attention!!! Node Ids are normally controlled by class Clustering: this method is only used when we are initializing
        the GraphCluster from a Graph object (one cluster for each Node in Graph)
        """
        new_node_id = self.node_id_counter
        self.node_id_counter += 1
        return new_node_id


    def get_node_alternative_id(self, node_attribute_object):
        """
        returns the alternative id that will describe this node

        The alternative id is the node id of the cluster followed by the node id of each
        element in the cluster, all separated by dots (eg. "7.12.45" for cluster 7 with
        elements 12 and 45)
        """
        id_parts = ["%s" %(node_attribute_object.get_node_id())]
        for one_element in node_attribute_object.get_list_elements():
            id_parts.append("%s" %(one_element.get_node_id()))

        return ".".join(id_parts)


    def _get_cluster_id_for_graph_node(self, graph, graph_node_id, dic_correspondences_nodes):
        """
        returns the id of the GraphCluster node that contains node "graph_node_id" of "graph".
        If that cluster node does not exist yet, it is created and added to self.

        "dic_correspondences_nodes" maps node ids of "graph" to node objects of this GraphCluster;
        it is updated in place when a new cluster node is created
        """
        if graph_node_id in dic_correspondences_nodes:
            # the GraphClusterNode was already created for this GraphNode... use it
            return dic_correspondences_nodes[graph_node_id].get_node_id()

        # a new GraphClusterNode needs to be created: get the attribute of graph_node_id
        graph_node_attribute = graph.get_node(identifier= graph_node_id, get_mode="error").get_node_attribute_object()

        cluster_id = self._get_new_node_id()

        # standard GraphNode whose attribute is a GraphClusterNodeAttribute holding a single element
        cluster_attribute = GraphClusterNodeAttribute(node_id=cluster_id,
                                                      list_elements=[graph_node_attribute])

        cluster_node = GraphNode(nodeID= cluster_id, attribute=cluster_attribute,
                                 original = 1,
                                 alternative_id= self.get_node_alternative_id(node_attribute_object=cluster_attribute))

        self.add_node(cluster_node)
        dic_correspondences_nodes[graph_node_id] = cluster_node

        return cluster_id


    def initialize_from_graph(self, graph):
        """
        initialize the GraphCluster from a Graph: places one GraphNodeAttribute in each cluster (ie. in each GraphClusterNodeAttribute)

        For every edge of "graph", one cluster node is created (only once) for each end of the edge,
        and an edge with an empty GraphEdgeAttribute is added between the two cluster nodes.

        NOTE(review): clusters are only created for nodes that appear in at least one edge of "graph";
                      nodes without edges are not added to the GraphCluster. Confirm this is intended.

        raises ValueError if "graph" is None
        """
        if graph is None:    raise ValueError ("No Graph given to transform into GraphCluster")

        if verbose:   sys.stderr.write("Initializing GraphCluster from graph\n")

        dic_correspondences_nodes = {}   # dic that follows this structure:
                                        #
                                        #     { node_id in graph: node_object in GraphCluster,
                                        #       node_id in graph: node_object in GraphCluster,
                                        #       ................
                                        #     }

        for one_graph_edge in graph.get_edge_object_list():
            # start node is processed before end node so that cluster ids are assigned in edge order
            cluster_id_a = self._get_cluster_id_for_graph_node(graph, one_graph_edge.get_start_node_id(),
                                                               dic_correspondences_nodes)
            cluster_id_b = self._get_cluster_id_for_graph_node(graph, one_graph_edge.get_end_node_id(),
                                                               dic_correspondences_nodes)

            # create edge for GraphCluster (standard GraphEdge)
            new_edge_attribute = GraphEdgeAttribute() # empty attribute: nothing to do with it at the moment
            new_edge = GraphEdge(node1_id=cluster_id_a , node2_id= cluster_id_b,
                                 attribute_object= new_edge_attribute)

            self.add_edge(new_edge)
        # END OF for one_graph_edge in graph.get_edge_object_list():


    def do_action(self):
        """
        this is a generic method that can be used by particular GraphCluster subclasses to do
        something after the clustering of this level (eg. in CIR, we fuse as well the clusters
        that have a common cluster partner).

        It is called from Clustering.cluster_graph after each step of the clustering

        GraphCluster subclasses that have nothing to do after a clustering step do not need to
        override it: this generic version does nothing.
        """
        pass

    def create_grouped_node(self, node_id1= None, node_id2= None, new_node_id= None, old_graph= None):
        """
        This method joins two GraphNodes into one GraphNode. The elements of the two node attributes
        are inserted into the new node attribute (a GraphClusterNodeAttribute, which holds a list of attributes)

        returns the resulting node. The node is not added to self by this method.

        "node_id1" is the id from one of the clusters to join

        "node_id2" is the id from one of the clusters to join

        "new_node_id" is the id of the new cluster node to create. If None, a new id is obtained
                      from this GraphCluster's own counter (see _get_new_node_id)

        "old_graph" is the graph that contains the nodes with node_id1 and node_id2

        raises ValueError if a node id is None, if "old_graph" is None, or if a node is not found in "old_graph"
        """
        if node_id1 is None or node_id2 is None:
            raise ValueError("Error: how can I create a clustered node from a None id?")

        if old_graph is None:
            raise ValueError("Error: no old_graph given to look for nodes %s and %s" %(node_id1, node_id2))

        if verbose_detailed:
            sys.stderr.write("Joining nodes %s and %s to form a new node with id %s\n" %(node_id1, node_id2, new_node_id))

        node1 = old_graph.get_node(node_id1)
        node2 = old_graph.get_node(node_id2)

        if node1 is None or node2 is None:
            raise ValueError("Error: nodes %s and %s must both exist in old_graph" %(node_id1, node_id2))

        if new_node_id is None:
            # the caller didn't give us a node id: take the next one from this GraphCluster's counter
            new_node_id = self._get_new_node_id()

        new_cluster_node_attribute = GraphClusterNodeAttribute(node_id= new_node_id)
        new_cluster_node_attribute.add_element_list(list_node_attribute_object= node1.get_node_attribute_object().get_list_elements())
        new_cluster_node_attribute.add_element_list(list_node_attribute_object= node2.get_node_attribute_object().get_list_elements())

        new_node = GraphNode(nodeID= new_node_id,
                             alternative_id= self.get_node_alternative_id(node_attribute_object=new_cluster_node_attribute),
                             attribute = new_cluster_node_attribute,
                             original = 0,
                             graph = self)

        return new_node


    def print_cluster_composition(self, output_target= None):
        """
        prints the node ids for each cluster in the graph

        "output_target" is a file-like object; if None, the current sys.stdout is used

        This is very general: should be overwritten by a method specific to the clustering being performed
        """
        # resolved at call time so that a redirected sys.stdout is honoured
        if output_target is None:
            output_target = sys.stdout

        for node in self.get_node_object_list():
            # node is a cluster, which has an attribute that contains elements (each element is a GraphNodeAttribute of
            #                                                                   the graph we are clustering)
            output_target.write( "cluster %s: " %node.get_node_id() )
            for element in node.get_node_attribute_object().get_list_elements():
                output_target.write("%s --" %(element.get_node_id()) )
            output_target.write("\n")

    def print_pairs_same_cluster(self, output_target):
        """
        prints pairs of elements that appear in the same cluster, one "p;id1-id2" line per pair

        To be overwritten by the subclass of GraphCluster in case an element is not described by its identifier
        """
        for node in self.get_node_object_list():
            elements_list = node.get_node_attribute_object().get_list_elements()
            # each unordered pair (i < j) is printed once; single-element clusters print nothing
            for i, element_i in enumerate(elements_list):
                for element_j in elements_list[i+1:]:
                    output_target.write( "p;%s-%s\n" %(element_i.get_node_id(), element_j.get_node_id()) )
        # END OF for node in self.get_node_object_list():


    def print_cluster_interactions(self, output_target= None):
        """
        prints the edge table of the graph

        "output_target" is a file-like object; if None, the current sys.stdout is used

        This is very general: should be overwritten by a method specific to the clustering being performed
        """
        # resolved at call time so that a redirected sys.stdout is honoured
        if output_target is None:
            output_target = sys.stdout

        self.output_edges_table(output_target=output_target)


    def print_pairs_interactions(self, output_target):
        """
        prints pairs of elements whose clusters interact, one "i;id_start-id_end" line per pair

        To be overwritten by the subclass of GraphCluster in case an element is not described by its identifier
        """
        for edge in self.get_edge_object_list():

            node_start = self.get_node(identifier= edge.get_start_node_id(), get_mode="error")
            list_elements_start = node_start.get_node_attribute_object().get_list_elements()

            node_end = self.get_node(identifier= edge.get_end_node_id(), get_mode="error")
            list_elements_end = node_end.get_node_attribute_object().get_list_elements()

            for one_element_start in list_elements_start:
                for one_element_end in list_elements_end:
                    output_target.write("i;%s-%s\n" %(one_element_start.get_node_id(), one_element_end.get_node_id()) )
        # END OF for edge in self.get_edge_object_list():


    # TO DO!!! These two methods (print_proteins_*) should not be here but on a higher class IPGGraphCluster

    def print_proteins_same_cluster(self, output_target, root_protein):
        """
        prints to "output_target" pairs of proteins (ie. proteins with elements in that cluster) that appear in the same cluster

        Each line is "share<TAB>cluster_name<TAB>p1<TAB>p2" with p1 <= p2. Pairs of identical proteins are skipped,
        and each protein pair is printed only once (with the first cluster where it is found).

        "root_protein" is used to uniquely identify the clusters (clusters have unique ids if they were all obtained from the
                       same network... but since we are doing multiple calls from run_piana_protein_by_protein, clusters
                       need to be distinguished by their id in the network and the root protein that was used to generate them)
        """
        # (p1, p2) tuples already written: tuple keys cannot collide the way "p1.p2" strings could
        already_printed_pairs = set()

        for node in self.get_node_object_list():
            elements_list = node.get_node_attribute_object().get_list_elements()
            cluster_name = "%s.%s" %(root_protein, node.get_node_id())

            # the way this loop is designed makes sure that only clusters with more than one element are printed out.
            for i, element_i in enumerate(elements_list):
                proteinPiana_i = element_i.get_proteinPiana()

                for element_j in elements_list[i+1:]:
                    proteinPiana_j = element_j.get_proteinPiana()

                    if proteinPiana_i == proteinPiana_j:
                        continue

                    # order the pair so that (a, b) and (b, a) are treated as the same pair
                    if proteinPiana_i <= proteinPiana_j:
                        p1, p2 = proteinPiana_i, proteinPiana_j
                    else:
                        p1, p2 = proteinPiana_j, proteinPiana_i

                    key = (p1, p2)
                    if key in already_printed_pairs:
                        continue
                    already_printed_pairs.add(key)

                    output_target.write("share\t%s\t%s\t%s\n" %(cluster_name , p1, p2) )
        # END OF for node in self.get_node_object_list():


    def print_proteins_interactions(self, output_target, root_protein):
        """
        prints to "output_target" pairs of proteins whose clusters interact

        Each line is "int<TAB>cluster_1_name-cluster_2_name<TAB>p1<TAB>p2" with p1 <= p2 and cluster_1_name being
        the cluster of p1. Each protein pair is printed only once.

        "root_protein" is used to uniquely identify the clusters (clusters have unique ids if they were all obtained from the
                       same network... but since we are doing multiple calls from run_multiple_pianas, clusters
                       need to be distinguished by their id in the network and the root protein that was used to generate them)
        """
        # (p1, p2) tuples already written: tuple keys cannot collide the way "p1.p2" strings could
        already_printed_pairs = set()

        for edge in self.get_edge_object_list():

            node_start = self.get_node(identifier= edge.get_start_node_id(), get_mode="error")
            list_elements_start = node_start.get_node_attribute_object().get_list_elements()

            node_end = self.get_node(identifier= edge.get_end_node_id(), get_mode="error")
            list_elements_end = node_end.get_node_attribute_object().get_list_elements()

            # cluster names only depend on the edge, not on the element pair
            cluster_start_name = "%s.%s" %(root_protein, node_start.get_node_id())
            cluster_end_name = "%s.%s" %(root_protein, node_end.get_node_id())

            for one_element_start in list_elements_start:
                proteinPiana_start = one_element_start.get_proteinPiana()

                for one_element_end in list_elements_end:
                    proteinPiana_end = one_element_end.get_proteinPiana()

                    # order the pair (keeping each protein with its own cluster) so that no duplication of pairs occurs
                    if proteinPiana_start <= proteinPiana_end:
                        p1, cluster_1_name = proteinPiana_start, cluster_start_name
                        p2, cluster_2_name = proteinPiana_end, cluster_end_name
                    else:
                        p1, cluster_1_name = proteinPiana_end, cluster_end_name
                        p2, cluster_2_name = proteinPiana_start, cluster_start_name

                    key = (p1, p2)
                    if key in already_printed_pairs:
                        continue
                    already_printed_pairs.add(key)

                    output_target.write("int\t%s-%s\t%s\t%s\n" %(cluster_1_name, cluster_2_name, p1, p2) )
        # END OF for edge in self.get_edge_object_list():

