CF84 - 1 year ago 116
Python Question

NetworkX: all spanning trees and their associated total weight

Given a simple undirected grid network like this:

```python
import networkx as nx
import matplotlib.pyplot as plt
# (in a Jupyter notebook, add `%matplotlib inline` here)

N = 3
G = nx.grid_2d_graph(N, N)
# Relabel the (i, j) grid nodes as integers 0 .. N*N-1
labels = dict(((i, j), i + (N-1-j) * N) for i, j in G.nodes())
nx.relabel_nodes(G, labels, False)
# Drawing position of each integer label
pos2 = {label: (N-j-1, N-i-1) for (i, j), label in labels.items()}

nx.draw_networkx(G, pos=pos2, with_labels=True, node_size=200, node_color='orange', font_size=10)
plt.axis('off')
plt.title('grid')
plt.show()
```

And given that each edge has a weight corresponding to its length:

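```python
# Weights
from math import sqrt

# Weight of each edge = Euclidean length between its endpoints in pos2
weights = dict()
for source, target in G.edges():
    x1, y1 = pos2[source]
    x2, y2 = pos2[target]
    weights[(source, target)] = round(sqrt((x2-x1)**2 + (y2-y1)**2), 3)

# Store each weight as the 'weight' attribute of the edge
# (assigning to G[u][v] directly would overwrite the edge data, or fail)
for u, v in G.edges():
    G[u][v]['weight'] = weights[(u, v)]
```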

How can I compute all spanning trees of the grid, together with the total weight of each one?

NB: this is a trivial case in which all weights equal 1.

This took longer than expected, but the following code finds all spanning trees of a general graph. Getting the associated total weight should then be trivial, since you have access to the edge list of each tree.
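As a minimal sketch of that last step, assuming the weights live in a dict keyed by `(u, v)` tuples like the question's `weights`, the total weight of one tree can be computed from its edge list as below. The helper names `edge_weight` and `total_weight` are hypothetical. The only subtlety is orientation: a tree may report an edge as `(v, u)` while the dict stores `(u, v)`, and the answer's internal edge sets contain both orientations, so edges are normalised to unordered pairs before summing.

```python
def edge_weight(weights, u, v):
    """Weight of the undirected edge {u, v}, whichever order `weights` uses."""
    # The question's `weights` dict is keyed in G.edges() order, but a tree
    # may report the same edge as (v, u).
    if (u, v) in weights:
        return weights[(u, v)]
    return weights[(v, u)]


def total_weight(tree_edges, weights):
    """Total weight of a spanning tree given as an iterable of 2-tuples."""
    # Collapse (u, v) and (v, u) into one undirected edge so that nothing is
    # counted twice (find_all_spanning_trees keeps both orientations in its
    # internal edge sets).
    distinct = {frozenset(e) for e in tree_edges}
    return sum(edge_weight(weights, *e) for e in distinct)


# Usage on a made-up triangle graph
weights = {(0, 1): 1.0, (1, 2): 1.0, (0, 2): 1.414}
print(total_weight([(1, 0), (1, 2)], weights))  # 2.0
```

For the grid, calling `total_weight(g.edges(), weights)` for each `g` returned by `find_all_spanning_trees` would give every tree's weight. In the unit-weight case every spanning tree of the 9-node grid has exactly 8 edges, so every total is 8.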

Don't use this on very large graphs: even the toy example yields 192 spanning trees.
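The count of 192 can be cross-checked without enumerating anything, using Kirchhoff's matrix-tree theorem: the number of spanning trees equals the determinant of the graph Laplacian with any one row and its matching column deleted. The sketch below assumes plain Python (no NetworkX), uses exact `Fraction` arithmetic to avoid rounding, and builds the grid with a hypothetical helper `grid_edges` whose row-major labelling differs from the question's labels (the count does not depend on labels).

```python
from fractions import Fraction


def grid_edges(n):
    """Hypothetical helper: edges of an n x n grid, nodes 0 .. n*n-1 row-major."""
    edges = []
    for r in range(n):
        for c in range(n):
            k = r * n + c
            if c + 1 < n:
                edges.append((k, k + 1))  # horizontal neighbour
            if r + 1 < n:
                edges.append((k, k + n))  # vertical neighbour
    return edges


def count_spanning_trees(num_nodes, edges):
    """Matrix-tree theorem: determinant of the reduced Laplacian."""
    L = [[Fraction(0)] * num_nodes for _ in range(num_nodes)]
    for u, v in edges:
        L[u][u] += 1
        L[v][v] += 1
        L[u][v] -= 1
        L[v][u] -= 1
    # Delete the last row and column, then take the determinant exactly
    M = [row[:-1] for row in L[:-1]]
    size = len(M)
    det = Fraction(1)
    for col in range(size):
        pivot = next((r for r in range(col, size) if M[r][col] != 0), None)
        if pivot is None:
            return 0  # singular: the graph is disconnected
        if pivot != col:
            M[col], M[pivot] = M[pivot], M[col]
            det = -det  # a row swap flips the sign
        det *= M[col][col]
        for r in range(col + 1, size):
            factor = M[r][col] / M[col][col]
            for c in range(col, size):
                M[r][c] -= factor * M[col][c]
    return int(det)


print(count_spanning_trees(9, grid_edges(3)))   # 192
print(count_spanning_trees(16, grid_edges(4)))  # 100352
```

The second line shows why enumeration does not scale: a 4x4 grid already has 100352 spanning trees.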

```python
import matplotlib.pyplot as plt
import networkx as nx


def _expand(G, explored_nodes, explored_edges):
    """
    Expand an existing partial tree by one edge, by a process akin to BFS.

    Arguments:
    ----------
    G: networkx.Graph() instance
        full graph

    explored_nodes: frozenset of ints
        nodes visited

    explored_edges: frozenset of 2-tuples
        edges visited (each edge is stored in both orientations)

    Returns:
    --------
    solutions: list, where each entry in turn contains two sets corresponding to explored_nodes and explored_edges
        all possible expansions of explored_nodes and explored_edges

    """
    frontier_nodes = list()
    frontier_edges = list()
    for v in explored_nodes:
        for u in G.neighbors(v):
            if u not in explored_nodes:
                # u can join the tree through the edge (u, v)
                frontier_nodes.append(u)
                frontier_edges.append([(u, v), (v, u)])

    return list(zip([explored_nodes | frozenset([v]) for v in frontier_nodes],
                    [explored_edges | frozenset(e) for e in frontier_edges]))


def find_all_spanning_trees(G, root=0):
    """
    Find all spanning trees of a graph.

    Arguments:
    ----------
    G: networkx.Graph() instance
        full graph

    root: node of G (default 0)
        node from which every tree is grown

    Returns:
    --------
    ST: list of networkx.Graph() instances
        list of all spanning trees

    """

    # initialise solution
    explored_nodes = frozenset([root])
    explored_edges = frozenset([])
    solutions = [(explored_nodes, explored_edges)]
    # we need to expand solutions number_of_nodes-1 times
    for ii in range(G.number_of_nodes() - 1):
        # get all new solutions
        solutions = [_expand(G, nodes, edges) for (nodes, edges) in solutions]
        # flatten nested structure and get unique expansions
        solutions = set([item for sublist in solutions for item in sublist])

    return [nx.from_edgelist(edges) for (nodes, edges) in solutions]


if __name__ == "__main__":

    N = 3
    G = nx.grid_2d_graph(N, N)
    labels = dict(((i, j), i + (N-1-j) * N) for i, j in G.nodes())
    nx.relabel_nodes(G, labels, False)
    inds = labels.keys()
    vals = labels.values()
    inds = [(N-j-1, N-i-1) for i, j in inds]
    pos2 = dict(zip(vals, inds))

    fig, ax = plt.subplots(1, 1)
    nx.draw_networkx(G, pos=pos2, with_labels=True, node_size=200, node_color='orange', font_size=10, ax=ax)
    plt.axis('off')
    plt.title('grid')

    ST = find_all_spanning_trees(G)
    print(len(ST))  # 192 for the 3x3 grid

    for g in ST:
        fig, ax = plt.subplots(1, 1)
        nx.draw_networkx(g, pos=pos2, with_labels=True, node_size=200, node_color='orange', font_size=10, ax=ax)
        plt.axis('off')
        plt.title('spanning tree')
        plt.show()
```
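For small graphs, a simpler (but different) technique is plain brute force: every spanning tree of an n-node graph has exactly n-1 edges, and n-1 edges with no cycle must connect all n nodes, so it suffices to test each (n-1)-edge subset with a union-find structure. This is a sketch, not the answer's method; it assumes nodes are the integers 0 .. n-1 (true of the relabelled grid) and edges are given as `(u, v, w)` triples. The helper names, including the row-major grid builder `grid_weighted_edges`, are hypothetical.

```python
from itertools import combinations


def grid_weighted_edges(n, weight=1.0):
    """Hypothetical helper: edges (u, v, w) of an n x n grid, nodes 0 .. n*n-1."""
    edges = []
    for r in range(n):
        for c in range(n):
            k = r * n + c
            if c + 1 < n:
                edges.append((k, k + 1, weight))  # horizontal neighbour
            if r + 1 < n:
                edges.append((k, k + n, weight))  # vertical neighbour
    return edges


def is_spanning_tree(num_nodes, edge_subset):
    """True if num_nodes-1 edges on nodes 0..num_nodes-1 contain no cycle."""
    parent = list(range(num_nodes))  # union-find forest

    def find(x):
        # Follow parent pointers to the root, halving the path as we go
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for u, v in edge_subset:
        ru, rv = find(u), find(v)
        if ru == rv:
            return False  # u and v already connected: this edge closes a cycle
        parent[ru] = rv
    # n-1 acyclic edges on n nodes form a single connected tree
    return True


def all_spanning_trees(num_nodes, weighted_edges):
    """Yield (edge list, total weight) for every spanning tree, by brute force."""
    for subset in combinations(weighted_edges, num_nodes - 1):
        edges = [(u, v) for u, v, _ in subset]
        if is_spanning_tree(num_nodes, edges):
            yield edges, sum(w for _, _, w in subset)


trees = list(all_spanning_trees(9, grid_weighted_edges(3)))
print(len(trees))              # 192
print({w for _, w in trees})   # {8.0}
```

For the question's weighted grid, the input list could be built as `[(u, v, G[u][v]['weight']) for u, v in G.edges()]`, assuming the weights were stored under the `'weight'` attribute. The number of subsets tested grows combinatorially with the number of edges, so this is only practical for toy examples.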