"""
File        : GoClusteringSimilarityFunction.py
Author      : Pablo Boixeda & Ramon Aragues
Creation    : 6.2005
Contents    : Similarity function that determines similarity between clusters with go terms
Called from : 

=======================================================================================================

"""

# GoClusteringSimilarityFunction.py: implements a class for finding similarity according to GO terms
#
# Copyright (C) 2005  Ramon Aragues
# author email: ramon.aragues@upf.edu and boliva@imim.es
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#    http://www.gnu.org/copyleft/gpl.html
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
# University Pompeu Fabra, hereby disclaims all copyright
# interest in the program 'PIANA'
# (software for working with protein-protein interaction networks) written 
# by Ramon Aragues

import numarray

import PianaGlobals
from GraphCluster import *
from Clustering import *
from ClusteringSimilarityFunction import *

# Module-level verbosity flag. It is set to 0 here and is not read anywhere in
# the visible part of this file (presumably a debug-output switch -- confirm
# against the rest of the project before removing).
verbose = 0

#------------------------------------------------------------------------------------------------
class GoClusteringSimilarityFunction(ClusteringSimilarityFunction):
    """
    Similarity function that determines similarity between clusters with go terms

    Every cluster element is expected to expose get_term_id() (and, for the
    .dot output, get_is_root()). Depths, term-to-term distances and term
    names are all obtained through self.dbaccess.
    """

    def __init__(self, piana_access= None, term_type= None, mode= None, path_length_threshold= None):

        """

        "piana_access" is a database accession object used to access information from a PIANA database

        "term_type" specifies which kind of GO terms we will use for the clustering.

          valid values are:
             - molecular_function
             - biological_process
             - cellular_component

        "mode" defines how to evaluate the distance between two clusters:
          - random takes one element from each cluster and evaluates similarity between them
                   (the current implementation always takes the first element of each cluster)
          - min takes the minimal similarity between elements of each cluster
          - max takes the maximal similarity between elements of each cluster
          - average takes the average similarity between all elements of each cluster


        "path_length_threshold": Maximum distance between two clusters to be joined

        """

        # TO DO!!! path length threshold should be applied to the original proteins, not to the clusters!!! Otherwise, at the end we are joining
        # proteins that in fact were far away from each other...

        self.term_type = term_type
        self.highest_depth = PianaGlobals.huge_value  # just a big number to indicate nothing has been done yet

        # the parent class receives the database access object, the mode and the threshold
        ClusteringSimilarityFunction.__init__(self, dbaccess = piana_access, mode = mode, path_length_threshold_value = path_length_threshold)

    def get_highest_depth(self):
        """
        returns highest depth found, ie. the smallest non-zero depth value seen
        so far by calculate_formula (PianaGlobals.huge_value if none was seen)

        used by the go clustering stop condition to stop the clustering
        """
        return self.highest_depth

    def get_proteinPianas_list(self, list_node_attribute=None):
        """
        Method that returns a protein piana list, from an attributes list given

        "list_node_attribute": list of attributes into a cluster. Each attribute is a protein pianaID
                               (None is treated as an empty list)

        returns a list with the result of get_proteinPiana() for each attribute, in the same order
        """
        if list_node_attribute is None:
            return []

        return [node_attribute.get_proteinPiana() for node_attribute in list_node_attribute]

    def _get_term_ids(self, list_node_attributes):
        """
        Private helper: returns the list of term ids (get_term_id()) of the node attributes given
        """
        return [node_attribute.get_term_id() for node_attribute in list_node_attributes]

    def get_terms_depth(self, list_of_terms_ids=None, search_mode=None):
        """
        Method that searches the go term with maximum or minimum depth in a cluster

        "list_of_terms_ids": the list of term ids in the cluster

        "search_mode": determines which of the terms will be returned:

           - min: returns the term id that has the lowest depth
           - max: returns the term id that has the highest depth

        returns the term id matching the search criteria (the first one found
        in case of ties), or None if the list is empty or no term in it has a
        known depth

        raises ValueError if search_mode is None or is not "min" or "max"
        """

        # TO DO!!! Change everywhere highest and lowest, and min and max, for something
        #          easier to understand, such as "specific" and "general"
        #   --> min and max refer to the degree of generality a certain term has, how high or low
        #       it appears in the hierarchy

        if search_mode is None:
            raise ValueError("Error:Search Mode can't be None")

        if search_mode not in ("min", "max"):
            raise ValueError("Error:Search Mode must be 'min' or 'max' (received %s)" % (search_mode,))

        prefer_deepest = (search_mode == "max")

        # best_term stays None until a term with a known depth is found: the
        # callers in this class already cope with a None result
        best_term = None
        best_depth = None

        for term_id in (list_of_terms_ids or []):
            # depth is retrieved only once per term
            current_depth = self.dbaccess.get_go_depth(term_id)

            if current_depth is None:
                # calculate_formula treats None as "no depth found": such terms cannot be ranked
                continue

            if best_depth is None:
                is_better = True
            elif prefer_deepest:
                is_better = current_depth > best_depth
            else:
                is_better = current_depth < best_depth

            if is_better:
                best_depth = current_depth
                best_term = term_id

        return best_term

    def calculate_formula(self,term1=None, term2= None):
        """
        Method that calculates the formula using two specific elements inside
        a cluster:

        "term1": Go term id of the first element

        "term2": Go term id of the second element

        returns:
           - PianaGlobals.huge_value if the distance between the terms is 0
           - 0 if the distance is None or PianaGlobals.huge_distance, or if one of the depths is None
           - otherwise (1 + |depth1 - depth2|) / (path_length_threshold + distance)

        As a side effect, self.highest_depth is lowered to the depth of term1 or
        term2 when that depth is smaller (depths that are None or 0 are ignored).
        """

        depth_term1 = self.dbaccess.get_go_depth(term1)
        depth_term2 = self.dbaccess.get_go_depth(term2)
        distance = self.dbaccess.get_term2term_distance(term1 = term1,
                                                        term2 = term2)

        # keep track of the smallest depth seen so far (read through get_highest_depth)
        for depth in (depth_term1, depth_term2):
            if depth and depth < self.highest_depth:
                self.highest_depth = depth

        if distance == 0:
            # if distance between the clusters is 0 (same term?), similarity is huge
            return PianaGlobals.huge_value

        if distance is None or distance == PianaGlobals.huge_distance:
            # if distance is huge, similarity is 0
            # If no distance is found on database, this is the case that will be applied...
            return 0

        if depth_term1 is None or depth_term2 is None:
            # if no depth found for clusters, similarity is 0
            return 0

        return (1 + abs(depth_term1 - depth_term2)) / float( self.path_length_threshold + distance)

    def calculate_similarity(self, list_node_attributes1, list_node_attributes2,
                             cluster1_id= None, cluster2_id=None,
                             clustered_graph= None, original_graph=None):
        """
        Method that returns similarity score from two lists of attributes of nodes that are being clustered

        "list_node_attributes1" is a list of node attributes that belong to the same cluster

        "list_node_attributes2" is a list of node attributes that belong to the same cluster

        The following arguments are  here just for being compatible with the general call to calculate_similarity

        "cluster1_id" is the id for cluster 1 (not used for GO clustering)
        "cluster2_id" is the id for cluster 2 (not used for GO clustering)

        "clustered_graph" is the current ClusterGraph (not used for GO clustering)
        "original_graph" (not used for GO clustering)

        This method calculates how similar two GO clusters (ie two groups of GO terms) are,
        combining the values of calculate_formula according to self.mode:

           - random : formula applied to the first term of each cluster
           - min    : smallest value over all pairs of terms (PianaGlobals.huge_value if there are no pairs)
           - max    : largest value over all pairs of terms (0 if there are no pairs)
           - average: mean value over all pairs of terms (0 if there are no pairs)

        raises ValueError if self.mode is not one of the modes above
        """

        # get go terms of each cluster from its node attributes
        go_terms_list1 = self._get_term_ids(list_node_attributes1)
        go_terms_list2 = self._get_term_ids(list_node_attributes2)

        if self.mode=="random":
            # NOTE(review): despite its name, this mode uses the first element of each cluster
            return self.calculate_formula(term1=go_terms_list1[0], term2= go_terms_list2[0])

        if self.mode not in ("min", "max", "average"):
            # previously an unknown mode ended in an UnboundLocalError on the return statement
            raise ValueError("Error: unknown clustering mode %s" % (self.mode,))

        # similarity of every pair (term of cluster 1, term of cluster 2)
        pair_similarities = []
        for go_term1 in go_terms_list1:
            for go_term2 in go_terms_list2:
                pair_similarities.append(self.calculate_formula(term1=go_term1, term2= go_term2))

        if self.mode=="min":
            # huge_value acts as the starting point of the search for the minimum
            return min([PianaGlobals.huge_value] + pair_similarities)

        if self.mode=="max":
            # 0 acts as the starting point of the search for the maximum
            return max([0] + pair_similarities)

        # self.mode=="average"
        if not pair_similarities:
            # no pairs to compare (one of the clusters is empty): nothing supports a similarity
            return 0

        # float() prevents integer division from truncating the mean
        return sum(pair_similarities) / float(len(pair_similarities))

    def print_go_graph_dot_file(self, filter_mode="all", output_target= None, use_alternative_id="no", representative_term="min"):
        """
        Generates .dot file output of the clustered graph of go
        It can be fed directly into the GraphViz program neato. This produces very nice network images...

        "output_target" is the object where the output is written (must provide write())

        filter_mode can be:
        - "all": prints all edges in the graph
        - "hidden": prints hidden edges of the graph
        - "unhidden": prints unhidden edges of the graph
        (NOTE: filter_mode is currently not used by this method: all edges are printed)

        "use_alternative_id" can be
            - "yes" --> uses alternative id for printing graph
            - "no"  --> uses internal id for printing graph
        (NOTE: use_alternative_id is currently not used by this method)

        "representative_term" sets which of the go terms in the cluster will be shown on the cluster box
           - min takes the term with the minimal depth (more general term)
           - max takes the term with the maximal depth (more specific term)

        raises ValueError if output_target is None or representative_term is not "min" or "max"
        """
        # TO DO!!! This should go into a new class GoGraph

        # TO DO!!! Change everywhere highest and lowest, and min and max, for something
        #          easier to understand, such as "specific" and "general"
        #   --> min and max refer to the degree of generality a certain term has, how high or low
        #       it appears in the hierarchy

        # NOTE(review): self.graph is not assigned anywhere in this class; presumably it is
        #               provided by the parent class or set by the caller -- confirm

        if output_target is None:
            raise ValueError("Error: output_target can't be None")

        # print graph headers
        output_target.write("graph G { graph [orientation=portrait, pack=true, overlap=scale]\n")
        output_target.write(" node [shape=box,fontsize=12,width=0.15,height=0.15,style=filled,fillcolor=lightblue];\n")

        for node in self.graph.get_node_object_list():

            elements = node.get_node_attribute_object().get_list_elements()   # returns all elements (ie. source node attributes) in the cluster

            # print only nodes that contain one of the root proteins (in order to set their colour)
            contains_root = False
            for element in elements:
                if element.get_is_root():
                    contains_root = True
                    break

            if not contains_root:
                continue

            # the representative term of the cluster is used as part of its label
            go_term = self.get_terms_depth(list_of_terms_ids=self._get_term_ids(elements),
                                           search_mode=representative_term)

            if go_term is None:
                # no term with a known depth in this cluster: there is nothing to label it with
                continue

            # use the node id to differentiate between clusters that have the same go name
            output_target.write('"%s.%s" [fillcolor = %s]\n' %(node.get_node_id(), self.dbaccess.get_protein_go_name(go_term),"yellow"))

        for edge in self.graph.get_edge_object_list():
            start_node = edge.get_start_node_id()
            end_node = edge.get_end_node_id()
            start_node_attributes = self.graph.get_node(start_node).get_node_attribute_object().get_list_elements()
            end_node_attributes = self.graph.get_node(end_node).get_node_attribute_object().get_list_elements()

            start_node_go_term = self.get_terms_depth(list_of_terms_ids=self._get_term_ids(start_node_attributes),
                                                      search_mode=representative_term)

            end_node_go_term = self.get_terms_depth(list_of_terms_ids=self._get_term_ids(end_node_attributes),
                                                    search_mode=representative_term)

            # edges are printed only when both ends have a representative term
            if start_node_go_term and end_node_go_term:
                output_target.write(' "%s.%s" -- "%s.%s" [len=2];\n' %(start_node, self.dbaccess.get_protein_go_name(start_node_go_term),
                                                                       end_node, self.dbaccess.get_protein_go_name(end_node_go_term)))
        # END OF for edge in self.edges:

        # print graph termination
        output_target.write( "}\n")