# -*- coding: utf8 -*-
'''
BASIC GRAPH CLASS
***********************************
*** 
*** BASIC GRAPH CLASS
*** 
*** written by Markus Doering
***    BGBM, Berlin, 2003
*** 
***********************************
$RCSfile: graph.py,v $
$Revision: 1231 $
$Author: j.holetschek $
$Date: 2012-11-21 15:12:34 +0100 (Mi, 21. Nov 2012) $

Exception classes, the graph class (a basic class for undirected graphs) and its tree subclass.
A graph is stored as a symmetric matrix whose cells hold a dictionary of edges if filled and None if not.
'''


import copy
from biocase.wrapper.graph.matrix     import *
from biocase.tools.various_functions   import unique

import logging
log = logging.getLogger("pywrapper.graph")


# ERROR CLASSES
# ----------------------------------------------------------
class NodeNotExistingError(Exception):
    """Raised when an operation refers to a node that is not part of the graph."""
    def __init__(self):
        super(NodeNotExistingError, self).__init__()
class NodeAlreadyExistingError(Exception):
    """Raised when a node is added to a graph that cannot accept it (typically because it is already present)."""
    def __init__(self):
        super(NodeAlreadyExistingError, self).__init__()
class EdgeNotExistingError(Exception):
    """Raised when an operation refers to an edge name that is not registered in the graph."""
    def __init__(self):
        super(EdgeNotExistingError, self).__init__()
class EdgeAlreadyExistingError(Exception):
    """Raised when an edge is added under a name that is already in use."""
    def __init__(self):
        super(EdgeAlreadyExistingError, self).__init__()



