Python itertools.permutations() Examples

The following 30 code examples show how to use itertools.permutations(). They are extracted from open-source projects. You can vote each example up or down, and you can reach the original project or source file through the links above each example.

You may check out the related API usage on the sidebar.

You may also want to check out all available functions/classes of the module itertools, or try the search function.

Example 1
Project: OpenFermion-Cirq   Author: quantumlib   File: fermionic_simulation_test.py    License: Apache License 2.0 6 votes vote down vote up
def assert_permute_consistent(gate):
    """Assert that `gate`'s in-place and out-of-place permutations agree with FSWAP networks.

    For every permutation `pos` of range(n_qubits):
      * `permute(pos)` (in place) must equal `permuted(pos)` applied to a copy
        taken just before the in-place call;
      * the unitary of that pre-permute copy must equal the unitary of the
        permuted gate sandwiched between a `cca.LinearPermutationGate` for the
        mapping {i: pos[i]} and one for the inverse mapping {pos[i]: i}, both
        built with `ofc.FSWAP`.
    Afterwards `permute` must raise ValueError for `range(1, n_qubits)` (one
    index short) and for `[1] * n_qubits`.

    The argument is copied with `__copy__` first (assumed to be an independent
    copy), so the caller's gate object is not the one being permuted.
    """
    gate = gate.__copy__()
    n_qubits = gate.num_qubits()
    qubits = cirq.LineQubit.range(n_qubits)
    # NOTE(review): `gate` is permuted in place and never reset, so each
    # iteration starts from the previous iteration's permuted gate.
    for pos in itertools.permutations(range(n_qubits)):
        # Snapshot of the gate *before* this iteration's in-place permute.
        permuted_gate = gate.__copy__()
        gate.permute(pos)
        assert permuted_gate.permuted(pos) == gate
        # Despite its name, this is the unitary of the pre-permute snapshot.
        actual_unitary = cirq.unitary(permuted_gate)

        # Permutation network i -> pos[i], the permuted gate, then the inverse
        # network pos[i] -> i.
        ops = [
            cca.LinearPermutationGate(n_qubits, dict(zip(range(n_qubits), pos)),
                                      ofc.FSWAP)(*qubits),
            gate(*qubits),
            cca.LinearPermutationGate(n_qubits, dict(zip(pos, range(n_qubits))),
                                      ofc.FSWAP)(*qubits)
        ]
        circuit = cirq.Circuit(ops)
        expected_unitary = cirq.unitary(circuit)
        assert np.allclose(actual_unitary, expected_unitary)

    # Malformed permutations: wrong length, and repeated indices.
    with pytest.raises(ValueError):
        gate.permute(range(1, n_qubits))
    with pytest.raises(ValueError):
        gate.permute([1] * n_qubits)
Example 2
Project: pydfs-lineup-optimizer   Author: DimaKudosh   File: rules.py    License: MIT License 6 votes vote down vote up
def apply(self, solver, players_dict):
        """Restrict opposing-team players at the configured position groups.

        Does nothing when `opposing_teams_position_restriction` is empty.
        Otherwise, for every game and every ordered pair of groups taken from
        the restriction, adds a `SolverSign.LTE` constraint with bound 1 on each
        (home player eligible for the first group, away player eligible for the
        second group) variable pair -- presumably "not both selected".
        """
        restriction = self.optimizer.opposing_teams_position_restriction
        if not restriction:
            return

        def eligible(team_players, wanted_positions):
            # Solver variables of the team's players sharing any wanted position.
            return [var for player, var in team_players.items()
                    if list_intersection(player.positions, wanted_positions)]

        for game in self.optimizer.games:
            home = {}
            away = {}
            for player, var in players_dict.items():
                if player.team == game.home_team:
                    home[player] = var
                if player.team == game.away_team:
                    away[player] = var
            for home_positions, away_positions in permutations(restriction, 2):
                pairs = product(eligible(home, home_positions),
                                eligible(away, away_positions))
                for pair in pairs:
                    solver.add_constraint(pair, None, SolverSign.LTE, 1)
