Posted by Mike, 4 years ago
Python Question

np.roll vs scipy.ndimage.interpolation.shift: discrepancy for integer shift values

I wrote some code to shift an array, and was trying to generalize it to handle non-integer shifts using the "shift" function in scipy.ndimage. The data is circular, so the result should wrap around, exactly as the np.roll command does it.
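For reference, np.roll wraps with a period of exactly the array length: elements pushed off one end re-enter at the other, and rolling by len(data) returns the input unchanged. A minimal illustration:

```python
import numpy as np

data = np.arange(5)

# A positive shift moves elements toward higher indices; the last two
# elements wrap around to the front.
print(np.roll(data, 2))          # [3 4 0 1 2]

# Rolling by the array length (the wrap period) leaves the array unchanged.
print(np.roll(data, len(data)))  # [0 1 2 3 4]
```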

However, scipy.ndimage.shift does not appear to wrap integer shifts properly. The following code snippet shows the discrepancy:

import numpy as np
import scipy.ndimage as sciim
import matplotlib.pyplot as plt

def shiftfunc(data, amt):
    # Cubic-spline shift with wrap-around boundary handling.
    # (scipy.ndimage.shift is the same function as the older
    # scipy.ndimage.interpolation.shift.)
    return sciim.shift(data, amt, mode='wrap', order=3)

if __name__ == "__main__":
    xvals = np.arange(100) * 1.0
    yvals = np.sin(xvals * 0.1)

    # Shift by the integer amount of 2 samples with both methods
    rollshift = np.roll(yvals, 2)
    interpshift = shiftfunc(yvals, 2)

    plt.plot(xvals, rollshift, label='np.roll', alpha=0.5)
    plt.plot(xvals, interpshift, label='interpolation.shift', alpha=0.5)
    plt.legend()
    plt.show()


[Plot: roll vs shift]

It can be seen that the first couple of values are highly discrepant, while the rest are fine. I suspect this is an implementation error in the prefiltering and interpolation operation when using the wrap option. A way around this would be to modify shiftfunc to revert to np.roll when the shift value is an integer, but this is unsatisfying.
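The prefiltering hypothesis can be checked by switching to order=1 (linear interpolation), for which ndimage.shift applies no spline prefilter. The sketch below uses a short ramp whose values equal their indices, so the wrapped values are easy to read, and calls scipy.ndimage.shift directly (the same function as scipy.ndimage.interpolation.shift). If the first samples still disagree with np.roll, the cause is the way mode='wrap' extends the array rather than the prefilter.

```python
import numpy as np
import scipy.ndimage as sciim

# Each value equals its index, so it is easy to see which sample wrapped in.
ramp = np.arange(10.0)

# order=1 (linear interpolation) skips the spline prefilter entirely,
# so any mismatch below comes from the boundary mode alone.
wrapped = sciim.shift(ramp, 2, mode='wrap', order=1)
rolled = np.roll(ramp, 2)

print(rolled)   # [8. 9. 0. 1. 2. 3. 4. 5. 6. 7.]
print(wrapped)  # [7. 8. 0. 1. 2. 3. 4. 5. 6. 7.]

# mode='wrap' treats the first and last samples as the same point, so the
# data repeats every len(ramp) - 1 samples: the result matches rolling the
# array with its last element dropped.
print(np.array_equal(wrapped[:-1], np.roll(ramp[:-1], 2)))  # True
```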

Am I missing something obvious here?

Is there a way to make ndimage.shift coincide with np.roll?

Answer

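I don't think there is anything wrong with the shift function. With mode='wrap', ndimage.shift treats the first and last samples as the same point, so the data repeats every len(data) - 1 samples. When you use roll, you therefore need to chop an extra element for a fair comparison. Please see the code below.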

import numpy as np
import scipy.ndimage as sciim
import matplotlib.pyplot as plt


def shiftfunc(data, amt):
    return sciim.shift(data, amt, mode='wrap', order=3)

def rollfunc(data, amt):
    # Assumes 0 <= amt < len(data).
    rollshift = np.roll(data, amt)
    # After rolling, index amt holds the original first element; remove it
    # so the result has one element fewer than the input.
    return np.concatenate((rollshift[:amt], rollshift[amt+1:]))

if __name__ == "__main__":
    shift_by = 5
    # linspace includes both endpoints, so yvals[0] and yvals[-1] are the
    # same point of the sine period (sin(0) and sin(2*pi)).
    xvals = np.linspace(0, 2*np.pi, 20)
    yvals = np.sin(xvals)
    rollshift = rollfunc(yvals, shift_by)
    interpshift = shiftfunc(yvals, shift_by)
    plt.plot(xvals, yvals, label='original', alpha=0.5)
    plt.plot(xvals[1:], rollshift, label='np.roll', alpha=0.5, marker='s')
    plt.plot(xvals, interpshift, label='interpolation.shift', alpha=0.5, marker='o')
    plt.legend()
    plt.show()

This produces a plot of the original, np.roll and interpolation.shift curves (image not preserved).
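If the goal is for ndimage.shift to reproduce np.roll exactly for integer shifts while still allowing fractional ones, SciPy 1.6 and later provide mode='grid-wrap', which wraps with a period of exactly len(data). A minimal sketch, assuming SciPy 1.6 or newer (older releases do not recognize this mode name), reusing the question's shiftfunc name and test signal:

```python
import numpy as np
import scipy.ndimage as sciim


def shiftfunc(data, amt):
    # 'grid-wrap' wraps with a period of exactly len(data), the same
    # convention np.roll uses; 'wrap' effectively uses len(data) - 1.
    return sciim.shift(data, amt, mode='grid-wrap', order=3)


xvals = np.arange(100) * 1.0
yvals = np.sin(xvals * 0.1)

# Integer shift: agrees with np.roll up to floating-point rounding.
print(np.allclose(shiftfunc(yvals, 2), np.roll(yvals, 2)))  # True

# Non-integer shift: cubic-spline interpolation of the periodic data.
half_step = shiftfunc(yvals, 0.5)
print(half_step.shape)  # (100,)
```

With 'grid-wrap', no element has to be trimmed before comparing against np.roll; the trimming in rollfunc is only needed when comparing against mode='wrap'.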
