# Source code for pepkit.modelling.af.post.analysis

#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import logging
import math
import shutil
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

from .base import BaseFeature
from .config import BaseConfig, IndexBasedConfig
from .dockq import (
    ensure_native_pdbs,
    find_native_pdb_file,
    inject_dockq_into_entry,
    native_pdb_dir_for_batch_path,
    read_mapping_csv,
)
from .indices import IndexCalculator
from .pae import PAE
from .plddt import PLDDT
from .ptm import PTM
from ....prep.utils import Utils

# Default config instances used only to supply CLI default values in
# Analysis.args(); per-run settings are always passed explicitly to Analysis().
_DEFAULT_BASE_CONFIG = BaseConfig()
_DEFAULT_ANALYSIS_CONFIG = IndexBasedConfig()

# Module-level logger; configured by _setup_logging() when run as a script.
LOGGER = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Core dataclasses
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class AnalysisInputs:
    """Immutable pair of input files resolved for one prediction rank."""

    # AF scores JSON for this rank; None when the file could not be found.
    json_path: Optional[Path]
    # Relaxed model PDB for this rank; None when the file could not be found.
    pdb_path: Optional[Path]
@dataclass(frozen=True)
class EntryMeta:
    """Immutable per-entry metadata extracted from an entry directory."""

    # Sequence length reported by Utils.get_length(); None when unavailable.
    length: Optional[int]
    # Wall-clock processing time parsed from the entry's log; None when absent.
    processing_time: Optional[float]
@dataclass
class BatchStats:
    """Mutable counters accumulated while processing a batch."""

    ok: int = 0          # entries processed without raising
    empty: int = 0       # entries that produced an empty result dict
    error: int = 0       # entries that raised during processing
    dockq_ok: int = 0    # ranks where DockQ injection succeeded
    dockq_fail: int = 0  # ranks where DockQ injection failed
class ProgressLogger:
    """Emit a progress log line each time another ``step_pct`` percent completes."""

    def __init__(self, total: int, step_pct: int) -> None:
        # Clamp both values to >= 1 so the percentage math never divides by
        # zero and the logging threshold always advances.
        self.total = max(1, int(total))
        self.step = max(1, int(step_pct))
        self.next_pct = self.step

    def tick(self, i: int) -> None:
        """Record that item *i* (1-based) finished; log when a threshold is crossed."""
        percent_done = int(100.0 * i / self.total)
        if percent_done < self.next_pct:
            return
        LOGGER.info("Progress: %d/%d (%d%%)", i, self.total, percent_done)
        self.next_pct += self.step
# --------------------------------------------------------------------------- # Analysis implementation # ---------------------------------------------------------------------------
class Analysis(BaseFeature):
    """
    High-level feature aggregation for AF(-Multimer) outputs.

    DockQ integration (via dockq.py):
        - Provide --mapping_csv with pdb_id,mapping to enable DockQ.
        - DockQ is computed for EACH entry *and* EACH rank.
        - Written inside each rank dict:
            rankXXX["total_dockq"]
            rankXXX["avg_dockq"]
    """

    def __init__(
        self,
        json_path: Optional[str] = None,
        pdb_path: Optional[str] = None,
        peptide_chain_position: str = "last",
        distance_cutoff: float = 8.0,
        round_digits: int = 2,
        *,
        pdockq2_d0: float = 10.0,
        pdockq2_sym_pae: bool = True,
    ) -> None:
        # The base class holds chain-selection and cutoff settings; PDB lines
        # are loaded lazily per rank, so none are passed here.
        super().__init__(
            pdb_lines=None,
            peptide_chain_position=peptide_chain_position,
            distance_cutoff=distance_cutoff,
        )
        # Both paths are optional: single_analysis() validates their presence.
        self.json_path = Path(json_path) if json_path else None
        self.pdb_path = Path(pdb_path) if pdb_path else None
        self.round_digits = int(round_digits)
        # pDockQ2 PAE-transform parameters: d0 scale and symmetrisation toggle.
        self.pdockq2_d0 = float(pdockq2_d0)
        self.pdockq2_sym_pae = bool(pdockq2_sym_pae)

    # -----------------------
    # Single-rank analysis
    # -----------------------