Example 3
Project: pydfs-lineup-optimizer   Author: DimaKudosh   File: rules.py    License: MIT License 6 votes vote down vote up
def apply(self, solver, players_dict):
        """Require opposing-team stacks for the configured position groups.

        Each requested group is normalised to a sorted tuple and requests are
        counted.  For every game and each ordered pair of positions in a group,
        one solver variable is created per (home player at the first position,
        away player at the second position) pair, with
        `add_constraint(pair, None, GTE, 2 * variable)` -- presumably so the
        variable can only be 1 when both players are picked.  The group's
        variables together must then be GTE the number of times it was requested.
        """
        raw_all_force_positions = self.optimizer.opposing_team_force_positions
        if not raw_all_force_positions:
            return
        requests = Counter(tuple(sorted(group)) for group in raw_all_force_positions)
        for positions, total_combinations in requests.items():
            positions_vars = []
            for game in self.optimizer.games:
                home = {}
                away = {}
                for player, var in players_dict.items():
                    if player.team == game.home_team:
                        home[player] = var
                    if player.team == game.away_team:
                        away[player] = var
                for home_position, away_position in permutations(positions, 2):
                    home_vars = [var for player, var in home.items()
                                 if home_position in player.positions]
                    away_vars = [var for player, var in away.items()
                                 if away_position in player.positions]
                    for pair in product(home_vars, away_vars):
                        # The suffix is the running count of variables created for this group.
                        name = 'force_positions_%s_%d' % (positions, len(positions_vars))
                        stack_var = solver.add_variable(name)
                        positions_vars.append(stack_var)
                        solver.add_constraint(pair, None, SolverSign.GTE, 2 * stack_var)
            solver.add_constraint(positions_vars, None, SolverSign.GTE, total_combinations)
Example 4
Project: pyGSTi   Author: pyGSTio   File: listtools.py    License: Apache License 2.0 6 votes vote down vote up
def _sorted_unique_permutations(items):
    """Yield each distinct ordering of `items` exactly once, in ascending lexicographic order.

    Uses the classic "next permutation" step starting from the sorted
    sequence, so repeated values never generate duplicate work (unlike
    `itertools.permutations`, which treats equal elements as distinct by
    position and would produce k! copies for k equal parts).
    """
    perm = sorted(items)
    last = len(perm) - 1
    while True:
        yield tuple(perm)
        # Rightmost position whose value can still be increased.
        i = last - 1
        while i >= 0 and perm[i] >= perm[i + 1]:
            i -= 1
        if i < 0:
            return
        # Swap it with the rightmost larger value, then make the tail ascending.
        j = last
        while perm[j] <= perm[i]:
            j -= 1
        perm[i], perm[j] = perm[j], perm[i]
        perm[i + 1:] = reversed(perm[i + 1:])


def partitions(n):
    """
    Iterate over all partitions of integer `n`.

    A partition of `n` here is defined as a list of one or more non-zero
    integers which sum to `n`.  Every partition is iterated over exactly
    once - there are no duplicates/repetitions.

    Each sorted partition from `sorted_partitions(n)` is expanded into all of
    its distinct orderings, in ascending lexicographic order.  This is the
    same sequence as filtering `itertools.permutations` of the ascending
    partition, but it runs in time proportional to the number of distinct
    orderings rather than len(p)!.

    Parameters
    ----------
    n : int
        The number to partition.

    Returns
    -------
    iterator
        Iterates over tuples of integers (partitions).
    """
    for p in sorted_partitions(n):
        if len(p) == 0:
            # The previous `pp > ()` filter never emitted an empty tuple; keep that.
            continue
        for pp in _sorted_unique_permutations(p):
            yield pp
Example 5
Project: linter-pylama   Author: AtomLinter   File: config.py    License: MIT License 6 votes vote down vote up
def _validate_options(cls, options):
        """Validate the mutually exclusive options.

        Return `True` iff only zero or one of `BASE_ERROR_SELECTION_OPTIONS`
        was selected.

        """
        for opt1, opt2 in \
                itertools.permutations(cls.BASE_ERROR_SELECTION_OPTIONS, 2):
            if getattr(options, opt1) and getattr(options, opt2):
                log.error('Cannot pass both {} and {}. They are '
                          'mutually exclusive.'.format(opt1, opt2))
                return False

        if options.convention and options.convention not in conventions:
            log.error("Illegal convention '{}'. Possible conventions: {}"
                      .format(options.convention,
                              ', '.join(conventions.keys())))
            return False
        return True 
