Java Code Examples for it.unimi.dsi.fastutil.ints.IntSet#contains()

The following examples show how to use it.unimi.dsi.fastutil.ints.IntSet#contains(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source Project: jstarcraft-ai | File: MAPEvaluator.java | License: Apache License 2.0 | 6 votes
/**
 * Computes the average precision (AP) of {@code rankList} against the relevant items in
 * {@code checkCollection}, considering at most the first {@code size} ranked items.
 *
 * <p>Each hit at 0-based position {@code index} contributes {@code hits / (index + 1)}, where
 * {@code hits} counts the hits up to and including that position. The sum is normalized by
 * {@code min(|checkCollection|, |truncated rankList|)}.
 *
 * @param checkCollection the relevant (ground-truth) item indices
 * @param rankList the recommended item indices, best first
 * @return the average precision, or {@code 0} when the normalizer is zero (empty recommendation
 *     list or empty ground truth) instead of {@code NaN}
 */
@Override
protected float measure(IntSet checkCollection, IntList rankList) {
    if (rankList.size() > size) {
        rankList = rankList.subList(0, size);
    }
    int normalizer = Math.min(checkCollection.size(), rankList.size());
    if (normalizer == 0) {
        // 0 / 0 would yield NaN and poison any average taken over users.
        return 0F;
    }
    int count = 0;
    float map = 0F;
    for (int index = 0; index < rankList.size(); index++) {
        if (checkCollection.contains(rankList.getInt(index))) {
            count++;
            map += 1F * count / (index + 1);
        }
    }
    return map / normalizer;
}
 
Example 2
/**
 * Builds random key/value test data.
 *
 * <p>Produces {@code maxNumKeys} distinct keys, each stored as a triple (key, value, value) in a
 * flat array, plus {@code maxNumKeys} random integers that are guaranteed not to be keys. Every
 * number is drawn from {@code [0, maxNumKeys << 2)}.
 *
 * @param random source of randomness (consumed in a fixed order, so seeded runs are reproducible)
 * @param maxNumKeys number of keys and of non-keys to generate
 * @return the generated keys/values and non-keys
 */
public static KeyTestInfo generateRandomKeys(Random random, int maxNumKeys) {
  final int bound = maxNumKeys << 2;
  final int[] triples = new int[maxNumKeys * 3];
  final int[] absent = new int[maxNumKeys];
  final IntSet usedKeys = new IntOpenHashBigSet(maxNumKeys);

  for (int slot = 0, base = 0; slot < maxNumKeys; slot++, base += 3) {
    // add() returns false for a repeat, so keep drawing until a fresh key is registered.
    int key;
    do {
      key = random.nextInt(bound);
    } while (!usedKeys.add(key));
    triples[base] = key;
    triples[base + 1] = random.nextInt(bound);
    triples[base + 2] = random.nextInt(bound);
  }

  for (int slot = 0; slot < absent.length; slot++) {
    int candidate = random.nextInt(bound);
    while (usedKeys.contains(candidate)) {
      candidate = random.nextInt(bound);
    }
    absent[slot] = candidate;
  }
  return new KeyTestInfo(triples, absent);
}
 
Example 3
/**
 * Baseline implementation. Augments the "standard" list with alternatives.
 *
 * <p>Every translation in {@code standard} is kept. Of the first {@code maxAltItems} entries of
 * {@code alt}, an entry is appended only if its derivation hash has not been seen yet, either in
 * {@code standard} or in an earlier accepted alternative. The combined list is then sorted by the
 * translations' natural ordering.
 *
 * <p>NOTE(review): duplicates are detected by {@code derivationHashCode} alone, so two different
 * derivations with colliding hashes are treated as duplicates.
 *
 * @param standard the primary translations; all are retained
 * @param alt candidate alternative translations, examined in order
 * @param maxAltItems maximum number of entries of {@code alt} to examine
 * @return a new, sorted list containing {@code standard} plus the accepted alternatives
 */