[docs] def single_analysis(self) -> Dict[str, Any]: if self.json_path is None or self.pdb_path is None: raise ValueError( "single_analysis requires both json_path and pdb_path" ) rec_json: Dict[str, Any] = Utils.process_json(self.json_path) pdb_lines: List[str] = Utils.process_pdb(self.pdb_path) ( peptide_indices, peptide_chain, protein_interface_indices, peptide_interface_indices, interacting_pairs, protein_chains, ) = self._compute_interface_indices(pdb_lines) plddt_summary = self._compute_plddt( rec_json=rec_json, peptide_indices=peptide_indices, protein_interface_indices=protein_interface_indices, peptide_interface_indices=peptide_interface_indices, ) pae_summary = self._compute_pae( rec_json=rec_json, peptide_indices=peptide_indices, protein_interface_indices=protein_interface_indices, peptide_interface_indices=peptide_interface_indices, interacting_pairs=interacting_pairs, ) ptm_summary = self._compute_ptm( rec_json=rec_json, peptide_chain=peptide_chain ) dockq_inputs = self._compute_pdockq_contacts_and_ptm( pdb_lines=pdb_lines, rec_json=rec_json, peptide_chain=peptide_chain, protein_chains=protein_chains, d0_pae=self.pdockq2_d0, sym_pae=self.pdockq2_sym_pae, ) d: Dict[str, Any] = { **plddt_summary, **pae_summary, **ptm_summary, "protein_interface_residues": protein_interface_indices, "peptide_interface_residues": peptide_interface_indices, "peptide_chain": peptide_chain, "protein_chains": protein_chains, "n_chains": 1 + len(protein_chains), **dockq_inputs, } d.update(self._compute_confidence_scores(d)) return d
def _compute_interface_indices( self, pdb_lines: List[str] ) -> Tuple[List[int], str, List[int], List[int], List[Tuple[int, int]], List[str]]: peptide_indices, peptide_chain = IndexCalculator.get_peptide_indices( pdb_lines, peptide_chain_position=self.peptide_chain_position, ) ( protein_interface_indices, peptide_interface_indices, protein_chains, _, interacting_pairs, ) = IndexCalculator.get_interface_indices( pdb_lines, peptide_chain=peptide_chain, distance_cutoff=self.distance_cutoff, ) return ( peptide_indices, peptide_chain, protein_interface_indices, peptide_interface_indices, interacting_pairs, protein_chains, ) def _compute_plddt( self, *, rec_json: Dict[str, Any], peptide_indices: List[int], protein_interface_indices: List[int], peptide_interface_indices: List[int], ) -> Dict[str, Any]: plddt_obj = PLDDT( rec_json, peptide_indices, protein_interface_indices, peptide_interface_indices, round_digits=self.round_digits, ) ( mean_plddt, median_plddt, peptide_plddt, protein_interface_plddt, peptide_interface_plddt, interface_plddt, ) = plddt_obj.summary() return { "mean_plddt": mean_plddt, "median_plddt": median_plddt, "peptide_plddt": peptide_plddt, "protein_interface_plddt": protein_interface_plddt, "peptide_interface_plddt": peptide_interface_plddt, "interface_plddt": interface_plddt, } def _compute_pae( self, *, rec_json: Dict[str, Any], peptide_indices: List[int], protein_interface_indices: List[int], peptide_interface_indices: List[int], interacting_pairs: List[Tuple[int, int]], ) -> Dict[str, Any]: pae_obj = PAE( rec_json, peptide_indices=peptide_indices, protein_interface_indices=protein_interface_indices, peptide_interface_indices=peptide_interface_indices, interacting_pairs=interacting_pairs, round_digits=self.round_digits, ) ( mean_pae, max_pae, peptide_pae, protein_interface_pae, peptide_interface_pae, mean_interface_pae, ) = pae_obj.summary() return { "mean_pae": mean_pae, "max_pae": max_pae, "peptide_pae": peptide_pae, "protein_interface_pae": 
protein_interface_pae, "peptide_interface_pae": peptide_interface_pae, "mean_interface_pae": mean_interface_pae, } def _compute_ptm( self, *, rec_json: Dict[str, Any], peptide_chain: str ) -> Dict[str, Any]: ptm_obj = PTM( rec_json, peptide_chain=peptide_chain, round_digits=self.round_digits, ) ( ptm, global_iptm, composite_ptm, peptide_ptm, protein_ptm, actif_ptm, ) = ptm_obj.summary() return { "ptm": ptm, "global_iptm": global_iptm, "composite_ptm": composite_ptm, "peptide_ptm": peptide_ptm, "protein_ptm": protein_ptm, "actif_ptm": actif_ptm, } def _compute_pdockq_contacts_and_ptm( self, *, pdb_lines: List[str], rec_json: Dict[str, Any], peptide_chain: str, protein_chains: List[str], d0_pae: float = 10.0, sym_pae: bool = True, ) -> Dict[str, Any]: from .contact import ContactCounter cc = ContactCounter( pdb_lines=pdb_lines, peptide_chain_position=self.peptide_chain_position, distance_cutoff=self.distance_cutoff, ) all_pairs_global: List[Tuple[int, int]] = [] n_contacts_total = 0 for prot_chain in protein_chains: if prot_chain == peptide_chain: continue r = cc.contact_count_pair( peptide_chain, prot_chain, return_global=True ) n_contacts_total += int(r.n_contacts) if r.pairs_global: all_pairs_global.extend(r.pairs_global) pae = rec_json.get("pae", None) mean_ptm = float("nan") if pae is not None and all_pairs_global: vals: List[float] = [] for gi, gj in all_pairs_global: i = int(gi) - 1 j = int(gj) - 1 try: pae_ij = float(pae[i][j]) pae_use = pae_ij if sym_pae: pae_ji = float(pae[j][i]) pae_use = 0.5 * (pae_ij + pae_ji) vals.append(1.0 / (1.0 + (pae_use / float(d0_pae)) ** 2)) except Exception: continue if vals: mean_ptm = float(sum(vals) / len(vals)) return { "n_contacts_pdockq": int(n_contacts_total), "mean_ptm_pdockq2": ( round(mean_ptm, self.round_digits) if mean_ptm == mean_ptm else float("nan") ), } def _compute_confidence_scores(self, d: Dict[str, Any]) -> Dict[str, Any]: from ..score.mpdockq import MPDockQ from ..score.pdockq import PDockQ from 
..score.pdockq2 import PDockQ2 def clean(x: float) -> Optional[float]: if isinstance(x, float) and math.isnan(x): return None return round(float(x), self.round_digits) return { "pdockq": clean(PDockQ().compute(d).score), "pdockq2": clean(PDockQ2().compute(d).score), "mpdockq": clean(MPDockQ(warn_if_lt3=False).compute(d).score), } # ----------------------- # Entry directory analysis (per-rank) # -----------------------
[docs] def all_analysis(self, dir_path: Union[str, Path]) -> Dict[str, Any]: entry_dir = _resolve_entry_dir(Path(dir_path)) entry_result: Dict[str, Any] = {} try: meta = self._entry_meta(entry_dir) for i in range(1, 6): key = f"rank{i:03d}" inputs = self._rank_inputs(entry_dir, i) if inputs is None: self._warn_missing_rank(entry_dir, i) continue try: rank_out = Analysis( json_path=str(inputs.json_path), pdb_path=str(inputs.pdb_path), peptide_chain_position=self.peptide_chain_position, distance_cutoff=self.distance_cutoff, round_digits=self.round_digits, pdockq2_d0=self.pdockq2_d0, pdockq2_sym_pae=self.pdockq2_sym_pae, ).single_analysis() # Keep path for DockQ injection later rank_out["_pred_pdb_path"] = ( str(inputs.pdb_path) if inputs.pdb_path else None ) entry_result[key] = rank_out except Exception as e: self.log_error( f"Error processing {entry_dir.name} rank {i}: {e}" ) continue entry_result["length"] = meta.length entry_result["processing_time"] = meta.processing_time return entry_result except Exception as e: self.log_error(f"Error processing entry {entry_dir.name}: {e}") return {}
def _entry_meta(self, entry_dir: Path) -> EntryMeta: length = Utils.get_length(entry_dir) log_matches = sorted(entry_dir.glob("*log.txt")) process_time = Utils.processing_time( log_matches[0] ) if log_matches else None return EntryMeta(length=length, processing_time=process_time) @staticmethod def _rank_inputs(entry_dir: Path, rank_i: int) -> Optional[AnalysisInputs]: rank_tag = f"rank_{rank_i:03d}" json_matches = sorted(entry_dir.glob(f"*_scores_{rank_tag}_*.json")) pdb_matches = sorted(entry_dir.glob(f"*relaxed_{rank_tag}_*.pdb")) if not json_matches or not pdb_matches: return None return AnalysisInputs( json_path=json_matches[0], pdb_path=pdb_matches[0] ) @staticmethod def _warn_missing_rank(entry_dir: Path, rank_i: int) -> None: LOGGER.warning( "No matching PDB or JSON file found for %s rank %d", entry_dir.name, rank_i, ) # ----------------------- # Batch analysis (+ DockQ injection into EACH rank) + PROGRESS LOGGING # -----------------------
[docs] def batch_analysis( self, batch_dir: Union[str, Path], *, delete_zips: bool = True, mapping_by_pdbid: Optional[Dict[str, Dict[str, str]]] = None, native_pdb_dir: Optional[Path] = None, progress_step_pct: int = 10, ) -> Dict[str, Any]: """ progress_step_pct=10 => log at 10%,20%,...,100% """ batch_result: Dict[str, Any] = {} entry_dirs, _ = _prepare_batch_dirs(batch_dir, delete_zips=delete_zips) n = len(entry_dirs) if n == 0: LOGGER.warning("No entry directories found under %s", batch_dir) return {} LOGGER.info("Processing %d entries...", n) progress = ProgressLogger(total=n, step_pct=progress_step_pct) stats = BatchStats() for i, entry_dir in enumerate(entry_dirs, start=1): progress.tick(i) name = entry_dir.name real_entry_dir = _resolve_entry_dir(entry_dir) try: entry_out = self._process_one_entry( real_entry_dir=real_entry_dir, entry_dir_name=name, mapping_by_pdbid=mapping_by_pdbid, native_pdb_dir=native_pdb_dir, stats=stats, ) batch_result[name] = entry_out stats.ok += 1 if not entry_out: stats.empty += 1 except Exception as e: self.log_error(f"Batch error in entry {name}: {e}") batch_result[name] = {} stats.error += 1 LOGGER.info( "DONE. ok=%d empty=%d error=%d | dockq_ok=%d dockq_fail=%d", stats.ok, stats.empty, stats.error, stats.dockq_ok, stats.dockq_fail, ) return batch_result
def _process_one_entry( self, *, real_entry_dir: Path, entry_dir_name: str, mapping_by_pdbid: Optional[Dict[str, Dict[str, str]]], native_pdb_dir: Optional[Path], stats: BatchStats, ) -> Dict[str, Any]: entry_out = self.all_analysis(real_entry_dir) if mapping_by_pdbid is None or native_pdb_dir is None: return entry_out pid = entry_dir_name.split("_")[0].lower() mapping = mapping_by_pdbid.get(pid) exp_pdb = find_native_pdb_file(native_pdb_dir, pid) if mapping is None: LOGGER.debug("DockQ: no mapping for %s", pid) return entry_out if exp_pdb is None: LOGGER.debug( "DockQ: native PDB missing for %s under %s", pid, native_pdb_dir, ) return entry_out dq_stats = inject_dockq_into_entry( entry_out=entry_out, entry_dir_name=entry_dir_name, mapping=mapping, exp_pdb=exp_pdb, round_digits=self.round_digits, ) stats.dockq_ok += dq_stats.ok stats.dockq_fail += dq_stats.fail return entry_out # ----------------------- # CLI # -----------------------
[docs] @staticmethod def args() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Extract summary metrics from JSON and PDB files" ) parser.add_argument("--json", type=str, help="Path to the JSON file") parser.add_argument("--pdb", type=str, help="Path to the PDB file") parser.add_argument( "--chain", type=str, choices=["last", "first", "none"], default=_DEFAULT_BASE_CONFIG.peptide_chain_position, help="Which chain to consider as peptide", ) parser.add_argument( "--cutoff", type=float, default=_DEFAULT_BASE_CONFIG.cutoff, help="Distance cutoff (Å) for defining interface residues", ) parser.add_argument( "--round", type=int, default=_DEFAULT_ANALYSIS_CONFIG.round_digits, help="Number of decimal places to round metrics", ) parser.add_argument( "--entry_dir", type=str, help="Path to a single entry directory", ) parser.add_argument( "--batch_dir", type=str, help="Path to a batch directory OR a .zip containing it", ) parser.add_argument( "--mapping_csv", type=str, default=None, help=( "Optional CSV with columns pdb_id,mapping." "If provided, downloads native PDBs " "and computes DockQ per rank." ), ) parser.add_argument( "--progress_pct", type=int, default=10, help="Progress logging step in percent (default: 10).", ) parser.add_argument("--pdockq2_d0", type=float, default=10.0) parser.add_argument("--pdockq2_sym_pae", action="store_true") parser.add_argument("--verbose", action="store_true") return parser
# ---------------------------------------------------------------------------
# File + zip utilities
# ---------------------------------------------------------------------------


