```"""Classes to define the lattice structure of a model.

The base class :class:`Lattice` defines the general structure of a lattice,
you can subclass this to define you own lattice.
The :class:`SimpleLattice` is a slight simplification for lattices with a single-site unit cell.
Further, we have some predefined lattices, namely
:class:`Square`, :class:`Triangular`, :class:`Honeycomb`, and :class:`Kagome` in 2D.

The :class:`IrregularLattice` provides a way to remove or add sites to an existing, regular
lattice.

"""
# Copyright 2018-2020 TeNPy Developers, GNU GPLv3

import numpy as np
import itertools
import warnings

from ..networks.site import Site
from ..tools.misc import to_iterable, to_iterable_of_len, inverse_permutation
from ..networks.mps import MPS  # only to check boundary conditions

__all__ = [
'Lattice', 'TrivialLattice', 'IrregularLattice', 'SimpleLattice', 'Chain', 'Ladder', 'Square',
'Triangular', 'Honeycomb', 'Kagome', 'get_lattice', 'get_order', 'get_order_grouped'
]

# (update module doc string if you add further lattices)

bc_choices = {'open': True, 'periodic': False}
"""dict: maps possible choices of boundary conditions in a lattice to bool/int."""

class Lattice:
r"""A general, regular lattice.

The lattice consists of a **unit cell** which is repeated in `dim` different directions.
A site of the lattice is thus identified by **lattice indices** ``(x_0, ..., x_{dim-1}, u)``,
where ``0 <= x_l < Ls[l]`` pick the position of the unit cell in the lattice and
``0 <= u < len(unit_cell)`` picks the site within the unit cell. The site is located
in 'space' at ``sum_l x_l*basis[l] + unit_cell_positions[u]`` (see :meth:`position`).
(Note that the position in space is only used for plotting, not for defining the couplings.)

In addition to the pure geometry, this class also defines an `order` of all sites.
This order maps the lattice to a finite 1D chain and defines the geometry of MPSs and MPOs.
The **MPS index** `i` corresponds thus to the lattice sites given by
``(x_0, ..., x_{dim-1}, u) = tuple(self.order[i])``.
Infinite boundary conditions of the MPS repeat in the first spatial direction of the lattice,
i.e., if the site at ``(x_0, x_1, ..., x_{dim-1},u)`` has MPS index `i`, the site at
at ``(x_0 + Ls, x_1, ..., x_{dim-1}, u)`` corresponds to MPS index ``i + N_sites``.
Use :meth:`mps2lat_idx` and :meth:`lat2mps_idx` for conversion of indices.
The function :meth:`mps2lat_values` performs the necessary reshaping and re-ordering from
arrays indexed in MPS form to arrays indexed in lattice form.

.. deprecated :: 0.5.0
The parameters and attributes `nearest_neighbors`, `next_nearest_neighbors` and
`next_next_nearest_neighbors` are deprecated. Instead, we use a dictionary `pairs`
with those names as keys and the corresponding values as specified before.

Parameters
----------
Ls : list of int
the length in each direction
unit_cell : list of :class:`~tenpy.networks.site.Site`
The sites making up a unit cell of the lattice.
If you want to specify it only after initialization, use ``None`` entries in the list.
order : str | ``('standard', snake_winding, priority)`` | ``('grouped', groups)``
A string or tuple specifying the order, given to :meth:`ordering`.
bc : (iterable of) {'open' | 'periodic' | int}
Boundary conditions in each direction of the lattice.
A single string holds for all directions.
An integer `shift` means that we have periodic boundary conditions along this direction,
but shift/tilt by ``-shift*lattice.basis`` (~cylinder axis for ``bc_MPS='infinite'``)
when going around the boundary along this direction.
bc_MPS : 'finite' | 'segment' | 'infinite'
Boundary conditions for an MPS/MPO living on the ordered lattice.
If the system is ``'infinite'``, the infinite direction is always along the first basis
vector (justifying the definition of `N_rings` and `N_sites_per_ring`).
basis : iterable of 1D arrays
For each direction one translation vectors shifting the unit cell.
Defaults to the standard ONB ``np.eye(dim)``.
positions : iterable of 1D arrays
For each site of the unit cell the position within the unit cell.
Defaults to ``np.zeros((len(unit_cell), dim))``.
nearest_neighbors : ``None`` | list of ``(u1, u2, dx)``
next_nearest_neighbors : ``None`` | list of ``(u1, u2, dx)``
next_next_nearest_neighbors : ``None`` | list of ``(u1, u2, dx)``
pairs : dict
Of the form ``{'nearest_neighbors': [(u1, u2, dx), ...], ...}``.
Typical keys are ``'nearest_neighbors', 'next_nearest_neighbors'``.
For each of them, it specifies a list of tuples ``(u1, u2, dx)`` which can
be used as parameters for :meth:`~tenpy.models.model.CouplingModel.add_coupling`
to generate couplings over each pair of ,e.g., ``'nearest_neighbors'``.
Note that this adds couplings for each pair *only in one direction*!

Attributes
----------
Ls : tuple of int
the length in each direction.
shape : tuple of int
the 'shape' of the lattice, same as ``Ls + (len(unit_cell), )``
N_cells : int
the number of unit cells in the lattice, ``np.prod(self.Ls)``.
N_sites : int
the number of sites in the lattice, ``np.prod(self.shape)``.
N_sites_per_ring : int
Defined as ``N_sites / Ls``, for an infinite system the number of cites per "ring".
N_rings : int
Alias for ``Ls``, for an infinite system the number of "rings" in the unit cell.
unit_cell : list of :class:`~tenpy.networks.site.Site`
the sites making up a unit cell of the lattice.
bc : bool ndarray
Boundary conditions of the couplings in each direction of the lattice,
translated into a bool array with the global `bc_choices`.
bc_shift : None | ndarray(int)
The shift in x-direction when going around periodic boundaries in other directions.
bc_MPS : 'finite' | 'segment' | 'infinite'
Boundary conditions for an MPS/MPO living on the ordered lattice.
If the system is ``'infinite'``, the infinite direction is always along the first basis
vector (justifying the definition of `N_rings` and `N_sites_per_ring`).
basis : ndarray (dim, Dim)
translation vectors shifting the unit cell. The row `i` gives the vector shifting in
direction `i`.
unit_cell_positions : ndarray, shape (len(unit_cell), Dim)
for each site in the unit cell a vector giving its position within the unit cell.
pairs : dict
See above.
_order : ndarray (N_sites, dim+1)
The place where :attr:`order` is stored.
_strides : ndarray (dim, )
necessary for :meth:`lat2mps_idx`.
_perm : ndarray (N, )
permutation needed to make `order` lexsorted, ``_perm = np.lexsort(_order.T)``.
_mps2lat_vals_idx : ndarray `shape`
index array for reshape/reordering in :meth:`mps2lat_vals`
_mps_fix_u : tuple of ndarray (N_cells, ) np.intp
for each site of the unit cell an index array selecting the mps indices of that site.
_mps2lat_vals_idx_fix_u : tuple of ndarray of shape `Ls`
similar as `_mps2lat_vals_idx`, but for a fixed `u` picking a site from the unit cell.
"""
def __init__(self,
Ls,
unit_cell,
order='default',
bc='open',
bc_MPS='finite',
basis=None,
positions=None,
nearest_neighbors=None,
next_nearest_neighbors=None,
next_next_nearest_neighbors=None,
pairs=None):
self.unit_cell = list(unit_cell)
self._set_Ls(Ls)  # after setting unit_cell
if positions is None:
positions = np.zeros((len(self.unit_cell), self.dim))
if basis is None:
basis = np.eye(self.dim)
self.unit_cell_positions = np.array(positions)
self.basis = np.array(basis)
self.boundary_conditions = bc  # property setter for self.bc and self.bc_shift
self.bc_MPS = bc_MPS
# calculate order for MPS
self.order = self.ordering(order)
# uses attribute setter to calculte _mps2lat_vals_idx_fix_u etc and lat2mps
self.pairs = pairs if pairs is not None else {}
for name, NN in [('nearest_neighbors', nearest_neighbors),
('next_nearest_neighbors', next_nearest_neighbors),
('next_next_nearest_neighbors', next_next_nearest_neighbors)]:
if NN is None:
continue  # no value set
msg = "Lattice.__init__() got argument `{0!s}`.\nSet as `neighbors['{0!s}'] instead!"
msg = msg.format(name)
warnings.warn(msg, FutureWarning)
if name in self.pairs:
raise ValueError("{0!s} sepcified twice!".format(name))
self.pairs[name] = NN
self.test_sanity()  # check consistency

def test_sanity(self):
"""Sanity check.

Raises ValueErrors, if something is wrong.
"""
assert self.dim == len(self.Ls)
assert self.shape == self.Ls + (len(self.unit_cell), )
assert self.N_cells == np.prod(self.Ls)
if self.bc.shape != (self.dim, ):
raise ValueError("Wrong len of bc")
assert self.bc.dtype == np.bool
chinfo = None
for site in self.unit_cell:
if not isinstance(site, Site):
continue
if chinfo is None:
chinfo = site.leg.chinfo
if site.leg.chinfo != chinfo:
raise ValueError("All sites must have the same ChargeInfo!")
if self.basis.shape != self.dim:
raise ValueError("Need one basis vector for each direction!")
if self.unit_cell_positions.shape != len(self.unit_cell):
raise ValueError("Need one position for each site in the unit cell.")
if self.basis.shape != self.unit_cell_positions.shape:
raise ValueError("Different space dimensions of `basis` and `unit_cell_positions`")
if self.bc_MPS not in MPS._valid_bc:
raise ValueError("invalid MPS boundary conditions")
if self.bc and self.bc_MPS == 'infinite':
raise ValueError("Need periodic boundary conditions along the x-direction "
"for 'infinite' `bc_MPS`")
if not isinstance(self, IrregularLattice):
assert self.N_sites == np.prod(self.shape)
# if one of the following assert fails,
# the `ordering` function might have returned an invalid array
assert np.all(self._order >= 0) and np.all(self._order <= self.shape)
# rows of `order` unique and _perm correct?
assert np.all(
np.sum(self._order * self._strides, axis=1)[self._perm] == np.arange(self.N_sites))