public static <TK,FV> List<RichTranslation<TK,FV>> mergeAndDedup(List<RichTranslation<TK,FV>> standard,
    List<RichTranslation<TK,FV>> alt, int maxAltItems) {

  IntSet hashCodeSet = new IntOpenHashSet(standard.size());
  for (RichTranslation<TK,FV> s : standard) {
    hashCodeSet.add(derivationHashCode(s.getFeaturizable().derivation));
  }

  List<RichTranslation<TK,FV>> returnList = new ArrayList<>(standard);
  for (int i = 0, sz = Math.min(maxAltItems, alt.size()); i < sz; ++i) {
    RichTranslation<TK,FV> t = alt.get(i);
    int hashCode = derivationHashCode(t.getFeaturizable().derivation);
    // add() records the hash and returns false if it was already present, so duplicates
    // within alt itself are also suppressed.
    if (hashCodeSet.add(hashCode)) {
      returnList.add(t);
    }
  }
  Collections.sort(returnList);

  return returnList;
}
 
Example 4
Source Project: deeplearning4j | File: ExpReplay.java | License: Apache License 2.0 | 6 votes
/**
 * Samples up to {@code size} distinct stored transitions uniformly at random (without
 * replacement) and returns duplicates of them.
 *
 * @param size requested batch size; the batch is capped at the number of stored transitions
 * @return a new list of {@code dup()} copies of the sampled transitions
 */
public ArrayList<Transition<A>> getBatch(int size) {
    ArrayList<Transition<A>> batch = new ArrayList<>(size);
    final int available = storage.size();
    final int count = Math.min(available, size);

    // Choose all indices first, rejecting any index that was already picked.
    final int[] picked = new int[count];
    final IntSet seen = new IntOpenHashSet();
    for (int slot = 0; slot < count; slot++) {
        int candidate;
        do {
            candidate = rnd.nextInt(available);
        } while (!seen.add(candidate));
        picked[slot] = candidate;
    }

    // Then copy the chosen transitions into the batch, in the order they were drawn.
    for (int index : picked) {
        batch.add(storage.get(index).dup());
    }
    return batch;
}
 
Example 5
Source Project: jstarcraft-rns | File: SBPRModel.java | License: Apache License 2.0 | 5 votes
/**
 * Initializes the model.
 *
 * <p>Reads the bias regularization, randomly initializes the item biases, caches each user's
 * rated-item set, and builds {@code socialItemList}. For every user, that list holds the items
 * rated by at least one trusted neighbor (per {@code socialMatrix}) but not by the user, in the
 * order they are first encountered.
 */
@Override
public void prepare(Configurator configuration, DataModule model, DataSpace space) {
    super.prepare(configuration, model, space);
    regBias = configuration.getFloat("recommender.bias.regularization", 0.01F);

    itemBiases = DenseVector.valueOf(itemSize);
    itemBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> {
        scalar.setValue(RandomUtility.randomFloat(1F));
    });

    userItemSet = getUserItemSet(scoreMatrix);

    // TODO consider refactoring.
    // Find, for each user, the items rated only by trusted neighbors.
    socialItemList = new ArrayList<>(userSize);
    for (int userIndex = 0; userIndex < userSize; userIndex++) {
        IntSet itemSet = userItemSet.get(userIndex);
        SparseVector socialVector = socialMatrix.getRowVector(userIndex);
        // The hash set gives O(1) duplicate checks (a List.contains scan is O(n) per item);
        // the list keeps first-encountered order.
        IntSet seenItems = new IntOpenHashSet();
        List<Integer> socialList = new ArrayList<>();
        for (VectorScalar term : socialVector) {
            SparseVector neighborVector = scoreMatrix.getRowVector(term.getIndex());
            for (VectorScalar entry : neighborVector) {
                // an item rated by this neighbor
                int itemIndex = entry.getIndex();
                if (!itemSet.contains(itemIndex) && seenItems.add(itemIndex)) {
                    socialList.add(itemIndex);
                }
            }
        }
        socialItemList.add(new ArrayList<>(socialList));
    }
}
 
Example 6
Source Project: jstarcraft-rns | File: RankingTask.java | License: Apache License 2.0 | 5 votes
/**
 * Ranks, for a single user, every item absent from that user's training data by descending
 * predicted score.
 *
 * @param recommender the trained model used for prediction
 * @param userIndex the user to recommend for
 * @return candidate item indices, highest predicted score first
 */