# GRAPH CLASS
# ----------------------------------------------------------
class graph:
    '''A basic class for undirected graphs.

    The graph is stored in a symmetric matrix with one row/column per node.
    A matrix cell is None while the two nodes have never been connected;
    otherwise it holds a dictionary {edge name: edge value}, which can become
    empty again once its edges are deleted. Parallel edges between the same
    two nodes are allowed. ``self.edges`` additionally maps every edge name
    to the tuple of the two nodes it connects.
    '''
    def __init__(self):
        # The matrix storing the graph: empty cells are None, filled cells
        # hold a dictionary {edge name: edge value}.
        self.matrix = symmetricMatrix([])
        # Registry of all edges: key = edge name, value = tuple of the 2 connected nodes.
        self.edges = {}


    # ----- NODES -----
    def addNode(self, node):
        '''Add one node to the graph without any edges.

        Any failure of the underlying matrix to add the row is reported as
        NodeAlreadyExistingError (presumably the node already exists).
        '''
        try:
            self.matrix.addRow(node)
        except Exception:
            raise NodeAlreadyExistingError()

    def addNodes(self, nodes):
        '''Add every node of the iterable ``nodes`` to the graph.'''
        for node in nodes:
            self.addNode(node)

    def delNode(self, node):
        '''Remove ``node`` from the graph together with all edges attached to it.

        Raises NodeNotExistingError if the node cannot be looked up or removed.
        '''
        try:
            # listAdjacentEdges returns a fresh list of (edge-name, other-node)
            # tuples, so deleting from self.edges while looping is safe.
            for edgename, _otherNode in self.listAdjacentEdges(node):
                del self.edges[edgename]
            self.matrix.delRow(node)
        except Exception:
            raise NodeNotExistingError()

    def nodeExists(self, node):
        '''Return 1 if ``node`` is part of the graph, 0 otherwise.'''
        # NOTE: relies on the matrix' internal column index (_cols).
        if node in self.matrix._cols:
            return 1
        return 0

    def listNodes(self):
        '''Return a list of all nodes (vertices) of the graph.'''
        return self.matrix.listRows()

    def listAdjacentNodes(self, node):
        '''Return all nodes sharing at least one edge with ``node``.'''
        return [col for col, edges in self.matrix.getRow(node).items()
                if type(edges) is dict and len(edges) > 0]

    def listConnectedNodes(self, node):
        '''Return all nodes reachable from ``node`` (including ``node`` itself), in depth-first order.'''
        connected = []
        self.DepthFirstSearch(node, connected.append, {})
        return connected

    def listUnconnectedNodes(self, node):
        '''Return all nodes that are NOT reachable from ``node``.'''
        connected = set(self.listConnectedNodes(node))
        return [n for n in self.listNodes() if n not in connected]

    def listIsolatedNodes(self):
        '''Return a list of isolated nodes, that is nodes without any edge.

        Cells that only hold an empty edge dictionary (left behind by delEdge)
        do not count as a connection.
        '''
        return [node for node in self.matrix.listRows() if self.degree(node) == 0]


    # ----- EDGES -----
    def addPath(self, nodes):
        '''Connect consecutive nodes of the sequence ``nodes`` at once.

        Every new edge gets an empty dictionary as value and an automatic
        (integer) name. A sequence with fewer than two nodes adds nothing.
        '''
        for first, last in zip(nodes[:-1], nodes[1:]):
            self.addEdge((first, last))

    def addEdge(self, nodes, value=None, name=None):
        '''Add an edge called ``name`` between the two nodes of ``nodes``.

        ``value`` is the relation value stored for the edge; when omitted (or
        None) a fresh empty dictionary is stored. Naming edges is optional:
        unnamed edges get an integer name one larger than the highest integer
        edge name in use (1 for the first). Returns the name of the new edge.
        Raises EdgeAlreadyExistingError if ``name`` is already in use.
        '''
        if name is None:
            intNames = [edge for edge in self.edges if type(edge) is int]
            if intNames:
                name = max(intNames) + 1
            else:
                name = 1
        if name in self.edges:
            raise EdgeAlreadyExistingError()
        if value is None:
            # a new dict per edge; a `{}` default would be shared by all edges
            value = {}
        edgeDict = self.matrix.get(nodes)
        if edgeDict is None:
            edgeDict = {}
        edgeDict[name] = value
        self.matrix.set(nodes, edgeDict)
        # Register the edge only after the matrix accepted it, so a raising
        # matrix lookup/update does not leave a dangling entry in self.edges.
        self.edges[name] = nodes
        return name

    def listEdges(self):
        '''Return the dictionary of all existing edges with their node tuple as values.'''
        return self.edges

    def listEdgesBetweenNodes(self, nodes):
        '''Return a list of all edge names between the two nodes of ``nodes``.

        Returns [] if the two nodes were never connected.
        '''
        node1, node2 = nodes
        edgeDict = self.matrix.get((node1, node2))
        if edgeDict is None:
            return []
        return list(edgeDict.keys())

    def listAdjacentEdges(self, node):
        '''Return all edges of ``node`` as tuples (edge-name, other-node).'''
        return [(edge, col)
                for col, edges in self.matrix.getRow(node).items()
                if type(edges) is dict
                for edge in edges]

    def getEdge(self, name):
        '''Return a tuple of (value of edge ``name``, tuple of connected nodes).

        Raises EdgeNotExistingError if no edge of that name exists.
        '''
        try:
            nodeTuple = self.edges[name]
        except (KeyError, TypeError):
            raise EdgeNotExistingError()
        return (self.matrix.get(nodeTuple)[name], nodeTuple)

    def delEdge(self, name):
        '''Delete the edge ``name``.

        Raises EdgeNotExistingError if no edge of that name exists. The
        matrix cell keeps a (possibly empty) edge dictionary afterwards.
        '''
        if name not in self.edges:
            raise EdgeNotExistingError()
        nodes = self.edges.pop(name)
        edgeDict = self.matrix.get(nodes)
        del edgeDict[name]
        self.matrix.set(nodes, edgeDict)

    def cut2Nodes(self, nodes):
        '''Delete all edges between the two nodes of ``nodes``.'''
        edges = self.matrix.get(nodes)
        if type(edges) is dict:
            # Snapshot the names first: delEdge may modify this dictionary.
            for edge in list(edges.keys()):
                self.delEdge(edge)

    # ----- GENERAL -----
    def makeTree(self, root):
        '''Return a new tree rooted at ``root``, built from a deep copy of this graph.

        Nodes not connected to ``root`` are left out of the tree.
        Raises CircleExistingError if the graph has a closed walk.
        '''
        if self.hasClosedWalk():
            raise CircleExistingError()
        treeObj = tree()
        treeObj.matrix = copy.deepcopy(self.matrix)
        treeObj.edges = copy.deepcopy(self.edges)
        treeObj.root = root
        # remove nodes not connected to root
        for node in self.listUnconnectedNodes(root):
            treeObj.delNode(node)
        return treeObj

    def degree(self, node):
        '''Return the degree (number of edges attached) of ``node``.'''
        return sum([len(edges) for edges in self.matrix.getRow(node).values()
                    if type(edges) is dict])

    def hasClosedWalk(self):
        '''Return 1 if the graph allows circular walks, 0 otherwise.

        Runs an exhaustive recursive search from every node with at least two
        edges, which may be expensive on dense graphs.
        '''
        for node in self.listNodes():
            # a closed walk needs at least 2 edges at each of its nodes
            if self.degree(node) > 1:
                if self._testClosed(node, node, {}) > 0:
                    return 1
        return 0

    def isConnected(self):
        '''Return 1 if all nodes are reachable from each other, 0 otherwise.

        A graph without nodes counts as connected.
        '''
        nodes = list(self.listNodes())
        if not nodes:
            return 1
        if len(self.listConnectedNodes(nodes[0])) == len(nodes):
            return 1
        return 0


    # ----- TRAVERSAL ALGORITHMS -----
    def BreadthFirstSearch(self, start, func):
        '''Traverse the graph in breadth-first order from ``start`` and call ``func`` once per reachable node.'''
        from collections import deque
        queue = deque([start])
        # Mark nodes when they are queued (not when popped) so no node is
        # queued - and passed to func - more than once.
        seen = set([start])
        while queue:
            node = queue.popleft()
            func(node)
            for neighbor in self.listAdjacentNodes(node):
                if neighbor not in seen:
                    seen.add(neighbor)
                    queue.append(neighbor)

    def _listNodes(self, node):
        '''Append ``node`` to self._tmp (callback helper for DepthFirstSearch).'''
        self._tmp.append(node)

    def DepthFirstSearch(self, start, func, _visited=None):
        '''Traverse the graph in depth-first order from ``start`` and call ``func`` once per reachable node.

        ``_visited`` is the recursion's dictionary of already visited nodes;
        a fresh one is created per top-level call.
        '''
        if _visited is None:
            _visited = {}
        _visited[start] = 1
        func(start)
        for node in self.listAdjacentNodes(start):
            if node not in _visited:
                self.DepthFirstSearch(node, func, _visited)

    def shortestPath(self, node1, node2):
        '''Find the shortest path between two nodes (SSSP/APSP). Not implemented yet; returns None.'''
        pass

    def minimumSpanningTree(self, node1, node2):
        '''Find the minimum spanning tree (Kruskal/Prim). Not implemented yet; returns None.'''
        pass

    def EulerPath(self):
        '''Find the Euler path. Not implemented yet; returns None.'''
        pass

    def HamiltonianPath(self):
        '''Find the Hamiltonian path. Not implemented yet; returns None.'''
        pass


    # ----- OUTPUT -----
    def __repr__(self):
        '''Return a simple string output of the graph, prefixed by the class name.'''
        return self.__class__.__name__ + ': ' + self.__reprGraph__()

    def __reprGraph__(self):
        '''Return "node, " for every isolated node followed by "a-b, " for every edge.'''
        parts = ["%s, " % (node,) for node in self.listIsolatedNodes()]
        parts.extend(["%s-%s, " % (nodes[0], nodes[1]) for nodes in self.edges.values()])
        return "".join(parts)

    # ----- INTERNAL ONLY ------
    def _testClosed(self, startNode, currNode, dictOfUsedEdges):
        '''Return 1 if a walk from ``currNode`` back to ``startNode`` exists that avoids the edges in ``dictOfUsedEdges``, else 0.'''
        for edgeName, otherNode in self.listAdjacentEdges(currNode):
            # every edge may be used only once per walk
            if edgeName in dictOfUsedEdges:
                continue
            if otherNode == startNode:
                # start node reached again: circular walk possible
                return 1
            # recursively test the adjacent node, with this edge marked as used
            newDict = dictOfUsedEdges.copy()
            newDict[edgeName] = 1
            if self._testClosed(startNode, otherNode, newDict) > 0:
                return 1
        return 0

    
    