def save_hdf5(self, hdf5_saver, h5gr, subpath):
"""Export `self` into a HDF5 file.

This method saves all the data it needs to reconstruct `self` with :meth:`from_hdf5`.

Specifically, it saves
:attr:`unit_cell`, :attr:`Ls`, :attr:`unit_cell_positions`, :attr:`basis`,
:attr:`boundary_conditions`, :attr:`pairs` under their name,
:attr:`bc_MPS` as ``"boundary_conditions_MPS"``, and
:attr:`order` as ``"order_for_MPS"``.
Moreover, it saves :attr:`dim` and :attr:`N_sites` as HDF5 attributes.

Parameters
----------
hdf5_saver : :class:`~tenpy.tools.hdf5_io.Hdf5Saver`
Instance of the saving engine.
h5gr : :class`Group`
HDF5 group which is supposed to represent `self`.
subpath : str
The `name` of `h5gr` with a ``'/'`` in the end.
"""
hdf5_saver.save(self.unit_cell, subpath + "unit_cell")
hdf5_saver.save(np.array(self.Ls, int), subpath + "lengths")
hdf5_saver.save(self.unit_cell_positions, subpath + "unit_cell_positions")
hdf5_saver.save(self.basis, subpath + "basis")
hdf5_saver.save(self.boundary_conditions, subpath + "boundary_conditions")
hdf5_saver.save(self.bc_MPS, subpath + "boundary_condition_MPS")
hdf5_saver.save(self.order, subpath + "order_for_MPS")
hdf5_saver.save(self.pairs, subpath + "pairs")
h5gr.attrs["dim"] = self.dim
h5gr.attrs["N_sites"] = self.N_sites

@classmethod
"""Load instance from a HDF5 file.

This method reconstructs a class instance from the data saved with :meth:`save_hdf5`.

Parameters
----------
h5gr : :class:`Group`
HDF5 group which is represent the object to be constructed.
subpath : str
The `name` of `h5gr` with a ``'/'`` in the end.

Returns
-------
obj : cls
Newly generated class instance containing the required data.
"""
obj = cls.__new__(cls)  # create class instance, no __init__() call

obj._set_Ls(Ls)

obj.test_sanity()
return obj

@property
def dim(self):
"""The dimension of the lattice."""
return len(self.Ls)

@property
def order(self):
"""Defines an ordering of the lattice sites, thus mapping the lattice to a 1D chain.

Each row of the array contains the lattice indices for one site,
the order of the rows thus specifies a path through the lattice,
along which an MPS will wind through through the lattice.

You can visualize the order with :meth:`plot_order`.
"""
return self._order

@order.setter
def order(self, order_):
# update the value itself
self._order = order_
# and the other stuff which is cached
self._perm = np.lexsort(order_.T)
self._mps2lat_vals_idx = np.empty(self.shape, np.intp)
self._mps2lat_vals_idx[tuple(order_.T)] = np.arange(self.N_sites)
# versions for fixed u
self._mps_fix_u = []
self._mps2lat_vals_idx_fix_u = []
for u in range(len(self.unit_cell)):
mps_fix_u = np.nonzero(order_[:, -1] == u)
self._mps_fix_u.append(mps_fix_u)
mps2lat_vals_idx = np.empty(self.Ls, np.intp)
mps2lat_vals_idx[tuple(order_[mps_fix_u, :-1].T)] = np.arange(self.N_cells)
self._mps2lat_vals_idx_fix_u.append(mps2lat_vals_idx)
self._mps_fix_u = tuple(self._mps_fix_u)

def ordering(self, order):
"""Provide possible orderings of the `N` lattice sites.

This function can be overwritten by derived lattices to define additional orderings.
The following orders are defined in this method:

================== =========================== =============================
`order`            equivalent `priority`       equivalent ``snake_winding``
================== =========================== =============================
``'Cstyle'``       (0, 1, ..., dim-1, dim)     (False, ..., False, False)
``'default'``
``'snake'``        (0, 1, ..., dim-1, dim)     (True, ..., True, True)
``'snakeCstyle'``
``'Fstyle'``       (dim-1, ..., 1, 0, dim)     (False, ..., False, False)
``'snakeFstyle'``  (dim-1, ..., 1, 0, dim)     (False, ..., False, False)
================== =========================== =============================

Parameters
----------
order : str | ``('standard', snake_winding, priority)`` | ``('grouped', groups)``
Specifies the desired ordering using one of the strings of the above tables.
Alternatively, an ordering is specified by a tuple with first entry specifying a
function, ``'standard'`` for :func:`get_order` and ``'grouped'`` for
:func:`get_order_grouped`, and other arguments in the tuple as specified in the
documentation of these functions.

Returns
-------
order : array, shape (N, D+1), dtype np.intp
the order to be used for :attr:`order`.

--------
get_order : generates the `order` from equivalent `priority` and `snake_winding`.
get_order_grouped : variant of `get_order`.
plot_order : visualizes the resulting `order`.
"""
if isinstance(order, str):
if order in ["default", "Cstyle"]:
priority = None
snake_winding = (False, ) * (self.dim + 1)
elif order == "Fstyle":
priority = range(self.dim, -1, -1)
snake_winding = (False, ) * (self.dim + 1)
elif order in ["snake", "snakeCstyle"]:
priority = None
snake_winding = (True, ) * (self.dim + 1)
elif order == "snakeFstyle":
priority = range(self.dim, -1, -1)
snake_winding = (True, ) * (self.dim + 1)
else:
# in a derived lattice use ``return super().ordering(order)`` as last option
# such that the derived lattice also has the orderings defined in this function.
raise ValueError("unknown ordering " + repr(order))
else:
descr = order
if descr == 'standard':
snake_winding, priority = order, order
elif descr == 'grouped':
return get_order_grouped(self.shape, order)
else:
raise ValueError("unknown ordering " + repr(order))
return get_order(self.shape, snake_winding, priority)

@property
def boundary_conditions(self):
"""Human-readable list of boundary conditions from :attr:`bc` and :attr:`bc_shift`.

Returns
-------
boundary_conditions : list of str
List of ``"open"`` or ``"periodic"``, one entry for each direction of the lattice.
"""
global bc_choices
bc_choices_reverse = dict([(v, k) for (k, v) in bc_choices.items()])
bc = [bc_choices_reverse[bc] for bc in self.bc]
if self.bc_shift is not None:
for i, shift in enumerate(self.bc_shift):
assert bc[i + 1] == "periodic"
bc[i + 1] = shift
return bc

@boundary_conditions.setter
def boundary_conditions(self, bc):
global bc_choices
if bc in list(bc_choices.keys()):
bc = [bc_choices[bc]] * self.dim
self.bc_shift = None
else:
bc = list(bc)  # we modify entries...
self.bc_shift = np.zeros(self.dim - 1, np.int_)
for i, bc_i in enumerate(bc):
if isinstance(bc_i, int):
if i == 0:
raise ValueError("Invalid bc: first entry can't be a shift")
self.bc_shift[i - 1] = bc_i
bc[i] = bc_choices['periodic']
else:
bc[i] = bc_choices[bc_i]
if not np.any(self.bc_shift != 0):
self.bc_shift = None
self.bc = np.array(bc)

def enlarge_mps_unit_cell(self, factor=2):
"""Repeat the unit cell for infinite MPS boundary conditions; in place.

Parameters
----------
factor : int
The new number of sites in the MPS unit cell will be increased from `N_sites` to
``factor*N_sites_per_ring``. Since MPS unit cells are repeated in the `x`-direction
in our convetion, the lattice shape goes from
``(Lx, Ly, ..., Lu)`` to ``(Lx*factor, Ly, ..., Lu)``.
"""
if self.bc_MPS != "infinite":
raise ValueError("can only enlarge the MPS unit cell for infinite MPS.")
new_Ls = list(self.Ls)
old_Lx = new_Ls
new_Ls = old_Lx * factor
old_order = self.order
new_order = []
for i in range(factor):
order = old_order.copy()
shift_x = i * old_Lx
order[:, 0] += shift_x
new_order.append(order)
new_order = np.vstack(new_order)
# now update the contents of `self`
self._set_Ls(new_Ls)
self.order = new_order  # property setter
self.test_sanity()

def position(self, lat_idx):
"""return 'space' position of one or multiple sites.

Parameters
----------
lat_idx : ndarray, ``(... , dim+1)``
Lattice indices.

Returns
-------
pos : ndarray, ``(..., dim)``
The position of the lattice sites specified by `lat_idx` in real-space.
"""
idx = self._asvalid_latidx(lat_idx)
res = np.take(self.unit_cell_positions, idx[..., -1], axis=0)
for i in range(self.dim):
res += idx[..., i, np.newaxis] * self.basis[i]
return res

def site(self, i):
"""return :class:`~tenpy.networks.site.Site` instance corresponding to an MPS index `i`"""
return self.unit_cell[self.order[i, -1]]

def mps_sites(self):
"""Return a list of sites for all MPS indices.

Equivalent to ``[self.site(i) for i in range(self.N_sites)]``.

This should be used for `sites` of 1D tensor networks (MPS, MPO,...).
"""
return [self.unit_cell[u] for u in self.order[:, -1]]

def mps2lat_idx(self, i):
"""Translate MPS index `i` to lattice indices ``(x_0, ..., x_{dim-1}, u)``.

Parameters
----------
i : int | array_like of int
MPS index/indices.

Returns
-------
lat_idx : array
First dimensions like `i`, last dimension has len `dim`+1 and contains the lattice
indices ``(x_0, ..., x_{dim-1}, u)`` corresponding to `i`.
For `i` accross the MPS unit cell and "infinite" `bc_MPS`, we shift `x_0` accordingly.
"""
if self.bc_MPS == 'infinite':
# allow `i` outsit of MPS unit cell for bc_MPS infinite
i0 = i
i = np.mod(i, self.N_sites)
if np.any(i0 != i):
lat = self.order[i].copy()
lat[..., 0] += (i0 - i) * self.N_rings // self.N_sites
# N_sites_per_ring might not be set for IrregularLattice
return lat
return self.order[i].copy()