@Override
protected IntList recommend(Model recommender, int userIndex) {
    ReferenceModule trainModule = trainModules[userIndex];
    ReferenceModule testModule = testModules[userIndex];

    // Items the user already has in training are excluded from ranking.
    IntSet trainedItems = new IntOpenHashSet();
    for (DataInstance instance : trainModule) {
        trainedItems.add(instance.getQualityFeature(itemDimension));
    }

    // TODO this code needs refactoring.
    // Probe instance: a copy of the user's first test instance; only the item feature varies below.
    ArrayInstance probe = new ArrayInstance(trainMarker.getQualityOrder(), trainMarker.getQuantityOrder());
    probe.copyInstance(testModule.getInstance(0));
    probe.setQualityFeature(userDimension, userIndex);

    List<Integer2FloatKeyValue> scored = new ArrayList<>(itemSize - trainedItems.size());
    for (int candidate = 0; candidate < itemSize; candidate++) {
        if (!trainedItems.contains(candidate)) {
            probe.setQualityFeature(itemDimension, candidate);
            recommender.predict(probe);
            scored.add(new Integer2FloatKeyValue(candidate, probe.getQuantityMark()));
        }
    }
    // Highest predicted score first; List.sort is stable, so ties keep ascending item order.
    scored.sort((left, right) -> Float.compare(right.getValue(), left.getValue()));

    IntList recommendList = new IntArrayList(scored.size());
    scored.forEach(keyValue -> recommendList.add(keyValue.getKey()));
    return recommendList;
}
 
Example 7
/**
 * Computes precision at {@code size}.
 *
 * <p>Counts the relevant items among the first {@code size} entries of {@code rankList} and
 * divides by {@code size} (not by the list length), so short lists are penalized.
 *
 * @param checkCollection the relevant (ground-truth) item indices
 * @param rankList the recommended item indices, best first
 * @return the precision at {@code size}
 */
@Override
protected float measure(IntSet checkCollection, IntList rankList) {
    IntList topList = rankList.size() > size ? rankList.subList(0, size) : rankList;
    int hits = 0;
    for (int position = 0; position < topList.size(); position++) {
        if (checkCollection.contains(topList.getInt(position))) {
            hits++;
        }
    }
    return hits / (size + 0F);
}
 
Example 8
Source Project: jstarcraft-ai | File: RecallEvaluator.java | License: Apache License 2.0 | 5 votes
/**
 * Computes recall at {@code size}.
 *
 * <p>Counts the relevant items among the first {@code size} entries of {@code rankList} and
 * divides by the number of relevant items. An empty {@code checkCollection} yields {@code NaN}
 * (0 / 0).
 *
 * @param checkCollection the relevant (ground-truth) item indices
 * @param rankList the recommended item indices, best first
 * @return the recall at {@code size}
 */
@Override
protected float measure(IntSet checkCollection, IntList rankList) {
    IntList topList = rankList.size() > size ? rankList.subList(0, size) : rankList;
    int hits = 0;
    for (int position = 0; position < topList.size(); position++) {
        if (checkCollection.contains(topList.getInt(position))) {
            hits++;
        }
    }
    float relevant = checkCollection.size() + 0F;
    return hits / relevant;
}
 
Example 9
/**
 * Shrinks {@code positions} in place to its intersection with {@code indexSet}.
 *
 * @param positions the set to filter; every value not in {@code indexSet} is removed
 * @param indexSet the values to keep; not modified
 */
private void intersect(IntSet positions, IntSet indexSet) {
    // retainAll removes non-members during one iteration, with no temporary set
    // (an IntArraySet's linear-time add would make the collect-then-remove approach quadratic).
    positions.retainAll(indexSet);
}
 
Example 10
Source Project: tagme | File: PageToCategoryIDs.java | License: Apache License 2.0 | 5 votes
/**
 * Builds, from the category-links source file, a map from page id to the ids of its non-hidden
 * categories, and returns it through {@code createDump}.
 *
 * <p>Rows whose category title is unknown to {@code CategoriesToWIDMap}, or maps to a hidden
 * category, are rejected (the parser callback returns {@code false}).
 *
 * @return the result of {@code createDump(map)}
 * @throws IOException if reading the source file fails; the reader is closed in every case
 */