Example 6
Project: plugin.video.emby   Author: MediaBrowser   File: test_parser.py    License: GNU General Public License v3.0 6 votes vote down vote up
def test_ybd(self):
        # A 4-digit year, a textual month (abbreviated or full) and a 1-2 digit
        # day cannot be confused with each other, so every ordering of those
        # tokens joined by any separator below must parse back to the same date.
        separators = ['-', ' ', '/', '.']
        day_tokens = ['%d', '%-d'] if PLATFORM_HAS_DASH_D else ['%d']

        orderings = []
        for tokens in itertools.product(['%Y'], ['%b', '%B'], day_tokens):
            orderings.extend(itertools.permutations(tokens))

        expected = datetime(2003, 9, 25)
        for sep in separators:
            for ordering in orderings:
                rendered = expected.strftime(sep.join(ordering))
                self.assertEqual(parse(rendered), expected)
Example 7
Project: recruit   Author: Frank-qlu   File: test_numeric.py    License: Apache License 2.0 6 votes vote down vote up
def test_count_nonzero_axis_consistent(self):
        # np.count_nonzero on an integer array must agree with the same data
        # cast to object dtype, for every ordering of every axis subset of size
        # 0..3 (the generic object path is checked against the integer path).
        # NOTE(review): range(len(axes)) stops before the full 4-axis tuple --
        # confirm that exclusion is intended.
        from itertools import combinations, permutations

        axes = (0, 1, 2, 3)
        as_int = np.random.RandomState(1234).randint(-100, 100, size=(5, 5, 5, 5))
        as_object = as_int.astype(object)

        axis_orders = (
            order
            for k in range(len(axes))
            for subset in combinations(axes, k)
            for order in permutations(subset)
        )
        for order in axis_orders:
            assert_equal(np.count_nonzero(as_int, axis=order),
                         np.count_nonzero(as_object, axis=order),
                         err_msg="Mismatch for axis: %s" % (order,))
Example 8
Project: recruit   Author: Frank-qlu   File: reshaping.py    License: Apache License 2.0 6 votes vote down vote up
def test_unstack(self, data, index, obj):
        # Unstacking any ordered selection of 1..nlevels-1 index levels must keep
        # the extension array type and match the object-dtype result.
        data = data[:len(index)]
        ser = (pd.Series(data, index=index) if obj == "series"
               else pd.DataFrame({"A": data, "B": data}, index=index))

        n_levels = index.nlevels
        # e.g. for three levels:
        # (0,), (1,), (2,), (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)
        for size in range(1, n_levels):
            for level in itertools.permutations(range(n_levels), size):
                unstacked = ser.unstack(level=level)
                assert all(isinstance(unstacked[col].array, type(data))
                           for col in unstacked.columns)
                expected = ser.astype(object).unstack(level=level)
                self.assert_frame_equal(unstacked.astype(object), expected)
Example 9
Project: recruit   Author: Frank-qlu   File: test_internals.py    License: Apache License 2.0 6 votes vote down vote up
def test_equals_block_order_different_dtypes(self):
        # GH 9330: BlockManager equality must not depend on the order of its blocks.
        layouts = [
            "a:i8;b:f8",  # basic case
            "a:i8;b:f8;c:c8;d:b",  # many types
            "a:i8;e:dt;f:td;g:string",  # more types
            "a:i8;b:category;c:category2;d:category2",  # categories
            "c:sparse;d:sparse_na;b:f8",  # sparse
        ]

        for layout in layouts:
            original = create_mgr(layout)
            for reordered_blocks in itertools.permutations(original.blocks):
                shuffled = BlockManager(reordered_blocks, original.axes)
                # Check equality in both directions.
                assert original.equals(shuffled)
                assert shuffled.equals(original)
Example 10
Project: zhusuan   Author: thu-ml   File: test_utils.py    License: MIT License 6 votes vote down vote up
def testGetBackwardOpsChain(self):
        # Graph under test is the chain a -> b -> c.
        a = tf.placeholder(tf.float32)
        b = tf.sqrt(a)
        c = tf.square(b)

        # The deepest seed present decides which ops must be reached.
        expectations = [(c, [a.op, b.op, c.op]), (b, [a.op, b.op]), (a, [a.op])]
        for count in range(4):
            for seeds in permutations([a, b, c], count):
                expected = next((ops for tensor, ops in expectations
                                 if tensor in seeds), [])
                self.assertEqual(get_backward_ops(seeds), expected)

        # Tensors listed in treat_as_inputs stop the backward walk.
        self.assertEqual(get_backward_ops([c], treat_as_inputs=[b]), [c.op])
        self.assertEqual(
            get_backward_ops([b, c], treat_as_inputs=[b]), [c.op])
        self.assertEqual(
            get_backward_ops([a, c], treat_as_inputs=[b]), [a.op, c.op])
