```#!/usr/bin/env python

from __future__ import division
import numpy as np
from scipy.linalg import hankel
from scipy.signal import convolve2d
import scipy.io as sio
import matplotlib.pyplot as plt

from ..tools import *

def bispectrumdx(x, y, z, nfft=None, wind=None, nsamp=None, overlap=None):
"""
Parameters:
x    - data vector or time-series
y    - data vector or time-series  (same dimensions as x)
z    - data vector or time-series  (same dimensions as x)
nfft - fft length [default = power of two > segsamp]
wind - window specification for frequency-domain smoothing
if 'wind' is a scalar, it specifies the length of the side
of the square for the Rao-Gabr optimal window  [default=5]
if 'wind' is a vector, a 2D window will be calculated via
w2(i,j) = wind(i) * wind(j) * wind(i+j)
if 'wind' is a matrix, it specifies the 2-D filter directly
segsamp - samples per segment [default: such that we have 8 segments]
- if x is a matrix, segsamp is set to the number of rows
overlap - percentage overlap, allowed range [0,99]. [default = 50];
- if x is a matrix, overlap is set to 0.

Output:
Bspec   - estimated bispectrum: an nfft x nfft array, with origin
at the center, and axes pointing down and to the right.
waxis   - vector of frequencies associated with the rows and columns
of Bspec;  sampling frequency is assumed to be 1.
"""

(lx, lrecs) = x.shape
(ly, nrecs) = y.shape
(lz, krecs) = z.shape

if lx != ly or lrecs != nrecs or ly != lz or nrecs != krecs:
raise Exception('x, y and z should have identical dimensions')

if ly == 1:
x = x.reshape(1,-1)
y = y.reshape(1,-1)
z = z.reshape(1,-1)
ly = nrecs
nrecs = 1

if not overlap: overlap = 50
overlap = max(0,min(overlap,99))
if nrecs > 1: overlap = 0
if not nsamp: nsamp = 0
if nrecs > 1: nsamp = ly
if nrecs == 1 and nsamp <= 0:
nsamp = np.fix(ly/ (8 - 7 * overlap/100))
if nfft < nsamp:
nfft = 2**nextpow2(nsamp)

overlap = np.fix(overlap/100 * nsamp)
nadvance = nsamp - overlap
nrecs = np.fix((ly*nrecs - overlap) / nadvance)

# create the 2-D window
if not wind: wind = 5

m = n = 0
try:
(m, n) = wind.shape
except ValueError:
(m,) = wind.shape
n = 1
except AttributeError:
m = n = 1

window = wind
# scalar: wind is size of Rao-Gabr window
if max(m, n) == 1:
winsize = wind
if winsize < 0: winsize = 5 # the window size L
winsize = winsize - (winsize%2) + 1 # make it odd
if winsize > 1:
mwind = np.fix(nfft/winsize) # the scale parameter M
lby2 = (winsize - 1)/2

theta = np.array([np.arange(-1*lby2, lby2+1)]) # force a 2D array
opwind = np.ones([winsize, 1]) * (theta**2) # w(m,n) = m**2
opwind = opwind + opwind.transpose() + (np.transpose(theta) * theta) # m**2 + n**2 + mn
opwind = 1 - ((2*mwind/nfft)**2) * opwind
Hex = np.ones([winsize,1]) * theta
Hex = abs(Hex) + abs(np.transpose(Hex)) + abs(Hex + np.transpose(Hex))
Hex = (Hex < winsize)
opwind = opwind * Hex
opwind = opwind * (4 * mwind**2) / (7 * np.pi**2)
else:
opwind = 1

# 1-D window passed: convert to 2-D
elif min(m, n) == 1:
window = window.reshape(1,-1)

if np.any(np.imag(window)) != 0:
print "1-D window has imaginary components: window ignored"
window = 1

if np.any(window) < 0:
print "1-D window has negative components: window ignored"
window = 1

lwind = np.size(window)
w = window.ravel(order='F')
# the full symmetric 1-D
windf = np.array(w[range(lwind-1, 0, -1) + [window]])
window = np.array([window], np.zeros([lwind-1,1]))
# w(m)w(n)w(m+n)
opwind = (windf * np.transpose(windf)) * hankel(np.flipud(window), window)
winsize = np.size(window)

# 2-D window passed: use directly
else:
winsize = m

if m != n:
print "2-D window is not square: window ignored"
window = 1
winsize = m

if m%2 == 0:
print "2-D window does not have odd length: window ignored"
window = 1
winsize = m

opwind = window

# accumulate triple products
Bspec = np.zeros([nfft, nfft]) # the hankel mask (faster)
locseg = np.arange(nsamp).transpose()
x = x.ravel(order='F')
y = y.ravel(order='F')
z = z.ravel(order='F')

for krec in xrange(nrecs):
xseg = x[locseg].reshape(1,-1)
yseg = y[locseg].reshape(1,-1)
zseg = z[locseg].reshape(1,-1)

Xf = np.fft.fft(xseg - np.mean(xseg), nfft) / nsamp
Yf = np.fft.fft(yseg - np.mean(yseg), nfft) / nsamp
CZf = np.fft.fft(zseg - np.mean(zseg), nfft) / nsamp
CZf = np.conjugate(CZf).ravel(order='F')

Bspec = Bspec + \
flat_eq(Bspec, (Xf * np.transpose(Yf)) * CZf[mask].reshape(nfft, nfft))
locseg = locseg + int(nadvance)

Bspec = np.fft.fftshift(Bspec) / nrecs

# frequency-domain smoothing
if winsize > 1:
lby2 = int((winsize-1)/2)
Bspec = convolve2d(Bspec,opwind)
Bspec = Bspec[range(lby2+1,lby2+nfft+1), :][:, np.arange(lby2+1,lby2+nfft+1)]

if nfft%2 == 0:
waxis = np.transpose(np.arange(-1*nfft/2, nfft/2)) / nfft
else:
waxis = np.transpose(np.arange(-1*(nfft-1)/2, (nfft-1)/2+1)) / nfft

# cont1 = plt.contour(abs(Bspec), 4, waxis, waxis)
cont = plt.contourf(waxis, waxis, abs(Bspec), 100, cmap=plt.cm.Spectral_r)
plt.colorbar(cont)
plt.title('Bispectrum estimated via the direct (FFT) method')
plt.xlabel('f1')
plt.ylabel('f2')
plt.show()

return (Bspec, waxis)

def test():
nl1 = sio.loadmat(here(__file__) + '/demo/nl1.mat')
dbic = bispectrumdx(nl1['x'], nl1['x'], nl1['y'], 128,5)

if __name__ == '__main__':
test()

```