@Override
protected int[][] parseSet() throws IOException {
	final Int2ObjectMap<IntSet> map = new Int2ObjectOpenHashMap<IntSet>(3000000);
	final IntSet hidden = DatasetLoader.get(new HiddenCategoriesWIDs(lang));
	File input = WikipediaFiles.CAT_LINKS.getSourceFile(lang);
	final Object2IntMap<String> categories = DatasetLoader.get(new CategoriesToWIDMap(lang));

	SQLWikiParser parser = new SQLWikiParser(log) {
		@Override
		public boolean compute(ArrayList<String> values) throws IOException {
			String c_title = cleanPageName(values.get(SQLWikiParser.CATLINKS_TITLE_TO));
			int id = Integer.parseInt(values.get(SQLWikiParser.CATLINKS_ID_FROM));
			if (!categories.containsKey(c_title)) {
				return false;
			}
			// getInt avoids boxing and the repeated lookups of get(Object).intValue().
			int categoryId = categories.getInt(c_title);
			if (hidden.contains(categoryId)) {
				return false;
			}
			// map is local and never stores null, so null means "no entry yet".
			IntSet set = map.get(id);
			if (set == null) {
				set = new IntOpenHashSet();
				map.put(id, set);
			}
			set.add(categoryId);
			return true;
		}
	};
	try (InputStreamReader reader = new InputStreamReader(new FileInputStream(input), Charset.forName("UTF-8"))) {
		parser.compute(reader);
	}
	return createDump(map);
}
 
Example 11
Source Project: samantha | File: NegativeSamplingExpander.java | License: MIT License | 5 votes
/**
 * Draws random negative-sample candidates in {@code [0, maxVal)}.
 *
 * <p>Makes {@code maxNumSample} draws (or {@code trues.size()} draws when {@code maxNumSample} is
 * null) and keeps each draw that is not in {@code trues}. Draws are with replacement and rejected
 * draws are not retried, so the result may be shorter than the number of draws and may contain
 * duplicates.
 *
 * @param trues the positive indices to exclude
 * @param maxVal exclusive upper bound of sampled values; must be positive when any draw is made
 * @return the accepted samples, in draw order
 */
private IntList getSampledIndices(IntSet trues, int maxVal) {
    IntList samples = new IntArrayList();
    int num = maxNumSample != null ? maxNumSample : trues.size();
    // A single generator per call; constructing a new Random for every draw is wasteful.
    Random random = new Random();
    for (int i = 0; i < num; i++) {
        int dice = random.nextInt(maxVal);
        if (!trues.contains(dice)) {
            samples.add(dice);
        }
    }
    return samples;
}
 
Example 12
Source Project: RankSys | File: FastFilters.java | License: Mozilla Public License 2.0 | 5 votes
/**
 * Item filter that discards items in the training preference data.
 *
 * <p>For each user, the item indices present in that user's training preferences are collected
 * into a hash set when the function is applied; the returned predicate rejects exactly those items.
 *
 * @param <U> type of the users
 * @param <I> type of the items
 * @param trainData preference data
 * @return a function giving, for each user, a predicate over item indices that is true iff the
 * user-item pair was not observed in the preference data
 */
public static <U, I> Function<U, IntPredicate> notInTrain(FastPreferenceData<U, I> trainData) {
    return user -> {
        final IntSet seenItems = new IntOpenHashSet();
        final int uidx = trainData.user2uidx(user);
        trainData.getUidxPreferences(uidx)
                .mapToInt(IdxPref::v1)
                .forEach(seenItems::add);

        final IntPredicate unseen = iidx -> !seenItems.contains(iidx);
        return unseen;
    };
}
 
Example 13
/**
 * Returns the elements of {@code source} that are members of {@code intSet}, in their original
 * order.
 *
 * @param intSet the allowed values
 * @param source the values to filter
 * @return {@code source} itself when every element is kept, otherwise a new array of the kept
 *     elements
 */
private static int[] filterInts(IntSet intSet, int[] source) {
  int[] buffer = new int[source.length];
  int kept = 0;
  for (int value : source) {
    if (intSet.contains(value)) {
      buffer[kept++] = value;
    }
  }
  if (kept == source.length) {
    // Nothing was dropped: avoid a copy and hand back the caller's array.
    return source;
  }
  int[] filtered = new int[kept];
  System.arraycopy(buffer, 0, filtered, 0, kept);
  return filtered;
}
 