Example 11
Project: sspam   Author: quarkslab   File: pattern_matcher.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def visit_BoolOp(self, target, pattern):
        """Match `pattern` against `target` on flattened BoolOp operands.

        The nodes match when their operators have the same type, they have the
        same number of operands, and some ordering of the target's operands
        matches the pattern's operands pairwise via `self.visit`.  Each ordering
        is tried on a fresh deep copy of the wildcard bindings in place on entry.

        Returns True on the first matching ordering, keeping the bindings it
        produced.  Returns False otherwise, and on False `self.wildcards` is
        restored to its state on entry.
        """
        if type(target.op) != type(pattern.op):
            return False
        if len(target.values) != len(pattern.values):
            return False
        # Try every assignment of target operands to pattern operands.
        old_context = deepcopy(self.wildcards)
        for perm in itertools.permutations(target.values):
            self.wildcards = deepcopy(old_context)
            # Stop at the first failing operand: the ordering can no longer
            # match, and its partial bindings are discarded anyway.
            if all(self.visit(value, sub_pattern)
                   for value, sub_pattern in zip(perm, pattern.values)):
                return True
        # No ordering matched: drop bindings left by the last failed attempt.
        self.wildcards = old_context
        return False
Example 12
Project: BERT-Relation-Extraction   Author: plkmo   File: infer.py    License: Apache License 2.0 6 votes vote down vote up
def get_all_sub_obj_pairs(self, sent):
        """Return ordered entity pairs built from the root of the first sentence.

        `sent` is either a string (parsed with `self.nlp`) or an already-parsed
        document exposing `.sents`.  Among the root's children, the last
        `nsubj`/`nsubjpass` child whose lower-cased text contains a run of a-z
        letters is the subject.  Every `dobj`/`attr`/`prep`/`ccomp` child is an
        object.  If both a subject and at least one object exist, every ordered
        pair of distinct entries from [subject, *objects] is returned.  Each side
        is its subtree's tokens as a list, or the single token when the subtree
        has exactly one.  Otherwise an empty list is returned.
        """
        doc = self.nlp(sent) if isinstance(sent, str) else sent
        root = next(doc.sents).root

        subject = None
        objs = []
        for child in root.children:
            dep = child.dep_
            if dep in ["nsubj", "nsubjpass"]:
                # Ignore subjects consisting only of numbers/symbols.
                if re.findall("[a-z]+", child.text.lower()):
                    subject = child
            elif dep in ["dobj", "attr", "prep", "ccomp"]:
                objs.append(child)

        if subject is None or not objs:
            return []

        def span(token):
            words = list(token.subtree)
            return words[0] if len(words) == 1 else words

        return [(span(a), span(b)) for a, b in permutations([subject] + objs, 2)]
Example 13
Project: BERT-Relation-Extraction   Author: plkmo   File: misc.py    License: Apache License 2.0 6 votes vote down vote up
def get_subject_objects(sent_):
    """Return ordered (entity, entity) pairs found under the root of `sent_`.

    The subject is the last `nsubj`/`nsubjpass` child of the root whose
    lower-cased text contains a run of a-z letters.  The objects are all
    `dobj`/`attr`/`prep`/`ccomp` children.  When a subject and at least one
    object exist, every ordered pair of distinct entries from
    [subject, *objects] is returned.  Each side is its subtree tokens as a
    list, or the bare token when the subtree holds exactly one token.
    Otherwise the result is empty.
    """
    subject = None
    objects = []
    for child in sent_.root.children:
        if child.dep_ in ["nsubj", "nsubjpass"]:
            # Numbers/symbols alone do not count as a subject.
            if len(re.findall("[a-z]+", child.text.lower())) > 0:
                subject = child
        elif child.dep_ in ["dobj", "attr", "prep", "ccomp"]:
            objects.append(child)

    pairs = []
    if subject is None or len(objects) == 0:
        return pairs
    for left, right in permutations([subject] + objects, 2):
        left_words = list(left.subtree)
        right_words = list(right.subtree)
        left_item = left_words[0] if len(left_words) == 1 else left_words
        right_item = right_words[0] if len(right_words) == 1 else right_words
        pairs.append((left_item, right_item))
    return pairs
