Here are examples of the Java API ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance taken from open-source projects. By voting up, you can indicate which examples are most useful and appropriate.
56 Examples
19
View Complete Implementation : NDCGLoss.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Computes the discounted cumulative gain (DCG) of a ranking. The dyad at
 * zero-based position {@code pos} contributes {@code (2^rel - 1) / log2(pos + 2)},
 * where {@code rel} is the relevance stored for that dyad.
 *
 * NOTE(review): the relevance value is unboxed from an {@link Integer}, so a dyad
 * of {@code ranking} without an entry in {@code relevance} would raise a
 * NullPointerException — confirm callers always supply a relevance for every dyad.
 *
 * @param ranking   the ranking whose DCG is computed
 * @param relevance relevance value per dyad
 * @return the DCG of {@code ranking}
 */
private double computeDCG(IDyadRankingInstance ranking, Map<Dyad, Integer> relevance) {
    final int n = ranking.length();
    double gainSum = 0;
    for (int pos = 0; pos < n; pos++) {
        Dyad dyadAtPos = ranking.getDyadAtPosition(pos);
        int rel = relevance.get(dyadAtPos);
        double gain = Math.pow(2, rel) - 1;
        gainSum += gain / log2(pos + 2.0);
    }
    return gainSum;
}
18
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Returns the natural logarithm of the Plackett-Luce probability of the
 * complete top ranking of the dyads in {@code drInstance}, where the model is
 * parametrized by the latent skill values predicted by the PLNet. Using the
 * logarithm is helpful for long rankings, whose plain probability shrinks
 * rapidly towards zero.
 *
 * @param drInstance
 *            {@link IDyadRankingInstance} whose top ranking is evaluated.
 * @return Log-probability of the top ranking of {@code drInstance} under the
 *         Plackett-Luce model parametrized by the PLNet skill values.
 */
public double getLogProbabilityOfTopRanking(final IDyadRankingInstance drInstance) {
    // k = Integer.MAX_VALUE imposes no cut-off, so every dyad is taken into account
    final int noCutOff = Integer.MAX_VALUE;
    return this.getLogProbabilityOfTopKRanking(drInstance, noCutOff);
}
18
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Returns the Plackett-Luce probability of the complete top ranking of the
 * dyads in {@code drInstance}, where the model is parametrized by the latent
 * skill values predicted by the PLNet. For long rankings this value becomes
 * very small; {@code getLogProbabilityOfTopRanking} is the numerically safer
 * alternative.
 *
 * @param drInstance
 *            {@link IDyadRankingInstance} whose top ranking is evaluated.
 * @return Probability of the top ranking of {@code drInstance} under the
 *         Plackett-Luce model parametrized by the PLNet skill values.
 */
public double getProbabilityOfTopRanking(final IDyadRankingInstance drInstance) {
    // k equal to the ranking length means all dyads are considered
    final int fullLength = drInstance.length();
    return this.getProbabilityOfTopKRanking(drInstance, fullLength);
}
18
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Returns the natural logarithm of the Plackett-Luce probability of the top
 * {@code k} positions of the ranking obtained by ordering the dyads of
 * {@code drInstance} by descending PLNet utility. Using the logarithm is
 * helpful because the plain probability shrinks rapidly with the ranking length.
 *
 * Under the Plackett-Luce model the dyad at position {@code i} is chosen among
 * ALL dyads not placed yet. The normalising sum at position {@code i} therefore
 * runs over positions {@code i..n-1} of the whole ranking, not only over the
 * top {@code k}. Restricting it to the top {@code k} would make the last
 * considered factor always equal to 1.
 *
 * @param drInstance
 *            {@link IDyadRankingInstance} for which the probability is
 *            computed.
 * @param k
 *            Number of top dyads to be considered; values larger than the
 *            number of dyads are clamped, values {@code <= 0} yield 0 (log 1).
 * @return Log-probability of the top k of the predicted ranking under the
 *         Plackett-Luce model parametrized by the PLNet skill values.
 */
public double getLogProbabilityOfTopKRanking(final IDyadRankingInstance drInstance, final int k) {
    List<Pair<Dyad, Double>> dyadUtilityPairs = this.getSortedDyadUtilityPairsForInstance(drInstance);
    int numberOfDyads = dyadUtilityPairs.size();
    int topK = Integer.min(k, numberOfDyads);
    // remainingSkills[i] = sum of exp(utility) over positions i..n-1, built back to front in O(n)
    double[] remainingSkills = new double[numberOfDyads + 1];
    for (int j = numberOfDyads - 1; j >= 0; j--) {
        remainingSkills[j] = remainingSkills[j + 1] + Math.exp(dyadUtilityPairs.get(j).getRight());
    }
    double logProbability = 0;
    for (int i = 0; i < topK; i++) {
        logProbability += dyadUtilityPairs.get(i).getRight() - Math.log(remainingSkills[i]);
    }
    return logProbability;
}
18
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Computes the natural logarithm of the Plackett-Luce probability of the
 * ranking exactly as given by {@code drInstance}. The dyads are not re-sorted;
 * their skills are the utilities predicted by the PLNet.
 *
 * For position {@code i} the log-factor is
 * {@code u_i - log(sum_{j >= i} exp(u_j))}. The previous implementation
 * subtracted the plain sum of the utilities instead of the log of the summed
 * exponentials, which is not a Plackett-Luce log-probability.
 *
 * @param drInstance
 *            The ranking whose log-probability is computed.
 * @return Logarithmic probability of the given ranking.
 */
public double getLogProbabilityRanking(final IDyadRankingInstance drInstance) {
    List<Pair<Dyad, Double>> dyadUtilityPairs = this.getDyadUtilityPairsForInstance(drInstance);
    int numberOfDyads = dyadUtilityPairs.size();
    // remainingSkills[i] = sum of exp(utility) over positions i..n-1, built back to front in O(n)
    double[] remainingSkills = new double[numberOfDyads + 1];
    for (int j = numberOfDyads - 1; j >= 0; j--) {
        remainingSkills[j] = remainingSkills[j + 1] + Math.exp(dyadUtilityPairs.get(j).getRight());
    }
    double logProbability = 0;
    for (int i = 0; i < numberOfDyads; i++) {
        logProbability += dyadUtilityPairs.get(i).getRight() - Math.log(remainingSkills[i]);
    }
    return logProbability;
}
18
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Performs a single online update of this {@link PLNetDyadRanker} with one
 * {@link IDyadRankingInstance}, following algorithm 2 in [1].
 *
 * On the first call the network is created. Its input size is the combined
 * length of the instance and alternative vectors of the first dyad. The scaled
 * gradient computed for {@code instance} is then subtracted from the network
 * parameters in place, and the iteration counter is incremented.
 *
 * @param instance
 *            The dyad ranking the update is based on.
 * @throws TrainingException
 *             If something fails during the update process.
 */
@Override
public void update(final IDyadRankingInstance instance) throws TrainingException {
    if (this.plNet == null) {
        Dyad firstDyad = instance.getDyadAtPosition(0);
        int inputSize = firstDyad.getInstance().length() + firstDyad.getAlternative().length();
        this.plNet = this.createNetwork(inputSize);
        this.plNet.init();
    }
    final INDArray gradientStep = this.computeScaledGradient(instance);
    this.plNet.params().subi(gradientStep);
    this.iteration++;
}
18
View Complete Implementation : AbstractDyadScaler.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Transforms only the alternatives of each dyad in the given
 * {@link IDyadRankingInstance}, according to the mean and standard deviation
 * of the data the scaler has been fit to. Each dyad is passed to the
 * single-dyad overload; attributes whose indices are contained in
 * {@code ignoredIndices} are not transformed.
 *
 * @param drInstance     The ranking whose dyads' alternatives are to be
 *                       standardized.
 * @param ignoredIndices The {@link List} of attribute indices the scaler
 *                       ignores.
 */
public void transformAlternatives(final IDyadRankingInstance drInstance, final List<Integer> ignoredIndices) {
    drInstance.forEach(dyad -> this.transformAlternatives(dyad, ignoredIndices));
}
18
View Complete Implementation : AbstractDyadScaler.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Transforms only the alternatives of each dyad in every ranking of the given
 * {@link DyadRankingDataset}, according to the mean and standard deviation of
 * the data the scaler has been fit to. Each ranking is handed to the
 * per-ranking overload; attributes whose indices are contained in
 * {@code ignoredIndices} are not transformed.
 *
 * @param dataset        The dataset whose alternatives are to be standardized.
 * @param ignoredIndices The {@link List} of attribute indices the scaler
 *                       ignores.
 */