Example 14
Source Project: jstarcraft-rns | File: PRankDModel.java | License: Apache License 2.0 | 4 votes
/**
 * Trains the model by stochastic updates over (rated item, sampled item) pairs.
 *
 * <p>In every epoch, each rating (u, i) of every user with at least one rating is paired with a
 * sampled item j that the user has not rated. The squared weighted pairwise error is accumulated
 * into {@code totalError}, and the user and item factors are shifted in proportion to that error.
 * Training stops early when {@code isConverged(epocheIndex)} returns true and the
 * {@code isConverged} flag is set.
 *
 * <p>NOTE(review): the rejection-sampling loop never terminates for a user whose rated items
 * cover every item that can be drawn; confirm this cannot happen for the input data.
 */
@Override
protected void doPractice() {
    List<IntSet> userItemSet = getUserItemSet(scoreMatrix);
    for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
        totalError = 0F;
        // for each rated user-item (u,i) pair
        for (int userIndex = 0; userIndex < userSize; userIndex++) {
            SparseVector userVector = scoreMatrix.getRowVector(userIndex);
            if (userVector.getElementSize() == 0) {
                continue;
            }
            IntSet itemSet = userItemSet.get(userIndex);
            for (VectorScalar term : userVector) {
                // each rated item i
                int positiveItemIndex = term.getIndex();
                float positiveScore = term.getValue();
                int negativeItemIndex = -1;
                do {
                    // draw an item j with probability proportional to
                    // popularity
                    // presumably itemProbabilities holds cumulative weights (its last value is
                    // used as the upper bound of the random draw); TODO confirm
                    negativeItemIndex = SampleUtility.binarySearch(itemProbabilities, 0, itemProbabilities.getElementSize() - 1, RandomUtility.randomFloat(itemProbabilities.getValue(itemProbabilities.getElementSize() - 1)));
                    // ensure that it is unrated by user u
                } while (itemSet.contains(negativeItemIndex));
                // the sampled (unrated) item is treated as having score 0
                float negativeScore = 0F;
                // compute predictions
                float positivePredict = predict(userIndex, positiveItemIndex), negativePredict = predict(userIndex, negativeItemIndex);
                // margin factor derived from the item-item correlation: sqrt(1 - tanh(corr * similarityFilter))
                float distance = (float) Math.sqrt(1 - Math.tanh(itemCorrelations.getValue(positiveItemIndex, negativeItemIndex) * similarityFilter));
                float itemWeight = itemWeights.getValue(negativeItemIndex);
                float error = itemWeight * (positivePredict - negativePredict - distance * (positiveScore - negativeScore));
                totalError += error * error;

                // update vectors
                float learnFactor = learnRatio * error;
                for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
                    // all three factors are read before any shift, so each update uses pre-update values
                    float userFactor = userFactors.getValue(userIndex, factorIndex);
                    float positiveItemFactor = itemFactors.getValue(positiveItemIndex, factorIndex);
                    float negativeItemFactor = itemFactors.getValue(negativeItemIndex, factorIndex);
                    userFactors.shiftValue(userIndex, factorIndex, -learnFactor * (positiveItemFactor - negativeItemFactor));
                    itemFactors.shiftValue(positiveItemIndex, factorIndex, -learnFactor * userFactor);
                    itemFactors.shiftValue(negativeItemIndex, factorIndex, learnFactor * userFactor);
                }
            }
        }

        totalError *= 0.5F;
        if (isConverged(epocheIndex) && isConverged) {
            break;
        }
        isLearned(epocheIndex);
        currentError = totalError;
    }
}
 
Example 15
Source Project: jstarcraft-rns | File: RankSGDModel.java | License: Apache License 2.0 | 4 votes
/**
 * Trains the model by stochastic updates over (rated item, sampled item) pairs.
 *
 * <p>In every epoch, each entry (u, i) of {@code scoreMatrix} is paired with a sampled item j that
 * user u has not rated. The squared pairwise error is accumulated into {@code totalError}, and the
 * user and item factors are shifted in proportion to that error. Training stops early when
 * {@code isConverged(epocheIndex)} returns true and the {@code isConverged} flag is set.
 *
 * <p>NOTE(review): the rejection-sampling loop never terminates for a user whose rated items
 * cover every item that can be drawn; confirm this cannot happen for the input data.
 */