Example 14
Project: aumfor   Author: virtualrealitysystems   File: scanprof.py    License: GNU General Public License v3.0 6 votes vote down vote up
def permscan(self, address_space, offset = 0, maxlen = None):
    """Time `self.oldscan` under every ordering of `self.checks`, print a report, then exit.

    After one warm-up scan, each permutation of `self.checks` is installed and
    timed with `timeit` (`self.repeats` runs of a `ScanProfInstance`).  A table
    sorted by time is printed, listing the result count, permutation number and
    ordering.  Always ends with `sys.exit(1)`, leaving `self.checks` set to the
    last permutation tried.

    The `print(...)` calls use a single parenthesized argument, so output is
    identical under Python 2 and Python 3 (the former print statements were a
    SyntaxError on Python 3).
    """
    times = []
    # Run a warm-up scan to ensure the file is cached as much as possible
    self.oldscan(address_space, offset, maxlen)

    perms = list(itertools.permutations(self.checks))
    total = len(perms)
    for index, ordering in enumerate(perms):
        self.checks = ordering
        print("Running scan {0}/{1}...".format(index + 1, total))
        profobj = ScanProfInstance(self.oldscan, address_space, offset, maxlen)
        value = timeit.timeit(profobj, number = self.repeats)
        times.append((value, len(list(profobj.results)), index))

    row = "{0:20} | {1:7} | {2:6} | {3}"
    print("Scan results")
    print(row.format("Time", "Results", "Perm #", "Ordering"))
    for val, count, index in sorted(times):
        print(row.format(val, count, index, perms[index]))
    sys.exit(1)
Example 15
Project: textdistance   Author: life4   File: compression_based.py    License: MIT License 6 votes vote down vote up
def __call__(self, *sequences):
        """Return the compression-based distance between `sequences`.

        The concatenated size is the smallest `self._get_size` over every
        concatenation order of the sequences.  The result is
        (concatenated size - smallest individual size * (count - 1)) / largest
        individual size; for two inputs this is the normalized compression
        distance formula.  Returns 0 when no sequences are given or when every
        individual size is 0.
        """
        if not sequences:
            return 0
        sequences = self._get_sequences(*sequences)

        empty = type(sequences[0])()
        text_like = isinstance(empty, (str, bytes))

        def concatenate(parts):
            # str/bytes are joined; any other sequence type is added with `+`.
            if text_like:
                return empty.join(parts)
            return sum(parts, empty)

        concat_len = min(self._get_size(concatenate(order))
                         for order in permutations(sequences))

        sizes = [self._get_size(seq) for seq in sequences]
        largest = max(sizes)
        if largest == 0:
            return 0
        return (concat_len - min(sizes) * (len(sequences) - 1)) / largest
Example 16
Project: vnpy_crypto   Author: birforce   File: test_numeric.py    License: MIT License 6 votes vote down vote up
def test_count_nonzero_axis_consistent(self):
        # Integer and object-dtype copies of the same data must give identical
        # np.count_nonzero results for every ordering of every axis subset of
        # size 0..3.
        # NOTE(review): the full 4-axis tuple is not exercised -- confirm intent.
        from itertools import combinations, permutations

        axes = (0, 1, 2, 3)
        generator = np.random.RandomState(1234)
        ints = generator.randint(-100, 100, size=(5, 5, 5, 5))
        objs = ints.astype(object)
        template = "Mismatch for axis: %s"

        for k in range(len(axes)):
            subsets = combinations(axes, k)
            for subset in subsets:
                for order in permutations(subset):
                    got_int = np.count_nonzero(ints, axis=order)
                    got_obj = np.count_nonzero(objs, axis=order)
                    assert_equal(got_int, got_obj, err_msg=template % (order,))
Example 17
Project: vnpy_crypto   Author: birforce   File: test_internals.py    License: MIT License 6 votes vote down vote up
def test_equals_block_order_different_dtypes(self):
        # GH 9330: a BlockManager rebuilt from any reordering of its blocks must
        # compare equal to the original, in both directions.
        layouts = [
            "a:i8;b:f8",  # basic case
            "a:i8;b:f8;c:c8;d:b",  # many types
            "a:i8;e:dt;f:td;g:string",  # more types
            "a:i8;b:category;c:category2;d:category2",  # categories
            "c:sparse;d:sparse_na;b:f8",  # sparse
        ]

        for layout in layouts:
            reference = create_mgr(layout)
            for block_order in itertools.permutations(reference.blocks):
                rebuilt = BlockManager(block_order, reference.axes)
                assert reference.equals(rebuilt)
                assert rebuilt.equals(reference)