# ERROR CLASSES FOR TREES ONLY
# ----------------------------------------------------------
class CircleExistingError(Exception):
    """Raised when an operation would produce a tree that contains a circle (closed walk)."""
    def __init__(self):
        super(CircleExistingError, self).__init__()
class RootMissingError(Exception):
    """Exception raised when an operation needs the root node of a tree but no root has been set."""
    def __init__(self):
        Exception.__init__(self)

# ----------------------------------------------------------
class tree(graph):
    '''An undirected graph without closed walks (circles), optionally rooted.

    ``self.root`` names the root node used by listOrderedNodesConnectedViaRoot
    and __repr__; it is None until it is set.
    '''
    def __init__(self):
        graph.__init__(self)
        self.root = None

    def addEdge(self, nodes, value=None, name=None):
        '''Add an edge like graph.addEdge, but refuse edges that would close a circle.

        The edge is added tentatively; if the graph then has a closed walk the
        edge is deleted again and CircleExistingError is raised. ``value``
        defaults to a fresh empty dictionary. Returns the name of the new edge.
        '''
        if value is None:
            # a new dict per edge instead of one shared mutable default
            value = {}
        name = graph.addEdge(self, nodes, value, name)
        if self.hasClosedWalk():
            # roll back the tentative edge before reporting the circle
            self.delEdge(name)
            raise CircleExistingError()
        return name

    def pathBetween2Nodes(self, nodes):
        '''Return the path (list of nodes from start to end) between the two nodes of ``nodes``.

        Returns [] if the nodes are not connected.
        Raises NodeNotExistingError if either node is not part of the tree.
        '''
        start, end = nodes
        if self.nodeExists(start) == 0 or self.nodeExists(end) == 0:
            raise NodeNotExistingError()
        if start == end:
            return [start]
        path = self._DFSSearchNode(start, end, [start])
        if path is None:
            return []
        return path

    def pathConnectingNodes(self, nodes):
        '''Return a path (list of nodes) connecting all nodes given, without duplicates.

        Paths are taken from the first node to every other node; in a tree
        their union connects all of them. A single node yields just that node.
        '''
        start = nodes[0]
        others = nodes[1:] or [start]
        result = []
        for other in others:
            result.extend(self.pathBetween2Nodes((start, other)))
        return unique(result)

    def listOrderedNodesConnectedViaRoot(self, nodes=None):
        '''Return a list of nodes connecting all given nodes to the root node.

        The nodes are ordered from the root downwards according to their
        distance from the root, duplicates removed.
        Raises RootMissingError if no root is set.
        '''
        if self.root is None:
            raise RootMissingError()
        if nodes is None:
            nodes = []
        # get the path from the root to every node
        paths = [[self.root]]
        for node in unique(nodes):
            paths.append(self.pathBetween2Nodes((self.root, node)))
        # collect nodes level by level (distance from root)
        result = []
        for depth in range(max([len(path) for path in paths])):
            for path in paths:
                if depth < len(path):
                    result.append(path[depth])
        # eliminate duplicates
        return unique(result)

    def __repr__(self):
        '''Return "Rooted Tree (<root>): " followed by the simple graph output.'''
        return "Rooted Tree (%s): %s" % (self.root, self.__reprGraph__())

    # ----- INTERNAL -----
    def _DFSSearchNode(self, currNode, endNode, visitedNodes):
        '''Depth-first search for ``endNode`` starting at ``currNode``.

        ``visitedNodes`` is the path walked so far. Returns the completed path
        (ending with ``endNode``) or None if ``endNode`` is not reachable.
        '''
        for node in self.listAdjacentNodes(currNode):
            if node == endNode:
                visitedNodes.append(node)
                return visitedNodes
            if node not in visitedNodes:
                deeper = self._DFSSearchNode(node, endNode, visitedNodes + [node])
                if deeper is not None:
                    return deeper
        return None
    


    
