Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 88 additions & 11 deletions src/main/java/edu/ucsd/msjava/msscorer/NewRankScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ public class NewRankScorer implements NewAdditiveScorer {
protected HashMap<Partition, Float[]> noiseErrDistTable = null;
protected HashMap<Partition, Float[]> ionExistenceTable = null;

// Caches of precomputed log scores. Populated by precomputeLogScoreTables()
// at the end of readFromInputStream. Bit-identical to the runtime
// Math.log(...) expressions they replace. Each lookup saves one
// Math.log call plus (for nodeLogTable) two HashMap.get calls per
// scoring call.
private transient HashMap<Partition, float[]> errorLogTable = null; // log(ionErr[i] / noiseErr[i])
private transient HashMap<Partition, HashMap<IonType, float[]>> nodeLogTable = null; // log(freq[i] / (noise[i] * min(ionCharge, numSegments)))

// Ion Types
private HashMap<Partition, IonType> mainIonTable;
private HashMap<Partition, IonType[]> ionTypeTable;
Expand Down Expand Up @@ -99,19 +107,30 @@ public boolean supportEdgeScores() {
}

public float getNodeScore(Partition part, IonType ionType, int rank) {
// ion score
int rankIndex = rank > maxRank ? maxRank - 1 : rank - 1;
// Fast path: precomputed log score, populated by precomputeLogScoreTables.
HashMap<IonType, float[]> ionLogs = (nodeLogTable != null) ? nodeLogTable.get(part) : null;
if (ionLogs != null) {
float[] logs = ionLogs.get(ionType);
if (logs != null && rankIndex >= 0 && rankIndex < logs.length)
return logs[rankIndex];
}
// Fallback to the original path (kept for safety during migration).
HashMap<IonType, Float[]> rankTable = rankDistTable.get(part); // rank -> probability
assert (rankTable != null);
int rankIndex = rank > maxRank ? maxRank - 1 : rank - 1;
float ionScore = getScoreFromTable(rankIndex, rankTable, ionType, false);

return ionScore;
return getScoreFromTable(rankIndex, rankTable, ionType, false);
}

public float getMissingIonScore(Partition part, IonType ionType) {
int rankIndex = maxRank;
HashMap<IonType, float[]> ionLogs = (nodeLogTable != null) ? nodeLogTable.get(part) : null;
if (ionLogs != null) {
float[] logs = ionLogs.get(ionType);
if (logs != null && rankIndex < logs.length)
return logs[rankIndex];
}
HashMap<IonType, Float[]> table = rankDistTable.get(part);
assert (table != null);
int rankIndex = maxRank;
return getScoreFromTable(rankIndex, table, ionType, false);
}

Expand All @@ -121,12 +140,14 @@ public float getErrorScore(Partition part, float error) {
errIndex = errorScalingFactor;
else if (errIndex < -errorScalingFactor)
errIndex = -errorScalingFactor;
Float[] ionErrHist = this.ionErrDistTable.get(part);
// float noiseProb = (errorScalingFactor-Math.abs(errIndex))/(errorScalingFactor*errorScalingFactor);
// if(noiseProb == 0)
// noiseProb = 1f/(errorScalingFactor*errorScalingFactor);
// return (float)Math.log(ionErrHist[errIndex+errorScalingFactor]/noiseProb);
errIndex += errorScalingFactor;
if (errorLogTable != null) {
float[] logs = errorLogTable.get(part);
if (logs != null && errIndex < logs.length)
return logs[errIndex];
}
// Fallback to the original path.
Float[] ionErrHist = this.ionErrDistTable.get(part);
Float[] noiseErrHist = this.noiseErrDistTable.get(part);
return (float) Math.log(ionErrHist[errIndex] / noiseErrHist[errIndex]);
}
Expand Down Expand Up @@ -415,11 +436,67 @@ private void readFromInputStream(InputStream is, boolean verbose) {
System.exit(-1);
}
in.close();
precomputeLogScoreTables();
} catch (IOException e) {
e.printStackTrace();
}
}

/**
* Precompute log(x/y) values that scoring methods would otherwise
* recompute on every call. The expressions match {@link #getErrorScore}
* and {@link #getScoreFromTable} exactly (same operations, same float
* rounding), so scoring results are bit-identical.
*
* Profiling on Astral showed native Math.log (libmLog) at ~5.5% of CPU
* before this cache.
*/
private void precomputeLogScoreTables() {
// --- errorLogTable: log(ionErr[i] / noiseErr[i]) per (partition, i) ---
if (ionErrDistTable != null && noiseErrDistTable != null) {
errorLogTable = new HashMap<Partition, float[]>(ionErrDistTable.size() * 2);
for (Map.Entry<Partition, Float[]> e : ionErrDistTable.entrySet()) {
Partition p = e.getKey();
Float[] ionErr = e.getValue();
Float[] noiseErr = noiseErrDistTable.get(p);
if (ionErr == null || noiseErr == null) continue;
int n = Math.min(ionErr.length, noiseErr.length);
float[] logs = new float[n];
for (int i = 0; i < n; i++)
logs[i] = (float) Math.log(ionErr[i] / noiseErr[i]);
errorLogTable.put(p, logs);
}
}

// --- nodeLogTable: log(freq[i] / (noise[i] * min(charge, numSegments))) per (partition, ionType, i) ---
if (rankDistTable != null) {
nodeLogTable = new HashMap<Partition, HashMap<IonType, float[]>>(rankDistTable.size() * 2);
for (Map.Entry<Partition, HashMap<IonType, Float[]>> pe : rankDistTable.entrySet()) {
HashMap<IonType, Float[]> ionTable = pe.getValue();
if (ionTable == null) continue;
Float[] noiseFrequencies = ionTable.get(IonType.NOISE);
if (noiseFrequencies == null) continue;
HashMap<IonType, float[]> perIon = new HashMap<IonType, float[]>(ionTable.size() * 2);
for (Map.Entry<IonType, Float[]> ie : ionTable.entrySet()) {
IonType ionType = ie.getKey();
Float[] frequencies = ie.getValue();
if (frequencies == null) continue;
int n = Math.min(frequencies.length, noiseFrequencies.length);
int chargeOrSeg = Math.min(ionType.getCharge(), numSegments);
float[] logs = new float[n];
for (int i = 0; i < n; i++) {
float ionFrequency = frequencies[i];
float noiseFrequency = noiseFrequencies[i] * chargeOrSeg;
// Match getScoreFromTable semantics exactly: guard against non-positive only in assertions.
logs[i] = (float) Math.log(ionFrequency / noiseFrequency);
}
Comment on lines +480 to +492
perIon.put(ionType, logs);
}
nodeLogTable.put(pe.getKey(), perIon);
}
}
}

// Builders
protected NewRankScorer tolerance(Tolerance mme) {
this.mme = mme;
Expand Down
Loading