Example 18
Project: aospy   Author: spencerahill   File: test_calc_basic.py    License: Apache License 2.0 5 votes vote down vote up
def test_calc_object_time_options():
    """Calc must raise ValueError for every time-reduction option on these params.

    Every ordered, non-empty selection of the time options is passed as
    `dtype_out_time` together with `test_params_not_time_defined` (presumably
    a variable without a defined time axis -- confirm against its definition).
    """
    time_options = ['av', 'std', 'ts', 'reg.av', 'reg.std', 'reg.ts']
    for i in range(1, len(time_options) + 1):
        for time_option in itertools.permutations(time_options, i):
            # Work on a per-case copy: writing into the shared module-level
            # dict leaked the last option into every later test using it.
            params = dict(test_params_not_time_defined, dtype_out_time=time_option)
            with pytest.raises(ValueError):
                Calc(**params)
Example 19
Project: aospy   Author: spencerahill   File: test_automate.py    License: Apache License 2.0 5 votes vote down vote up
def test_prune_invalid_time_reductions(var):
    # With no requested reductions the pruner returns None.  Otherwise a
    # variable with `def_time` keeps the requested reductions unchanged, and
    # one without it keeps none.
    time_options = ['av', 'std', 'ts', 'reg.av', 'reg.std', 'reg.ts']
    spec = {'var': var, 'dtype_out_time': None}
    assert _prune_invalid_time_reductions(spec) is None

    for count in range(1, len(time_options) + 1):
        for requested in itertools.permutations(time_options, count):
            spec['dtype_out_time'] = requested
            expected = requested if spec['var'].def_time else []
            assert _prune_invalid_time_reductions(spec) == expected
Example 20
Project: ANGRYsearch   Author: DoTheEvo   File: angrysearch.py    License: GNU General Public License v2.0 5 votes vote down vote up
def like_query_adjustment(self, input_text):
        """Build LIKE patterns matching the words of `input_text` in any order.

        Double quotes are doubled first.  Each ordering of the
        whitespace-separated words becomes a double-quoted pattern
        `"%w1%w2%...%"`, and the patterns are joined with ' OR path LIKE '
        (presumably the caller prefixes the first 'path LIKE ').
        NOTE(review): the number of patterns grows as n! in the word count.
        """
        escaped = input_text.replace('"', '""')
        words = escaped.strip().split()
        patterns = ('"%{0}%"'.format('%'.join(order)) for order in permutations(words))
        return ' OR path LIKE '.join(patterns)

    # FTS CHECKBOX IS CHECKED, FTS VIRTUAL TABLES ARE USED 
Example 21
Project: gcp-variant-transforms   Author: googlegenomics   File: vcf_parser_test.py    License: Apache License 2.0 5 votes vote down vote up
def test_sort_variants(self):
  # These variants are listed in ascending order; sorting any shuffle of
  # them must reproduce exactly that order.
  expected_order = [
      Variant(reference_name='a', start=20, end=22),
      Variant(reference_name='a', start=20, end=22, quality=20),
      Variant(reference_name='b', start=20, end=22),
      Variant(reference_name='b', start=21, end=22),
      Variant(reference_name='b', start=21, end=23)]

  for shuffled in permutations(expected_order):
    self.assertEqual(sorted(shuffled), expected_order)
Example 22
Project: pyscf   Author: pyscf   File: shci.py    License: Apache License 2.0 5 votes vote down vote up
def populate(self, array, list, value):
        """Assign `value` at every index obtained by permuting both halves of `list` together.

        `list` is split into `up` (its first len(list) // 2 entries) and `dn`
        (the rest).  For every permutation t of range(len(list) // 2), the index
        tuple up[t...] + dn[t...] of `array` is set to `value`.
        """
        # Floor division: `len(list) / 2` is a float on Python 3, which broke
        # both the slicing and the range() below (Python 2 floored implicitly).
        dim = len(list) // 2
        up = list[:dim]
        dn = list[dim:]
        import itertools

        for t in itertools.permutations(range(dim)):
            updn = [up[i] for i in t] + [dn[i] for i in t]
            array[tuple(updn)] = value
