BPL BPL - 2 months ago 13
Python Question

How to know node execution (before, middle, after calls) using DFS

Let's say I have implemented a simple iterative DFS like this:

import sys
import traceback


def dfs(graph, start):
    visited, stack = [], [start]
    while stack:
        node = stack.pop()

        if node not in visited:
            visited.append(node)
            childs = reversed(graph.get(node, list()))
            stack.extend([item for item in childs if item not in visited])

    return visited

if __name__ == "__main__":
    graphs = [
        {
            'A': ['B', 'C'],
            'B': ['D']
        }
    ]

    for i, g in enumerate(graphs):
        try:
            print "{0}Graph {1}{2}".format('-' * 40, i, '-' * 33)

            for f in [dfs]:
                print f.__name__, '-->', f(g, 'A')
                print '-' * 80
        except Exception as e:
            print "Exception in user code: {0}".format(e)
            print '-' * 60
            traceback.print_exc(file=sys.stdout)
            print '-' * 60


The output of the above snippet is this:

----------------------------------------Graph 0---------------------------------
dfs --> ['A', 'B', 'D', 'C']
--------------------------------------------------------------------------------


Now, I'm trying to figure out how to get the following output (just printing is fine, instead of running each node's methods):

A_start, B_start, D_start, D_end, B_end, A_middle, C_start, C_end, A_end


*_middle is only executed between the executions of a node's subnodes. For instance, if a node has no subnodes, or only a single one, its *_middle never gets executed. That's why my desired output only has A_middle (and none of B_middle, C_middle, D_middle) in the above example.

How can I do this?

EDIT:

Trying to find the recursive solution to my problem:

def dfs(graph, node):
    if node not in graph:
        return

    print '{0}_start'.format(node)

    for i, node in enumerate(graph[node]):
        if i > 0:
            print '{0}_middle'.format(node)

        dfs(graph, node)

    print '{0}_end'.format(node)

if __name__ == "__main__":
    graphs = [
        {
            'A': ['B', 'C'],
            'B': ['D']
        }
    ]

    for i, g in enumerate(graphs):
        try:
            print "{0}Graph {1}{2}".format('-' * 40, i, '-' * 33)

            for f in [dfs]:
                print f.__name__, '-->'
                f(g, 'A')
                print '-' * 80
        except Exception as e:
            print "Exception in user code: {0}".format(e)
            print '-' * 60
            traceback.print_exc(file=sys.stdout)
            print '-' * 60


This gives me the wrong output:

----------------------------------------Graph 0---------------------------------
dfs -->
A_start
B_start
D_end
C_middle
C_end
--------------------------------------------------------------------------------


EDIT2:

Let me explain the reasons behind the answer I accepted. First of all, I really appreciate the effort from all of you, @Gerrat, @ffledgling and @Blckknght.

1) Gerrat -> It meets the requirements and it was the first answer. My original post only contained the iterative version, but I later edited the question to say a recursive one would be good enough. One small note: on graphs containing cycles it will fail with "maximum recursion depth exceeded". It wouldn't be a bad thing if I could iterate/repeat infinitely over the graph, though; take a look at the nodes called "repeat" in this tool.

2) ffledgling -> It doesn't meet the requirements. It gives this output:
A_start, B_start, D_start, D_end, B_middle, B_end, A_middle, C_start, C_end, A_middle, A_end
when it should give this instead:
A_start, B_start, D_start, D_end, B_end, A_middle, C_start, C_end, A_end

3) Blckknght -> It meets the requirements. It answers the first version of the question I posted, and it won't crash on cyclic graphs. Even though it was the last answer, I think it deserves to be accepted.

NS: It's not an easy decision to pick one answer when all of them show this much effort, so I've upvoted all of them. If you have any concerns, please let me know.

Answer

As the other answers have shown, the main issue with your current recursive code is the base case:

if node not in graph:
    return

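This incorrectly skips the output when a node has no children. Get rid of those lines and use enumerate(graph.get(node, [])) instead of enumerate(graph[node]) in the for loop. The loop variable also needs its own name (for example child): reusing node shadows the parent, so the middle and end messages would be printed with a child's name. With both changes it should work as desired.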
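A minimal sketch of the recursive fix, written with Python 3 style formatting and print() rather than the Python 2 print statements used in the question. The names dfs_events, child and events are illustrative and not from the original post; the messages are collected in a list instead of being printed directly so the result is easy to check. It assumes the graph has no cycles, since this recursive form would otherwise keep recursing until it hits the recursion limit.

```python
def dfs_events(graph, node, events=None):
    """Recursive DFS that records start/middle/end messages for each node."""
    if events is None:
        events = []

    events.append('{0}_start'.format(node))

    # graph.get() returns [] for leaf nodes, so no special base case is needed.
    # The loop variable is named child so that it does not shadow node.
    for i, child in enumerate(graph.get(node, [])):
        if i > 0:
            # Emitted on behalf of the parent, between two of its children.
            events.append('{0}_middle'.format(node))
        dfs_events(graph, child, events)

    events.append('{0}_end'.format(node))
    return events


graph = {'A': ['B', 'C'], 'B': ['D']}
print(', '.join(dfs_events(graph, 'A')))
# A_start, B_start, D_start, D_end, B_end, A_middle, C_start, C_end, A_end
```

Because the middle message is only appended when i > 0, a node with zero or one child never produces one, which matches the requirement in the question.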

Making your iterative code work is quite a bit more complicated. One way to do it is to push 2-tuples onto the stack. The first value is a node, as before, but the second is either the node's predecessor (so we can print a middle message for the parent), or None, indicating that we need to print the end marker for the node.

However, keeping track of which nodes have been visited also gets a bit more complicated. Rather than a single list of nodes, I'm using a dictionary mapping each node to an integer. A missing key means the node has not yet been visited. A 1 means the node has been visited and its start message has been printed. A 2 means that at least one of the node's children has been visited, and each further child should print a middle message on the parent's behalf. A 3 means the end message has been printed.

def dfs(graph, start):
    visited = {}
    stack = [(start, "XXX_THIS_NODE_DOES_NOT_EXIST_XXX")]
    while stack:
        node, parent = stack.pop()
        if parent is None:
            if visited[node] < 3:
                print "{}_end".format(node)
            visited[node] = 3

        elif node not in visited:
            if visited.get(parent) == 2:
                print "{}_middle".format(parent)
            elif visited.get(parent) == 1:
                visited[parent] = 2

            print "{}_start".format(node)
            visited[node] = 1
            stack.append((node, None))
            for child in reversed(graph.get(node, [])):
                if child not in visited:
                    stack.append((child, node))

Because I'm using a dictionary for visited, returning it at the end is probably not appropriate, so I've removed the return statement. I think you could restore it if you really wanted to by using a collections.OrderedDict rather than a normal dict, and returning its keys().
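As a rough sketch of that OrderedDict idea, here is one possible Python 3 port of the iterative function. The name dfs_with_order, the events list and the no_parent sentinel are assumptions made for this illustration, not part of the code above: messages are collected in a list rather than printed, and a fresh object() replaces the placeholder string as the root's "parent".

```python
from collections import OrderedDict


def dfs_with_order(graph, start):
    """Iterative DFS; returns (events, visit_order)."""
    visited = OrderedDict()  # node -> state 1, 2 or 3, kept in first-visit order
    events = []
    no_parent = object()     # sentinel parent for the root; matches no real node
    stack = [(start, no_parent)]

    while stack:
        node, parent = stack.pop()
        if parent is None:
            # End marker pushed when the node was first visited.
            if visited[node] < 3:
                events.append('{}_end'.format(node))
            visited[node] = 3
        elif node not in visited:
            if visited.get(parent) == 2:
                # The parent already started a child, so this is a later child.
                events.append('{}_middle'.format(parent))
            elif visited.get(parent) == 1:
                visited[parent] = 2

            events.append('{}_start'.format(node))
            visited[node] = 1
            stack.append((node, None))
            for child in reversed(graph.get(node, [])):
                if child not in visited:
                    stack.append((child, node))

    # keys() is a view in Python 3, so copy it into a list before returning.
    return events, list(visited.keys())


events, order = dfs_with_order({'A': ['B', 'C'], 'B': ['D']}, 'A')
print(', '.join(events))
print(order)
```

On the example graph the second return value is ['A', 'B', 'D', 'C'], the same visit order the original iterative function returned. Since nodes already in visited are never started again, the loop also terminates on a cyclic graph such as {'A': ['B'], 'B': ['A']}.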