def lat2mps_idx(self, lat_idx):
"""Translate lattice indices ``(x_0, ..., x_{D-1}, u)`` to MPS index `i`.

Parameters
----------
lat_idx : array_like [..., dim+1]
The last dimension corresponds to lattice indices ``(x_0, ..., x_{D-1}, u)``.
All lattice indices should be positive and smaller than the corresponding entry in
``self.shape``. Exception: for "infinite" `bc_MPS`, an `x_0` outside indicates shifts
accross the boundary.

Returns
-------
i : array_like
MPS index/indices corresponding to `lat_idx`.
Has the same shape as `lat_idx` without the last dimension.
"""
idx = self._asvalid_latidx(lat_idx)
if self.bc_MPS == 'infinite':
i_shift = idx[..., 0] - np.mod(idx[..., 0], self.N_rings)
idx[..., 0] -= i_shift
i = np.sum(np.mod(idx, self.shape) * self._strides, axis=-1)  # before permutation
i = np.take(self._perm, i)  # after permutation
if self.bc_MPS == 'infinite':
i += i_shift * self.N_sites // self.N_rings
# N_sites_per_ring might not be set for IrregularLattice
return i

def mps_idx_fix_u(self, u=None):
"""return an index array of MPS indices for which the site within the unit cell is `u`.

If you have multiple sites in your unit-cell, an onsite operator is in general not defined
for all sites. This functions returns an index array of the mps indices which belong to
sites given by ``self.unit_cell[u]``.

Parameters
----------
u : None | int
Selects a site of the unit cell. ``None`` (default) means all sites.

Returns
-------
mps_idx : array
MPS indices for which ``self.site(i) is self.unit_cell[u]``. Ordered ascending.
"""
if u is not None:
return self._mps_fix_u[u]
return self._perm

def mps_lat_idx_fix_u(self, u=None):
"""Similar as :meth:`mps_idx_fix_u`, but return also the corresponding lattice indices.

Parameters
----------
u : None | int
Selects a site of the unit cell. ``None`` (default) means all sites.

Returns
-------
mps_idx : array
MPS indices `i` for which ``self.site(i) is self.unit_cell[u]``.
lat_idx : 2D array
The row `j` contains the lattice index (without `u`) corresponding to ``mps_idx[j]``.
"""
mps_idx = self.mps_idx_fix_u(u)
return mps_idx, self.order[mps_idx, :-1]

def mps2lat_values(self, A, axes=0, u=None):
"""Reshape/reorder `A` to replace an MPS index by lattice indices.

Parameters
----------
A : ndarray
Some values. Must have ``A.shape[axes] = self.N_sites`` if `u` is ``None``, or
``A.shape[axes] = self.N_cells`` if `u` is an int.
axes : (iterable of) int
chooses the axis which should be replaced.
u : ``None`` | int
Optionally choose a subset of MPS indices present in the axes of `A`, namely the
indices corresponding to ``self.unit_cell[u]``, as returned by :meth:`mps_idx_fix_u`.
The resulting array will not have the additional dimension(s) of `u`.

Returns
-------
res_A : ndarray
Reshaped and reordered verions of A. Such that MPS indices along the specified axes
are replaced by lattice indices, i.e., if MPS index `j` maps to lattice site
`(x0, x1, x2)`, then ``res_A[..., x0, x1, x2, ...] = A[..., j, ...]``.

Examples
--------
Say you measure expection values of an onsite term for an MPS, which gives you an 1D array
`A`, where `A[i]` is the expectation value of the site given by ``self.mps2lat_idx(i)``.
Then this function gives you the expectation values ordered by the lattice:

>>> print(lat.shape, A.shape)
(10, 3, 2) (60,)
>>> A_res = lat.mps2lat_values(A)
>>> A_res.shape
(10, 3, 2)
>>> A_res[lat.mps2lat_idx(5)] == A
True

If you have a correlation function ``C[i, j]``, it gets just slightly more complicated:

>>> print(lat.shape, C.shape)
(10, 3, 2) (60, 60)
>>> lat.mps2lat_values(C, axes=[0, 1]).shape
(10, 3, 2, 10, 3, 2)

If the unit cell consists of different physical sites, an onsite operator might be defined
only on one of the sites in the unit cell. Then you can use :meth:`mps_idx_fix_u` to get
the indices of sites it is defined on, measure the operator on these sites, and use
the argument `u` of this function.

>>> u = 0
>>> idx_subset = lat.mps_idx_fix_u(u)
>>> A_u = A[idx_subset]
>>> A_u_res = lat.mps2lat_values(A_u, u=u)
>>> A_u_res.shape
(10, 3)
>>> np.all(A_res[:, :, u] == A_u_res[:, :])
True

.. todo ::
make sure this function is used for expectation values...
"""
axes = to_iterable(axes)
if len(axes) > 1:
axes = [(ax + A.ndim if ax < 0 else ax) for ax in axes]
A = self.mps2lat_values(A, ax, u)  # recursion with single axis
return A
# choose the appropriate index arrays
if u is None:
idx = self._mps2lat_vals_idx
else:
idx = self._mps2lat_vals_idx_fix_u[u]
return np.take(A, idx, axis=axes)

def mps2lat_values_masked(self, A, axes=-1, mps_inds=None, include_u=None):
"""Reshape/reorder an array `A` to replace an MPS index by lattice indices.

This is a generalization of :meth:`mps2lat_values` allowing for the case of an arbitrary
set of MPS indices present in each axis of `A`.

Parameters
----------
A : ndarray
Some values.
axes : (iterable of) int
Chooses the axis of `A` which should be replaced.
If multiple axes are given, you also need to give multiple index arrays as `mps_inds`.
mps_inds : (list of) 1D ndarray
Specifies for each `axis` in `axes`, for which MPS indices we have values in the
corresponding `axis` of `A`.
Defaults to ``[np.arange(A.shape[ax]) for ax in axes]``.
For indices accross the MPS unit cell and "infinite" `bc_MPS`,
we shift `x_0` accordingly.
include_u : (list of) bool
Specifies for each `axis` in `axes`, whether the `u` index of the lattice should be
included into the output array `res_A`. Defaults to ``len(self.unit_cell) > 1``.

Returns
-------
Reshaped and reordered copy of A. Such that MPS indices along the specified axes
are replaced by lattice indices, i.e., if MPS index `j` maps to lattice site
`(x0, x1, x2)`, then ``res_A[..., x0, x1, x2, ...] = A[..., mps_inds[j], ...]``.
"""
try:
iter(axes)
except TypeError:  # axes is single int
axes = [axes]
mps_inds = [mps_inds]
include_u = [include_u]
else:  # iterable axes
if mps_inds is None:
mps_inds = [None] * len(axes)
if include_u is None:
include_u = [None] * len(axes)
if len(axes) != len(mps_inds) or len(axes) != len(include_u):
raise ValueError("Lenght of `axes`, `mps_inds` and `include_u` different")
# sort axes ascending
axes = [(ax + A.ndim if ax < 0 else ax) for ax in axes]

# goal: use numpy advanced indexing for the data copy
lat_inds = []  # lattice indices to be used
new_shapes = []  # shape to be
for ax, mps_inds_ax, include_u_ax in zip(axes, mps_inds, include_u):
if mps_inds_ax is None:  # default
mps_inds_ax = np.arange(A.shape[ax])
if include_u_ax is None:  # default
include_u_ax = (len(self.unit_cell) > 1)
if mps_inds_ax.ndim != 1:
raise ValueError("got non-1D array in `mps_inds` " + str(mps_inds_ax.shape))
lat_inds_ax = self.mps2lat_idx(mps_inds_ax)
shape = list(self.shape)
max_i = np.max(mps_inds_ax)
if max_i >= self.N_sites:
shape += (max_i - self.N_sites) * self.N_rings // self.N_sites + 1
min_i = np.min(mps_inds_ax)
if min_i < 0:
# we use numpy indexing to simply wrap around negative indices
shape += (abs(min_i) - 1) * self.N_rings // self.N_sites + 1
if not include_u_ax:
shape = shape[:-1]
lat_inds_ax = lat_inds_ax[:, :-1]
new_shapes.append(shape)
lat_inds.append(lat_inds_ax)

res_A_ndim = A.ndim - len(axes) + sum([len(s) for s in new_shapes])
res_A_shape = []
res_A_inds = []
dim_before = 0
dim_after = A.ndim - 1
for ax in range(A.ndim):
if ax in axes:
i = axes.index(ax)
inds_ax = lat_inds[i].T
res_A_shape.extend(new_shapes[i])
for inds in lat_inds[i].T:
inds = inds.reshape( * dim_before + [len(inds)] +  * dim_after)
res_A_inds.append(inds)
else:
d = A.shape[ax]
res_A_shape.append(d)
inds = np.arange(d).reshape( * dim_before + [d] +  * dim_after)
res_A_inds.append(inds)
dim_before += 1
dim_after -= 1

# and finally we are in the position to create the masked Array
fill_value = np.ma.array([0, 1], dtype=A.dtype).get_fill_value()
res_A_data = np.full(res_A_shape, fill_value, dtype=A.dtype)

res_A[tuple(res_A_inds)] = A  # copy data, automatically unmasks entries
return res_A