@Override
protected void doPractice() {
    List<IntSet> userItemSet = getUserItemSet(scoreMatrix);
    for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
        totalError = 0F;
        // for each rated user-item (u,i) pair
        for (MatrixScalar term : scoreMatrix) {
            int userIndex = term.getRow();
            IntSet itemSet = userItemSet.get(userIndex);
            int positiveItemIndex = term.getColumn();
            float positiveScore = term.getValue();
            int negativeItemIndex = -1;

            do {
                // draw an item j with probability proportional to
                // popularity
                // presumably itemProbabilities holds cumulative weights (its last value is
                // used as the upper bound of the random draw); TODO confirm
                negativeItemIndex = SampleUtility.binarySearch(itemProbabilities, 0, itemProbabilities.getElementSize() - 1, RandomUtility.randomFloat(itemProbabilities.getValue(itemProbabilities.getElementSize() - 1)));
                // ensure that it is unrated by user u
            } while (itemSet.contains(negativeItemIndex));

            // the sampled (unrated) item is treated as having score 0
            float negativeScore = 0F;
            // compute predictions
            float error = (predict(userIndex, positiveItemIndex) - predict(userIndex, negativeItemIndex)) - (positiveScore - negativeScore);
            totalError += error * error;

            // update vectors
            float value = learnRatio * error;
            for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
                // all three factors are read before any shift, so each update uses pre-update values
                float userFactor = userFactors.getValue(userIndex, factorIndex);
                float positiveItemFactor = itemFactors.getValue(positiveItemIndex, factorIndex);
                float negativeItemFactor = itemFactors.getValue(negativeItemIndex, factorIndex);

                userFactors.shiftValue(userIndex, factorIndex, -value * (positiveItemFactor - negativeItemFactor));
                itemFactors.shiftValue(positiveItemIndex, factorIndex, -value * userFactor);
                itemFactors.shiftValue(negativeItemIndex, factorIndex, value * userFactor);
            }
        }

        // NOTE(review): 0.5D here vs 0.5F in the similar PRankD trainer; scaling by 0.5 is exact
        // either way, but the literal type is inconsistent.
        totalError *= 0.5D;
        if (isConverged(epocheIndex) && isConverged) {
            break;
        }
        isLearned(epocheIndex);
        currentError = totalError;
    }
}
 
Example 16
Source Project: tagme | File: WikipediaEdges.java | License: Apache License 2.0 | 4 votes
/**
 * Extracts page-link edges from the page-links source file and writes them to {@code file}.
 *
 * <p>Each accepted link is written as a line {@code idFrom SEP_CHAR idTo} (after resolving
 * redirects on both ends) to a temporary file. That file is then sorted numerically on both
 * columns with duplicates removed into {@code file}. The temporary file is deleted even if parsing
 * or sorting fails, and the writer is always closed.
 *
 * @param file destination of the sorted edge list
 * @throws IOException if reading the source or writing the edges fails
 */
