diff --git a/frustratometer/classes/AWSEM.py b/frustratometer/classes/AWSEM.py index 770f4ad..b755de8 100644 --- a/frustratometer/classes/AWSEM.py +++ b/frustratometer/classes/AWSEM.py @@ -57,6 +57,7 @@ def __init__(self, pdb_structure: object, sequence: str =None, expose_indicator_functions: bool=False, + sparse: bool=False, **parameters)->object: """ Generate AWSEM object diff --git a/frustratometer/classes/Gamma.py b/frustratometer/classes/Gamma.py index 821bffe..76da501 100644 --- a/frustratometer/classes/Gamma.py +++ b/frustratometer/classes/Gamma.py @@ -26,7 +26,6 @@ def __init__(self, data, segment_definition=None, description=None, alphabet=Non self._init_from_file(data) else: raise TypeError("Unsupported type for initializing Gamma.") - print(self.gamma_array) self.alphabet = alphabet if alphabet is not None else self.default_alphabet.copy() self.segment_definition = segment_definition if segment_definition is not None else self.default_segment_definition.copy() diff --git a/frustratometer/frustration/frustration.py b/frustratometer/frustration/frustration.py index 81ec880..bffb50e 100644 --- a/frustratometer/frustration/frustration.py +++ b/frustratometer/frustration/frustration.py @@ -796,7 +796,7 @@ def compute_auc(roc_score): auc : float AUC value """ - fpr, tpr = roc + fpr, tpr = roc_score auc = np.sum(tpr[:-1] * (fpr[1:] - fpr[:-1])) return auc @@ -811,7 +811,7 @@ def plot_roc(roc_score): Array containing lists of false and true positive rates """ import matplotlib.pyplot as plt - plt.plot(roc[0], roc[1]) + plt.plot(roc_score[0], roc_score[1]) plt.xlabel('False positive rate (1-specificity)') plt.ylabel('True positive rate (sensiticity)') plt.suptitle('Receiver operating characteristic') diff --git a/frustratometer/optimization/EnergyTerm.py b/frustratometer/optimization/EnergyTerm.py index 6cd87bc..dfcb322 100644 --- a/frustratometer/optimization/EnergyTerm.py +++ b/frustratometer/optimization/EnergyTerm.py @@ -80,45 +80,65 @@ def dummy_decorator(func, *args, **kwargs): return func @property - #@lru_cache(maxsize=None) def energies_function(self): - """ Returns the energy function as a numba dispatcher. """ - energy_function = self.energy_function - def compute_energies(seq_indices:np.ndarray): - """Compute the energies of multiple sequences.""" - energies = np.zeros(len(seq_indices)) - for i in numba.prange(len(seq_indices)): - energies[i] = energy_function(seq_indices[i]) - return energies - - if self.use_numba: - return numba.njit(types.Array(types.float64, 1, 'C')(types.Array(types.int64, 2, 'A', readonly=True)), parallel=True)(compute_energies) - else: - return compute_energies + if not hasattr(self, '_energies_function_cache'): + energy_function = self.energy_function + + def compute_energies(seq_indices: np.ndarray): + energies = np.zeros(len(seq_indices)) + for i in numba.prange(len(seq_indices)): + energies[i] = energy_function(seq_indices[i]) + return energies + + if self.use_numba: + self._energies_function_cache = numba.njit( + types.Array(types.float64, 1, 'C')(types.Array(types.int64, 2, 'A', readonly=True)), + parallel=True + )(compute_energies) + else: + self._energies_function_cache = compute_energies + return self._energies_function_cache @property - #@lru_cache(maxsize=None) def energy_function(self): - """ Returns the energy function as a numba dispatcher. """ - if self.use_numba: - return numba.njit(types.float64(types.Array(types.int64, 1, 'A', readonly=True)))(self.compute_energy) - return self.compute_energy - + if not hasattr(self, '_energy_function_cache'): + if self.use_numba: + self._energy_function_cache = numba.njit( + types.float64(types.Array(types.int64, 1, 'A', readonly=True)) + )(self.compute_energy) + else: + self._energy_function_cache = self.compute_energy + return self._energy_function_cache + @property - #@lru_cache(maxsize=None) def denergy_mutation_function(self): - """ Returns the mutation energy change function as a numba dispatcher. """ - if self.use_numba: - return numba.njit(types.float64(types.Array(types.int64, 1, 'A', readonly=True),types.int64,types.int64))(self.compute_denergy_mutation) - return self.compute_denergy_mutation + if not hasattr(self, '_denergy_mutation_function_cache'): + if self.use_numba: + self._denergy_mutation_function_cache = numba.njit( + types.float64( + types.Array(types.int64, 1, 'A', readonly=True), + types.int64, + types.int64 + ) + )(self.compute_denergy_mutation) + else: + self._denergy_mutation_function_cache = self.compute_denergy_mutation + return self._denergy_mutation_function_cache @property - #@lru_cache(maxsize=None) def denergy_swap_function(self): - """ Returns the swap energy change function as a numba dispatcher. """ - if self.use_numba: - return numba.njit(types.float64(types.Array(types.int64, 1, 'A', readonly=True),types.int64,types.int64))(self.compute_denergy_swap) - return self.compute_denergy_swap + if not hasattr(self, '_denergy_swap_function_cache'): + if self.use_numba: + self._denergy_swap_function_cache = numba.njit( + types.float64( + types.Array(types.int64, 1, 'A', readonly=True), + types.int64, + types.int64 + ) + )(self.compute_denergy_swap) + else: + self._denergy_swap_function_cache = self.compute_denergy_swap + return self._denergy_swap_function_cache @staticmethod #@abc.abstractmethod #TODO: Add abstract method decorator. Currently not working due to the late initialization of the methods. diff --git a/frustratometer/optimization/optimization.py b/frustratometer/optimization/optimization.py index 0bdc5e5..bceba16 100644 --- a/frustratometer/optimization/optimization.py +++ b/frustratometer/optimization/optimization.py @@ -1106,6 +1106,286 @@ def find_optimal_replicas(self, max_replicas=32, n_repeats=5, n_steps=10000): return results +_VMPT_EMPTY = -1e30 # log-weight of an unvisited bin (additive identity of ExpSum) +_VMPT_VISITED = -1e20 # a bin counts as visited if its log-weight exceeds this + + +class VirtualMoveParallelTempering(MonteCarlo): + """Virtual Move Parallel Tempering with an adaptive bias in (energy, heterogeneity). + + Samples the Hamiltonian H = energy - Ep*heterogeneity across a temperature + ladder while growing a per-replica bias potential Wpot(E, het) that flattens + the sampled (E, het) histogram (metadynamics-style, with Frenkel waste-recycling + deposits). Replica swaps use a bias-corrected acceptance so the extended + ensemble stays in detailed balance. + + The unbiased production histogram `Histo` (log-weights) and the bias `Wpot` are + stored on the instance after a run; `Visits` is the raw biased occupancy. + + Parameters + ---------- + energy : EnergyTerm energy axis (e.g. AwsemEnergy); its delta drives E. + heterogeneity : EnergyTerm heterogeneity axis (default: Heterogeneity()). + Ep : float coupling in H = energy - Ep*heterogeneity. + energy_range, het_range : (min, max, bin) of each bias axis. + umbrella : float bias deposit rate. + unbias_het : bool also remove the Ep*het tilt from `Histo` (off: het + stays a reaction coordinate, recommended). + """ + + def __init__(self, sequence: str, energy: EnergyTerm, heterogeneity: EnergyTerm = None, + Ep: float = 20.0, energy_range=(-2000.0, 0.0, 10.0), het_range=(0.0, 170.0, 1.0), + umbrella: float = 0.05, unbias_het: bool = False, + alphabet: str = _AA, use_numba: bool = True, evaluation_energies: dict = {}): + self.heterogeneity = heterogeneity if heterogeneity is not None else Heterogeneity(alphabet=alphabet, use_numba=use_numba) + self.Ep = Ep + self.energy_range = energy_range + self.het_range = het_range + self.umbrella = umbrella + self.unbias_het = unbias_het + Emin, Emax, Ebin = energy_range + Pmin, Pmax, Pbin = het_range + self.sizeE = int(round((Emax - Emin) / Ebin)) + self.sizeP = int(round((Pmax - Pmin) / Pbin)) + super().__init__(sequence=sequence, energy=energy, alphabet=alphabet, + use_numba=use_numba, evaluation_energies=evaluation_energies) + + def initialize_functions(self): + alphabet_size = len(self.alphabet) + sequence_size = self.seq_len + Ep = self.Ep + umbrella = self.umbrella + unbias_het = self.unbias_het + Emin, Emax, Ebin = self.energy_range + Pmin, Pmax, Pbin = self.het_range + sizeE, sizeP = self.sizeE, self.sizeP + EMPTY, VISITED = _VMPT_EMPTY, _VMPT_VISITED + + compute_energy = self.energy.energy_function + mutation_denergy = self.energy.denergy_mutation_function + swap_denergy = self.energy.denergy_swap_function + compute_het = self.heterogeneity.energy_function + mutation_dhet = self.heterogeneity.denergy_mutation_function + swap_dhet = self.heterogeneity.denergy_swap_function + + def expsum(a, b): + if a > b: + return a + np.log1p(np.exp(b - a)) + return b + np.log1p(np.exp(a - b)) + expsum = self.numbify(expsum) + + def bin_index(value, vmin, vbin, size): + idx = int(round((value - vmin) / vbin)) + if idx < 0 or idx >= size: + return -1 + return idx + bin_index = self.numbify(bin_index) + + def deposit(i, espo, beta, E_old, P_old, E_new, P_new, Whisto, Histo, Wpot, record_production): + if espo < 50.0: + espo1 = -np.log1p(np.exp(espo)) + else: + espo1 = -espo + espo_New = espo + espo1 + espo_Old = espo1 + eo = bin_index(E_old, Emin, Ebin, sizeE) + po = bin_index(P_old, Pmin, Pbin, sizeP) + en = bin_index(E_new, Emin, Ebin, sizeE) + pn = bin_index(P_new, Pmin, Pbin, sizeP) + if eo >= 0 and po >= 0: + Whisto[i, eo, po] = expsum(espo_Old, Whisto[i, eo, po]) + if en >= 0 and pn >= 0: + Whisto[i, en, pn] = expsum(espo_New, Whisto[i, en, pn]) + if not record_production: + return + w_old = Wpot[i, eo, po] if (eo >= 0 and po >= 0) else 0.0 + w_new = Wpot[i, en, pn] if (en >= 0 and pn >= 0) else 0.0 + prod_Old = espo_Old - w_old * beta + prod_New = espo_New - w_new * beta + if unbias_het: + prod_Old -= Ep * P_old * beta + prod_New -= Ep * P_new * beta + if eo >= 0 and po >= 0: + Histo[i, eo, po] = expsum(prod_Old, Histo[i, eo, po]) + if en >= 0 and pn >= 0: + Histo[i, en, pn] = expsum(prod_New, Histo[i, en, pn]) + deposit = self.numbify(deposit) + + def montecarlo_steps(temperature, i, seq_index, Whisto, Histo, Wpot, Visits, + record_production, n_steps=1000, kb=0.008314): + seq_index = seq_index.copy() + beta = 1.0 / (kb * temperature) + E = compute_energy(seq_index) + P = compute_het(seq_index) + for _ in range(n_steps): + if np.random.random() > 0.5: + res1 = np.random.randint(0, sequence_size) + res2 = np.random.randint(0, sequence_size - 1) + res2 += (res2 >= res1) + dE = swap_denergy(seq_index, res1, res2) + dP = swap_dhet(seq_index, res1, res2) + is_swap = True + res = 0 + aa = 0 + else: + r = np.random.randint(0, alphabet_size * sequence_size) + res = r // alphabet_size + aa = r % alphabet_size + dE = mutation_denergy(seq_index, res, aa) + dP = mutation_dhet(seq_index, res, aa) + is_swap = False + res1 = 0 + res2 = 0 + E_new = E + dE + P_new = P + dP + eo = bin_index(E, Emin, Ebin, sizeE) + po = bin_index(P, Pmin, Pbin, sizeP) + en = bin_index(E_new, Emin, Ebin, sizeE) + pn = bin_index(P_new, Pmin, Pbin, sizeP) + w_old = Wpot[i, eo, po] if (eo >= 0 and po >= 0) else 0.0 + w_new = Wpot[i, en, pn] if (en >= 0 and pn >= 0) else 0.0 + DW = w_new - w_old + espo = (-dE + Ep * dP + DW) * beta + deposit(i, espo, beta, E, P, E_new, P_new, Whisto, Histo, Wpot, record_production) + if np.random.random() < np.exp(min(0.0, espo)): + if is_swap: + tmp = seq_index[res1] + seq_index[res1] = seq_index[res2] + seq_index[res2] = tmp + else: + seq_index[res] = aa + E = E_new + P = P_new + if record_production: + ei = bin_index(E, Emin, Ebin, sizeE) + pi = bin_index(P, Pmin, Pbin, sizeP) + if ei >= 0 and pi >= 0: + Visits[i, ei, pi] += 1.0 + return seq_index, E, P + montecarlo_steps = self.numbify(montecarlo_steps) + + def parallel_montecarlo_step(seq_indices, temperatures, n_steps_per_cycle, + Whisto, Histo, Wpot, Visits, record_production, kb=0.008314): + n_replicas = len(temperatures) + energies = np.zeros(n_replicas) + hets = np.zeros(n_replicas) + for i in numba.prange(n_replicas): + seq, E, P = montecarlo_steps(temperatures[i], i, seq_indices[i], Whisto, Histo, + Wpot, Visits, record_production, n_steps_per_cycle, kb) + seq_indices[i] = seq + energies[i] = E + hets[i] = P + return energies, hets + parallel_montecarlo_step = self.numbify(parallel_montecarlo_step, parallel=True) + + def update_bias(Whisto, Wpot): + n_replicas = Whisto.shape[0] + for i in range(n_replicas): + wmin = 1e300 + for e in range(Whisto.shape[1]): + for p in range(Whisto.shape[2]): + w = Whisto[i, e, p] + if w > VISITED: + Wpot[i, e, p] -= umbrella * w + if Wpot[i, e, p] < wmin: + wmin = Wpot[i, e, p] + for e in range(Whisto.shape[1]): + for p in range(Whisto.shape[2]): + Wpot[i, e, p] -= wmin + Whisto[i, e, p] = EMPTY + update_bias = self.numbify(update_bias) + + def reduced_energy(E, P, replica, temperature, Wpot, kb): + H = E - Ep * P + e = bin_index(E, Emin, Ebin, sizeE) + p = bin_index(P, Pmin, Pbin, sizeP) + w = Wpot[replica, e, p] if (e >= 0 and p >= 0) else 0.0 + beta = 1.0 / (kb * temperature) + return beta * (H - w) + reduced_energy = self.numbify(reduced_energy) + + def replica_exchange(seq_indices, energies, hets, temperatures, Wpot, parity, + swap_attempts, swap_accepts, kb=0.008314): + n_replicas = len(temperatures) + for i in range(parity, n_replicas - 1, 2): + j = i + 1 + u_ii = reduced_energy(energies[i], hets[i], i, temperatures[i], Wpot, kb) + u_jj = reduced_energy(energies[j], hets[j], j, temperatures[j], Wpot, kb) + u_ij = reduced_energy(energies[i], hets[i], j, temperatures[j], Wpot, kb) + u_ji = reduced_energy(energies[j], hets[j], i, temperatures[i], Wpot, kb) + delta = (u_ii + u_jj) - (u_ji + u_ij) + swap_attempts[i] += 1 + if np.random.random() < np.exp(min(0.0, delta)): + tmp = seq_indices[i].copy() + seq_indices[i] = seq_indices[j] + seq_indices[j] = tmp + energies[i], energies[j] = energies[j], energies[i] + hets[i], hets[j] = hets[j], hets[i] + swap_accepts[i] += 1 + replica_exchange = self.numbify(replica_exchange) + + self.montecarlo_steps = montecarlo_steps + self.parallel_montecarlo_step = parallel_montecarlo_step + self.update_bias = update_bias + self.replica_exchange = replica_exchange + + @csv_writer + def parallel_tempering(self, seq_indices=None, temperatures=None, n_steps=int(1E7), + n_steps_per_cycle=int(1E4), n_equilibration_steps=0, + bias_update_interval=10, record_interval=10, kb=0.008314, + csv_filename="vmpt_results.csv", csv_write=None): + if temperatures is None: + temperatures = np.geomspace(0.2, 25.0, 24) / kb + temperatures = np.asarray(temperatures, dtype=float) + n_replicas = len(temperatures) + if seq_indices is None: + seq_indices = self.generate_random_sequences(n_replicas) + seq_indices = np.ascontiguousarray(seq_indices) + + Whisto = np.full((n_replicas, self.sizeE, self.sizeP), _VMPT_EMPTY) + Histo = np.full((n_replicas, self.sizeE, self.sizeP), _VMPT_EMPTY) + Wpot = np.zeros((n_replicas, self.sizeE, self.sizeP)) + Visits = np.zeros((n_replicas, self.sizeE, self.sizeP)) + swap_attempts = np.zeros(n_replicas, dtype=np.int64) + swap_accepts = np.zeros(n_replicas, dtype=np.int64) + + n_cycles = int(n_steps // n_steps_per_cycle) + n_equilibration_cycles = int(n_equilibration_steps // n_steps_per_cycle) + diagnostics = [] + + for s in range(n_cycles): + record_production = s >= n_equilibration_cycles + energies, hets = self.parallel_montecarlo_step( + seq_indices, temperatures, n_steps_per_cycle, + Whisto, Histo, Wpot, Visits, record_production, kb) + + if (s + 1) % bias_update_interval == 0: + self.update_bias(Whisto, Wpot) + + self.replica_exchange(seq_indices, energies, hets, temperatures, Wpot, + s % 2, swap_attempts, swap_accepts, kb) + + if s % record_interval == 0 or s == n_cycles - 1: + step = (s + 1) * n_steps_per_cycle + eval_energies = {key: term.energies(seq_indices) + for key, term in self.evaluation_energies.items()} + for i, temp in enumerate(temperatures): + step_data = {'Step': step, 'Temperature': temp, + 'Sequence': index_to_sequence(seq_indices[i], self.alphabet), + 'Energy': energies[i], 'Heterogeneity': hets[i], + 'Total Energy': energies[i] - self.Ep * hets[i]} + step_data.update({key: eval_energies[key][i] for key in self.evaluation_energies}) + csv_write(step_data) + diagnostics.append({'cycle': s, 'step': step, + 'wpot_max': float(Wpot.max()), + 'histo_nonempty_bins': int(np.sum(Histo > _VMPT_VISITED))}) + + self.Whisto, self.Histo, self.Wpot, self.Visits = Whisto, Histo, Wpot, Visits + self.swap_attempts, self.swap_accepts = swap_attempts, swap_accepts + self.diagnostics = diagnostics + self.temperatures = temperatures + return seq_indices + if __name__ == '__main__': diff --git a/tests/test_optimization.py b/tests/test_optimization.py index 7b939a1..8b32477 100644 --- a/tests/test_optimization.py +++ b/tests/test_optimization.py @@ -343,7 +343,6 @@ def test_diff_mean_inner_product_1_by_1(n_elements = 10): _AA = '-ACDEFGHIKLMNPQRSTVWY' @pytest.fixture(params=[(10, 2, 0.0), (10, 2, 4.15), (None, 10, 4.15)]) -@pytest.mark.parametrize(["distance_cutoff_contact", "min_sequence_separation_contact", "k_electrostatics"], []) def model(request): native_pdb = "tests/data/1bfz.pdb" distance_cutoff_contact, min_sequence_separation_contact, k_electrostatics = request.param @@ -475,4 +474,70 @@ def test_awsem_energy_variance(model, reduced_alphabet, use_numba): # contact_gamma=np.concatenate([a.ravel() for a in model.gamma_array[3:]]) # contact_energy_predicted = (contact_gamma * np.concatenate([a.ravel() for a in true_indicator2D])).sum() # contact_energy_expected = model.couplings_energy() -# assert np.isclose(contact_energy_predicted,contact_energy_expected), f"Expected energy {contact_energy_expected} but got {contact_energy_predicted}" \ No newline at end of file +# assert np.isclose(contact_energy_predicted,contact_energy_expected), f"Expected energy {contact_energy_expected} but got {contact_energy_predicted}" + +################################### +# Virtual Move Parallel Tempering # +################################### + +def _vmpt_model(): + structure = Structure.full_pdb("tests/data/1r69.pdb", "A") + return AWSEM(structure, distance_cutoff_contact=10, min_sequence_separation_contact=2) + + +def test_vmpt_adaptive_bias(tmp_path): + """Contract for VMPT: the adaptive bias must (A) grow into a healthy, bounded, + non-saturated potential, (B) fill a non-degenerate production histogram, and + (C) collapse to ordinary parallel tempering when umbrella=0.""" + VISITED = -1e20 + model = _vmpt_model() + energy = AwsemEnergy(model=model, alphabet=_AA, use_numba=True) + temperatures = np.geomspace(0.3, 20.0, 8) / 0.008314 + energy_range = (-1400.0, 600.0, 20.0) + het_range = (80.0, 170.0, 2.0) + + sampler = VirtualMoveParallelTempering( + sequence=model.sequence, energy=energy, Ep=20.0, + energy_range=energy_range, het_range=het_range, umbrella=0.05, alphabet=_AA) + sampler.parallel_tempering( + temperatures=temperatures, n_steps=60000, n_steps_per_cycle=500, + n_equilibration_steps=10000, bias_update_interval=5, record_interval=20, + csv_filename=str(tmp_path / "vmpt.csv")) + + Wpot, Histo = sampler.Wpot, sampler.Histo + # A: bias healthy (finite, grew, not saturated-degenerate like the old run) + assert np.all(np.isfinite(Wpot)) + assert Wpot.max() > 0.0 + frac_at_max = np.mean(np.abs(Wpot - Wpot.max()) < 1e-6) + assert frac_at_max < 0.9, f"bias looks saturated/degenerate (frac_at_max={frac_at_max})" + # B: production histogram non-degenerate (real weight, not the ~1e-12 floor) + nonempty = int(np.sum(Histo > VISITED)) + assert nonempty > 20, f"only {nonempty} non-empty histogram bins" + assert Histo[Histo > VISITED].max() > 0.0 + + # C: umbrella=0 must reduce to plain parallel tempering (no bias) + plain = VirtualMoveParallelTempering( + sequence=model.sequence, energy=energy, Ep=20.0, + energy_range=energy_range, het_range=het_range, umbrella=0.0, alphabet=_AA) + plain.parallel_tempering( + temperatures=temperatures, n_steps=30000, n_steps_per_cycle=500, + n_equilibration_steps=5000, bias_update_interval=5, record_interval=20, + csv_filename=str(tmp_path / "vmpt_umbrella0.csv")) + assert np.abs(plain.Wpot).max() < 1e-9 + assert int(np.sum(plain.Histo > VISITED)) > 20 + + +def test_vmpt_runs_without_numba(tmp_path): + """The pure-Python path (use_numba=False) must execute end to end.""" + model = _vmpt_model() + energy = AwsemEnergy(model=model, alphabet=_AA, use_numba=False) + sampler = VirtualMoveParallelTempering( + sequence=model.sequence, energy=energy, Ep=20.0, + energy_range=(-1400.0, 600.0, 50.0), het_range=(80.0, 170.0, 5.0), + umbrella=0.05, alphabet=_AA, use_numba=False) + sampler.parallel_tempering( + temperatures=np.geomspace(0.5, 15.0, 3) / 0.008314, + n_steps=600, n_steps_per_cycle=200, n_equilibration_steps=200, + bias_update_interval=2, record_interval=1, csv_filename=str(tmp_path / "vmpt_nonumba.csv")) + assert np.all(np.isfinite(sampler.Wpot)) + assert int(np.sum(sampler.Histo > -1e20)) > 0