def _write_json(path: Path, obj: Any) -> None:
    """Serialise *obj* as indented JSON, creating parent directories."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(obj, fh, indent=4)


def _has_any_af_files(d: Path) -> bool:
    """True when *d* contains at least one scores JSON and one model PDB."""
    has_json = bool(list(d.glob("*_scores_*.json")))
    has_pdb = bool(list(d.glob("*relaxed_*.pdb")))
    return has_json and has_pdb


def _resolve_entry_dir(entry_dir: Path, max_depth: int = 3) -> Path:
    """Descend through wrapper directories until AF output files are found.

    Handles archives that extract as ``entry/entry/...`` or with a single
    intermediate directory; gives up after *max_depth* levels.
    """
    cur = Path(entry_dir)
    for _ in range(max_depth):
        if _has_any_af_files(cur):
            return cur
        # Common zip layout: a directory containing a same-named child.
        same = cur / cur.name
        if same.is_dir() and _has_any_af_files(same):
            return same
        subs = [p for p in cur.iterdir() if p.is_dir()]
        if len(subs) != 1:
            return cur
        cur = subs[0]
    return cur


def _unzip_zip_to_dir(zip_path: Path, out_dir: Path) -> Path:
    """Extract *zip_path* into *out_dir*, flattening known wrapper layouts.

    A pre-existing *out_dir* is treated as already extracted and returned
    unchanged. Extraction happens in a temp sibling directory that is always
    cleaned up.
    """
    zip_path = Path(zip_path)
    out_dir = Path(out_dir)
    if out_dir.exists():
        return out_dir
    if not zip_path.exists():
        raise FileNotFoundError(f"Zip not found: {zip_path}")

    staging = out_dir.with_name(out_dir.name + "__tmp_extract")
    if staging.exists():
        shutil.rmtree(staging)
    staging.mkdir(parents=True, exist_ok=True)
    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(staging)

        # Prefer the "outputdir/outputdir/<single dir>" layout, then a lone
        # top-level directory; otherwise move everything as-is.
        candidate: Optional[Path] = None
        inner = staging / "outputdir" / "outputdir"
        if inner.is_dir():
            items = [p for p in inner.iterdir()]
            if len(items) == 1 and items[0].is_dir():
                candidate = items[0]
        if candidate is None:
            top_dirs = [p for p in staging.iterdir() if p.is_dir()]
            top_files = [p for p in staging.iterdir() if p.is_file()]
            if len(top_dirs) == 1 and not top_files:
                candidate = top_dirs[0]

        if candidate is not None:
            shutil.move(str(candidate), str(out_dir))
        else:
            out_dir.mkdir(parents=True, exist_ok=True)
            for p in staging.iterdir():
                shutil.move(str(p), str(out_dir / p.name))
    finally:
        if staging.exists():
            shutil.rmtree(staging)
    return out_dir


def _unzip_all_zips_in_folder(
    folder: Path, *, delete_zips: bool
) -> List[Path]:
    """Extract every ``*.zip`` in *folder*; return the zip paths handled.

    When *delete_zips* is true each archive is removed after extraction
    (failures to delete are only logged).
    """
    folder = Path(folder)
    extracted: List[Path] = []
    zips = sorted(folder.glob("*.zip"))
    if not zips:
        return extracted
    LOGGER.info("Unzipping %d zip(s) in %s", len(zips), folder)
    for z in zips:
        _unzip_zip_to_dir(z, folder / z.stem)
        extracted.append(z)
        if not delete_zips:
            continue
        try:
            z.unlink()
        except Exception as e:
            LOGGER.warning("Failed to delete zip %s: %s", z, e)
    return extracted


def _normalize_batch_root(batch_dir: Union[str, Path]) -> Path:
    """Return a batch root directory, extracting a ``.zip`` argument first."""
    batch_path = Path(batch_dir)
    if not batch_path.exists():
        raise FileNotFoundError(f"Batch path not found: {batch_path}")
    if batch_path.is_file() and batch_path.suffix == ".zip":
        out_dir = batch_path.parent / batch_path.stem
        LOGGER.info("Unzipping batch archive %s -> %s", batch_path, out_dir)
        _unzip_zip_to_dir(batch_path, out_dir)
        return out_dir
    if batch_path.is_dir():
        return batch_path
    raise ValueError(f"Unsupported batch_dir type: {batch_path}")


def _prepare_batch_dirs(
    batch_dir: Union[str, Path],
    *,
    delete_zips: bool,
) -> Tuple[List[Path], List[Path]]:
    """Normalise the batch root, expand nested zips, and list entry dirs."""
    root = _normalize_batch_root(batch_dir)

    extracted_zip_paths: List[Path] = []
    extracted_zip_paths.extend(
        _unzip_all_zips_in_folder(root, delete_zips=delete_zips)
    )

    entry_dirs: List[Path] = []
    for child in sorted(root.iterdir()):
        if not child.is_dir():
            continue
        # Entries may themselves contain per-rank zips; expand those too.
        extracted_zip_paths.extend(
            _unzip_all_zips_in_folder(child, delete_zips=delete_zips)
        )
        entry_dirs.append(child)
    return entry_dirs, extracted_zip_paths


def _setup_logging(verbose: bool) -> None:
    """Configure root logging; DEBUG when *verbose*, INFO otherwise."""
    level = logging.DEBUG if verbose else logging.INFO
    logging.basicConfig(
        level=level, format="%(asctime)s | %(levelname)s | %(message)s"
    )
def main() -> None:
    """CLI entry point: dispatch to batch, entry, or single-pair analysis.

    Exactly one mode must be selected: ``--batch_dir``, ``--entry_dir``, or
    ``--json`` together with ``--pdb``. Batch/entry results are written to a
    ``result.json`` next to the inputs; single results go to stdout.
    """
    parser = Analysis.args()
    args = parser.parse_args()
    _setup_logging(bool(getattr(args, "verbose", False)))

    has_batch = bool(args.batch_dir)
    has_entry = bool(args.entry_dir)
    has_single = bool(args.json and args.pdb)
    modes = sum([has_batch, has_entry, has_single])
    if modes == 0:
        parser.print_help()
        raise SystemExit(2)
    if modes > 1:
        raise SystemExit(
            # BUG FIX: the original implicit concatenation lacked a space,
            # rendering as "--entry_dirOR (--json AND --pdb)."
            "Provide exactly one mode: --batch_dir OR --entry_dir "
            "OR (--json AND --pdb)."
        )

    # Optional DockQ support: load the mapping CSV and make sure the native
    # PDBs exist next to whichever inputs were given.
    mapping_by_pdbid: Optional[Dict[str, Dict[str, str]]] = None
    native_pdb_dir: Optional[Path] = None
    if args.mapping_csv:
        mapping_by_pdbid = read_mapping_csv(args.mapping_csv)
        if args.batch_dir:
            native_pdb_dir = native_pdb_dir_for_batch_path(args.batch_dir)
        elif args.entry_dir:
            native_pdb_dir = Path(args.entry_dir).parent / "pdb"
        else:
            native_pdb_dir = Path(args.pdb).resolve().parent / "pdb"
        ensure_native_pdbs(mapping_by_pdbid, native_pdb_dir)

    if args.batch_dir:
        batch_root = _normalize_batch_root(args.batch_dir)
        analysis = Analysis(
            json_path=None,
            pdb_path=None,
            peptide_chain_position=args.chain,
            distance_cutoff=args.cutoff,
            round_digits=args.round,
            pdockq2_d0=args.pdockq2_d0,
            pdockq2_sym_pae=bool(args.pdockq2_sym_pae),
        )
        result = analysis.batch_analysis(
            batch_root,
            delete_zips=True,
            mapping_by_pdbid=mapping_by_pdbid,
            native_pdb_dir=native_pdb_dir,
            progress_step_pct=int(args.progress_pct),
        )
        out_path = Path(batch_root) / "result.json"
        LOGGER.info("Writing result to %s", out_path)
        _write_json(out_path, result)
        return

    if args.entry_dir:
        analysis = Analysis(
            json_path=None,
            pdb_path=None,
            peptide_chain_position=args.chain,
            distance_cutoff=args.cutoff,
            round_digits=args.round,
            pdockq2_d0=args.pdockq2_d0,
            pdockq2_sym_pae=bool(args.pdockq2_sym_pae),
        )
        result = analysis.all_analysis(args.entry_dir)
        out_path = Path(args.entry_dir) / "result.json"
        _write_json(out_path, result)
        return

    analysis = Analysis(
        json_path=args.json,
        pdb_path=args.pdb,
        peptide_chain_position=args.chain,
        distance_cutoff=args.cutoff,
        round_digits=args.round,
        pdockq2_d0=args.pdockq2_d0,
        pdockq2_sym_pae=bool(args.pdockq2_sym_pae),
    )
    result = analysis.single_analysis()
    print(json.dumps(result, indent=4))
if __name__ == "__main__":
    # Script entry point; keeps the module importable without side effects.
    main()