Example 23
Project: pyscf   Author: pyscf   File: util.py    License: Apache License 2.0 5 votes vote down vote up
def p_count(permutation, destination=None, debug=False):
    """
    Counts the transpositions (pairwise swaps) separating `permutation` from `destination`.

    Args:
        permutation (iterable): unique entries, e.g. a rearrangement of 0..N-1;
        any iterable (including a one-shot generator) is accepted;
        destination (iterable): the reference ordering of the same entries;
        defaults to `sorted(permutation)`;
        debug (bool): prints debug information if True;

    Returns:
        The number of swaps needed to reach `permutation` from `destination`,
        computed as the sum over cycles of (cycle length - 1).

    Raises:
        ValueError: if `permutation` and `destination` differ in length.
    """
    # Materialise once so one-shot iterables survive the len()/sorted() calls.
    permutation = tuple(permutation)
    if destination is None:
        destination = sorted(permutation)
    else:
        destination = tuple(destination)
    if len(permutation) != len(destination):
        raise ValueError("Permutation and destination do not match: {:d} vs {:d}".format(len(permutation), len(destination)))
    # Re-express `permutation` as indices into `destination`.
    destination = dict((element, i) for i, element in enumerate(destination))
    permutation = tuple(destination[i] for i in permutation)
    visited = [False] * len(permutation)
    result = 0
    for i in range(len(permutation)):
        if not visited[i]:
            # Walk the cycle starting at i; each step beyond i costs one swap.
            j = i
            while permutation[j] != i:
                j = permutation[j]
                result += 1
                visited[j] = True
    if debug:
        print("p_count(" + ", ".join("{:d}".format(i) for i in permutation) + ") = {:d}".format(result))
    return result
Example 24
Project: pyscf   Author: pyscf   File: shci.py    License: Apache License 2.0 5 votes vote down vote up
def populate(self, array, list, value):
        """Assign `value` at every index obtained by permuting both halves of `list` together.

        `list` is split into `up` (its first len(list) // 2 entries) and `dn`
        (the rest).  For every permutation t of range(len(list) // 2), the index
        tuple up[t...] + dn[t...] of `array` is set to `value`.
        """
        # Floor division: `len(list) / 2` is a float on Python 3, which broke
        # both the slicing and the range() below (Python 2 floored implicitly).
        dim = len(list) // 2
        up = list[:dim]
        dn = list[dim:]
        import itertools

        for t in itertools.permutations(range(dim)):
            updn = [up[i] for i in t] + [dn[i] for i in t]
            array[tuple(updn)] = value
Example 25
Project: pyscf   Author: pyscf   File: dmrgci.py    License: Apache License 2.0 5 votes vote down vote up
def populate(self, array, list, value):
        """Assign `value` at every index obtained by permuting both halves of `list` together.

        `list` is split into `up` (its first len(list) // 2 entries) and `dn`
        (the rest).  For every permutation t of range(len(list) // 2), the index
        tuple up[t...] + dn[t...] of `array` is set to `value`.
        """
        # Floor division: `len(list)/2` is a float on Python 3, which broke
        # both the slicing and the range() below.
        dim = len(list) // 2
        up = list[:dim]
        dn = list[dim:]
        import itertools
        for t in itertools.permutations(range(dim)):
            updn = [up[i] for i in t] + [dn[i] for i in t]
            array[tuple(updn)] = value
Example 26
Project: symspellpy   Author: mammothb   File: test_editdistance.py    License: MIT License 5 votes vote down vote up
def build_test_strings():
    """Return "" followed by every arrangement of every non-empty subset of "abcd".

    Strings are grouped by length.  Within a length, subsets follow
    `combinations` order, and arrangements within a subset follow
    `permutations` order.
    """
    alphabet = "abcd"
    return [""] + [
        "".join(arrangement)
        for size in range(1, len(alphabet) + 1)
        for subset in combinations(alphabet, size)
        for arrangement in permutations(subset)
    ]
Example 27
Project: pydfs-lineup-optimizer   Author: DimaKudosh   File: utils.py    License: MIT License 5 votes vote down vote up
def link_players_with_positions(
        players: List['Player'],
        positions: List[LineupPosition]
) -> Dict['Player', LineupPosition]:
    """
    This method tries to set positions for given players, and raise error if can't.

    Positions that exactly one remaining player (in priority order) can fill
    are assigned first.  The remaining players are then tried in every order,
    each taking the first still-free position it is eligible for, until some
    order places all of them.

    Raises:
        LineupOptimizerException: if no order places every remaining player.
    """
    positions = positions[:]
    players_with_positions = {}  # type: Dict['Player', LineupPosition]
    players = sorted(players, key=get_player_priority)
    # Iterate over a snapshot: `positions` shrinks inside the loop, and
    # removing from the list being iterated silently skipped the position
    # right after each removed one.
    for position in positions[:]:
        players_for_position = [p for p in players if list_intersection(position.positions, p.positions)]
        if len(players_for_position) == 1:
            only_player = players_for_position[0]
            players_with_positions[only_player] = position
            positions.remove(position)
            players.remove(only_player)
    # Entries written by a failed ordering are overwritten by the successful
    # one, since a successful ordering assigns every remaining player.
    for players_permutation in permutations(players):
        is_correct = True
        remaining_positions = positions[:]
        for player in players_permutation:
            for position in remaining_positions:
                if list_intersection(player.positions, position.positions):
                    players_with_positions[player] = position
                    remaining_positions.remove(position)
                    break
            else:
                is_correct = False
                break
        if is_correct:
            break
    else:
        raise LineupOptimizerException('Unable to build lineup')
    return players_with_positions