def count_neighbors(self, u=0, key='nearest_neighbors'):
"""Count e.g. the number of nearest neighbors for a site in the bulk.

Parameters
----------
u : int
Specifies the site in the unit cell, for which we should count the number
of neighbors (or whatever `key` specifies).
key : str
Key of :attr:`pairs` to select what to count.

Returns
-------
number : int
Number of nearest neighbors (or whatever `key` specified) for the `u`-th site in the
unit cell, somewhere in the bulk of the lattice. Note that it might not be the correct
value at the edges of a lattice with open boundary conditions.
"""
pairs = self.pairs[key]
count = 0
for u1, u2, dx in pairs:
if u1 == u:
count += 1
if u2 == u:
count += 1
return count

def number_nearest_neighbors(self, u=0):
"""Deprecated.

.. deprecated :: 0.5.0
"""
msg = "Use ``count_neighbors(u, 'nearest_neighbors')`` instead."
warnings.warn(msg, FutureWarning)
return self.count_neighbors(u, 'nearest_neighbors')

def number_next_nearest_neighbors(self, u=0):
"""Deprecated.

.. deprecated :: 0.5.0
"""
msg = "Use ``count_neighbors(u, 'next_nearest_neighbors')`` instead."
warnings.warn(msg, FutureWarning)
return self.count_neighbors(u, 'next_nearest_neighbors')

def coupling_shape(self, dx):
"""Calculate correct shape of the `strengths` for a coupling.

Parameters
----------
dx : tuple of int
Translation vector in the lattice for a coupling of two operators.
Corresponds to `dx` argument of

Returns
-------
coupling_shape : tuple of int
Len :attr:`dim`. The correct shape for an array specifying the coupling strength.
`lat_indices` has only rows within this shape.
shift_lat_indices : array
Translation vector from origin to the lower left corner of box spanned by `dx`.
"""
shape = [La - abs(dxa) * int(bca) for La, dxa, bca in zip(self.Ls, dx, self.bc)]
shift_strength = [min(0, dxa) for dxa in dx]
return tuple(shape), np.array(shift_strength)