public void transformAlternatives(final DyadRankingDataset dataset, final List<Integer> ignoredIndices) {
    dataset.forEach((IDyadRankingInstance ranking) -> this.transformAlternatives(ranking, ignoredIndices));
}
18
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Returns the Plackett-Luce probability of the ranking exactly as given by
 * {@code drInstance}, using the latent skill values predicted by the PLNet.
 * The dyads are not re-sorted.
 *
 * The factor for position {@code i} is {@code exp(u_i) / sum_{j >= i} exp(u_j)}.
 * If one of these denominators is exactly zero (every remaining exponential
 * underflowed), the probability is undefined and {@link Double#NaN} is returned.
 *
 * @param drInstance
 *            {@link IDyadRankingInstance} for which the probability is
 *            computed.
 * @return Probability of {@code drInstance} under the Plackett-Luce model
 *         parametrized by the PLNet skill values, or NaN as described above.
 */
public double getProbabilityRanking(final IDyadRankingInstance drInstance) {
    final List<Pair<Dyad, Double>> scored = this.getDyadUtilityPairsForInstance(drInstance);
    final int size = scored.size();
    double probability = 1;
    for (int pos = 0; pos < size; pos++) {
        double denominator = 0;
        for (int rest = pos; rest < size; rest++) {
            denominator += Math.exp(scored.get(rest).getRight());
        }
        if (denominator == 0) {
            // NaN would propagate through every later factor anyway
            return Double.NaN;
        }
        probability *= Math.exp(scored.get(pos).getRight()) / denominator;
    }
    return probability;
}
18
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Returns the Plackett-Luce probability of the top {@code k} positions of the
 * ranking obtained by ordering the dyads of {@code drInstance} by descending
 * PLNet utility.
 *
 * The dyad at position {@code i} is chosen among ALL dyads not placed yet, so
 * the denominator at position {@code i} sums {@code exp(utility)} over
 * positions {@code i..n-1} of the whole ranking. Restricting it to the top
 * {@code k} (as before) made the last considered factor always 1; for example,
 * {@code k == 1} always yielded probability 1.
 *
 * @param drInstance the dyads whose predicted top-k ranking is evaluated
 * @param k number of top positions to consider; values larger than the number
 *          of dyads are clamped, values {@code <= 0} yield 1
 * @return the top-k probability, or {@link Double#NaN} if a denominator is
 *         exactly zero (all remaining exponentials underflowed)
 */
public double getProbabilityOfTopKRanking(final IDyadRankingInstance drInstance, final int k) {
    List<Pair<Dyad, Double>> dyadUtilityPairs = this.getSortedDyadUtilityPairsForInstance(drInstance);
    int numberOfDyads = dyadUtilityPairs.size();
    int topK = Integer.min(k, numberOfDyads);
    // remainingSkills[i] = sum of exp(utility) over positions i..n-1, built back to front in O(n)
    double[] remainingSkills = new double[numberOfDyads + 1];
    for (int j = numberOfDyads - 1; j >= 0; j--) {
        remainingSkills[j] = remainingSkills[j + 1] + Math.exp(dyadUtilityPairs.get(j).getRight());
    }
    double probability = 1;
    for (int i = 0; i < topK; i++) {
        if (remainingSkills[i] == 0) {
            return Double.NaN;
        }
        probability *= Math.exp(dyadUtilityPairs.get(i).getRight()) / remainingSkills[i];
    }
    return probability;
}
17
View Complete Implementation : DyadMinMaxScaler.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Reverts the transformation of the instance part of every dyad in every
 * ranking of the dataset, rounding to the given number of decimal places.
 *
 * @param dataset  the dataset whose dyad instances are restored in place
 * @param decimals number of decimal places for rounding
 */
public void untransformInstances(final DyadRankingDataset dataset, final int decimals) {
    dataset.forEach((IDyadRankingInstance ranking) -> ranking.forEach(dyad -> this.untransformInstance(dyad, decimals)));
}
17
View Complete Implementation : AbstractDyadScaler.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Transforms only the instance part of each dyad in the given
 * {@link DyadRankingDataset}, according to the mean and standard deviation
 * of the data the scaler has been fit to. Attributes whose indices are
 * contained in {@code ignoredIndices} are not transformed.
 *
 * Each ranking is dispatched to the overload for its concrete type: either
 * {@link SparseDyadRankingInstance} or {@link DyadRankingInstance}.
 *
 * @param dataset        The dataset whose instances are to be standardized.
 * @param ignoredIndices The {@link List} of attribute indices the scaler
 *                       ignores.
 * @throws IllegalArgumentException if a ranking is of any other type
 */
public void transformInstances(final DyadRankingDataset dataset, final List<Integer> ignoredIndices) {
    for (IDyadRankingInstance ranking : dataset) {
        if (ranking instanceof SparseDyadRankingInstance) {
            this.transformInstances((SparseDyadRankingInstance) ranking, ignoredIndices);
            continue;
        }
        if (!(ranking instanceof DyadRankingInstance)) {
            throw new IllegalArgumentException("The scalers only support SparseDyadRankingInstance and DyadRankingInstance!");
        }
        this.transformInstances((DyadRankingInstance) ranking, ignoredIndices);
    }
}
17
View Complete Implementation : DyadMinMaxScaler.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Reverts the transformation of the alternative part of every dyad in every
 * ranking of the dataset.
 *
 * @param dataset the dataset whose dyad alternatives are restored in place
 */
public void untransformAlternatives(final DyadRankingDataset dataset) {
    for (final IDyadRankingInstance ranking : dataset) {
        ranking.forEach(dyad -> this.untransformAlternative(dyad));
    }
}
17
View Complete Implementation : DyadMinMaxScaler.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Reverts the transformation of the instance part of every dyad in every
 * ranking of the dataset.
 *
 * @param dataset the dataset whose dyad instances are restored in place
 */
public void untransformInstances(final DyadRankingDataset dataset) {
    dataset.forEach((IDyadRankingInstance ranking) -> {
        for (final Dyad current : ranking) {
            this.untransformInstance(current);
        }
    });
}
17
View Complete Implementation : DyadMinMaxScaler.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Reverts the transformation of the alternative part of every dyad in every
 * ranking of the dataset, rounding to the given number of decimal places.
 *
 * @param dataset  the dataset whose dyad alternatives are restored in place
 * @param decimals number of decimal places for rounding
 */
public void untransformAlternatives(final DyadRankingDataset dataset, final int decimals) {
    for (final IDyadRankingInstance ranking : dataset) {
        for (final Dyad current : ranking) {
            this.untransformAlternative(current, decimals);
        }
    }
}
17
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Stacks the vector representation of every dyad of a ranking into an
 * {@link INDArray} matrix. Row {@code r} holds the vector of the r-th dyad
 * in iteration order.
 *
 * @param drInstance
 *            The dyad ranking to convert to a matrix.
 * @return The dyad ranking in {@link INDArray} matrix form.
 */
private INDArray dyadRankingToMatrix(final IDyadRankingInstance drInstance) {
    final List<INDArray> rows = new ArrayList<>(drInstance.length());
    drInstance.forEach(dyad -> rows.add(this.dyadToVector(dyad)));
    return Nd4j.vstack(rows);
}
17
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Returns the (dyad, predicted utility) pairs of {@code drInstance} sorted by
 * descending utility.
 *
 * @param drInstance the dyads to score
 * @return a mutable list of pairs, highest utility first
 */
private List<Pair<Dyad, Double>> getSortedDyadUtilityPairsForInstance(final IDyadRankingInstance drInstance) {
    final List<Pair<Dyad, Double>> pairs = this.getDyadUtilityPairsForInstance(drInstance);
    // sorting ascending on the negated utility yields descending utility order
    pairs.sort(Comparator.comparing(p -> -p.getRight()));
    return pairs;
}
17
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Scores every dyad of {@code drInstance} with the PLNet and returns the
 * (dyad, utility) pairs in the ranking's iteration order.
 *
 * If no network exists yet, one is created and initialised first. Its input
 * size is the combined length of the first dyad's instance and alternative
 * vectors.
 *
 * @param drInstance the dyads to score
 * @return a mutable list of (dyad, predicted utility) pairs
 */
private List<Pair<Dyad, Double>> getDyadUtilityPairsForInstance(final IDyadRankingInstance drInstance) {
    if (this.plNet == null) {
        final Dyad first = drInstance.getDyadAtPosition(0);
        this.plNet = this.createNetwork(first.getInstance().length() + first.getAlternative().length());
        this.plNet.init();
    }
    final List<Pair<Dyad, Double>> pairs = new ArrayList<>(drInstance.length());
    for (final Dyad candidate : drInstance) {
        final double utility = this.plNet.output(this.dyadToVector(candidate)).getDouble(0);
        pairs.add(new Pair<Dyad, Double>(candidate, utility));
    }
    return pairs;
}
17
View Complete Implementation : DyadRankingLossUtil.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Computes the mean of {@code lossFunction} over corresponding pairs of
 * ground-truth and predicted dyad rankings. Pair {@code i} consists of the
 * i-th element of each dataset.
 *
 * @param lossFunction
 *            The loss function applied to each pair of
 *            {@link IDyadRankingInstance}s.
 * @param trueOrderings
 *            The ground-truth orderings.
 * @param predictedOrderings
 *            The predicted orderings, aligned by index with
 *            {@code trueOrderings}.
 * @return Average loss over all pairs (NaN if both datasets are empty, since
 *         the sum is divided by zero).
 * @throws IllegalArgumentException
 *             if the two datasets differ in size.
 */
public static double computeAverageLoss(DyadRankingLossFunction lossFunction, DyadRankingDataset trueOrderings, DyadRankingDataset predictedOrderings) {
    final int pairCount = trueOrderings.size();
    if (pairCount != predictedOrderings.size()) {
        throw new IllegalArgumentException("The list of predictions and the list of ground truth dyad rankings need to have the same length!");
    }
    double lossSum = 0.0d;
    for (int idx = 0; idx < pairCount; idx++) {
        lossSum += lossFunction.loss(trueOrderings.get(idx), predictedOrderings.get(idx));
    }
    return lossSum / pairCount;
}
17
View Complete Implementation : KendallsTauDyadRankingLoss.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Tells whether the dyad at position {@code i} of the predicted ranking occurs
 * in the actual ranking strictly after {@code actualIndex}. Only positions
 * below {@code dyadRankingLength} are searched.
 *
 * @return {@code true} if such an occurrence exists
 */
private boolean isRankingCorrectForIndex(IDyadRankingInstance actual, IDyadRankingInstance predicted, int dyadRankingLength, int actualIndex, int i) {
    final Dyad pairedDyad = predicted.getDyadAtPosition(i);
    for (int pos = actualIndex + 1; pos < dyadRankingLength; pos++) {
        if (actual.getDyadAtPosition(pos).equals(pairedDyad)) {
            return true;
        }
    }
    return false;
}
16
View Complete Implementation : FeatureTransformPLDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Predicts a ranking for every element of the dataset by delegating to the
 * single-instance {@code predict}.
 *
 * @param dataset the rankings to predict, processed in iteration order
 * @return one predicted ranking per dataset element, in the same order
 * @throws PredictionException if any single prediction fails
 */
@Override
public List<IDyadRankingInstance> predict(final DyadRankingDataset dataset) throws PredictionException {
    final List<IDyadRankingInstance> rankings = new ArrayList<>(dataset.size());
    for (final IDyadRankingInstance query : dataset) {
        final IDyadRankingInstance ranking = this.predict(query);
        rankings.add(ranking);
    }
    return rankings;
}
16
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Predicts a ranking for every element of the dataset by delegating to the
 * single-instance {@code predict}.
 *
 * @param dataset the rankings to predict
 * @return the predicted rankings, aligned with the dataset's iteration order
 * @throws PredictionException if any single prediction fails
 */
@Override
public List<IDyadRankingInstance> predict(final DyadRankingDataset dataset) throws PredictionException {
    final int expectedSize = dataset.size();
    final List<IDyadRankingInstance> predictedRankings = new ArrayList<>(expectedSize);
    for (final IDyadRankingInstance queryRanking : dataset) {
        predictedRankings.add(this.predict(queryRanking));
    }
    return predictedRankings;
}
16
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Returns the pair of {@link Dyad}s for which the model is least certain.
 * The dyads are ordered by descending PLNet utility, and the adjacent pair
 * with the smallest absolute utility difference is chosen. On ties, the
 * earliest such pair wins.
 *
 * @param drInstance
 *            Ranking for which certainty should be assessed; must contain
 *            at least 2 dyads.
 * @return A new {@link DyadRankingInstance} with the two least-certain dyads,
 *         higher-utility dyad first.
 * @throws IllegalArgumentException
 *             if {@code drInstance} contains fewer than 2 dyads.
 */
public IDyadRankingInstance getPairWithLeastCertainty(final IDyadRankingInstance drInstance) {
    // Validate before the network is (lazily) created: previously an empty
    // ranking failed inside getDyadAtPosition(0) instead of raising this error,
    // and a single-dyad ranking initialised the network before being rejected.
    if (drInstance.length() < 2) {
        throw new IllegalArgumentException("The query instance must contain at least 2 dyads!");
    }
    // creates the network on demand and sorts by descending predicted utility
    List<Pair<Dyad, Double>> dyadUtilityPairs = this.getSortedDyadUtilityPairsForInstance(drInstance);
    int indexOfPairWithLeastCertainty = 0;
    double currentlyLowestCertainty = Double.MAX_VALUE;
    for (int i = 0; i < dyadUtilityPairs.size() - 1; i++) {
        double currentCertainty = Math.abs(dyadUtilityPairs.get(i).getRight() - dyadUtilityPairs.get(i + 1).getRight());
        if (currentCertainty < currentlyLowestCertainty) {
            currentlyLowestCertainty = currentCertainty;
            indexOfPairWithLeastCertainty = i;
        }
    }
    List<Dyad> leastCertainDyads = new ArrayList<>(2);
    leastCertainDyads.add(dyadUtilityPairs.get(indexOfPairWithLeastCertainty).getLeft());
    leastCertainDyads.add(dyadUtilityPairs.get(indexOfPairWithLeastCertainty + 1).getLeft());
    return new DyadRankingInstance(leastCertainDyads);
}
16
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Updates the ranker with a minibatch of dyad rankings. If no network exists
 * yet, it is created from the first ranking's first dyad (instance length plus
 * alternative length). Every ranking is then converted to its matrix form and
 * the resulting minibatch is passed to {@code updateWithMinibatch}.
 *
 * @param instances the rankings forming the minibatch
 * @throws TrainingException if the minibatch update fails
 */
@Override
public void update(final Set<IDyadRankingInstance> instances) throws TrainingException {
    final List<INDArray> batchMatrices = new ArrayList<>(instances.size());
    for (final IDyadRankingInstance ranking : instances) {
        if (this.plNet == null) {
            final Dyad firstDyad = ranking.getDyadAtPosition(0);
            this.plNet = this.createNetwork(firstDyad.getInstance().length() + firstDyad.getAlternative().length());
            this.plNet.init();
        }
        final INDArray rankingMatrix = ranking.toMatrix();
        batchMatrices.add(rankingMatrix);
    }
    this.updateWithMinibatch(batchMatrices);
}
16
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Returns the model's certainty about the relative order of exactly two
 * dyads. Certainty is the absolute difference of the utilities the PLNet
 * predicts for them; a smaller value means the model is less certain which
 * dyad belongs first.
 *
 * If no network exists yet, it is created on demand through
 * {@code getDyadUtilityPairsForInstance}, just like the ranker's other query
 * methods. Previously an untrained ranker failed here with a
 * NullPointerException.
 *
 * @param queryInstance ranking consisting of exactly two dyads
 * @return absolute utility difference of the two dyads
 * @throws IllegalArgumentException if {@code queryInstance} does not contain
 *                                  exactly two dyads
 */
@Override
public double getCertainty(final IDyadRankingInstance queryInstance) {
    if (queryInstance.length() != 2) {
        throw new IllegalArgumentException("Can only provide certainty for pairs of dyads!");
    }
    List<Pair<Dyad, Double>> dyadUtilityPairs = this.getDyadUtilityPairsForInstance(queryInstance);
    return Math.abs(dyadUtilityPairs.get(0).getRight() - dyadUtilityPairs.get(1).getRight());
}
16
View Complete Implementation : DyadRankingMLLossFunctionWrapper.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Adapts a label-based measure to dyad rankings. Position {@code j} of the
 * actual ranking gets label {@code j}. Its predicted label is the position at
 * which a dyad with an equal alternative occurs in the predicted ranking.
 * Dyads are matched by alternative only, and the first match wins. An
 * unmatched entry keeps the label 0.0.
 *
 * @param actual    the ground-truth ranking
 * @param predicted the predicted ranking
 * @return the value of {@code measure} on (actual labels, predicted labels)
 */
@Override
public double loss(IDyadRankingInstance actual, IDyadRankingInstance predicted) {
    final double[] actualLabels = new double[actual.length()];
    final double[] predictedLabels = new double[predicted.length()];
    for (int predPos = 0; predPos < actualLabels.length; predPos++) {
        actualLabels[predPos] = predPos;
        int actualPos = 0;
        while (actualPos < predictedLabels.length
                && !predicted.getDyadAtPosition(predPos).getAlternative().equals(actual.getDyadAtPosition(actualPos).getAlternative())) {
            actualPos++;
        }
        if (actualPos < predictedLabels.length) {
            predictedLabels[actualPos] = predPos;
        }
    }
    return measure.calculateMeasure(actualLabels, predictedLabels);
}
16
View Complete Implementation : KendallsTauDyadRankingLoss.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Computes Kendall's tau coefficient between the actual and the predicted
 * dyad ranking: {@code (concordant - discordant) / (n(n-1)/2)}. The result
 * lies in [-1, 1], where 1 means the rankings agree on every pair.
 *
 * For each predicted position, the pair (predicted dyad, later-predicted dyad)
 * is concordant iff the later-predicted dyad also occurs after the first one
 * in the actual ranking.
 *
 * @param actual    the ground-truth ranking; its length defines n
 * @param predicted the predicted ranking
 * @return Kendall's tau of the two rankings
 * @throws IllegalArgumentException if the actual ranking has at most one dyad
 */
@Override
public double loss(IDyadRankingInstance actual, IDyadRankingInstance predicted) {
    int dyadRankingLength = actual.length();
    if (dyadRankingLength <= 1) {
        throw new IllegalArgumentException("Dyad rankings must have length greater than 1.");
    }
    long nConc = 0;
    long nDisc = 0;
    for (int predIndex = 0; predIndex < dyadRankingLength - 1; predIndex++) {
        Dyad predDyad = predicted.getDyadAtPosition(predIndex);
        // position of predDyad in the actual ranking; stays -1 if it does not occur there
        int actualIndex = -1;
        for (int i = 0; i < dyadRankingLength; i++) {
            if (actual.getDyadAtPosition(i).equals(predDyad)) {
                actualIndex = i;
                break;
            }
        }
        for (int i = predIndex + 1; i < dyadRankingLength; i++) {
            if (isRankingCorrectForIndex(actual, predicted, dyadRankingLength, actualIndex, i)) {
                nConc++;
            } else {
                nDisc++;
            }
        }
    }
    // n * (n - 1) is computed in double: in int it overflows once n exceeds 46341
    return 2.0 * (nConc - nDisc) / ((double) dyadRankingLength * (dyadRankingLength - 1));
}
16
View Complete Implementation : KendallsTauOfTopK.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Computes a top-k variant of Kendall's distance between two dyad rankings.
 * A penalty is summed over every pair of positions (actualI, actualJ) with
 * actualI < actualJ in the actual ranking. Membership of each dyad in the top
 * k of either ranking decides which of the four case checks apply.
 *
 * NOTE(review): the four cases look like the case analysis for comparing
 * top-k lists (presumably Fagin et al., "Comparing top k lists"). Confirm
 * against checkCase1..checkCase4, whose bodies are not visible here.
 *
 * @param actual    the ground-truth ranking
 * @param predicted the predicted ranking
 * @return the accumulated penalty over all pairs
 * @throws IllegalArgumentException if {@code this.k <= 1}
 */
@Override
public double loss(final IDyadRankingInstance actual, final IDyadRankingInstance predicted) {
if (this.k <= 1) {
throw new IllegalArgumentException("Dyad rankings must have length greater than 1.");
}
double kendallsDistance = 0;
// visit every pair of positions (actualI, actualJ) with actualI < actualJ in the actual ranking
for (int actualI = 0; actualI < actual.length() - 1; actualI++) {
Dyad actualDyad = actual.getDyadAtPosition(actualI);
// position of actualDyad in the predicted ranking; stays -1 if it is absent.
// NOTE(review): -1 satisfies "< this.k" below, so an absent dyad counts as being in the predicted top k — confirm this is intended.
int predictedI = -1;
for (int i = 0; i < predicted.length(); i++) {
if (predicted.getDyadAtPosition(i).equals(actualDyad)) {
predictedI = i;
break;
}
}
for (int actualJ = actualI + 1; actualJ < actual.length(); actualJ++) {
Dyad actPairedDyad = actual.getDyadAtPosition(actualJ);
// position of actPairedDyad in the predicted ranking; -1 if absent (same caveat as predictedI)
int predictedJ = -1;
for (int j = 0; j < predicted.length(); j++) {
if (predicted.getDyadAtPosition(j).equals(actPairedDyad)) {
predictedJ = j;
break;
}
}
// penalty is threaded through the four case checks in order; each receives the current value and returns the updated one
double penalty = 0;
boolean iAndJAreBothInPredictedTopK = predictedI < this.k && predictedJ < this.k;
boolean iAndJAreBothInActualTopK = actualI < this.k && actualJ < this.k;
// case 1: i,j are both in the top k list of the predicted and actual ranking
penalty = this.checkCase1(actualI, predictedI, actualJ, predictedJ, penalty, iAndJAreBothInPredictedTopK, iAndJAreBothInActualTopK);
boolean justIIsInPredictedTopK = predictedI < this.k && predictedJ >= this.k;
boolean justJIsInPredictedTopK = predictedJ < this.k && predictedI >= this.k;
boolean justIIsInActualTopK = actualI < this.k && actualJ >= this.k;
boolean justJIsInActualTopK = actualJ < this.k && actualI >= this.k;
// case 2: i,j are both in one top k ranking but for the other ranking just one
// is in the top k
penalty = this.checkCase2(actualI, predictedI, actualJ, predictedJ, penalty, iAndJAreBothInPredictedTopK, iAndJAreBothInActualTopK, justIIsInPredictedTopK, justJIsInPredictedTopK, justIIsInActualTopK, justJIsInActualTopK);
// case 3: i, but not j, appears in one top k list, and j, but not i, appears
// in the other top k list
penalty = this.checkCase3(penalty, justIIsInPredictedTopK, justJIsInPredictedTopK, justIIsInActualTopK, justJIsInActualTopK);
// case 4: NOTE(review): checkCase4's body is not visible; presumably i and j both
// appear in one top k list and neither appears in the other — confirm
penalty = this.checkCase4(actualI, predictedI, actualJ, predictedJ, penalty, iAndJAreBothInPredictedTopK, iAndJAreBothInActualTopK);
kendallsDistance += penalty;
}
}
return kendallsDistance;
}
16
View Complete Implementation : DyadRankingFeatureTransformNegativeLogLikelihoodDerivative.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Computes the partial derivative of the negative log-likelihood with respect
 * to w_i, following equation (19) of [1].
 *
 * For every ranking n and every position m except the last, it subtracts the
 * i-th feature of the dyad at m. It then adds the exp(w·z)-weighted average of
 * the i-th feature over positions m..end. Positions whose weight sum is
 * exactly zero are skipped.
 *
 * @param i
 *            the index of the partial derivative.
 * @param vector
 *            the w vector
 * @return the partial derivative with respect to w_i
 */
private double computeDerivativeForIndex(final int i, final Vector vector) {
    final int numRankings = this.dataset.size();
    double featureSum = 0d;
    double expectedFeatureSum = 0d;
    for (int n = 0; n < numRankings; n++) {
        final IDyadRankingInstance ranking = this.dataset.get(n);
        final int rankingLength = ranking.length();
        for (int m = 0; m + 1 < rankingLength; m++) {
            featureSum += this.featureTransforms.get(ranking).get(ranking.getDyadAtPosition(m)).getValue(i);
            double weightedFeatures = 0d;
            double normaliser = 0d;
            for (int rest = m; rest < rankingLength; rest++) {
                final Vector z = this.featureTransforms.get(ranking).get(ranking.getDyadAtPosition(rest));
                final double weight = Math.exp(vector.dotProduct(z));
                weightedFeatures += z.getValue(i) * weight;
                normaliser += weight;
            }
            if (normaliser != 0) {
                expectedFeatureSum += weightedFeatures / normaliser;
            }
        }
    }
    return -featureSum + expectedFeatureSum;
}
16
View Complete Implementation : AdvancedDyadDatasetDyadRankerTester.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Checks that the ranker's predictions agree with generated test rankings:
 * the mean Kendall's tau over {@code nTestInstances} predictions must be at
 * least 0.5.
 *
 * NOTE(review): every iteration calls getDyadRankingInstance with the same
 * SEED, so the loop may evaluate the same ranking 100 times — confirm this
 * is intended.
 */
@Test
public void testSwapOrdering1() throws PredictionException {
    int maxDyadRankingLength = 4;
    int nTestInstances = 100;
    double avgKendallTau = 0;
    for (int testInst = 0; testInst < nTestInstances; testInst++) {
        IDyadRankingInstance test = DyadRankingInstanceSupplier.getDyadRankingInstance(maxDyadRankingLength, SEED);
        IDyadRankingInstance predict = ranker.predict(test);
        double kendallTau = new KendallsTauDyadRankingLoss().loss(test, predict);
        avgKendallTau += kendallTau;
    }
    avgKendallTau /= nTestInstances;
    // fully qualified so the assertion compiles without relying on an import
    org.junit.Assert.assertTrue("average Kendall's tau should be >= 0.5 but was " + avgKendallTau, avgKendallTau >= 0.5d);
}
15
View Complete Implementation : DyadDatasetPoolProvider.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Returns the position of a dyad in the ranking stored for the dyad's
 * instance features. Returns -1 if there is no ranking for those features,
 * or if that ranking does not contain the dyad. Previously the latter case
 * returned the ranking's length instead of -1.
 *
 * @param dyad
 *            the dyad to look up
 * @return Zero-based position of the dyad in the ranking, or -1 if it is not
 *         contained.
 */
private int getPositionInRankingByInstanceFeatures(final Dyad dyad) {
    if (!this.dyadRankingsByInstances.containsKey(dyad.getInstance())) {
        return -1;
    }
    IDyadRankingInstance ranking = this.dyadRankingsByInstances.get(dyad.getInstance());
    for (int curPos = 0; curPos < ranking.length(); curPos++) {
        if (ranking.getDyadAtPosition(curPos).equals(dyad)) {
            return curPos;
        }
    }
    // a ranking exists for these instance features but it does not contain the dyad
    return -1;
}
15
View Complete Implementation : FeatureTransformPLDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Predicts a ranking for the dyads of {@code instance}. The skill of every dyad
 * is computed with {@code computeSkillForDyad}, and the dyads are ordered by
 * that skill.
 *
 * NOTE(review): the comparator sorts by ascending skill, so the dyad with the
 * lowest skill ends up at position 0. Under a Plackett-Luce model the
 * highest-skill dyad is normally ranked first. Confirm against
 * computeSkillForDyad (not visible here) whether this order is intended.
 *
 * @param instance the dyads to rank
 * @return a new {@link DyadRankingInstance} with the dyads ordered by skill
 * @throws PredictionException if the ranker has not been trained yet
 *                             ({@code w} is null)
 */
@Override
public IDyadRankingInstance predict(final IDyadRankingInstance instance) throws PredictionException {
    if (this.w == null) {
        throw new PredictionException("The Ranker has not been trained yet.");
    }
    // this method predicts; the previous message wrongly claimed the ranker was being trained
    log.debug("Predicting ranking for instance {}", instance);
    List<Pair<Double, Dyad>> skillForDyads = new ArrayList<>();
    for (Dyad d : instance) {
        double skill = this.computeSkillForDyad(d);
        skillForDyads.add(new Pair<Double, Dyad>(skill, d));
    }
    return new DyadRankingInstance(skillForDyads.stream().sorted((p1, p2) -> Double.compare(p1.getX(), p2.getX())).map(Pair::getY).collect(Collectors.toList()));
}
15
View Complete Implementation : NDCGLoss.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Computes an NDCG-style score, IDCG / DCG, for a predicted dyad ranking.
 * The dyad at position {@code rank} of the actual ranking gets relevance
 * {@code -(rank + 1)} for {@code rank < l}. DCG is evaluated on the predicted
 * ranking and IDCG on the actual ranking. When DCG is exactly 0, 0 is
 * returned.
 *
 * NOTE(review): {@code l} is not declared in this method; it is presumably a
 * field holding the cut-off. computeDCG looks up every dyad of a ranking in
 * the relevance map, so an {@code l} smaller than the ranking length would
 * leave dyads without a relevance value — confirm.
 *
 * @param actual    the ground-truth ranking
 * @param predicted the predicted ranking; must have the same length
 * @return IDCG / DCG, or 0 if DCG is 0
 * @throws IllegalArgumentException if the rankings have length <= 1 or differ
 *                                  in length
 */
@Override
public double loss(IDyadRankingInstance actual, IDyadRankingInstance predicted) {
    final int rankingLength = actual.length();
    if (rankingLength <= 1) {
        throw new IllegalArgumentException("Dyad rankings must have length greater than 1.");
    }
    if (rankingLength != predicted.length()) {
        throw new IllegalArgumentException("Dyad rankings must have equal length.");
    }
    final Map<Dyad, Integer> relevance = new HashMap<>();
    for (int rank = 0; rank < l; rank++) {
        relevance.put(actual.getDyadAtPosition(rank), -(rank + 1));
    }
    final double dcg = computeDCG(predicted, relevance);
    final double idcg = computeDCG(actual, relevance);
    return dcg != 0 ? idcg / dcg : 0;
}
15
View Complete Implementation : TopKOfPredicted.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Returns the fraction of the predicted top-k dyads that do not occur among
 * the actual top-k dyads. 0 means the two top-k sets coincide.
 *
 * @param actual    the ground-truth ranking (must have at least k dyads)
 * @param predicted the predicted ranking (must have at least k dyads)
 * @return number of misplaced dyads divided by k, or 0.0 if there are none
 */
@Override
public double loss(final IDyadRankingInstance actual, final IDyadRankingInstance predicted) {
    final List<Dyad> actualTopK = new ArrayList<>();
    for (int pos = 0; pos < this.k; pos++) {
        actualTopK.add(actual.getDyadAtPosition(pos));
    }
    int misses = 0;
    for (int pos = 0; pos < this.k; pos++) {
        if (!actualTopK.contains(predicted.getDyadAtPosition(pos))) {
            misses++;
        }
    }
    // the explicit zero branch also avoids 0/0 when k == 0
    return misses == 0 ? 0.0d : ((double) misses / (double) this.k);
}
15
View Complete Implementation : DyadRankingFeatureTransformNegativeLogLikelihood.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Negative log-likelihood of the dataset under the feature-transform
 * Plackett-Luce model, equation (18) of [1]. We adhere to their notation but
 * unify the sums:
 * {@code sum_n sum_{m < M_n - 1} ( -w·z_{n,m} + log sum_{l = m}^{M_n - 1} exp(w·z_{n,l}) )}.
 *
 * The previous bounds (m over all positions, l < M_n - 1) left the inner sum
 * empty at the last position. The result was log(0) = -Infinity, so the
 * function always returned -Infinity. The bounds now match the derivative
 * used for optimisation. The last position is excluded because its term,
 * -w·z + log(exp(w·z)), is identically zero.
 *
 * @param w the weight vector
 * @return the negative log-likelihood for {@code w}
 */
@Override
public double apply(final Vector w) {
    double firstSum = 0d;
    double secondSum = 0d;
    int largeN = this.dataset.size();
    for (int smallN = 0; smallN < largeN; smallN++) {
        IDyadRankingInstance instance = this.dataset.get(smallN);
        int mN = instance.length();
        for (int m = 0; m < mN - 1; m++) {
            Dyad dyad = instance.getDyadAtPosition(m);
            firstSum = firstSum + w.dotProduct(this.featureTransforms.get(instance).get(dyad));
            // normaliser over all dyads still unplaced at position m (positions m..mN-1)
            double innerSum = 0d;
            for (int l = m; l < mN; l++) {
                Dyad innerDyad = instance.getDyadAtPosition(l);
                innerSum = innerSum + Math.exp(w.dotProduct(this.featureTransforms.get(instance).get(innerDyad)));
            }
            secondSum = secondSum + Math.log(innerSum);
        }
    }
    return -firstSum + secondSum;
}
15
View Complete Implementation : AbstractDyadScaler.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Fits the scaler by collecting one {@link SummaryStatistics} per attribute of
 * the instance part (statsX) and of the alternative part (statsY). Every dyad
 * of every ranking contributes to these statistics.
 *
 * The attribute counts are taken from the first dyad of the first ranking.
 * NOTE(review): this assumes all dyads share those dimensions and that the
 * dataset is non-empty — confirm with callers.
 *
 * @param dataset The dataset the scaler should be fit to.
 */
public void fit(final DyadRankingDataset dataset) {
    final Dyad prototype = dataset.get(0).getDyadAtPosition(0);
    final int lengthX = prototype.getInstance().length();
    final int lengthY = prototype.getAlternative().length();
    this.statsX = new SummaryStatistics[lengthX];
    this.statsY = new SummaryStatistics[lengthY];
    for (int attr = 0; attr < lengthX; attr++) {
        this.statsX[attr] = new SummaryStatistics();
    }
    for (int attr = 0; attr < lengthY; attr++) {
        this.statsY[attr] = new SummaryStatistics();
    }
    for (final IDyadRankingInstance ranking : dataset) {
        for (final Dyad sample : ranking) {
            for (int attr = 0; attr < lengthX; attr++) {
                this.statsX[attr].addValue(sample.getInstance().getValue(attr));
            }
            for (int attr = 0; attr < lengthY; attr++) {
                this.statsY[attr].addValue(sample.getAlternative().getValue(attr));
            }
        }
    }
}
15
View Complete Implementation : DyadUnitIntervalScaler.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Scales, in place, every alternative feature of every dyad in the dataset by the
 * corresponding entry of {@code lengthOfY}. Zero-valued features are left as they
 * are, so a zero divisor cannot turn them into NaN (0 / 0).
 *
 * <p>The number of features is taken from the first dyad of the first ranking.
 *
 * @param dataset        the dataset whose alternatives are scaled in place
 * @param ignoredIndices feature indices that must be left unscaled; {@code null}
 *                       is treated as "no ignored indices"
 */
@Override
public void transformAlternatives(final DyadRankingDataset dataset, final List<Integer> ignoredIndices) {
    int lengthY = dataset.get(0).getDyadAtPosition(0).getAlternative().length();
    for (IDyadRankingInstance instance : dataset) {
        for (Dyad dyad : instance) {
            for (int i = 0; i < lengthY; i++) {
                // Honour the caller's request to leave these features untouched.
                if (ignoredIndices != null && ignoredIndices.contains(i)) {
                    continue;
                }
                double value = dyad.getAlternative().getValue(i);
                if (value != 0.0d) {
                    value /= this.lengthOfY[i];
                }
                dyad.getAlternative().setValue(i, value);
            }
        }
    }
}
15
View Complete Implementation : DyadUnitIntervalScaler.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Scales, in place, every instance feature of every dyad in the dataset by the
 * corresponding entry of {@code lengthOfX}. Zero-valued features are left as they
 * are, so a zero divisor cannot turn them into NaN (0 / 0).
 *
 * <p>The number of features is taken from the first dyad of the first ranking.
 *
 * @param dataset        the dataset whose instances are scaled in place
 * @param ignoredIndices feature indices that must be left unscaled; {@code null}
 *                       is treated as "no ignored indices"
 */
@Override
public void transformInstances(final DyadRankingDataset dataset, final List<Integer> ignoredIndices) {
    int lengthX = dataset.get(0).getDyadAtPosition(0).getInstance().length();
    for (IDyadRankingInstance instance : dataset) {
        for (Dyad dyad : instance) {
            for (int i = 0; i < lengthX; i++) {
                // Honour the caller's request to leave these features untouched.
                if (ignoredIndices != null && ignoredIndices.contains(i)) {
                    continue;
                }
                double value = dyad.getInstance().getValue(i);
                if (value != 0.0d) {
                    value /= this.lengthOfX[i];
                }
                dyad.getInstance().setValue(i, value);
            }
        }
    }
}
15
View Complete Implementation : DyadRankerGATSPTest.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Trims the sparse dyad ranking instances by randomly selecting alternatives
 * from each dyad ranking instance.
 *
 * <p>For every instance with at least {@code dyadRankingLength} dyads, exactly
 * {@code dyadRankingLength} positions are chosen uniformly at random. Their
 * alternatives, in their original relative order, form a new
 * {@link SparseDyadRankingInstance} whose instance features are those of the
 * instance's first dyad. Instances shorter than {@code dyadRankingLength} are
 * dropped. Randomness comes from {@link Collections#shuffle(List)} without a
 * seed, so results differ between runs.
 *
 * @param dataset           the dataset to trim; it is only read, but the new
 *                          instances reuse the original alternative vectors
 * @param dyadRankingLength the length of the trimmed dyad ranking instances
 * @return a new dataset with one trimmed instance per sufficiently long input
 *         instance
 */
private static DyadRankingDataset randomlyTrimSparseDyadRankingInstances(final DyadRankingDataset dataset, final int dyadRankingLength) {
    DyadRankingDataset trimmedDataset = new DyadRankingDataset();
    for (IDyadRankingInstance instance : dataset) {
        int length = instance.length();
        if (length < dyadRankingLength) {
            // too short to provide the requested number of alternatives
            continue;
        }
        // Exactly dyadRankingLength TRUE flags, randomly permuted, mark the kept positions.
        List<Boolean> keep = new ArrayList<>(length);
        for (int i = 0; i < length; i++) {
            keep.add(i < dyadRankingLength);
        }
        Collections.shuffle(keep);
        List<Vector> trimmedAlternatives = new ArrayList<>(dyadRankingLength);
        for (int i = 0; i < length; i++) {
            if (keep.get(i)) {
                trimmedAlternatives.add(instance.getDyadAtPosition(i).getAlternative());
            }
        }
        trimmedDataset.add(new SparseDyadRankingInstance(instance.getDyadAtPosition(0).getInstance(), trimmedAlternatives));
    }
    return trimmedDataset;
}
15
View Complete Implementation : DyadRankingBasedNodeEvaluator.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Ranks the given pipelines with the dyad ranker and returns them in predicted
 * order. The dataset meta-features are the shared instance part of every dyad;
 * each map key (a pipeline characterization) is one alternative.
 *
 * @param randomPipelines pipelines keyed by their characterization vectors
 * @return the pipelines in the order predicted by the dyad ranker; an entry is
 *         {@code null} if a predicted alternative is not a key of the map
 * @throws PredictionException if the dyad ranker fails to predict
 */
private List<ComponentInstance> rankRandomPipelines(final Map<Vector, ComponentInstance> randomPipelines) throws PredictionException {
    List<Vector> candidateAlternatives = new ArrayList<>(randomPipelines.keySet());
    // A sparse instance shares one instance vector across all alternatives.
    SparseDyadRankingInstance query = new SparseDyadRankingInstance(new DenseDoubleVector(this.datasetMetaFeatures), candidateAlternatives);
    IDyadRankingInstance ranking = this.dyadRanker.predict(query);
    List<ComponentInstance> orderedPipelines = new ArrayList<>();
    ranking.forEach(dyad -> orderedPipelines.add(randomPipelines.get(dyad.getAlternative())));
    return orderedPipelines;
}
15
View Complete Implementation : DyadDatasetPoolProvider.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Adds a {@link IDyadRankingInstance} instance to the pool and indexes it.
 *
 * <p>The whole ranking is stored under the instance features and under the
 * alternative features of its first dyad. Every dyad is also added to the per-key
 * dyad sets for its instance features and its alternative features.
 *
 * @param instance the dyad ranking to add; it must contain at least one dyad,
 *                 because its first dyad provides the keys of the ranking maps
 */
private void addDyadRankingInstance(final IDyadRankingInstance instance) {
    // Add the dyad ranking instance to the pool
    this.pool.add(instance);
    // Index the whole ranking by the features of its first dyad
    Dyad firstDyad = instance.getDyadAtPosition(0);
    this.dyadRankingsByInstances.put(firstDyad.getInstance(), instance);
    this.dyadRankingsByAlternatives.put(firstDyad.getAlternative(), instance);
    // Index every dyad by its instance features and by its alternative features;
    // computeIfAbsent replaces the containsKey/put/get triple lookup.
    for (Dyad dyad : instance) {
        this.dyadsByInstances.computeIfAbsent(dyad.getInstance(), key -> new HashSet<Dyad>()).add(dyad);
        this.dyadsByAlternatives.computeIfAbsent(dyad.getAlternative(), key -> new HashSet<Dyad>()).add(dyad);
    }
}
14
View Complete Implementation : FeatureTransformPLDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Computes the likelihood of the parameter vector w. Algorithm (16) of [1]:
 *
 * <pre>
 * L(w) = prod_n prod_m exp(w.z_nm) / sum_{l = m}^{M_n - 1} exp(w.z_nl)
 * </pre>
 *
 * <p>The product is accumulated in log space with a numerically stable
 * log-sum-exp and exponentiated only at the end. Large scores w.z therefore no
 * longer overflow {@code Math.exp} to Infinity, which would turn the ratio into
 * Infinity / Infinity = NaN. Each dyad is feature-transformed once per call.
 *
 * @param w
 *            the parameter vector whose likelihood is computed
 * @param dataset
 *            the dataset on which the likelihood should be evaluated
 * @return the likelihood, measured as a probability
 */
private double likelihoodOfParameter(final Vector w, final DyadRankingDataset dataset) {
    double logLikelihood = 0.0;
    int largeN = dataset.size();
    for (int smallN = 0; smallN < largeN; smallN++) {
        IDyadRankingInstance dyadRankingInstance = dataset.get(smallN);
        int mN = dyadRankingInstance.length();
        // w.z for every position, each dyad transformed exactly once.
        double[] scores = new double[mN];
        for (int m = 0; m < mN; m++) {
            Vector zNM = this.featureTransform.transform(dyadRankingInstance.getDyadAtPosition(m));
            scores[m] = w.dotProduct(zNM);
        }
        // Walk backwards, maintaining logTail = log(sum_{l = m}^{mN - 1} exp(scores[l])).
        double logTail = Double.NEGATIVE_INFINITY;
        for (int m = mN - 1; m >= 0; m--) {
            logTail = logAddExp(logTail, scores[m]);
            logLikelihood += scores[m] - logTail;
        }
    }
    return Math.exp(logLikelihood);
}

/** Returns log(exp(a) + exp(b)) without overflowing for large arguments. */
private static double logAddExp(final double a, final double b) {
    if (a == Double.NEGATIVE_INFINITY) {
        return b;
    }
    if (b == Double.NEGATIVE_INFINITY) {
        return a;
    }
    double max = Math.max(a, b);
    return max + Math.log1p(Math.exp(-Math.abs(a - b)));
}
14
View Complete Implementation : SimpleDyadDatasetDyadRankerTester.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Offers the alternatives as (alternative2, alternative1) for the instance
 * {1, 1, 1} and expects the alternative {0.0} ranked first and {1.0} second.
 * NOTE(review): {@code replacedertEquals} appears to be a scraper-mangled
 * assertion name; left unchanged, confirm against the original source.
 */
@Test
public void testSwapOrdering1() throws PredictionException {
    System.out.println("Now testing if alternative1 > alternative2");
    SparseDyadRankingInstance query = new SparseDyadRankingInstance(new DenseDoubleVector(new double[] { 1.0, 1.0, 1.0 }), Arrays.asList(alternative2, alternative1));
    IDyadRankingInstance ranking = this.ranker.predict(query);
    double[] expectedFirst = { 0.0 };
    replacedertEquals(expectedFirst, ranking.getDyadAtPosition(0).getAlternative().asArray());
    double[] expectedSecond = { 1.0 };
    replacedertEquals(expectedSecond, ranking.getDyadAtPosition(1).getAlternative().asArray());
}
14
View Complete Implementation : SimpleDyadDatasetDyadRankerTester.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Offers the alternatives as (alternative2, alternative1) for the instance
 * {1, 1, 0} and expects the alternative {1.0} ranked first and {0.0} second.
 * NOTE(review): {@code replacedertEquals} appears to be a scraper-mangled
 * assertion name; left unchanged, confirm against the original source.
 */
@Test
public void testSwapOrdering0() throws PredictionException {
    System.out.println("Now testing if alternative2 > alternative1");
    SparseDyadRankingInstance query = new SparseDyadRankingInstance(new DenseDoubleVector(new double[] { 1.0, 1.0, 0.0 }), Arrays.asList(alternative2, alternative1));
    IDyadRankingInstance ranking = this.ranker.predict(query);
    double[] expectedFirst = { 1.0 };
    replacedertEquals(expectedFirst, ranking.getDyadAtPosition(0).getAlternative().asArray());
    double[] expectedSecond = { 0.0 };
    replacedertEquals(expectedSecond, ranking.getDyadAtPosition(1).getAlternative().asArray());
}
13
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Ranks the dyads of {@code instance} by the utility the PLNet assigns to each,
 * highest utility first.
 *
 * <p>If no network exists yet, one is created and initialised with an input
 * size equal to the combined instance and alternative length of the first dyad.
 *
 * @param instance the dyads to rank
 * @return a new {@link DyadRankingInstance} holding the dyads in descending
 *         order of predicted utility
 * @throws PredictionException declared by the ranker interface; not thrown
 *         explicitly by this method
 */
@Override
public IDyadRankingInstance predict(final IDyadRankingInstance instance) throws PredictionException {
    if (this.plNet == null) {
        Dyad firstDyad = instance.getDyadAtPosition(0);
        this.plNet = this.createNetwork(firstDyad.getInstance().length() + firstDyad.getAlternative().length());
        this.plNet.init();
    }
    List<Pair<Dyad, Double>> scored = new ArrayList<>(instance.length());
    for (Dyad dyad : instance) {
        double utility = this.plNet.output(this.dyadToVector(dyad)).getDouble(0);
        scored.add(new Pair<Dyad, Double>(dyad, utility));
    }
    // Sorting ascending by negated utility yields descending utility.
    scored.sort(Comparator.comparing(p -> -p.getRight()));
    List<Dyad> ordered = new ArrayList<>(scored.size());
    for (Pair<Dyad, Double> entry : scored) {
        ordered.add(entry.getLeft());
    }
    return new DyadRankingInstance(ordered);
}
13
View Complete Implementation : ADyadRankedNodeQueue.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Adds a node to OPEN.
 *
 * <p>The node is characterized and paired with the context characterization as a
 * new query dyad. If {@code useScaler} is set, the dyad's alternative is scaled.
 * NaNs in the characterization are replaced by zeroes. All query dyads are then
 * re-ranked by the dyad ranker and the queue is rebuilt in predicted order.
 *
 * <p>If a {@link PredictionException} occurs, everything recorded for this node so
 * far (characterization, query dyad, node-to-characterization mapping) is rolled
 * back. This keeps the bookkeeping collections aligned, and a node for which
 * {@code false} was returned cannot reappear in the queue during a later re-ranking.
 *
 * @param e the node to add
 * @return {@code true} if the node was already queued or was ranked into the queue;
 *         {@code false} if {@code e} is {@code null} or ranking failed
 */
@Override
public boolean add(final Node<N, V> e) {
    if (this.queue.contains(e)) {
        return true;
    } else if (e == null) {
        return false;
    }
    // Track what has been recorded so a failure undoes exactly that, nothing more.
    boolean characterizationRecorded = false;
    boolean queryDyadRecorded = false;
    boolean mappingRecorded = false;
    try {
        this.logger.debug("Add node to OPEN, is Goal: {}", e.isGoal());
        // characterize new node
        Vector characterization = this.characterize(e);
        this.nodeCharacterizations.add(characterization);
        characterizationRecorded = true;
        Dyad newDyad = new Dyad(this.contextCharacterization, characterization);
        this.queryDyads.add(newDyad);
        queryDyadRecorded = true;
        if (this.useScaler) {
            // scale node; the scaler works on datasets, so wrap the single dyad
            DyadRankingDataset dataset = new DyadRankingDataset();
            dataset.add(new DyadRankingInstance(Arrays.asList(newDyad)));
            this.scaler.transformAlternatives(dataset);
        }
        this.replaceNaNByZeroes(characterization);
        // add new pairing of node and characterization
        this.nodesAndCharacterizationsMap.put(e, characterization);
        mappingRecorded = true;
        // predict new ranking and reorder queue accordingly
        IDyadRankingInstance prediction = this.dyadRanker.predict(new DyadRankingInstance(this.queryDyads));
        this.queue.clear();
        for (int i = 0; i < prediction.length(); i++) {
            Node<N, V> toAdd = this.nodesAndCharacterizationsMap.inverse().get(prediction.getDyadAtPosition(i).getAlternative());
            if (toAdd != null) {
                this.queue.add(toAdd);
            } else {
                this.logger.warn("Got a node in a prediction that doesnt exist");
            }
        }
        return true;
    } catch (PredictionException e1) {
        this.logger.warn("Failed to characterize: {}", e1.getLocalizedMessage());
        // Roll back, newest first, only what was actually recorded for this node.
        if (mappingRecorded) {
            this.nodesAndCharacterizationsMap.remove(e);
        }
        if (queryDyadRecorded) {
            this.queryDyads.remove(this.queryDyads.size() - 1);
        }
        if (characterizationRecorded) {
            this.nodeCharacterizations.remove(this.nodeCharacterizations.size() - 1);
        }
        return false;
    }
}
12
View Complete Implementation : DyadDatasetPoolProvider.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Answers a ranking query with the true ranking of the queried dyads.
 *
 * <p>Each queried dyad gets its position from
 * {@code getPositionInRankingByInstanceFeatures}, and the dyads are returned in
 * ascending order of that position. If {@code removeDyadsWhenQueried} is set, the
 * returned dyads are removed from the pool. Every answered ranking is appended to
 * {@code queriedRankings}.
 *
 * @param queryInstance the dyads to rank; currently must be a
 *                      {@link SparseDyadRankingInstance}
 * @return the queried dyads in their true order
 * @throws IllegalArgumentException if {@code queryInstance} is not a
 *                                  {@link SparseDyadRankingInstance}
 */
@Override
public IDyadRankingInstance query(final IDyadRankingInstance queryInstance) {
    // Incremented before validation, so rejected queries are counted as well.
    this.numberQueries++;
    if (!(queryInstance instanceof SparseDyadRankingInstance)) {
        throw new IllegalArgumentException("Currently only supports SparseDyadRankingInstances!");
    }
    SparseDyadRankingInstance drInstance = (SparseDyadRankingInstance) queryInstance;
    List<Pair<Dyad, Integer>> dyadPositionPairs = new ArrayList<>(drInstance.length());
    for (Dyad dyad : drInstance) {
        int position = this.getPositionInRankingByInstanceFeatures(dyad);
        dyadPositionPairs.add(new Pair<Dyad, Integer>(dyad, position));
    }
    // Sort in ascending order of the true ranking position (presumably 0 = top).
    Collections.sort(dyadPositionPairs, Comparator.comparing(Pair<Dyad, Integer>::getRight));
    List<Dyad> dyadList = new ArrayList<>(dyadPositionPairs.size());
    for (Pair<Dyad, Integer> pair : dyadPositionPairs) {
        dyadList.add(pair.getFirst());
    }
    DyadRankingInstance trueRanking = new DyadRankingInstance(dyadList);
    if (this.removeDyadsWhenQueried) {
        for (Dyad dyad : dyadList) {
            this.removeDyadFromPool(dyad);
        }
    }
    this.queriedRankings.add(trueRanking);
    return trueRanking;
}
12
View Complete Implementation : DyadRankerMetaminingTest.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Shuffles the dataset with the fixed seed and trains the ranker on the first 70%.
 * Then checks that the average Kendall's tau on the remaining 30% exceeds 0.5.
 *
 * <p>Training or prediction failures fail the test instead of being printed and
 * ignored.
 *
 * <p>NOTE(review): {@code replacedertTrue} appears to be a scraper-mangled assertion
 * name; left unchanged because its import is not visible here.
 */
@Test
public void test() {
    Collections.shuffle(this.dataset, new Random(seed));
    // split data: first 70% for training, the remaining 30% for testing
    int splitIndex = (int) (0.7 * this.dataset.size());
    DyadRankingDataset trainData = new DyadRankingDataset(this.dataset.subList(0, splitIndex));
    DyadRankingDataset testData = new DyadRankingDataset(this.dataset.subList(splitIndex, this.dataset.size()));
    try {
        // train the ranker
        this.ranker.train(trainData);
        double avgKendallTau = DyadRankingLossUtil.computeAverageLoss(new KendallsTauDyadRankingLoss(), testData, this.ranker);
        System.out.println("Average Kendall's tau for " + this.ranker.getClass().getSimpleName() + ": " + avgKendallTau);
        replacedertTrue(avgKendallTau > 0.5d);
    } catch (TrainingException | PredictionException e) {
        // Printing the stack trace alone would let the test pass despite the failure.
        throw new AssertionError("Training or prediction failed", e);
    }
}
11
View Complete Implementation : DyadRankingLossUtil.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Computes the average loss over several dyad orderings. Predictions are
 * obtained by the given {@link IDyadRanker}.
 *
 * <p>Before each prediction the dyads of the true ordering are shuffled with
 * {@code random}. A ranker that returns its input unchanged therefore cannot
 * obtain a perfect result.
 *
 * @param lossFunction
 *            The loss function to be used for the individual
 *            {@link IDyadRankingInstance}s
 * @param trueOrderings
 *            The true orderings represented by {@link IDyadRankingInstance}s
 * @param ranker
 *            The {@link IDyadRanker} used to make predictions
 * @param random
 *            Source of randomness for shuffling the dyads before prediction;
 *            pass a seeded instance for reproducible results
 * @return Average loss over all {@link IDyadRankingInstance}s; {@code NaN} if
 *         {@code trueOrderings} is empty
 * @throws PredictionException
 *             if the ranker fails to predict a ranking
 */
public static double computeAverageLoss(DyadRankingLossFunction lossFunction, DyadRankingDataset trueOrderings, IDyadRanker ranker, Random random) throws PredictionException {
    double lossSum = 0.0d;
    for (IDyadRankingInstance actual : trueOrderings) {
        // shuffle the instance such that a ranker that doesn't do anything can't come
        // up with a perfect result
        List<Dyad> shuffleContainer = Lists.newArrayList(actual.iterator());
        Collections.shuffle(shuffleContainer, random);
        IDyadRankingInstance predicted = ranker.predict(new DyadRankingInstance(shuffleContainer));
        lossSum += lossFunction.loss(actual, predicted);
    }
    // An empty dataset yields 0.0 / 0 = NaN.
    return lossSum / trueOrderings.size();
}
10
View Complete Implementation : PLNetDyadRanker.java
Copyright GNU Affero General Public License v3.0
Author : fmohr
Copyright GNU Affero General Public License v3.0
Author : fmohr
/**
 * Computes the gradient of the PLNet's error function for a given instance. The
 * returned gradient is already scaled by the updater. The update procedure is
 * based on algorithm 2 in [1].
 *
 * For every position k of the ranking, the loss gradient with respect to the
 * k-th network output is back-propagated through a clone of the network fed with
 * the k-th dyad. The result is passed through the updater of {@code this.plNet}
 * and summed into the returned vector.
 *
 * NOTE(review): the updater is invoked once per ranking position, i.e.
 * instance.length() times per call. If the updater is stateful (momentum, Adam,
 * ...), its state advances that many times; confirm this is intended.
 *
 * @param instance
 * The instance to compute the scaled gradient for.
 * @return The gradient for the given instance, multiplied by the updater's
 * learning rate.
 */
private INDArray computeScaledGradient(final IDyadRankingInstance instance) {
// build the network inputs for this ranking
INDArray dyadMatrix;
// one input vector per dyad, in ranking order; fed to the clone one at a time below
List<INDArray> dyadList = new ArrayList<>(instance.length());
for (Dyad dyad : instance) {
INDArray dyadVector = this.dyadToVector(dyad);
dyadList.add(dyadVector);
}
// all dyads of the ranking as one matrix, so a single forward pass yields every output
dyadMatrix = this.dyadRankingToMatrix(instance);
List<INDArray> activations = this.plNet.feedForward(dyadMatrix);
// the last activation is the network output
INDArray output = activations.get(activations.size() - 1);
output = output.transpose();
// accumulator for the summed, updater-scaled gradients (one entry per network parameter)
INDArray deltaW = Nd4j.zeros(this.plNet.params().length());
Gradient deltaWk = null;
// back-propagation runs on a clone so that setInput/feedForward do not disturb this.plNet
MultiLayerNetwork plNetClone = this.plNet.clone();
for (int k = 0; k < instance.length(); k++) {
// compute derivative of loss w.r.t. k
plNetClone.setInput(dyadList.get(k));
// presumably a training-mode forward pass so backpropGradient has activations to use
plNetClone.feedForward(true, false);
INDArray lossGradient = PLNetLoss.computeLossGradient(output, k);
// compute backprop gradient for weight updates w.r.t. k
Pair<Gradient, INDArray> p = plNetClone.backpropGradient(lossGradient, null);
deltaWk = p.getFirst();
// the updater is expected to scale deltaWk in place (see "already scaled by the updater" above)
this.plNet.getUpdater().update(this.plNet, deltaWk, this.iteration, this.epoch, 1, LayerWorkspaceMgr.noWorkspaces());
deltaW.addi(deltaWk.gradient());
}
return deltaW;
}