# Copyright 2019-2020 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Unit tests for strawberryfields.apps.similarity
"""
# pylint: disable=no-self-use,unused-argument,too-many-arguments
import itertools
from collections import Counter
from unittest import mock

import networkx as nx
import numpy as np
import pytest

import strawberryfields as sf
from strawberryfields.apps import similarity

# Apply the ``apps`` marker to every test in this module so the whole module can be selected or
# deselected as a group with pytest's ``-m`` option.
pytestmark = pytest.mark.apps

# Integer partitions ("orbits") of a total photon number K, keyed by K. Every orbit is written
# in non-increasing order.
all_orbits = {
    3: [[1, 1, 1], [2, 1], [3]],
    4: [[1, 1, 1, 1], [2, 1, 1], [3, 1], [2, 2], [4]],
    5: [[1, 1, 1, 1, 1], [2, 1, 1, 1], [3, 1, 1], [2, 2, 1], [4, 1], [3, 2], [5]],
}

# All orbits above concatenated into one flat list, K = 3 first, then K = 4, then K = 5.
all_orbits_cumulative = list(itertools.chain.from_iterable(all_orbits.values()))

# Expected events keyed by ``(K, max_count_per_mode)``. Position i holds the event expected for
# ``all_orbits[K][i]``: K itself when no element of the orbit exceeds ``max_count_per_mode``,
# and ``None`` otherwise.
all_events = {
    (3, 1): [3, None, None],
    (4, 1): [4, None, None, None, None],
    (5, 1): [5, None, None, None, None, None, None],
    (3, 2): [3, 3, None],
    (4, 2): [4, 4, None, 4, None],
    (5, 2): [5, 5, None, 5, None, None, None],
}


@pytest.mark.parametrize("dim", [3, 4, 5])
def test_sample_to_orbit(dim):
    """Test if function ``similarity.sample_to_orbit`` correctly returns the original orbit after
    taking all permutations over the orbit. The starting orbits are all orbits for a fixed photon
    number ``dim``, each zero-padded into a sample over exactly ``dim`` modes."""
    for o in all_orbits[dim]:
        # Pad with trailing zeros up to ``dim`` modes. Orbits that already have ``dim`` elements
        # receive no padding (``[0] * 0`` is empty).
        sample = o + [0] * (dim - len(o))
        permutations = itertools.permutations(sample)
        assert all(similarity.sample_to_orbit(p) == o for p in permutations), o


@pytest.mark.parametrize("dim", [3, 4, 5])
class TestOrbits:
    """Tests for the function ``strawberryfields.apps.similarity.orbits``"""

    def test_orbit_sum(self, dim):
        """Test if function generates orbits that are lists that sum to ``dim``."""
        assert all(sum(o) == dim for o in similarity.orbits(dim))

    def test_orbit_sorted(self, dim):
        """Test if function generates orbits that are lists sorted in descending order."""
        assert all(o == sorted(o, reverse=True) for o in similarity.orbits(dim))

    def test_orbits(self, dim):
        """Test if function returns all the integer partitions of ``dim``. This test does not
        require ``similarity.orbits`` to return the orbits in any specified order."""
        partition = all_orbits[dim]
        orb = similarity.orbits(dim)

        assert sorted(partition) == sorted(orb)


@pytest.mark.parametrize("dim", [3, 4, 5])
@pytest.mark.parametrize("max_count_per_mode", [1, 2])
def test_sample_to_event(dim, max_count_per_mode):
    """Check ``similarity.sample_to_event`` on every orbit containing ``dim`` photons. An orbit
    whose largest element exceeds ``max_count_per_mode`` must map to ``None``; every other orbit
    must map to the event ``dim``."""
    expected = all_events[(dim, max_count_per_mode)]

    observed = []
    for orbit in all_orbits[dim]:
        observed.append(similarity.sample_to_event(orbit, max_count_per_mode))

    assert observed == expected


class TestOrbitToSample:
    """Tests for the function ``strawberryfields.apps.similarity.orbit_to_sample``"""

    def test_low_modes(self):
        """Check that a ``ValueError`` is raised when ``modes`` is smaller than the number of
        elements in the input orbit."""
        with pytest.raises(ValueError, match="Number of modes cannot"):
            similarity.orbit_to_sample([1, 2, 3], 2)

    @pytest.mark.parametrize("orb_dim", [3, 4, 5])
    @pytest.mark.parametrize("modes_dim", [6, 7])
    def test_sample_length(self, orb_dim, modes_dim):
        """Check that the sample built from the collision-free orbit of ``orb_dim`` photons has
        exactly ``modes_dim`` entries."""
        collision_free = all_orbits[orb_dim][0]
        sample = similarity.orbit_to_sample(collision_free, modes_dim)
        assert len(sample) == modes_dim

    def test_sample_composition(self):
        """Check that each generated sample is a rearrangement of its input orbit. Every orbit
        from ``all_orbits_cumulative`` (all orbits of 3-5 photons) is turned into a sample over
        five modes, and the multiset of the sample's entries is compared with that of the orbit
        padded with zeros."""
        modes = 5

        # Trailing zeros make each orbit as long as a sample so that element counts line up.
        padded_orbits = [orbit + [0] * (modes - len(orbit)) for orbit in all_orbits_cumulative]
        expected_counts = [Counter(orbit) for orbit in padded_orbits]

        sample_counts = [
            Counter(similarity.orbit_to_sample(orbit, modes)) for orbit in all_orbits_cumulative
        ]

        assert sample_counts == expected_counts


class TestEventToSample:
    """Tests for the function ``strawberryfields.apps.similarity.event_to_sample``"""

    def test_low_count(self):
        """Check that a negative ``max_count_per_mode`` raises a ``ValueError``."""
        with pytest.raises(ValueError, match="Maximum number of photons"):
            similarity.event_to_sample(2, -1, 5)

    def test_high_photon(self):
        """Check that a ``ValueError`` is raised when ``photon_number`` is too large to be spread
        over ``modes`` without exceeding ``max_count_per_mode`` in some mode (here: 5 photons,
        at most 1 per mode, 4 modes)."""
        with pytest.raises(ValueError, match="No valid samples can be generated."):
            similarity.event_to_sample(5, 1, 4)

    @pytest.mark.parametrize("photon_num", [5, 6])
    @pytest.mark.parametrize("modes_dim", [10, 11])
    @pytest.mark.parametrize("count", [3, 4])
    def test_sample_length(self, photon_num, modes_dim, count):
        """Check that the generated sample has exactly ``modes_dim`` entries."""
        generated = similarity.event_to_sample(photon_num, count, modes_dim)
        assert len(generated) == modes_dim

    @pytest.mark.parametrize("photon_num", [5, 6])
    @pytest.mark.parametrize("modes_dim", [10, 11])
    @pytest.mark.parametrize("count", [3, 4])
    def test_sample_sum(self, photon_num, modes_dim, count):
        """Check that the generated sample contains ``photon_num`` photons in total."""
        generated = similarity.event_to_sample(photon_num, count, modes_dim)
        assert sum(generated) == photon_num

    @pytest.mark.parametrize("photon_num", [5, 6])
    @pytest.mark.parametrize("modes_dim", [10, 11])
    @pytest.mark.parametrize("count", [3, 4])
    def test_sample_max_count(self, photon_num, modes_dim, count):
        """Check that no mode of the generated sample holds more than ``count`` photons."""
        generated = similarity.event_to_sample(photon_num, count, modes_dim)
        assert max(generated) <= count


# Hard-coded orbit cardinalities for ``test_orbit_cardinality``. Each row is
# ``[orbit, modes, expected]``, where ``expected`` is the number of distinct samples over
# ``modes`` modes whose non-zero entries form ``orbit``. For example, ``(1, 1)`` over 4 modes
# gives C(4, 2) = 6.
orbits = [
    [(1, 1, 2), 4, 12],
    [(1, 1), 4, 6],
    [(1, 2, 3), 4, 24],
    [(1, 1, 1, 1), 5, 5],
    [(1, 1, 2), 5, 30],
    [(1, 2, 3), 5, 60],
]


@pytest.mark.parametrize("orbit, modes, expected", orbits)
def test_orbit_cardinality(orbit, modes, expected):
    """Test if function ``strawberryfields.apps.similarity.orbit_cardinality`` returns the
    correct number of samples over ``modes`` modes for some hard-coded examples."""

    assert similarity.orbit_cardinality(list(orbit), modes) == expected


# Hard-coded event cardinalities for ``test_event_cardinality``. Each row is
# ``[photons, max_count, modes, expected]``, where ``expected`` is the number of samples over
# ``modes`` modes with ``photons`` photons in total and at most ``max_count`` photons in any
# single mode.
events = [
    [5, 3, 6, 216],
    [6, 3, 6, 336],
    [5, 2, 6, 126],
    [5, 3, 7, 413],
    [6, 3, 7, 728],
    [5, 2, 7, 266],
]


@pytest.mark.parametrize("photons, max_count, modes, expected", events)
def test_event_cardinality(photons, max_count, modes, expected):
    """Check ``strawberryfields.apps.similarity.event_cardinality`` against a table of
    hard-coded sample counts."""
    cardinality = similarity.event_cardinality(photons, max_count, modes)
    assert cardinality == expected


class TestGetState:
    """Tests for the function ``strawberryfields.apps.similarity._get_state``"""

    def test_loss(self, monkeypatch):
        """Test if function correctly creates the SF program for lossy GBS."""
        graph = nx.complete_graph(5)
        mock_eng_run = mock.MagicMock()

        with monkeypatch.context() as m:
            m.setattr(sf.LocalEngine, "run", mock_eng_run)
            similarity._get_state(graph, loss=0.5)
            p_func = mock_eng_run.call_args[0][0]

        assert isinstance(p_func.circuit[1].op, sf.ops.LossChannel)

    def test_no_loss(self, monkeypatch):
        """Test if function correctly creates the SF program for GBS without loss, i.e., that no
        command in the program applies a ``LossChannel``."""
        graph = nx.complete_graph(5)
        mock_eng_run = mock.MagicMock()

        with monkeypatch.context() as m:
            m.setattr(sf.LocalEngine, "run", mock_eng_run)
            similarity._get_state(graph)
            p_func = mock_eng_run.call_args[0][0]

        # ``circuit`` holds commands whose operation lives in ``.op`` (as in ``test_loss``); the
        # command object itself is never a ``LossChannel``, so the operation must be inspected.
        assert not any(isinstance(cmd.op, sf.ops.LossChannel) for cmd in p_func.circuit)

    def test_max_loss(self):
        """Test if function samples from the vacuum when maximum loss is applied."""
        dim = 5
        graph = nx.complete_graph(dim)
        state = similarity._get_state(graph, 0, 1)
        cov = state.cov()
        disp = state.displacement()

        assert np.allclose(cov, 0.5 * state.hbar * np.eye(2 * dim))
        assert np.allclose(disp, np.zeros(dim))


class TestProbOrbitExact:
    """Tests for the function ``strawberryfields.apps.similarity.prob_orbit_exact``"""

    def test_invalid_n_mean(self):
        """Check that a negative mean photon number raises a ``ValueError``."""
        graph = nx.complete_graph(5)
        with pytest.raises(ValueError, match="Mean photon number must be non-negative"):
            similarity.prob_orbit_exact(graph, [1, 1, 1, 1], n_mean=-1)

    def test_invalid_loss(self):
        """Check that a loss parameter outside of the range zero to one raises a
        ``ValueError``."""
        graph = nx.complete_graph(5)
        expected_message = "Loss parameter must take a value between zero and one"
        with pytest.raises(ValueError, match=expected_message):
            similarity.prob_orbit_exact(graph, [1, 1, 1, 1], loss=2)

    def test_prob_vacuum_orbit(self):
        """Check that the empty orbit has probability one when the GBS device is configured with
        zero mean photon number."""
        graph = nx.complete_graph(5)
        vacuum_prob = similarity.prob_orbit_exact(graph, [], 0)
        assert vacuum_prob == 1.0

    @pytest.mark.parametrize("k", [1, 3, 5, 7, 9])
    def test_odd_photon_numbers(self, k):
        """Check that orbits with an odd total photon number have zero probability."""
        graph = nx.complete_graph(10)
        odd_prob = similarity.prob_orbit_exact(graph, [k])
        assert odd_prob == 0.0

    def test_correct_result_returned(self, monkeypatch):
        """Check that the state from ``_get_state`` is used and that the probabilities of all
        permutations of the orbit are summed. ``fock_prob`` is patched to return 1/8 for every
        sample; over a 4-mode graph the orbit [1, 1] has 6 permutations."""
        graph = nx.complete_graph(4)

        def constant_prob(*args, **kwargs):
            return 1 / 8

        with monkeypatch.context() as m:
            m.setattr("strawberryfields.backends.BaseGaussianState.fock_prob", constant_prob)
            total = similarity.prob_orbit_exact(graph, [1, 1])

        assert total == 6 / 8

    def test_known_result(self):
        """Check the probability against precomputed values and a manual sum over samples."""
        graph = nx.complete_graph(3)

        prob = similarity.prob_orbit_exact(graph, [1, 1], 2)
        assert np.allclose(prob, 0.24221526825385403)

        # Over three modes the orbit [1, 1] consists of exactly these three samples.
        state = similarity._get_state(graph, 2)
        manual = sum(state.fock_prob(s) for s in ([1, 1, 0], [1, 0, 1], [0, 1, 1]))
        assert np.allclose(prob, manual)

        assert np.allclose(similarity.prob_orbit_exact(graph, [4], 1), 0)


class TestProbEventExact:
    """Tests for the function ``strawberryfields.apps.similarity.prob_event_exact``"""

    def test_invalid_n_mean(self):
        """Check that a negative mean photon number raises a ``ValueError``."""
        graph = nx.complete_graph(5)
        with pytest.raises(ValueError, match="Mean photon number must be non-negative"):
            similarity.prob_event_exact(graph, 2, 2, n_mean=-1)

    def test_invalid_loss(self):
        """Check that a loss parameter outside of the range zero to one raises a
        ``ValueError``."""
        graph = nx.complete_graph(5)
        expected_message = "Loss parameter must take a value between zero and one"
        with pytest.raises(ValueError, match=expected_message):
            similarity.prob_event_exact(graph, 2, 2, loss=2)

    def test_invalid_photon_number(self):
        """Check that a negative photon number raises a ``ValueError``."""
        graph = nx.complete_graph(5)
        with pytest.raises(ValueError, match="Photon number must not be below zero"):
            similarity.prob_event_exact(graph, -1, 2)

    def test_low_count(self):
        """Check that a negative ``max_count_per_mode`` raises a ``ValueError``."""
        graph = nx.complete_graph(5)
        expected_message = "Maximum number of photons per mode must be non-negative"
        with pytest.raises(ValueError, match=expected_message):
            similarity.prob_event_exact(graph, 2, max_count_per_mode=-1)

    def test_prob_vacuum_event(self):
        """Check that the zero-photon event has probability one when the GBS device is configured
        with zero mean photon number."""
        graph = nx.complete_graph(5)
        vacuum_prob = similarity.prob_event_exact(graph, 0, 0, 0)
        assert vacuum_prob == 1.0

    @pytest.mark.parametrize("k", [3, 5, 7, 9])
    @pytest.mark.parametrize("nmax", [1, 2])
    def test_odd_photon_numbers(self, k, nmax):
        """Check that events with an odd total photon number have zero probability."""
        graph = nx.complete_graph(10)
        odd_prob = similarity.prob_event_exact(graph, k, nmax)
        assert odd_prob == 0.0

    def test_correct_result_returned(self, monkeypatch):
        """Check that the state from ``_get_state`` is used and that probabilities over all
        constituent orbits of the event are summed. ``fock_prob`` is patched to return 1/8 for
        every sample. On a 4-mode graph the event with ``photon_number = 2`` and
        ``max_count_per_mode = 1`` contains only the orbit [1, 1], which has 6 permutations."""
        graph = nx.complete_graph(4)

        def constant_prob(*args, **kwargs):
            return 1 / 8

        with monkeypatch.context() as m:
            m.setattr("strawberryfields.backends.BaseGaussianState.fock_prob", constant_prob)
            total = similarity.prob_event_exact(graph, 2, 1)

        assert total == 6 / 8

    def test_known_result(self):
        """Check event probabilities against orbit probabilities and a precomputed value."""
        graph = nx.complete_graph(4)
        p_max_two = similarity.prob_event_exact(graph, 2, 2, 1)
        p_max_one = similarity.prob_event_exact(graph, 2, 1, 1)
        p_orbit = similarity.prob_orbit_exact(graph, [1, 1], 1)
        p_n_mean_four = similarity.prob_event_exact(graph, 2, 2, 4)

        # Allowing two photons per mode adds the orbit [2]; the test expects it to contribute
        # nothing for this graph.
        assert np.allclose(p_max_two - p_max_one, 0)
        assert np.allclose(p_max_one, p_orbit)
        assert np.allclose(p_n_mean_four, 0.21087781178526066)


class TestProbOrbitMC:
    """Tests for the function ``strawberryfields.apps.similarity.prob_orbit_mc``"""

    def test_invalid_samples(self):
        """Check that requesting fewer than one sample raises a ``ValueError``."""
        graph = nx.complete_graph(5)
        with pytest.raises(ValueError, match="Number of samples must be at least one"):
            similarity.prob_orbit_mc(graph, [1, 1, 1, 1], samples=0)

    def test_invalid_n_mean(self):
        """Check that a negative mean photon number raises a ``ValueError``."""
        graph = nx.complete_graph(5)
        with pytest.raises(ValueError, match="Mean photon number must be non-negative"):
            similarity.prob_orbit_mc(graph, [1, 1, 1, 1], n_mean=-1)

    def test_invalid_loss(self):
        """Check that a loss parameter outside of the range zero to one raises a
        ``ValueError``."""
        graph = nx.complete_graph(5)
        expected_message = "Loss parameter must take a value between zero and one"
        with pytest.raises(ValueError, match=expected_message):
            similarity.prob_orbit_mc(graph, [1, 1, 1, 1], loss=2)

    def test_mean_computation_orbit(self, monkeypatch):
        """Check the sample-mean estimate. ``fock_prob`` is patched to return 1/5 for every
        sample, i.e., one over the cardinality of the orbit [1, 1, 1, 1] over 5 modes, so the
        estimated orbit probability must be one."""
        graph = nx.complete_graph(5)

        def uniform_prob(*args, **kwargs):
            return 0.2

        with monkeypatch.context() as m:
            m.setattr("strawberryfields.backends.BaseGaussianState.fock_prob", uniform_prob)
            estimate = similarity.prob_orbit_mc(graph, [1, 1, 1, 1])

        assert np.allclose(estimate, 1.0)

    @pytest.mark.parametrize("k", [1, 3, 5, 7, 9])
    def test_odd_photon_numbers(self, k):
        """Check that orbits with an odd total photon number have zero probability."""
        graph = nx.complete_graph(10)
        odd_prob = similarity.prob_orbit_mc(graph, [k])
        assert odd_prob == 0.0

    def test_prob_vacuum_orbit(self):
        """Check that the empty orbit has probability one when the GBS device is configured with
        zero mean photon number."""
        graph = nx.complete_graph(5)
        vacuum_prob = similarity.prob_orbit_mc(graph, [], 0)
        assert vacuum_prob == 1.0

    def test_max_loss(self):
        """Check that full loss yields zero probability for a non-empty orbit."""
        graph = nx.complete_graph(5)
        lossy_prob = similarity.prob_orbit_mc(graph, [1, 1, 1, 1], samples=1, loss=1)
        assert lossy_prob == 0.0

    def test_known_result(self):
        """Check the estimated probabilities against precomputed values."""
        graph = nx.complete_graph(3)

        assert np.allclose(similarity.prob_orbit_mc(graph, [1, 1], 2), 0.2422152682538481)
        assert np.allclose(similarity.prob_orbit_mc(graph, [4], 1), 0)


class TestProbEventMC:
    """Tests for the function ``strawberryfields.apps.similarity.prob_event_mc``"""

    def test_invalid_samples(self):
        """Test if function raises a ``ValueError`` when a number of samples less than one is
        requested."""
        g = nx.complete_graph(5)
        with pytest.raises(ValueError, match="Number of samples must be at least one"):
            similarity.prob_event_mc(g, 2, 2, samples=0)

    def test_invalid_n_mean(self):
        """Test if function raises a ``ValueError`` when the mean photon number is specified to
        be negative."""
        g = nx.complete_graph(5)
        with pytest.raises(ValueError, match="Mean photon number must be non-negative"):
            similarity.prob_event_mc(g, 2, 2, n_mean=-1)

    def test_invalid_loss(self):
        """Test if function raises a ``ValueError`` when the loss parameter is specified outside
        of range."""
        g = nx.complete_graph(5)
        with pytest.raises(
            ValueError, match="Loss parameter must take a value between zero and one"
        ):
            similarity.prob_event_mc(g, 2, 2, loss=2)

    def test_invalid_photon_number(self):
        """Test if function raises a ``ValueError`` when a photon number below zero is specified"""
        g = nx.complete_graph(5)
        with pytest.raises(ValueError, match="Photon number must not be below zero"):
            similarity.prob_event_mc(g, -1, 2)

    def test_low_count(self):
        """Test if function raises a ``ValueError`` if ``max_count_per_mode`` is negative."""
        g = nx.complete_graph(5)
        with pytest.raises(
            ValueError, match="Maximum number of photons per mode must be non-negative"
        ):
            similarity.prob_event_mc(g, 2, max_count_per_mode=-1)

    def test_prob_vacuum_event(self):
        """Tests if the function gives the right probability for an event with zero photons when
        the GBS device has been configured to have zero mean photon number."""
        graph = nx.complete_graph(5)
        assert similarity.prob_event_mc(graph, 0, 0, 0) == 1.0

    def test_mean_computation_event(self, monkeypatch):
        """Tests if the calculation of the sample mean is performed correctly. The test
        monkeypatches the fock_prob function so that the probability is the same for each sample
        and is equal to 1/336, i.e., one over the number of samples in the event with 6 modes,
        6 photons, and max 3 photons per mode."""
        graph = nx.complete_graph(6)
        with monkeypatch.context() as m:
            m.setattr(
                "strawberryfields.backends.BaseGaussianState.fock_prob",
                lambda *args, **kwargs: 1.0 / 336,
            )
            assert np.allclose(similarity.prob_event_mc(graph, 6, 3), 1.0)

    @pytest.mark.parametrize("k", [3, 5, 7, 9])
    @pytest.mark.parametrize("nmax", [1, 2])
    def test_odd_photon_numbers(self, k, nmax):
        """Test if function returns zero probability for odd number of total photons."""
        graph = nx.complete_graph(10)
        assert similarity.prob_event_mc(graph, k, nmax) == 0.0

    def test_max_loss(self):
        """Test if function returns zero probability for a non-empty event when maximum loss is
        applied."""
        graph = nx.complete_graph(6)
        assert similarity.prob_event_mc(graph, 6, 3, samples=1, loss=1) == 0.0

    def test_known_result(self):
        """Tests if the probability for known cases is correctly
        reproduced."""
        g = nx.complete_graph(4)
        p = similarity.prob_event_mc(g, 2, 1, 4)
        assert np.allclose(p, 0.2108778117852639)

        graph = nx.complete_graph(20)
        assert np.allclose(similarity.prob_event_mc(graph, 20, 1, 1, samples=10), 0)


class TestFeatureVectorOrbits:
    """Tests for the function ``strawberryfields.apps.similarity.feature_vector_orbits``"""

    def test_invalid_orbits_list(self):
        """Test if function raises a ``ValueError`` when the list of orbits is empty."""
        g = nx.complete_graph(5)
        with pytest.raises(ValueError, match="List of orbits must have at least one orbit"):
            similarity.feature_vector_orbits(g, list_of_orbits=[])

    def test_bad_orbit_photon_numbers(self):
        """Test if function raises a ``ValueError`` when input is an orbit with numbers
        below zero."""
        graph = nx.complete_graph(5)
        with pytest.raises(ValueError, match="Cannot request orbits with photon number below zero"):
            similarity.feature_vector_orbits(graph, [[-1, 1]])

    def test_invalid_n_mean(self):
        """Test if function raises a ``ValueError`` when the mean photon number is specified to
        be negative."""
        g = nx.complete_graph(5)
        with pytest.raises(ValueError, match="Mean photon number must be non-negative"):
            similarity.feature_vector_orbits(g, [[1, 1], [2]], n_mean=-1)

    def test_invalid_loss(self):
        """Test if function raises a ``ValueError`` when the loss parameter is specified outside
        of range."""
        g = nx.complete_graph(5)
        with pytest.raises(
            ValueError, match="Loss parameter must take a value between zero and one"
        ):
            similarity.feature_vector_orbits(g, [[1, 1], [2]], loss=2)

    def test_calls_exact_for_zero_samples(self):
        """Test if function calls the exact function for zero samples. The feature vector is a
        list with one entry per orbit, so it is compared with a one-element list instead of the
        bare probability (which would only match through numpy broadcasting)."""
        g = nx.complete_graph(5)
        assert similarity.feature_vector_orbits(g, [[1, 1]], samples=0) == [
            similarity.prob_orbit_exact(g, [1, 1])
        ]

    def test_correct_vector_returned(self, monkeypatch):
        """Test if function correctly constructs the feature vector. The ``prob_orbit_exact``
        and ``prob_orbit_mc`` functions called within ``feature_vector_orbits`` are
        monkeypatched to return hard-coded outputs that depend only on the orbit."""

        with monkeypatch.context() as m:
            m.setattr(
                similarity,
                "prob_orbit_mc",
                lambda _graph, orbit, n_mean, samples, loss: 1.0 / sum(orbit),
            )
            graph = nx.complete_graph(8)
            assert similarity.feature_vector_orbits(graph, [[1, 1], [2, 1, 1]], samples=1) == [
                0.5,
                0.25,
            ]

        with monkeypatch.context() as m:
            m.setattr(
                similarity,
                "prob_orbit_exact",
                lambda _graph, orbit, n_mean, loss: 0.5 * (1.0 / sum(orbit)),
            )
            graph = nx.complete_graph(8)
            assert similarity.feature_vector_orbits(graph, [[1, 1], [2, 1, 1]]) == [0.25, 0.125]

    def test_known_result(self):
        """Tests if the probability for known cases is correctly
        reproduced."""
        graph = nx.complete_graph(4)
        p = similarity.feature_vector_orbits(graph, [[1, 1], [2, 1, 1]], 2)
        assert np.allclose(p, [0.22918531118334962, 0.06509669127495163])
        assert np.allclose(
            similarity.feature_vector_orbits(graph, [[2], [4]], 1, samples=10), [0, 0]
        )


class TestFeatureVectorEvents:
    """Tests for the function ``strawberryfields.apps.similarity.feature_vector_events``"""

    def test_invalid_event_photon_numbers(self):
        """Test if function raises a ``ValueError`` when the list of event photon numbers
        is empty."""
        g = nx.complete_graph(5)
        with pytest.raises(
            ValueError, match="List of photon numbers must have at least one element"
        ):
            similarity.feature_vector_events(g, [], 1)

    def test_invalid_n_mean(self):
        """Test if function raises a ``ValueError`` when the mean photon number is specified to
        be negative."""
        g = nx.complete_graph(5)
        with pytest.raises(ValueError, match="Mean photon number must be non-negative"):
            similarity.feature_vector_events(g, [2, 4], 2, n_mean=-1)

    def test_invalid_loss(self):
        """Test if function raises a ``ValueError`` when the loss parameter is specified outside
        of range."""
        g = nx.complete_graph(5)
        with pytest.raises(
            ValueError, match="Loss parameter must take a value between zero and one"
        ):
            similarity.feature_vector_events(g, [2, 4], 2, loss=2)

    def test_bad_event_photon_numbers(self):
        """Test if function raises a ``ValueError`` when one of the requested event photon
        numbers is below zero."""
        # Build the graph outside ``pytest.raises`` so only the call under test is checked.
        graph = nx.complete_graph(5)
        with pytest.raises(ValueError, match="Cannot request events with photon number below zero"):
            similarity.feature_vector_events(graph, [-1, 4], 2)

    def test_low_count(self):
        """Test if function raises a ``ValueError`` if ``max_count_per_mode`` is negative."""
        g = nx.complete_graph(5)
        with pytest.raises(
            ValueError, match="Maximum number of photons per mode must be non-negative"
        ):
            similarity.feature_vector_events(g, [2, 4], max_count_per_mode=-1)

    def test_calls_exact_for_zero_samples(self):
        """Test if function calls the exact function for zero samples. The feature vector is a
        list with one entry per event, so it is compared with a one-element list instead of the
        bare probability (which would only match through numpy broadcasting)."""
        g = nx.complete_graph(5)
        assert similarity.feature_vector_events(g, [2], 1, samples=0) == [
            similarity.prob_event_exact(g, 2, 1)
        ]

    def test_correct_vector_returned(self, monkeypatch):
        """Test if function correctly constructs the feature vector. The ``prob_event_exact``
        and ``prob_event_mc`` functions called within ``feature_vector_events`` are
        monkeypatched to return hard-coded outputs that depend only on the photon number."""

        with monkeypatch.context() as m:
            m.setattr(
                similarity,
                "prob_event_mc",
                lambda _graph, photons, max_count, n_mean, samples, loss: 1.0 / photons,
            )
            g = nx.complete_graph(8)
            assert similarity.feature_vector_events(g, [2, 4, 8], 1, samples=1) == [
                0.5,
                0.25,
                0.125,
            ]

        with monkeypatch.context() as m:
            m.setattr(
                similarity,
                "prob_event_exact",
                lambda _graph, photons, max_count, n_mean, loss: 0.5 * (1.0 / photons),
            )
            g = nx.complete_graph(8)
            assert similarity.feature_vector_events(g, [2, 4, 8], 1) == [0.25, 0.125, 0.0625]

    def test_known_result(self):
        """Tests if the probability for known cases is correctly
        reproduced."""
        g = nx.complete_graph(4)
        p = similarity.feature_vector_events(g, [2, 4], 2, 4)
        assert np.allclose(p, [0.21087781178526066, 0.11998024483275889])

        graph = nx.complete_graph(20)
        assert np.allclose(
            similarity.feature_vector_events(graph, [18, 20], 1, 1, samples=10), [0, 0]
        )


class TestFeatureVectorOrbitsSampling:
    """Tests for the function
    ``strawberryfields.apps.similarity.feature_vector_orbits_sampling``"""

    def test_invalid_orbits_list(self):
        """Test if function raises a ``ValueError`` when the list of orbits is empty."""
        with pytest.raises(ValueError, match="List of orbits must have at least one orbit"):
            similarity.feature_vector_orbits_sampling([[1, 1, 0], [1, 0, 1]], [])

    def test_bad_orbit_photon_numbers(self):
        """Test if function raises a ``ValueError`` when input is an orbit with numbers
        below zero."""
        with pytest.raises(ValueError, match="Cannot request orbits with photon number below zero"):
            similarity.feature_vector_orbits_sampling([[1, 1, 0], [1, 0, 1]], [[-1, 1]])

    def test_correct_vector_returned(self, monkeypatch):
        """Test if function correctly constructs the feature vector corresponding to some hard
        coded samples. This test uses a set of samples, corresponding orbits, and resultant
        feature vector to test against the output of ``feature_vector_orbits_sampling``. The
        ``sample_to_orbit`` function called within ``feature_vector_orbits_sampling`` is
        monkeypatched to return the hard coded orbits corresponding to the samples."""
        samples_orbits_mapping = {
            (1, 1, 0, 0, 0): [1, 1],
            (1, 1, 1, 0, 0): [1, 1, 1],
            (1, 1, 1, 1, 0): [1, 1, 1, 1],
            (1, 1, 1, 1, 1): [1, 1, 1, 1, 1],
            (2, 0, 0, 0, 0): [2],
            (3, 0, 0, 0, 0): [3],
            (4, 0, 0, 0, 0): [4],
            (5, 0, 0, 0, 0): [5],
            (0, 1, 1, 0, 0): [1, 1],
        }
        samples = list(samples_orbits_mapping.keys()) + [(1, 1, 1, 1, 1)]  # add a repetition
        list_of_orbits = [[1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1]]
        # Of the 10 samples, [1, 1] occurs twice, [1, 1, 1, 1] once and [1, 1, 1, 1, 1] twice.
        fv_true = [0.2, 0.1, 0.2]

        with monkeypatch.context() as m:
            m.setattr(similarity, "sample_to_orbit", lambda x: samples_orbits_mapping[x])
            fv = similarity.feature_vector_orbits_sampling(samples, list_of_orbits)

        assert fv_true == fv


class TestFeatureVectorEventsSampling:
    """Tests for the function
    ``strawberryfields.apps.similarity.feature_vector_events_sampling``"""

    def test_invalid_event_photon_numbers(self):
        """Test if function raises a ``ValueError`` when the list of event photon numbers
        is empty."""
        with pytest.raises(
            ValueError, match="List of photon numbers must have at least one element"
        ):
            similarity.feature_vector_events_sampling([[1, 1, 0], [1, 0, 1]], [], 1)

    def test_bad_event_photon_numbers(self):
        """Test if function raises a ``ValueError`` when one of the requested event photon
        numbers is below zero."""
        with pytest.raises(ValueError, match="Cannot request events with photon number below zero"):
            similarity.feature_vector_events_sampling([[1, 1, 0], [1, 0, 1]], [-1, 4], 1)

    def test_low_count(self):
        """Test if function raises a ``ValueError`` if ``max_count_per_mode`` is negative."""
        with pytest.raises(
            ValueError, match="Maximum number of photons per mode must be non-negative"
        ):
            similarity.feature_vector_events_sampling([[1, 1, 0], [1, 0, 1]], [2, 4], -1)

    def test_correct_vector_returned(self, monkeypatch):
        """Test if function correctly constructs the feature vector corresponding to some hard
        coded samples. This test uses a set of samples, corresponding events, and resultant
        feature vector to test against the output of ``feature_vector_events_sampling``. The
        ``sample_to_event`` function called within ``feature_vector_events_sampling`` is
        monkeypatched to return the hard coded events corresponding to the samples."""
        samples_events_mapping = {  # max_count_per_mode = 1
            (1, 1, 0, 0, 0): 2,
            (1, 1, 1, 0, 0): 3,
            (1, 1, 1, 1, 0): 4,
            (1, 1, 1, 1, 1): 5,
            (2, 0, 0, 0, 0): None,
            (3, 0, 0, 0, 0): None,
            (4, 0, 0, 0, 0): None,
            (5, 0, 0, 0, 0): None,
            (0, 1, 1, 0, 0): 2,
        }
        samples = list(samples_events_mapping.keys()) + [(1, 1, 1, 1, 1)]  # add a repetition
        event_photon_numbers = [2, 1, 3, 5]  # test alternative ordering
        # Of the 10 samples, event 2 occurs twice, event 1 never, event 3 once and event 5 twice.
        fv_true = [0.2, 0, 0.1, 0.2]

        with monkeypatch.context() as m:
            m.setattr(similarity, "sample_to_event", lambda x, _: samples_events_mapping[x])
            fv = similarity.feature_vector_events_sampling(samples, event_photon_numbers, 1)

        assert fv_true == fv