Rudresh Ajgaonkar - 1 year ago

Python Question

I have data in the format shown in the sample data at the end of this question.

I used matplotlib to generate a scatter plot and then fit a surface through the scatter points using the following code:

```
from scipy.interpolate import griddata
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

x = np.asarray([3, 5, 9, 3, 3, 7, 6, 9, 1, 9])
y = np.asarray([4, 3, 3, 10, 8, 2, 4, 10, 9, 3])
z = np.asarray([1, 2, 4, 10, 1, 7, 10, 3, 1, 7])
# x = np.random.random(100)

xi = np.linspace(min(x), max(x), 50)
# print(xi)
yi = np.linspace(min(y), max(y), 50)
X, Y = np.meshgrid(xi, yi)
Z = np.nan_to_num(griddata((x, y), z, (X, Y), method='cubic'))

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x, y, z)
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm,
                linewidth=0, antialiased=False, alpha=0.4)
plt.show()
```

What I am looking to do is color the surface according to category, where red represents category 1 and blue represents category 2.

So, to get something like this, I think I need to generate a 2D array and then use a colormap/colorscale to color the categories accordingly.

The above output was created using

Can someone explain how I can generate the Z data that will let me color the categories differently?

I have tried something like splitting the 2D matrix into half 0s and half 1s, and got output something like this.

Consider the following sample data:

```
x   y   z   Category
3   4   1   Cat 1
5   3   2   cat2
9   3   4   cat2
3   10  10  cat3
3   8   1   cat3
7   2   7   cat2
6   4   10  Cat 1
9   10  3   Cat 4
1   9   1   Cat 1
9   3   7   cat2
```
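Before interpolating anything, each category label has to become a number. One option, sketched here as an assumption rather than the only way, is to let `np.unique(..., return_inverse=True)` assign each label an integer code and then scale the codes to [0, 1] so they can be fed to a colormap. The `SAMPLE` string is just the table above pasted into the script for illustration. Note that `np.unique` sorts labels by string order, where uppercase sorts before lowercase, so `'Cat 4'` gets code 1, not 3.

```python
import numpy as np

# Hypothetical inline copy of the sample table (x, y, z, category).
# The category label is everything after the third column and may contain a space.
SAMPLE = """\
3 4 1 Cat 1
5 3 2 cat2
9 3 4 cat2
3 10 10 cat3
3 8 1 cat3
7 2 7 cat2
6 4 10 Cat 1
9 10 3 Cat 4
1 9 1 Cat 1
9 3 7 cat2
"""

xs, ys, zs, cats = [], [], [], []
for line in SAMPLE.splitlines():
    # maxsplit=3 keeps labels such as 'Cat 1' in one piece
    x_str, y_str, z_str, cat = line.split(maxsplit=3)
    xs.append(int(x_str))
    ys.append(int(y_str))
    zs.append(int(z_str))
    cats.append(cat)

x, y, z = np.asarray(xs), np.asarray(ys), np.asarray(zs)
category = np.asarray(cats)

# labels: sorted unique labels; codes: index of each row's label in `labels`
labels, codes = np.unique(category, return_inverse=True)
# Scale integer codes to [0, 1] so a colormap can consume them
color = codes / (len(labels) - 1)
print(labels, codes)
```

A hand-written dictionary gives you control over which category gets which colour; the automatic encoding only saves typing when there are many categories.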

Answer

Just as `griddata` can be used to interpolate the 1D `z` array onto a 2D grid, you can use `griddata` to interpolate a 1D `color` array onto the same 2D grid:

```
color = [colormap[cat] for cat in category]
C = np.nan_to_num(griddata((x, y), color, (X, Y), method='cubic'))
```
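One side effect worth knowing: `method='cubic'` treats the category codes as a continuous quantity, so where two categories meet the surface passes through in-between codes (and therefore in-between colours, possibly the colour of an unrelated category). The sketch below compares it with `method='nearest'`, which copies the code of the closest sample to every grid point. The code values in the dictionary are illustrative assumptions.

```python
import numpy as np
from scipy.interpolate import griddata

x = np.asarray([3, 5, 9, 3, 3, 7, 6, 9, 1, 9])
y = np.asarray([4, 3, 3, 10, 8, 2, 4, 10, 9, 3])
category = ['Cat 1', 'cat2', 'cat2', 'cat3', 'cat3',
            'cat2', 'Cat 1', 'Cat 4', 'Cat 1', 'cat2']
# Illustrative category -> code mapping (values in [0, 1])
colormap = {'Cat 1': 1, 'cat2': 0, 'cat3': 0.333, 'Cat 4': 0.666}
color = np.array([colormap[cat] for cat in category], dtype=float)

xi = np.linspace(x.min(), x.max(), 50)
yi = np.linspace(y.min(), y.max(), 50)
X, Y = np.meshgrid(xi, yi)

# Cubic: smooth, but produces in-between codes where categories meet,
# and NaN outside the convex hull (replaced by 0 here)
C_cubic = np.nan_to_num(griddata((x, y), color, (X, Y), method='cubic'))
# Nearest: every grid point copies the code of its closest sample,
# so only the original category codes appear (no blending, no NaN)
C_nearest = griddata((x, y), color, (X, Y), method='nearest')

print(np.unique(C_nearest))
print(len(np.unique(C_cubic)))
```

Nearest-neighbour gives crisp, blocky category regions; cubic gives a smoother look at the cost of blended colours along boundaries.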

Then you can use the colormap `cm.coolwarm` (assigned to `cmap` below) to map the values in `C` to RGBA `facecolors`:

```
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cmap,
                linewidth=0, antialiased=False, alpha=0.4, facecolors=cmap(C))
```
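To see what `cmap(C)` actually hands to `plot_surface`, here is a small sketch with a toy 2D array (an illustrative assumption, similar to the "half 0s, half 1s" attempt in the question). Calling a colormap on a 2D float array returns an RGBA array with one extra trailing axis of length 4, one colour per grid cell. By default, values below 0 or above 1 are drawn with the colormap's end colours.

```python
import numpy as np
from matplotlib import cm

cmap = cm.coolwarm

# Toy grid of category codes: left half 0 (blue end), right half 1 (red end)
C = np.zeros((4, 6))
C[:, 3:] = 1.0

facecolors = cmap(C)  # RGBA values, shape C.shape + (4,)
print(facecolors.shape)
print(facecolors[0, 0], facecolors[0, -1])
```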

Putting it all together with the sample data from the question:

```
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import griddata

x = np.asarray([3, 5, 9, 3, 3, 7, 6, 9, 1, 9])
y = np.asarray([4, 3, 3, 10, 8, 2, 4, 10, 9, 3])
z = np.asarray([1, 2, 4, 10, 1, 7, 10, 3, 1, 7])
category = np.array(['Cat 1', 'cat2', 'cat2', 'cat3', 'cat3',
                     'cat2', 'Cat 1', 'Cat 4', 'Cat 1', 'cat2'])

# coolwarm: 0 --> blue, 1 --> red
# want: 'Cat 1' --> red, 'cat2' --> blue, 'cat3' --> ?, 'Cat 4' --> ?
colormap = {'Cat 1': 1, 'cat2': 0, 'cat3': 0.333, 'Cat 4': 0.666}
color = [colormap[cat] for cat in category]

xi = np.linspace(min(x), max(x), 50)
yi = np.linspace(min(y), max(y), 50)
X, Y = np.meshgrid(xi, yi)
# Interpolate the heights and the color codes onto the same grid
Z = np.nan_to_num(griddata((x, y), z, (X, Y), method='cubic'))
C = np.nan_to_num(griddata((x, y), color, (X, Y), method='cubic'))

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
cmap = cm.coolwarm
ax.scatter(x, y, z, c=color, cmap=cmap)
# With facecolors given, each cell is painted with its RGBA value from cmap(C)
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cmap,
                linewidth=0, antialiased=False, alpha=0.4, facecolors=cmap(C))
plt.show()
```
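If you want every category painted in one exact colour rather than a position on `coolwarm`, a possible variation (a sketch, not the method above) is to interpolate integer category indices with `method='nearest'` and index an RGBA palette array directly, skipping the colormap. The `palette` colours and the `cat_grid` name are illustrative assumptions; `griddata` returns floats, so the indices are cast back to `int` before indexing.

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on old matplotlib)
from scipy.interpolate import griddata

x = np.asarray([3, 5, 9, 3, 3, 7, 6, 9, 1, 9])
y = np.asarray([4, 3, 3, 10, 8, 2, 4, 10, 9, 3])
z = np.asarray([1, 2, 4, 10, 1, 7, 10, 3, 1, 7])
category = ['Cat 1', 'cat2', 'cat2', 'cat3', 'cat3',
            'cat2', 'Cat 1', 'Cat 4', 'Cat 1', 'cat2']

# Hypothetical palette: one RGBA colour per category (alpha 0.4 baked in)
palette = {
    'Cat 1': (1.0, 0.0, 0.0, 0.4),  # red
    'cat2': (0.0, 0.0, 1.0, 0.4),   # blue
    'cat3': (0.0, 0.6, 0.0, 0.4),   # green
    'Cat 4': (1.0, 0.6, 0.0, 0.4),  # orange
}
labels = list(palette)
rgba = np.array([palette[lab] for lab in labels])     # shape (4, 4)
idx = np.array([labels.index(cat) for cat in category])

xi = np.linspace(x.min(), x.max(), 50)
yi = np.linspace(y.min(), y.max(), 50)
X, Y = np.meshgrid(xi, yi)
Z = np.nan_to_num(griddata((x, y), z, (X, Y), method='cubic'))

# Nearest-neighbour keeps each grid cell on exactly one category index
cat_grid = griddata((x, y), idx, (X, Y), method='nearest').astype(int)
facecolors = rgba[cat_grid]  # shape (50, 50, 4)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x, y, z, c=rgba[idx, :3])
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=facecolors,
                linewidth=0, antialiased=False)
plt.show()
```

Because the palette lookup is a plain array index, adding a category only means adding one entry to `palette`; there is no need to pick code values that avoid colliding on the colormap.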