@Override
protected void parseFile(File file) throws IOException
{

	final Int2IntMap redirects = DatasetLoader.get(new RedirectMap(lang));
	final IntSet disambiguations = DatasetLoader.get(new DisambiguationWIDs(lang));
	final IntSet listpages = DatasetLoader.get(new ListPageWIDs(lang));
	final IntSet ignores = DatasetLoader.get(new IgnoreWIDs(lang));
	// NOTE(review): loaded directly rather than via DatasetLoader, presumably so the in-place
	// removals below do not alter a cached instance; confirm.
	final IntSet valids = new AllWIDs(lang).getDataset();
	valids.removeAll(redirects.keySet());
	valids.removeAll(ignores);
	final Object2IntMap<String> titles = DatasetLoader.get(new TitlesToWIDMap(lang));

	File tmp = Dataset.createTmpFile();
	try
	{
		try (final BufferedWriter out = new BufferedWriter(new FileWriter(tmp)))
		{
			SQLWikiParser parser = new SQLWikiParser(log) {
				@Override
				public boolean compute(ArrayList<String> values) throws IOException
				{
					int idFrom = Integer.parseInt(values.get(SQLWikiParser.PAGELINKS_ID_FROM));
					if (redirects.containsKey(idFrom)) idFrom = redirects.get(idFrom);

					int ns = Integer.parseInt(values.get(SQLWikiParser.PAGELINKS_NS));

					// List pages are accepted even when they are also marked as disambiguation pages:
					// this is necessary because some pages that are lists end up, in English, among
					// the disambiguation pages (because of the category All_set_index_articles).
					boolean validSource = ns == SQLWikiParser.NS_ARTICLE
							&& !redirects.containsKey(idFrom)
							&& !ignores.contains(idFrom)
							&& (listpages.contains(idFrom) || !disambiguations.contains(idFrom))
							&& valids.contains(idFrom);
					if (!validSource) return false;

					String titleTo = Dataset.cleanPageName(values.get(SQLWikiParser.PAGELINKS_TITLE_TO));

					// presumably getInt returns a negative default for unknown titles, hence idTo >= 0
					int idTo = titles.getInt(titleTo);

					if (redirects.containsKey(idTo)) idTo = redirects.get(idTo);
					// NOTE(review): the disambiguation/list-page filter is applied to the source page
					// only. Re-testing it here for idFrom would always be true (checked above and idFrom
					// is unchanged). If links *to* disambiguation pages should also be dropped, test
					// idTo instead; confirm intended semantics.
					if (idTo >= 0 && !ignores.contains(idTo) && valids.contains(idTo))
					{
						out.append(Integer.toString(idFrom));
						out.append(SEP_CHAR);
						out.append(Integer.toString(idTo));
						out.append('\n');
						return true;
					}
					return false;
				}
			};

			File input = WikipediaFiles.PAGE_LINKS.getSourceFile(lang);
			parser.compute(input);
		}

		log.info("Now sorting edges...");

		ExternalSort sorter = new ExternalSort();
		sorter.setUniq(true);
		sorter.setNumeric(true);
		sorter.setColumns(new int[]{0,1});
		sorter.setInFile(tmp.getAbsolutePath());
		sorter.setOutFile(file.getAbsolutePath());
		sorter.run();
	}
	finally
	{
		tmp.delete();
	}

	log.info("Sorted. Done.");

}
 
Example 17
/**
   * Extracts the n-best list.
   *
   * <p>Every marked node is expanded into a final derivation. The results are sorted by their
   * natural ordering and, when {@code distinct} is set, filtered so that each target-sequence hash
   * appears once. At most {@code size} derivations are returned. Distinctness is applied before
   * truncation, so up to {@code size} distinct derivations are returned when that many exist.
   *
   * <p>NOTE(review): distinctness is keyed on {@code targetSequence.hashCode()} only, so different
   * target sequences with colliding hashes are treated as duplicates.
   *
   * @param size maximum number of derivations to return
   * @param distinct whether to keep only the best-sorted derivation per target sequence
   * @return the n-best derivations, best first; empty if the lattice is incomplete
   */
  public List<Derivation<TK,FV>> decode(int size, boolean distinct, int sourceInputId, 
      FeatureExtractor<TK, FV> featurizer, Scorer<FV> scorer, SearchHeuristic<TK, FV> heuristic, 
      OutputSpace<TK, FV> outputSpace) {
    if (isIncompleteLattice) return Collections.emptyList();

    // TODO(spenceg) Remaining bugs
    //
    //  1) Sometimes duplicate derivations can be extracted. Probably has to do with recombination.
    //
    List<Derivation<TK,FV>> candidates = new ArrayList<>(size);
    for (int i = 0, sz = markedNodes.size(); i < sz; ++i) {
      Derivation<TK,FV> node = markedNodes.get(i);
      candidates.add(constructDerivation(node, sourceInputId, featurizer,
          scorer, heuristic, outputSpace));
    }

    // Sort before deduplicating: the ordering of markedNodes doesn't account for
    // combination costs.
    List<Derivation<TK,FV>> sorted = candidates.stream().sorted().collect(Collectors.toList());

    // Deduplicate before truncating, so duplicates do not consume slots of the n-best list.
    List<Derivation<TK,FV>> returnList = new ArrayList<>(Math.min(size, sorted.size()));
    IntSet uniqSet = distinct ? new IntOpenHashSet() : null;
    for (Derivation<TK,FV> d : sorted) {
      if (returnList.size() >= size) {
        break;
      }
      if (distinct && !uniqSet.add(d.targetSequence.hashCode())) {
        continue;
      }
      returnList.add(d);
    }
    return returnList;
  }