# -----------------------------------------------------------
if __name__ == "__main__":
    import biocase.tools.showobject as obj

    def _printState(gr, title):
        '''Print a blank line, ``title`` and a short summary of the graph ``gr``.'''
        print("")
        print(title)
        print("All Nodes %s" % (gr.listNodes()))
        print("All Edges %s" % (gr.listEdges()))
        print("graph has a closed walk ? %i" % (gr.hasClosedWalk()))
        print("print graph: %s" % (str(gr)))

    # --- build a small sample graph ---
    g = graph()
    g.addNode('heio')
    print(str(g))

    g.addNodes(['markus', 'pia', 'brian', 'silke', 'peter', 'stefan',
                'ewald', 'melitta', 'irene', 'paul'])
    g.addEdge(('markus', 'pia'), value='verliebt')
    g.addEdge(('brian', 'pia'), value='ex')
    g.addEdge(('markus', 'silke'), value='geschwister')
    g.addEdge(('markus', 'pia'), value='schlimm verliebt')
    g.addEdge(('brian', 'silke'), value='bald verliebt', name='SilBri')
    g.addPath(('pia', 'irene', 'paul', 'melitta', 'ewald', 'stefan'))

    # --- inspect it ---
    print(obj.view(g))
    print("")
    print("-" * 80)
    print("All Nodes %s" % (g.listNodes()))
    print("All Edges %s" % (g.listEdges()))
    print("Degree(markus): %i" % (g.degree('markus')))
    print("adjacentNodes(markus): %s" % (g.listAdjacentNodes('markus')))
    print("adjacentEdges(pia): %s" % ([e[0] for e in g.listAdjacentEdges('pia')]))
    print("graph has a closed walk ? %i" % (g.hasClosedWalk()))
    print("isolated nodes: %s" % (g.listIsolatedNodes()))
    print("print graph: %s" % (str(g)))

    # --- modify it step by step ---
    g.delEdge(1)
    _printState(g, "deleted edge 1")

    g.cut2Nodes(('silke', 'brian'))
    _printState(g, "cut off nodes silke & brian")

    g.cut2Nodes(('silke', 'silke'))
    _printState(g, "cut off nodes silke & silke")

    # --- turn it into a tree rooted at markus ---
    print("")
    print("##### TREE NOW !!! #####")
    t = g.makeTree('markus')
    print(str(t))
    print("markus-irene: %s" % (t.pathBetween2Nodes(('markus', 'irene'))))
    print("markus-irene-silke: %s" % (t.pathConnectingNodes(('markus', 'irene', 'silke'))))
    print("root brian: %s" % (t.pathConnectingNodes([t.root, 'brian'])))
    print("root-ordered paul,brian: %s" % (t.listOrderedNodesConnectedViaRoot(['paul', 'brian'])))
    