Example 28
Project: hadrian   Author: modelop   File: array.py    License: Apache License 2.0 5 votes vote down vote up
def __call__(self, state, scope, pos, paramTypes, a):
        """Return every ordering of `a`, each converted to a list.

        `state.checkTime()` is called after every 1000th ordering (presumably
        to enforce an execution time limit -- confirm).
        """
        out = []
        for count, permutation in enumerate(itertools.permutations(a), 1):
            if count % 1000 == 0:
                state.checkTime()
            out.append(list(permutation))
        return out
Example 29
Project: python-control   Author: python-control   File: minreal_test.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def testMinrealBrute(self):
        # For random systems of every distinct (states, inputs, outputs) triple
        # from 1..5, minreal() either reduces the state count (and is counted),
        # or must preserve the poles and every SISO channel's zeros.

        def assert_each_has_partner(zeros, candidates):
            # Every zero must have a numerically identical counterpart.
            for z in zeros:
                self.assertAlmostEqual(min(abs(candidates - z)), 0.)

        for n, m, p in permutations(range(1,6), 3):
            s = matlab.rss(n, p, m)
            sr = s.minreal()
            if s.states > sr.states:
                self.nreductions += 1
                continue

            # Poles: compare the sorted eigenvalues of A.
            np.testing.assert_array_almost_equal(
                np.sort(eigvals(s.A)), np.sort(eigvals(sr.A)))

            # Zeros: check each SISO channel from input i to output j.
            for i in range(m):
                for j in range(p):
                    siso_full = matlab.ss(s.A, s.B[:,i], s.C[j,:], s.D[j,i])
                    siso_min = matlab.ss(sr.A, sr.B[:,i], sr.C[j,:], sr.D[j,i])
                    zeros_full = matlab.zero(siso_full)
                    zeros_min = matlab.zero(siso_min)

                    # Sorting does not work for this comparison, so match the
                    # zeros by nearest counterpart in both directions.
                    self.assertEqual(len(zeros_full), len(zeros_min))
                    assert_each_has_partner(zeros_full, zeros_min)
                    assert_each_has_partner(zeros_min, zeros_full)

        # Depends on the random seed set at the top of the test file.
        self.assertEqual(self.nreductions, 2)
Example 30
Project: pyGSTi   Author: pyGSTio   File: compilationlibrary.py    License: Apache License 2.0 5 votes vote down vote up
def compute_connectivity_of(self, gate_name):
        """
        Compute the connectivity (the nearest-neighbor links) for `gate_name`
        using the (compiled) gates available in this library.  The result, a
        :class:`QubitGraph`, is stored in `self.connectivity[gate_name]`.

        Every compiled gate label named `gate_name` links each ordered pair of
        the qubits it acts on.

        Parameters
        ----------
        gate_name : str
            Name of the gate whose qubit links are collected.

        Returns
        -------
        None
        """
        nQ = int(round(_np.log2(self.model.dim)))  # assumes *unitary* mode (OK?)
        qubit_labels = self.model.state_space_labels.labels[0]
        # Maps qubit labels to integer row/column indices.
        index_of = {label: k for k, label in enumerate(qubit_labels)}
        assert(len(qubit_labels) == nQ), "Number of qubit labels is inconsistent with Model dimension!"

        linked = _np.zeros((nQ, nQ), dtype=bool)
        matching_labels = (label for label in self.keys() if label.name == gate_name)
        for compiled_gatelabel in matching_labels:
            # permutations(..., 2) yields both directions of every qubit pair.
            for src, dst in _itertools.permutations(compiled_gatelabel.qubits, 2):
                linked[index_of[src], index_of[dst]] = True

        self.connectivity[gate_name] = _QubitGraph(qubit_labels, linked)