def possible_couplings(self, u1, u2, dx):
"""Find possible MPS indices for two-site couplings.

For periodic boundary conditions (``bc[a] == False``)
the index ``x_a`` is taken modulo ``Ls[a]`` and runs through ``range(Ls[a])``.
For open boundary conditions, ``x_a`` is limited to ``0 <= x_a < Ls[a]`` and
``0 <= x_a+dx[a] < lat.Ls[a]``.

Parameters
----------
u1, u2 : int
Indices within the unit cell; the `u1` and `u2` of
dx : array
Length :attr:`dim`. The translation in terms of basis vectors for the coupling.

Returns
-------
mps1, mps2 : array
For each possible two-site coupling the MPS indices for the `u1` and `u2`.
lat_indices : 2D int array
Rows of `lat_indices` correspond to rows of `mps_ijkl` and contain the lattice indices
of the "lower left corner" of the box containing the coupling.
coupling_shape : tuple of int
Len :attr:`dim`. The correct shape for an array specifying the coupling strength.
`lat_indices` has only rows within this shape.
"""
coupling_shape, shift_lat_indices = self.coupling_shape(dx)
if any([s == 0 for s in coupling_shape]):
return [], [], np.zeros([0, self.dim]), coupling_shape
Ls = np.array(self.Ls)
mps_i, lat_i = self.mps_lat_idx_fix_u(u1)
lat_j_shifted = lat_i + dx
lat_j = np.mod(lat_j_shifted, Ls)  # assuming PBC
if self.bc_shift is not None:
shift = np.sum(((lat_j_shifted - lat_j) // Ls)[:, 1:] * self.bc_shift, axis=1)
lat_j_shifted[:, 0] -= shift
lat_j[:, 0] = np.mod(lat_j_shifted[:, 0], Ls)
keep = self._keep_possible_couplings(lat_j, lat_j_shifted, u2)
mps_i = mps_i[keep]
lat_indices = lat_i[keep] + shift_lat_indices[np.newaxis, :]
lat_indices = np.mod(lat_indices, coupling_shape)
lat_j = lat_j[keep]
lat_j_shifted = lat_j_shifted[keep]
mps_j = self.lat2mps_idx(np.concatenate([lat_j, [[u2]] * len(lat_j)], axis=1))
if self.bc_MPS == 'infinite':
# shift j by whole MPS unit cells for couplings along the infinite direction
# N_sites_per_ring might not be set for IrregularLattice
mps_j_shift = (lat_j_shifted[:, 0] - lat_j[:, 0]) * self.N_sites // self.N_rings
mps_j += mps_j_shift
# finally, ensure 0 <= min(i, j) < N_sites.
mps_ij_shift = np.where(mps_j_shift < 0, -mps_j_shift, 0)
mps_i += mps_ij_shift
mps_j += mps_ij_shift
return mps_i, mps_j, lat_indices, coupling_shape

def _keep_possible_couplings(self, lat_j, lat_j_shifted, u2):
"""filter possible j sites of a coupling from :meth:`possible_couplings`"""
return np.all(
np.logical_or(
lat_j_shifted == lat_j,  # not accross the boundary
np.logical_not(self.bc)),  # direction has PBC
axis=1)

def multi_coupling_shape(self, dx):
"""Calculate correct shape of the `strengths` for a multi_coupling.

Parameters
----------
dx : 2D array, shape (N_ops, :attr:`dim`)
``dx[i, :]`` is the translation vector in the lattice for the `i`-th operator.
Corresponds to the `dx` of each operator given in the argument `ops` of

Returns
-------
coupling_shape : tuple of int
Len :attr:`dim`. The correct shape for an array specifying the coupling strength.
`lat_indices` has only rows within this shape.
shift_lat_indices : array
Translation vector from origin to the lower left corner of box spanned by `dx`.
(Unlike for :meth:`coupling_shape` it can also contain entries > 0)
"""
# coupling_shape(dx) is equivalent to
# multi_coupling_shape(np.array([*self.dim, dx]))
Ls = self.Ls
shape = [None] * len(Ls)
shift_strength = [None] * len(Ls)
for a in range(len(Ls)):
max_dx, min_dx = np.max(dx[:, a]), np.min(dx[:, a])
box_dx = max_dx - min_dx
shape[a] = Ls[a] - box_dx * int(self.bc[a])
shift_strength[a] = min_dx  # note: can be positive!
return tuple(shape), np.array(shift_strength)

def possible_multi_couplings(self, ops):
"""Generalization of :meth:`possible_couplings` to couplings with more than 2 sites.

Parameters
----------
ops : list of ``(opname, dx, u)``
Same as the argument `ops` of

Returns
-------
mps_ijkl : 2D int array
Each row contains MPS indices `i,j,k,l,...`` for each of the operators positions.
The positions are defined by `dx` (j,k,l,... relative to `i`) and boundary coundary
conditions of `self` (how much the `box` for given `dx` can be shifted around without
hitting a boundary - these are the different rows).
lat_indices : 2D int array
Rows of `lat_indices` correspond to rows of `mps_ijkl` and contain the lattice indices
of the "lower left corner" of the box containing the coupling.
coupling_shape : tuple of int
Len :attr:`dim`. The correct shape for an array specifying the coupling strength.
`lat_indices` has only rows within this shape.
"""
D = self.dim
Nops = len(ops)
Ls = np.array(self.Ls)
# make 3D arrays ["iteration over lattice", "operator", "spatial direction"]
# recall numpy broadcasing: 1D equivalent to [np.newaxis, np.newaxis, :]
dx = np.array([op_dx for _, op_dx, op_u in ops], dtype=np.int_).reshape([1, Nops, D])
u = np.array([op_u for _, op_dx, op_u in ops], dtype=np.int_).reshape([1, Nops, 1])
coupling_shape, shift_lat_indices = self.multi_coupling_shape(dx[0, :, :])
if any([s == 0 for s in coupling_shape]):
return [], [], coupling_shape
lat_indices = np.indices(coupling_shape).reshape([1, self.dim, -1]).transpose([2, 0, 1])
lat_ijkl_shifted = lat_indices + (dx - shift_lat_indices)
lat_ijkl = np.mod(lat_ijkl_shifted, Ls)
if self.bc_shift is not None:
shift = np.sum(((lat_ijkl_shifted - lat_ijkl) // Ls)[:, :, 1:] * self.bc_shift, axis=2)
lat_ijkl_shifted[:, :, 0] -= shift
lat_ijkl[:, :, 0] = np.mod(lat_ijkl_shifted[:, :, 0], Ls)
keep = self._keep_possible_multi_couplings(lat_ijkl, lat_ijkl_shifted, u)
lat_indices = lat_indices[keep, 0, :]  # make 2D as to be returned
lat_ijkl = lat_ijkl[keep, :, :]
u = np.broadcast_to(u, lat_ijkl.shape[:2] + (1, ))
mps_ijkl = self.lat2mps_idx(np.concatenate([lat_ijkl, u], axis=2))
if self.bc_MPS == 'infinite':
# N_sites_per_ring might not be set for IrregularLattice
mps_ijkl += ((lat_ijkl_shifted[keep, :, 0] - lat_ijkl[:, :, 0]) * self.N_sites //
self.N_rings)
return mps_ijkl, lat_indices, coupling_shape

def _keep_possible_multi_couplings(self, lat_ijkl, lat_ijkl_shifted, u_ijkl):
"""Filter possible couplings from :meth:`possible_multi_couplings`"""
return np.all(
np.logical_or(
lat_ijkl_shifted == lat_ijkl,  # not accross the boundary
np.logical_not(self.bc)),  # direction has PBC
axis=(1, 2))

def plot_sites(self, ax, markers=['o', '^', 's', 'p', 'h', 'D'], **kwargs):
"""Plot the sites of the lattice with markers.

Parameters
----------
ax : :class:`matplotlib.axes.Axes`
The axes on which we should plot.
markers : list
List of values for the keywork `marker` of ``ax.plot()`` to distinguish the different
sites in the unit cell, a site `u` in the unit cell is plotted with a marker
``markers[u % len(markers)]``.
**kwargs :
Further keyword arguments given to ``ax.plot()``.
"""
kwargs.setdefault("linestyle", 'None')
use_marker = ('marker' not in kwargs)
for u in range(len(self.unit_cell)):
pos = self.position(self.order[self.mps_idx_fix_u(u), :])
if pos.shape == 1:
pos = pos * np.array([[1., 0]])  # use broadcasting to add a column with zeros
if pos.shape != 2:
raise ValueError("can only plot in 2 dimensions.")
if use_marker:
kwargs['marker'] = markers[u % len(markers)]
ax.plot(pos[:, 0], pos[:, 1], **kwargs)

def plot_order(self, ax, order=None, textkwargs={}, **kwargs):
"""Plot a line connecting sites in the specified "order" and text labels enumerating them.

Parameters
----------
ax : :class:`matplotlib.axes.Axes`
The axes on which we should plot.
order : None | 2D array (self.N_sites, self.dim+1)
The order as returned by :meth:`ordering`; by default (``None``) use :attr:`order`.
textkwargs: ``None`` | dict
If not ``None``, we add text labels enumerating the sites in the plot. The dictionary
can contain keyword arguments for ``ax.text()``.
**kwargs :
Further keyword arguments given to ``ax.plot()``.
"""
if order is None:
order = self.order
pos = self.position(order)
kwargs.setdefault('color', 'r')
if pos.shape == 1:
pos = pos * np.array([[1., 0]])  # use broadcasting to add a column with zeros
if pos.shape != 2:
raise ValueError("can only plot in 2 dimensions.")
ax.plot(pos[:, 0], pos[:, 1], **kwargs)
if textkwargs is not None:
textkwargs.setdefault('color', kwargs['color'])
for i, p in enumerate(pos):
ax.text(p, p, str(i), **textkwargs)

def plot_coupling(self, ax, coupling=None, wrap=False, **kwargs):
"""Plot lines connecting nearest neighbors of the lattice.

Parameters
----------
ax : :class:`matplotlib.axes.Axes`
The axes on which we should plot.
coupling : list of (u1, u2, dx)
By default (``None``), use ``self.pairs['nearest_neighbors']``.
Specifies the connections to be plotted; iteating over lattice indices `(i0, i1, ...)`,
we plot a connection from the site ``(i0, i1, ..., u1)`` to the site
``(i0+dx, i1+dx, ..., u2)``, taking into account the boundary conditions.
wrap : bool
If True, wrap
**kwargs :
Further keyword arguments given to ``ax.plot()``.
"""
if coupling is None:
coupling = self.pairs['nearest_neighbors']
kwargs.setdefault('color', 'k')
Ls = np.array(self.Ls)
for u1, u2, dx in coupling:
if wrap:
mps_i, mps_j, _, _ = self.possible_couplings(u1, u2, dx)
pos1 = self.position(self.mps2lat_idx(mps_i))
pos2 = self.position(self.mps2lat_idx(mps_j))
else:
dx = np.r_[np.array(dx), u2 - u1]  # append the difference in u to dx
lat_idx_1 = self.order[self._mps_fix_u[u1], :]
lat_idx_2 = lat_idx_1 + dx[np.newaxis, :]
lat_idx_2_mod = np.mod(lat_idx_2[:, :-1], Ls)
keep = self._keep_possible_couplings(lat_idx_2_mod, lat_idx_2[:, :-1], u2)
# get positions
pos1 = self.position(lat_idx_1[keep, :])
pos2 = self.position(lat_idx_2[keep, :])
pos = np.stack((pos1, pos2), axis=0)
# ax.plot connects columns of 2D array by lines
if pos.shape == 1:
pos = pos * np.array([[[1., 0]]])  # use broadcasting to add a column with zeros
if pos.shape != 2:
raise ValueError("can only plot in 2 dimensions.")
ax.plot(pos[:, :, 0], pos[:, :, 1], **kwargs)

def plot_basis(self, ax, origin=(0., 0.), shade=None, **kwargs):
"""Plot arrows indicating the basis vectors of the lattice.

Parameters
----------
ax : :class:`matplotlib.axes.Axes`
The axes on which we should plot.
**kwargs :
Keyword arguments for ``ax.arrow``.
"""
kwargs.setdefault("width", 0.03)
kwargs.setdefault("color", 'g')
origin = np.array(origin)
basis = np.array([self.basis[i] for i in range(self.dim)])
if basis.shape == 1:
basis = basis * np.array([[1., 0]])
if basis.shape != 2:
raise ValueError("can only plot in 2 dimensions.")
shade = True if self.dim == 2 else False
from matplotlib.patches import Polygon
xy = [origin, origin + basis, origin + basis + basis, origin + basis]
for i in range(self.dim):
vec = basis[i]
ax.arrow(origin, origin, vec, vec, **kwargs)

def plot_bc_identified(self, ax, direction=-1, shift=None, cylinder_axis=False, **kwargs):
"""Mark two sites indified by periodic boundary conditions.

Works only for lattice with a 2-dimensional basis.

Parameters
----------
ax : :class:`matplotlib.axes.Axes`
The axes on which we should plot.
direction : int
The direction of the lattice along which we should mark the idenitified sites.
If ``None``, mark it along all directions with periodic boundary conditions.
cylinder_axis : bool
Whether to plot the cylinder axis as well.
shift : None | np.ndarray
The origin starting from where we mark the identified sites.
Defaults to the first entry of :attr:`unit_cell_positions`.
**kwargs :
Keyword arguments for the used ``ax.plot``.
"""
if direction is None:
dirs = [i for i in range(self.dim) if not self.bc[i]]
else:
if direction < 0:
direction += self.dim
dirs = [direction]
shift = self.unit_cell_positions
kwargs.setdefault("marker", "o")
kwargs.setdefault("markersize", 10)
kwargs.setdefault("color", "orange")
x_y = []
for i in dirs:
if self.bc[i]:
raise ValueError("Boundary conditons are not periodic for given direction")
x_y.append(shift)
x_y.append(shift + self.Ls[i] * self.basis[i])
if self.bc_shift is not None and i > 0:
x_y[-1] = x_y[-1] + self.bc_shift[i - 1] * self.basis
x_y = np.array(x_y)
if x_y.shape != 2:
raise ValueError("can only plot in 2D")
ax.plot(x_y[:, 0], x_y[:, 1], **kwargs)
if cylinder_axis:
if len(x_y) != 2 or self.dim != 2:
raise ValueError("can't plot cylinder axis for multiple directions")
center = np.mean(x_y, axis=0)
diff = x_y[1, :] - x_y
perp = np.array([diff, -diff])
x_y_cyl = np.array([center - perp, center + perp])
kwargs.setdefault('linestyle', '--')
kwargs['marker'] = None
ax.plot(x_y_cyl[:, 0], x_y_cyl[:, 1], **kwargs)

def _asvalid_latidx(self, lat_idx):
"""convert lat_idx to an ndarray with correct last dimension."""
lat_idx = np.asarray(lat_idx, dtype=np.intp)
if lat_idx.shape[-1] != len(self.shape):
raise ValueError("wrong len of last dimension of lat_idx: " + str(lat_idx.shape))
return lat_idx

def _set_Ls(self, Ls):
self.Ls = tuple([int(L) for L in Ls])
self.N_cells = int(np.prod(self.Ls))
self.shape = self.Ls + (len(self.unit_cell), )
self.N_sites = int(np.prod(self.shape))
self.N_rings = self.Ls
self.N_sites_per_ring = int(self.N_sites // self.N_rings)
strides = 
for L in self.Ls:
strides.append(strides[-1] * L)
self._strides = np.array(strides, np.intp)

@property
def nearest_neighbors(self):
msg = ("Deprecated access with ``lattice.nearest_neighbors``.\n"
warnings.warn(msg, FutureWarning)
return self.pairs['nearest_neighbors']

@property
def next_nearest_neighbors(self):
msg = ("Deprecated access with ``lattice.next_nearest_neighbors``.\n"
warnings.warn(msg, FutureWarning)
return self.pairs['next_nearest_neighbors']

@property
def next_next_nearest_neighbors(self):
msg = ("Deprecated access with ``lattice.next_next_nearest_neighbors``.\n"
warnings.warn(msg, FutureWarning)
return self.pairs['next_next_nearest_neighbors']

class TrivialLattice(Lattice):
"""Trivial lattice consisting of a single (possibly large) unit cell in 1D.

This is usefull if you need a valid :class:`Lattice` with given :meth:`mps_sites`
and don't care about the actual geometry, e.g, because you don't intend to use the
:class:`~tenpy.models.model.CouplingModel`.

Parameters
----------
mps_sites : list of :class:`~tenpy.networks.site.Site`
The sites making up a unit cell of the lattice.
**kwargs :
Further keyword arguments given to :class:`Lattice`.
"""
def __init__(self, mps_sites, **kwargs):
Lattice.__init__(self, , mps_sites, **kwargs)

class IrregularLattice(Lattice):
"""A variant of a regular lattice, where we might have extra sites or sites missing.

.. note ::
The lattice defines only the geometry of the sites, not the couplings;
you can have position-dependent couplings/onsite terms despite having a regular lattice.

By adjusting the :attr:`order` and a few private attributes and methods, we can
make functions like :meth:`possible_couplings` work with a more "irregular" lattice structure,
where some of the sites are missing and other sites added instead.

Parameters
----------
regular_lattice : :class:`Lattice`
The lattice this is based on.
remove : 2D array | None
Each row is a lattice index ``(x_0, ..., x_{dim-1}, u)`` of a site to be removed.
If ``None``, don't remove something.
add : Tuple[2D array, list] | None
Each row of the 2D array is a lattice index ``(x_0, ..., x_{dim-1}, u)`` specifiying
where a site is to be added; `u` is the index of the site within the final
:attr:`unit_cell` of the irregular lattice.
For each row of the 2D array, there is one entry in the list specifying where the site
is inserted in the MPS; the values are compared to the MPS indices of the *regular* lattice
and sorted into it, so "2.5" goes between what was site 2 and 3 in the regular lattice.
An entry `None` indicates that the site should be inserted after the lattice site
``(x_0, ..., x_{dim-1}, -1)`` of the `regular_lattice`.
Extra sites to be added to the unit cell.
add_positions : iterable of 1D arrays
For each extra site in `add_unit_cell` the position within the unit cell.

Attributes
----------
regular_lattice : :class:`Lattice`
The lattice this is based on.
remove, add : 2D array | None
See above.  Used in :meth:`ordering` only.

Examples
--------
Let's imagine that we have two different sites; for concreteness we can thing of a
fermion site, which we represent with ``'F'``, and a spin site ``'S'``.

You could now imagine that to have fermion chain with spins on the "bonds".
In the periodic/infinite case, you would simply define

>>> lat = Lattice(, unit_cell=['F', 'S'], bc='periodic', bc_MPS='infinite')
>>> lat.mps_sites()
['F', 'S', 'F', 'S']

For a finite system, you don't want to terminate with a spin on the right, so you need to
remove the very last site by specifying the lattice index ``[L-1, 1]`` of that site:

>>> L = 4
>>> reg_lat = Lattice([L], unit_cell=['F', 'S'], bc='open', bc_MPS='finite')
>>> irr_lat = IrregularLattice(reg_lat, remove=[[L-1, 1]])
>>> irr_lat.mps_sites()
['F', 'S', 'F', 'S', 'F', 'S', 'F']

Another simple example would be to add a spin in the center of a fermion chain.
In that case, we add another site to the unit cell and specify the lattice index as
``[(L-1)//2, 1]`` (where the 1 is the index of ``'S'`` in the unit cell ``['F', 'S']`` of the
irregular lattice).
The `None` for the MPS index is equivalent to ``(L-1)/2`` in this case.

>>> reg_lat = Lattice([L], unit_cell=['F'])
>>> irr_lat = IrregularLattice(reg_lat, add=([[(L-1)//2, 1]], [None]),
['F', 'F', 'S', 'F', 'F']

"""
_REMOVED = -123456  # value in self._perm indicating removed sites.

def __init__(self,
regular_lattice,
remove=None,
self.regular_lattice = regular_lattice
if remove is not None:
remove = np.array(remove, np.intp)
self.remove = remove
Lattice.__init__(
self,
regular_lattice.Ls,
unit_cell,
bc=regular_lattice.boundary_conditions,
bc_MPS=regular_lattice.bc_MPS,
basis=regular_lattice.basis,
positions=positions,
pairs=self.regular_lattice.pairs,
)
# done

def save_hdf5(self, hdf5_saver, h5gr, subpath):
super().save_hdf5(hdf5_saver, h5gr, subpath)
hdf5_saver.save(self.regular_lattice, subpath + "regular_lattice")
hdf5_saver.save(self.remove, subpath + "remove")

@classmethod
return obj

def ordering(self, order):
"""Provide possible orderings of the lattice sites.

Parameters
----------
order :
Argument for the :meth:`Lattice.ordering` of the :attr:`regular_lattice`, or
2D ndarray providing the order of the *regular lattice*.

Returns
-------
order : array, shape (N, D+1)
The order to be used for :attr:`order`, i.e. `order` with added/removed sites
as specified by :attr:`remove` and :attr:`add`.
"""
order_ = self.regular_lattice.ordering(order)
mps_reg = np.arange(len(order_))
if self.remove is not None:
self._perm = np.lexsort(order_.T)  # allow to temporarily use lat2mps_idx for lattice
# indices with u from regular lattice
keep = np.ones([len(order_)], np.bool_)
keep[self.lat2mps_idx(self.remove)] = False
order_ = order_[keep]
mps_reg = mps_reg[keep]
# sort such that MPS indices are ascending
close_to = np.array(lat_idx[i])
close_to[-1] = len(self.regular_lattice.unit_cell) - 1
order_ = np.concatenate((order_, np.array(lat_idx)), axis=0)
order_ = order_[sort, :]
return order_

@Lattice.order.setter
def order(self, order_):
self._order = np.array(order_, dtype=np.intp)

# this defines `self._perm`
perm = np.full([np.prod(self.shape)], self._REMOVED)
perm[np.sum(self._order * self._strides, axis=1)] = np.arange(len(order_))
self._perm = perm

# and the other stuff which is cached
# versions for fixed u
self._mps_fix_u = []
for u in range(len(self.unit_cell)):
mps_fix_u = np.nonzero(order_[:, -1] == u)
self._mps_fix_u.append(mps_fix_u)
self._mps_fix_u = tuple(self._mps_fix_u)
self.N_sites = len(order_)
_, counts = np.unique(order_[:, 0], return_counts=True)
if np.all(counts == counts):
self.N_sites_per_ring = counts
else:
self.N_sites_per_ring = None

# mps2lat_idx and lat2mps_idx work thanks to they way _perm is defined,
# mps2lat_values, mps2lat_values_masked work as well

def mps_idx_fix_u(self, u=None):
if u is not None:
return self._mps_fix_u[u]
return self._perm[self._perm != self._REMOVED]

# make possible_couplings and possible_multi_couplings work

def _keep_possible_couplings(self, lat_j, lat_j_shifted, u2):
"""filter possible j sites of a coupling from :meth:`possible_couplings`"""
keep = super()._keep_possible_couplings(lat_j, lat_j_shifted, u2)
lat_j_u = np.concatenate([lat_j, np.full([len(lat_j), 1], u2)], axis=1)
i = np.sum(lat_j_u * self._strides, axis=-1)
i = np.take(self._perm, i, 0)
return np.logical_and(keep, i != self._REMOVED)

def _keep_possible_multi_couplings(self, lat_ijkl, lat_ijkl_shifted, u_ijkl):
"""filter possible j sites of a coupling from :meth:`possible_couplings`"""
keep = super()._keep_possible_multi_couplings(lat_ijkl, lat_ijkl_shifted, u_ijkl)
u_ijkl = np.broadcast_to(u_ijkl, lat_ijkl.shape[:2] + (1, ))
i = np.sum(np.concatenate([lat_ijkl, u_ijkl], axis=2) * self._strides, axis=-1)
i = np.take(self._perm, i, 0)
return np.logical_and(keep, np.all(i != self._REMOVED, axis=1))

def _set_Ls(self, Ls):
super()._set_Ls(Ls)
# N_sites, N_sites_per_ring set by order setter
# self.N_sites = None
self.N_sites_per_ring = None

class SimpleLattice(Lattice):
"""A lattice with a unit cell consiting of just a single site.

In many cases, the unit cell consists just of a single site, such that the the last entry of
`u` of an 'lattice index' can only be ``0``.
From the point of internal algorithms, we handle this class like a :class:`Lattice` --
in that way we don't need to distinguish special cases in the algorithms.

Yet, from the point of a tenpy user, for example if you measure an expectation value
on each site in a `SimpleLattice`, you expect to get an ndarray of dimensions ``self.Ls``,
not ``self.shape``. To avoid that problem, `SimpleLattice` overwrites just the meaning of
``u=None`` in :meth:`mps2lat_values` to be the same as ``u=0``.

Parameters
----------
Ls : list of int
the length in each direction
site : :class:`~tenpy.networks.site.Site`
the lattice site. The `unit_cell` of the :class:`Lattice` is just ``[site]``.
**kwargs :
Additional keyword arguments given to the :class:`Lattice`.
If `order` is specified in the form ``('standard', snake_windingi, priority)``,
the `snake_winding` and `priority` should only be specified for the spatial directions.
Similarly, `positions` can be specified as a single vector.
"""
def __init__(self, Ls, site, **kwargs):
if 'positions' in kwargs:
Dim = len(kwargs['basis']) if 'basis' in kwargs else len(Ls)
kwargs['positions'] = np.reshape(kwargs['positions'], (1, Dim))
if 'order' in kwargs and not isinstance(kwargs['order'], str):
descr, snake_winding, priority = kwargs['order']
assert descr == 'standard'
snake_winding = tuple(snake_winding) + (False, )
priority = tuple(priority) + (max(priority) + 1., )
kwargs['order'] = descr, snake_winding, priority
Lattice.__init__(self, Ls, [site], **kwargs)

def mps2lat_values(self, A, axes=0, u=None):
"""same as :meth:`Lattice.mps2lat_values`, but ignore ``u``, setting it to ``0``."""
return super().mps2lat_values(A, axes, 0)

class Chain(SimpleLattice):
"""A chain of L equal sites.

.. plot ::

import matplotlib.pyplot as plt
from tenpy.models import lattice
plt.figure(figsize=(5, 1.4))
ax = plt.gca()
lat = lattice.Chain(4, None, bc='periodic')
lat.plot_coupling(ax, linewidth=3.)
lat.plot_order(ax, linestyle=':')
lat.plot_sites(ax)
ax.set_xlim(-1.)
ax.set_ylim(-0.5, 0.5)
ax.set_aspect('equal')
plt.show()

Parameters
----------
L : int
The lenght of the chain.
site : :class:`~tenpy.networks.site.Site`
The local lattice site. The `unit_cell` of the :class:`Lattice` is just ``[site]``.
**kwargs :
Additional keyword arguments given to the :class:`Lattice`.
`pairs` are initialize with ``[next_]next_]nearest_neighbors``.
`positions` can be specified as a single vector.
"""
dim = 1

def __init__(self, L, site, **kwargs):
kwargs.setdefault('pairs', {})
kwargs['pairs'].setdefault('nearest_neighbors', [(0, 0, np.array())])
kwargs['pairs'].setdefault('next_nearest_neighbors', [(0, 0, np.array())])
kwargs['pairs'].setdefault('next_next_nearest_neighbors', [(0, 0, np.array())])
# and otherwise default values.
SimpleLattice.__init__(self, [L], site, **kwargs)

def ordering(self, order):
"""Provide possible orderings of the `N` lattice sites.

The following orders are defined in this method compared to :meth:`Lattice.ordering`:

================== ============================================================
`order`            Resulting order
================== ============================================================
``'default'``      ``0, 1, 2, 3, 4, ... ,L-1``
------------------ ------------------------------------------------------------
``'folded'``       ``0, L-1, 1, L-2, ... , L//2``.
This order might be usefull if you want to consider a
ring with periodic boundary conditions with a finite MPS:
It avoids the ultra-long range of the coupling from site
0 to L present in the default order.
================== ============================================================
"""
if isinstance(order, str) and order == 'default' or order == 'folded':
(L, u) = self.shape
assert u == 1
ordering = np.zeros([L, 2], dtype=np.intp)
if order == 'default':
ordering[:, 0] = np.arange(L, dtype=np.intp)
elif order == 'folded':
order = []
for i in range(L // 2):
order.append(i)
order.append(L - i - 1)
if L % 2 == 1:
order.append(L // 2)
assert len(order) == L
ordering[:, 0] = np.array(order, dtype=np.intp)
else:
assert (False)  # should not be possible
return ordering
return super().ordering(order)

.. plot ::

import matplotlib.pyplot as plt
from tenpy.models import lattice
plt.figure(figsize=(5, 1.4))
ax = plt.gca()
lat.plot_coupling(ax, linewidth=3.)
lat.plot_order(ax, linestyle=':')
lat.plot_sites(ax)
ax.set_aspect('equal')
ax.set_xlim(-1.)
ax.set_ylim(-1.)
plt.show()

Parameters
----------
L : int
The length of each chain, we have 2*L sites in total.
sites : (list of) :class:`~tenpy.networks.site.Site`
The two local lattice sites making the `unit_cell` of the :class:`Lattice`.
If only a single :class:`~tenpy.networks.site.Site` is given, it is used for both chains.
**kwargs :
Additional keyword arguments given to the :class:`Lattice`.
`basis`, `pos` and `pairs` are set accordingly.
"""
dim = 1

def __init__(self, L, sites, **kwargs):
sites = _parse_sites(sites, 2)
basis = np.array([[1., 0.]])
pos = np.array([[0., 0.], [0., 1.]])
kwargs.setdefault('basis', basis)
kwargs.setdefault('positions', pos)
NN = [(0, 0, np.array()), (1, 1, np.array()), (0, 1, np.array())]
nNN = [(0, 1, np.array()), (1, 0, np.array())]
nnNN = [(0, 0, np.array()), (1, 1, np.array())]
kwargs.setdefault('pairs', {})
kwargs['pairs'].setdefault('nearest_neighbors', NN)
kwargs['pairs'].setdefault('next_nearest_neighbors', nNN)
kwargs['pairs'].setdefault('next_next_nearest_neighbors', nnNN)
Lattice.__init__(self, [L], sites, **kwargs)

class Square(SimpleLattice):
"""A square lattice.

.. plot ::

import matplotlib.pyplot as plt
from tenpy.models import lattice
plt.figure(figsize=(5, 5))
ax = plt.gca()
lat = lattice.Square(4, 4, None, bc='periodic')
lat.plot_coupling(ax, linewidth=3.)
lat.plot_order(ax, linestyle=':')
lat.plot_sites(ax)
lat.plot_basis(ax, origin=-0.5*(lat.basis + lat.basis))
ax.set_aspect('equal')
ax.set_xlim(-1)
ax.set_ylim(-1)
plt.show()

Parameters
----------
Lx, Ly : int
The length in each direction.
site : :class:`~tenpy.networks.site.Site`
The local lattice site. The `unit_cell` of the :class:`Lattice` is just ``[site]``.
**kwargs :
Additional keyword arguments given to the :class:`Lattice`.
`pairs` are set accordingly.
If `order` is specified in the form ``('standard', snake_winding, priority)``,
the `snake_winding` and `priority` should only be specified for the spatial directions.
Similarly, `positions` can be specified as a single vector.
"""
dim = 2

def __init__(self, Lx, Ly, site, **kwargs):
NN = [(0, 0, np.array([1, 0])), (0, 0, np.array([0, 1]))]
nNN = [(0, 0, np.array([1, 1])), (0, 0, np.array([1, -1]))]
nnNN = [(0, 0, np.array([2, 0])), (0, 0, np.array([0, 2]))]
kwargs.setdefault('pairs', {})
kwargs['pairs'].setdefault('nearest_neighbors', NN)
kwargs['pairs'].setdefault('next_nearest_neighbors', nNN)
kwargs['pairs'].setdefault('next_next_nearest_neighbors', nnNN)
SimpleLattice.__init__(self, [Lx, Ly], site, **kwargs)

class Triangular(SimpleLattice):
"""A triangular lattice.

.. plot ::

import matplotlib.pyplot as plt
from tenpy.models import lattice
plt.figure(figsize=(4, 5))
ax = plt.gca()
lat = lattice.Triangular(4, 4, None, bc='periodic')
lat.plot_coupling(ax, linewidth=3.)
lat.plot_order(ax, linestyle=':')
lat.plot_sites(ax)
lat.plot_basis(ax, origin=-0.5*(lat.basis + lat.basis))
ax.set_aspect('equal')
plt.show()

Parameters
----------
Lx, Ly : int
The length in each direction.
site : :class:`~tenpy.networks.site.Site`
The local lattice site. The `unit_cell` of the :class:`Lattice` is just ``[site]``.
**kwargs :
Additional keyword arguments given to the :class:`Lattice`.
`pairs` are set accordingly.
If `order` is specified in the form ``('standard', snake_windingi, priority)``,
the `snake_winding` and `priority` should only be specified for the spatial directions.
Similarly, `positions` can be specified as a single vector.
"""
dim = 2

def __init__(self, Lx, Ly, site, **kwargs):
sqrt3_half = 0.5 * np.sqrt(3)  # = cos(pi/6)
basis = np.array([[sqrt3_half, 0.5], [0., 1.]])
NN = [(0, 0, np.array([1, 0])), (0, 0, np.array([-1, 1])), (0, 0, np.array([0, -1]))]
nNN = [(0, 0, np.array([2, -1])), (0, 0, np.array([1, 1])), (0, 0, np.array([-1, 2]))]
nnNN = [(0, 0, np.array([2, 0])), (0, 0, np.array([0, 2])), (0, 0, np.array([-2, 2]))]
kwargs.setdefault('basis', basis)
kwargs.setdefault('pairs', {})
kwargs['pairs'].setdefault('nearest_neighbors', NN)
kwargs['pairs'].setdefault('next_nearest_neighbors', nNN)
kwargs['pairs'].setdefault('next_next_nearest_neighbors', nnNN)
SimpleLattice.__init__(self, [Lx, Ly], site, **kwargs)

class Honeycomb(Lattice):
"""A honeycomb lattice.

.. plot ::

import matplotlib.pyplot as plt
from tenpy.models import lattice
plt.figure(figsize=(5, 6))
ax = plt.gca()
lat = lattice.Honeycomb(4, 4, None, bc='periodic')
lat.plot_coupling(ax, linewidth=3.)
lat.plot_order(ax, linestyle=':')
lat.plot_sites(ax)
lat.plot_basis(ax, origin=-0.5*(lat.basis + lat.basis))
ax.set_aspect('equal')
ax.set_xlim(-1)
ax.set_ylim(-1)
plt.show()

Parameters
----------
Lx, Ly : int
The length in each direction.
sites : (list of) :class:`~tenpy.networks.site.Site`
The two local lattice sites making the `unit_cell` of the :class:`Lattice`.
If only a single :class:`~tenpy.networks.site.Site` is given, it is used for both sites.
**kwargs :
Additional keyword arguments given to the :class:`Lattice`.
`basis`, `pos` and `pairs` are set accordingly.
For the Honeycomb lattice ``'fourth_nearest_neighbors', 'fifth_nearest_neighbors'``
are set in :attr:`pairs`.
"""
dim = 2

def __init__(self, Lx, Ly, sites, **kwargs):
sites = _parse_sites(sites, 2)
basis = np.array(([0.5 * np.sqrt(3), 0.5], [0., 1]))
delta = np.array([1 / (2. * np.sqrt(3.)), 0.5])
pos = (-delta / 2., delta / 2)
kwargs.setdefault('basis', basis)
kwargs.setdefault('positions', pos)
NN = [(0, 1, np.array([0, 0])), (1, 0, np.array([1, 0])), (1, 0, np.array([0, 1]))]
nNN = [(0, 0, np.array([1, 0])), (0, 0, np.array([0, 1])), (0, 0, np.array([1, -1])),
(1, 1, np.array([1, 0])), (1, 1, np.array([0, 1])), (1, 1, np.array([1, -1]))]
nnNN = [(1, 0, np.array([1, 1])), (0, 1, np.array([-1, 1])), (0, 1, np.array([1, -1]))]
NN4 = [(0, 1, np.array([0, 1])), (0, 1, np.array([1, 0])), (0, 1, np.array([1, -2])),
(0, 1, np.array([0, -2])), (0, 1, np.array([-2, 0])), (0, 1, np.array([-2, 1]))]
NN5 = [(0, 0, np.array([1, 1])), (0, 0, np.array([2, -1])), (0, 0, np.array([-1, 2])),
(1, 1, np.array([1, 1])), (1, 1, np.array([2, -1])), (1, 1, np.array([-1, 2]))]
kwargs.setdefault('pairs', {})
kwargs['pairs'].setdefault('nearest_neighbors', NN)
kwargs['pairs'].setdefault('next_nearest_neighbors', nNN)
kwargs['pairs'].setdefault('next_next_nearest_neighbors', nnNN)
kwargs['pairs'].setdefault('fourth_nearest_neighbors', NN4)
kwargs['pairs'].setdefault('fifth_nearest_neighbors', NN5)
Lattice.__init__(self, [Lx, Ly], sites, **kwargs)

def ordering(self, order):
"""Provide possible orderings of the `N` lattice sites.

The following orders are defined in this method compared to :meth:`Lattice.ordering`:

================== =========================== =============================
`order`            equivalent `priority`       equivalent ``snake_winding``
================== =========================== =============================
``'default'``      (0, 2, 1)                   (False, False, False)
``'snake'``        (0, 2, 1)                   (False, True, False)
================== =========================== =============================
"""
if isinstance(order, str):
if order == "default":
priority = (0, 2, 1)
snake_winding = (False, False, False)
return get_order(self.shape, snake_winding, priority)
elif order == "snake":
priority = (0, 2, 1)
snake_winding = (False, False, True)
return get_order(self.shape, snake_winding, priority)
return super().ordering(order)

@property
def fourth_nearest_neighbors(self):
msg = ("Deprecated access with ``lattice.fourth_nearest_neighbors``.\n"
warnings.warn(msg, FutureWarning)
return self.pairs['fourth_nearest_neighbors']

@property
def fifth_nearest_neighbors(self):
msg = ("Deprecated access with ``lattice.fifth_nearest_neighbors``.\n"
warnings.warn(msg, FutureWarning)
return self.pairs['fifth_nearest_neighbors']

class Kagome(Lattice):
"""A Kagome lattice.

.. plot ::

import matplotlib.pyplot as plt
from tenpy.models import lattice
plt.figure(figsize=(5, 4))
ax = plt.gca()
lat = lattice.Kagome(4, 4, None, bc='periodic')
lat.plot_coupling(ax, linewidth=3.)
lat.plot_order(ax, linestyle=':')
lat.plot_sites(ax)
lat.plot_basis(ax, origin=-0.25*(lat.basis + lat.basis))
ax.set_aspect('equal')
ax.set_xlim(-1)
ax.set_ylim(-1)
plt.show()

Parameters
----------
Lx, Ly : int
The length in each direction.
sites : (list of) :class:`~tenpy.networks.site.Site`
The two local lattice sites making the `unit_cell` of the :class:`Lattice`.
If only a single :class:`~tenpy.networks.site.Site` is given, it is used for both sites.
**kwargs :
Additional keyword arguments given to the :class:`Lattice`.
`basis`, `pos` and `pairs` are set accordingly.
"""
dim = 2

def __init__(self, Lx, Ly, sites, **kwargs):
sites = _parse_sites(sites, 3)
#     2
#    / \
#   /   \
#  0-----1
pos = np.array([[0, 0], [1, 0], [0.5, 0.5 * 3**0.5]])
basis = [2 * pos, 2 * pos]
kwargs.setdefault('basis', basis)
kwargs.setdefault('positions', pos)
NN = [(0, 1, np.array([0, 0])), (0, 2, np.array([0, 0])), (1, 2, np.array([0, 0])),
(1, 0, np.array([1, 0])), (2, 0, np.array([0, 1])), (2, 1, np.array([-1, 1]))]
nNN = [(0, 1, np.array([0, -1])), (0, 2, np.array([1, -1])), (1, 0, np.array([1, -1])),
(1, 2, np.array([1, 0])), (2, 0, np.array([1, 0])), (2, 1, np.array([0, 1]))]
nnNN = [(0, 0, np.array([1, -1])), (1, 1, np.array([0, 1])), (2, 2, np.array([1, 0]))]
kwargs.setdefault('pairs', {})
kwargs['pairs'].setdefault('nearest_neighbors', NN)
kwargs['pairs'].setdefault('next_nearest_neighbors', nNN)
kwargs['pairs'].setdefault('next_next_nearest_neighbors', nnNN)
Lattice.__init__(self, [Lx, Ly], sites, **kwargs)

def get_lattice(lattice_name):
"""Given the name of a :class:`Lattice` class, get the lattice class itself.

Parameters
----------
lattice_name : str
Name of a :class:`Lattice` class defined in the module :mod:`~tenpy.models.lattice`,
for example ``"Chain", "Square", "Honeycomb", ...``.

Returns
-------
LatticeClass : :class:`Lattice`
The lattice class (type, not instance) specified by `lattice_name`.
"""
LatticeClass = globals()[lattice_name]
assert issubclass(LatticeClass, Lattice)
return LatticeClass

def get_order(shape, snake_winding, priority=None):
"""Built the :attr:`Lattice.order` in (Snake-) C-Style for a given lattice shape.

In this function, the word 'direction' referst to a physical direction of the lattice or the
index `u` of the unit cell as an "artificial direction".

Parameters
----------
shape : tuple of int
The shape of the lattice, i.e., the length in each direction.
snake_winding : tuple of bool
For each direction one bool, whether we should wind as a "snake" (True) in that direction
(i.e., going forth and back) or simply repeat ascending (False)
priority : ``None`` | tuple of float
If ``None`` (default), use C-Style ordering.
Otherwise, this defines the priority along which direction to wind first;
the direction with the highest priority increases fastest.
For example, "C-Style" order is enforced by ``priority=(0, 1, 2, ...)``,
and Fortrans F-style order is enforced by ``priority=(dim, dim-1, ..., 1, 0)``
group : ``None`` | tuple of tuple
If ``None`` (default), ignore it.
Otherwise, it specifies that we group the fastests changing dimension

Returns
-------
order : ndarray (np.prod(shape), len(shape))
An order of the sites for :attr:`Lattice.order` in the specified `ordering`.

--------
Lattice.ordering : method in :class:`Lattice` to obtain the order from parameters.
Lattice.plot_order : visualizes the resulting order in a :class:`Lattice`.
get_order_grouped : a variant grouping sites of the unit cell.
"""
if priority is not None:
# reduce this case to C-style order and a few permutations
perm = np.argsort(priority)
inv_perm = inverse_permutation(perm)
transp_shape = np.array(shape)[perm]
transp_snake = np.array(snake_winding)[perm]
order = get_order(transp_shape, transp_snake, None)  # in plain C-style
order = order[:, inv_perm]
return order
# simpler case: generate C-style order
shape = tuple(shape)
if not any(snake_winding):
# optimize: can use np.mgrid
res = np.mgrid[tuple([slice(0, L) for L in shape])]
return res.reshape((len(shape), np.prod(shape))).T
# some snake: generate direction by direction, each time adding a new column to `order`
snake_winding = tuple(snake_winding) + (False, )
dim = len(shape)
order = np.empty((1, 0), dtype=np.intp)
for i in range(dim):
L = shape[dim - 1 - i]
snake = snake_winding[dim - i]  # previous direction snake?
L0, D = order.shape
# insert a new first column into order
new_order = np.empty((L * L0, D + 1), dtype=np.intp)
new_order[:, 0] = np.repeat(np.arange(L), L0)
if not snake:
new_order[:, 1:] = np.tile(order, (L, 1))
else:  # snake
new_order[:L0, 1:] = order
L0_2 = 2 * L0
if L > 1:
# reverse order to go back for second index
new_order[L0:L0_2, 1:] = order[::-1, :]
if L > 2:
# repeat (ascending, descending) up to length L
rep = L // 2 - 1
if rep > 0:
new_order[L0_2:(rep + 1) * L0_2, 1:] = np.tile(new_order[:L0_2, 1:], (rep, 1))
if L % 2 == 1:
new_order[-L0:, 1:] = order
order = new_order
return order

def get_order_grouped(shape, groups):
"""Variant of :func:`get_order`, grouping some sites of the unit cell.

In this function, the word 'direction' referst to a physical direction of the lattice or the
index `u` of the unit cell as an "artificial direction".
This function is usefull for lattices with a unit cell of more than 2 sites (e.g. Kagome).
The argument `group` is a
To explain the order, assume we have a 3-site unit cell in a 2D lattice with shape
(Lx, Ly, Lu).
Calling this function with groups=((1,), (2, 0)) returns an order of the following form::

# columns: [x, y, u]
[0, 0, 1]  # first for u = 1 along y
[0, 1, 1]
:
[0, Ly-1, 1]
[0, 0, 2]  # then for u = 2 and 0
[0, 0, 0]
[0, 1, 2]
[0, 1, 0]
:
[0, Ly-1, 2]
[0, Ly-1, 0]
# and then repeat the above for increasing `x`.

Parameters
----------
shape : tuple of int
The shape of the lattice, i.e., the length in each direction.
groups : tuple of tuple of int
A partition and reordering of range(shape[-1]) into smaller groups.
The ordering goes first within a group, then along the last spatial dimensions, then
changing between different groups and finally in Cstyle order along the remaining spatial
dimensions.

Returns
-------
order : ndarray (np.prod(shape), len(shape))
An order of the sites for :attr:`Lattice.order` in the specified `ordering`.

--------
:meth:`Lattice.ordering` : method in :class:`Lattice` to obtain the order from parameters.
:meth:`Lattice.plot_order` : visualizes the resulting order in a :class:`Lattice`.
"""
Ly = shape[-2]
Lu = shape[-1]
N_sites = np.prod(shape)
# sanity check for argument group
groups = list(groups)
all = [g for gr in groups for g in gr]
all_set = set(all)
assert all_set == set(range(Lu))  # does every number appear?
assert len(all) == len(all_set) == Lu  # exactly once?
assert len(shape) > 1
rLy = np.arange(Ly)
pre_order = np.empty((Ly * Lu, 2), dtype=np.intp)
start = 0
for gr in groups:
gr = np.array(gr)
Lgr = len(gr)
end = start + Lgr * Ly
pre_order[start:end, 0] = np.repeat(rLy, Lgr)
pre_order[start:end, 1] = np.tile(gr, Ly)
start = end
other_order = get_order(shape[:-2], [False])
order = np.empty((N_sites, len(shape)), dtype=np.intp)
order[:, :-2] = np.repeat(other_order, Ly * Lu, axis=0)
order[:, -2:] = np.tile(pre_order, (N_sites // (Ly * Lu), 1))
return order

def _parse_sites(sites, expected_number):
try:  # allow to specify a single site
iter(sites)
except TypeError:
return [sites] * expected_number
if len(sites) != expected_number:
raise ValueError("need to specify a single site or exactly {0:d}, got {1:d}".format(
expected_number, len(sites)))
return sites
```