Source code for modelarchive.modelcif.fix_af3

"""ModelCIF files generated by AlphaFold 3 deviate from the official ModelCIF
definition dictionary in specific cases. In particular, for homomeric
assemblies, each molecular entity copy is written as a separate entity in the
CIF document, instead of defining a single entity referenced multiple times.
This module provides functionality to correct the deviations.
"""

# Copyright (c) 2026, SIB - Swiss Institute of Bioinformatics and
#                     Biozentrum - University of Basel
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# pylint: disable=too-many-lines

from pathlib import Path
import json
import zipfile

from gemmi import cif
import numpy as np
import requests

from .. import _utils
from . import access
from . import edit


def _is_null(value):
    """Tell whether a raw CIF value is a null marker (``?`` or ``.``).

    Borrowed from gemmi.

    Args:
        value (str): Raw CIF value to check.

    Returns:
        bool: True if ``value`` has length 1 and that single element is
        ``?`` or ``.``.
    """
    # ToDo: This may become a public function in the future.
    if len(value) != 1:
        return False
    return value[0] in ("?", ".")


# Character classes for CIF values, borrowed from gemmi. The index is the
# code point modulo 256. The code in this module only distinguishes
# 1 (character may appear in an unquoted value) from everything else.
# Built once at import time instead of on every _char_table() call, which is
# executed once per character of every value passed through _quote().
# fmt: off
_CHAR_TABLE = (
    0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 2, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    2, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0,
) + (0,) * 128
# fmt: on


def _char_table(c):
    """Look up the character class of a single character.

    Borrowed from gemmi.

    Args:
        c (str): A single character.

    Returns:
        int: Table entry for ``ord(c) % 256``; 1 marks characters that are
        allowed in an unquoted value.
    """
    return _CHAR_TABLE[ord(c) % 256]


def _quote(v):
    """Quote a string for use as a raw CIF value.

    Borrowed from gemmi, but double quotes are preferred over single quotes
    to match the style used by AlphaFold.

    Args:
        v (str): Unquoted value.

    Returns:
        str: ``v`` unchanged if it is non-empty, not a null marker and made
        only of characters allowed in unquoted values. Otherwise ``v``
        wrapped in double quotes, single quotes, or a ``;`` text field.
    """
    needs_no_quotes = (
        len(v) > 0
        and not _is_null(v)
        and all(_char_table(char) == 1 for char in v)
    )
    if needs_no_quotes:
        return v

    # Pick the delimiter: a text field is needed for multi-line values and
    # for values containing both kinds of quotes.
    if "\n" in v:
        delimiter = ";"
    elif '"' not in v:
        delimiter = '"'
    elif "'" not in v:
        delimiter = "'"
    else:
        delimiter = ";"

    if delimiter == ";":
        # the closing semicolon goes on a line of its own
        return f";{v}\n;"
    return f"{delimiter}{v}{delimiter}"


def fix_model_name(block, mdl_rank):
    """Set _ma_model_list.model_name according to the rank of the model.

    AlphaFold 3 labels every model "Top ranked model" in
    _ma_model_list.model_name, whatever its rank. After this call, only
    ``mdl_rank == 1`` carries that label; any other rank is named
    "#<``mdl_rank``> ranked model".

    Examples:
        >>> from gemmi import cif
        >>> from modelarchive.modelcif import fix_af3
        >>> # get sample CIF data
        >>> cif_data = '''data_test
        ... _ma_model_list.data_id 1
        ... _ma_model_list.model_name "Top ranked model"
        ... _ma_model_list.model_type "Ab initio model"
        ... _ma_model_list.ordinal_id 1
        ... '''
        >>> block = cif.read_string(cif_data).sole_block()
        >>> fix_af3.fix_model_name(block, 2)
        >>> print(block.as_string())
        data_test
        _ma_model_list.data_id 1
        _ma_model_list.model_name "#2 ranked model"
        _ma_model_list.model_type "Ab initio model"
        _ma_model_list.ordinal_id 1
        <BLANKLINE>
        >>> fix_af3.fix_model_name(block, 1)
        >>> print(block.as_string())
        data_test
        _ma_model_list.data_id 1
        _ma_model_list.model_name "Top ranked model"
        _ma_model_list.model_type "Ab initio model"
        _ma_model_list.ordinal_id 1
        <BLANKLINE>

    Args:
        block (|gemmicifBlock|): CIF block to operate on.
        mdl_rank (int): Rank of the AlphaFold 3 model. If ``mdl_rank == 1``,
            the name is set to "Top ranked model".

    Returns:
        None

    Raises:
        RuntimeError: If the ``_ma_model_list`` category contains more than
            one row.
        edit.NotFoundItemError: If _ma_model_list.model_name can not be found
            in ``block``.
    """
    mdl_name = (
        "Top ranked model" if mdl_rank == 1 else f"#{mdl_rank} ranked model"
    )

    table = access.get_table(block, "_ma_model_list", items=["model_name"])
    if not table:
        raise edit.NotFoundItemError(
            msg="File is missing _ma_model_list.model_name, single model "
            "required"
        )
    if len(table) != 1:
        raise RuntimeError("File must have a single model in _ma_model_list.")

    # Values assigned to a row are written as-is, so quote explicitly.
    table[0]["model_name"] = _quote(mdl_name)
def _get_ordinal_ids(cur_ids, num_ids_needed):
    """Find a set of ordinal IDs avoiding existing ones.

    Candidates are the numerals "1", "2", "3", ... in ascending order; those
    already present in ``cur_ids`` are skipped.

    Args:
        cur_ids (collection of str): IDs already in use (hopefully something
            like "1", "2", "3", ...). Entries that are not numerals simply
            never collide with a candidate.
        num_ids_needed (int): Number of IDs to provide.

    Returns:
        list[str]: The first ``num_ids_needed`` numerals not in ``cur_ids``.
    """
    # Every candidate is tested for membership; a set keeps each test O(1)
    # even when the caller hands in a list.
    taken = set(cur_ids)
    # Of len(taken) + num_ids_needed candidates at most len(taken) can be
    # occupied, so enough free ones always remain.
    last_candidate = len(taken) + num_ids_needed
    free_ids = [
        candidate
        for candidate in map(str, range(1, last_candidate + 1))
        if candidate not in taken
    ]
    return free_ids[:num_ids_needed]
class NotIdentifiedRecordError(RuntimeError):
    """Base class for errors about records that can not be identified.

    Not meant to be raised directly; it only serves as the common parent of
    the more specific "NotIdentified" exceptions.

    Args:
        msg (str): Exception message.
    """

    def __init__(self, msg):
        # Kept explicit so that a message is mandatory.
        super().__init__(msg)
class NotIdentifiedDuplicatedRecordError(NotIdentifiedRecordError):
    """Raised when a record that must be unique occurs more than once.

    Attributes:
        category (str): Category with the non-unique records.

    Args:
        category (str): Category with the non-unique records.
        record_id (str): Identifier for the duplicated record. Not bound to a
            specific item on purpose.
    """

    def __init__(self, category, record_id):
        super().__init__(
            f"Duplicated records found in category '{category}' for "
            f"'{record_id}'"
        )
        self.category = category
class NotIdentifiedSingleRecordError(NotIdentifiedRecordError):
    """Raised when one specific record can not be identified in a table.

    Attributes:
        category (str): Affected category.
        item (str|None): Affected item.

    Args:
        category (str): Affected category.
        item (str, optional): Missing item, extends the exception message.
        value (str, optional): Value, in case a record is found but with
            mismatching value. Extends the exception message; ignored if
            ``item`` is not given.
    """

    def __init__(self, category, item=None, value=None):
        if item is None:
            detail = ""
        elif value is None:
            detail = f", missing item '{item}'"
        else:
            detail = f", mismatch at item '{item}={value}'"
        super().__init__(
            f"Could not identify record in category '{category}'{detail}."
        )
        self.category = category
        self.item = item
class NotIdentifiedContextRecordError(NotIdentifiedRecordError):
    """Raised when the record for a specific context can not be identified.

    Attributes:
        category (str): Category for which the exception is raised.
        item (str|None): Involved item, if any.

    Args:
        category (str): Affected category.
        item (str, optional): Affected item.
        context (str, optional): Context, appended verbatim to the message
            (callers provide their own separator, e.g. ``": ..."``).
    """

    def __init__(self, category, item=None, context=None):
        parts = [f"Could not identify record in category '{category}'"]
        if item is not None:
            parts.append(f", item '{item}'")
        if context is not None:
            parts.append(context)
        parts.append(".")
        super().__init__("".join(parts))
        self.category = category
        self.item = item
def _fix_citation_fallback(block, primary_citation_id):
    """Fix citation for AF3 if the usual approach (lookup by PMID) failed.

    If ``block`` has no _citation records, the AlphaFold 3 citation is
    written from scratch. Otherwise the first record whose values are either
    unset or equal to the expected AlphaFold 3 values is completed with the
    expected values.

    Args:
        block (|gemmicifBlock|): CIF block to operate on.
        primary_citation_id (str): Citation ID to use if the record found has
            no usable ID (null or "primary") or has to be created.

    Returns:
        tuple: ``(old_af3_sw_cit_id, new_af3_sw_cit_id)``. The old ID is
        ``None`` if the citation had to be created.

    Raises:
        NotIdentifiedSingleRecordError: If an expected item is missing from
            _citation or if no record is compatible with the AlphaFold 3
            publication.
    """
    exp_cit = {
        "country": "UK",
        "journal_full": "Nature",
        "journal_id_ASTM": "NATUAS",
        "journal_id_CSD": "0006",
        "journal_id_ISSN": "0028-0836",
        "journal_volume": "630",
        "page_first": "493",
        "page_last": "500",
        "pdbx_database_id_DOI": "10.1038/s41586-024-07487-w",
        "pdbx_database_id_PubMed": "38718835",
        "title": "Accurate structure prediction of biomolecular interactions "
        + "with AlphaFold 3",
        "year": "2024",
    }
    cat = "_citation."
    cit_table = access.get_table(block, cat)
    num_rows = len(cit_table) if cit_table else 0

    if num_rows == 0:
        # no citation at all, write the expected one as a new record
        block.set_pairs(cat, {"id": primary_citation_id, **exp_cit})
        return None, primary_citation_id

    # check that all items exist
    for itm in exp_cit:
        if f"{cat}{itm}" not in cit_table.tags:
            raise NotIdentifiedSingleRecordError(cat[:-1], item=itm)

    # search for a record with all matching values or '?'
    row = None
    mismatch = None
    for idx in range(num_rows):
        candidate = cit_table[idx]
        # Checking for empty string "" is because gemmi as_string()
        # translates "?" and "." to "".
        mismatch = next(
            (
                (key, val)
                for key, val in exp_cit.items()
                if cif.as_string(candidate[f"{cat}{key}"]) not in ("", val)
            ),
            None,
        )
        if mismatch is None:
            row = candidate
            break
    if row is None:
        # num_rows > 0 here, so 'mismatch' holds the first deviation of the
        # last record inspected.
        raise NotIdentifiedSingleRecordError(
            cat[:-1], item=mismatch[0], value=mismatch[1]
        )

    old_af3_sw_cit_id = row["id"]
    if old_af3_sw_cit_id in ("?", ".", "primary"):
        new_af3_sw_cit_id = primary_citation_id
    else:
        new_af3_sw_cit_id = old_af3_sw_cit_id

    # fix record
    row["id"] = new_af3_sw_cit_id
    for key, val in exp_cit.items():
        # Row values are stored raw; values with spaces (the title) must be
        # quoted or the written CIF is broken.
        row[key] = _quote(val)

    return old_af3_sw_cit_id, new_af3_sw_cit_id


def _is_af3_sw_name(sw_name):
    """Check if given _software.name is for AF3.

    The check is case-insensitive and only looks at the beginning of the
    name, which has to be "alphafold".
    """
    return sw_name.lower().startswith("alphafold")


class _AF3ItemSetter:
    """Callback for adding/updating a _software column.

    Returns a fixed value for the AF3 row. Fails if there are multiple AF3
    rows.

    Attributes:
        item (str): Item whose existing value is kept for non-AF3 rows when
            called with ``same=True``.
        sw_found (bool): Whether an AF3 row has been seen so far.
        value (str): Value returned for the AF3 row.
    """

    # This is a tiny helper-class with the purpose of storing a state, only
    # for local use, disable Pylint warning
    # pylint: disable=too-few-public-methods
    def __init__(self, value, item="citation_id"):
        self.item = item
        self.sw_found = False
        self.value = value

    def __call__(self, row, same=False):
        """Return the value to store for ``row``.

        Args:
            row: Row of the _software table, must provide "name".
            same (bool): For non-AF3 rows, return the current value of
                ``self.item`` instead of "?".

        Raises:
            NotIdentifiedDuplicatedRecordError: If a second AF3 row is
                encountered.
        """
        if _is_af3_sw_name(row["name"]):
            if self.sw_found:
                raise NotIdentifiedDuplicatedRecordError(
                    "_software", "AlphaFold"
                )
            self.sw_found = True
            return self.value
        if same:
            return row[self.item]
        return "?"
def _fix_software(block, new_af3_sw_cit_id):
    """Update the AF3 record in _software with the _citation record ID.

    Adds the column _software.citation_id if it does not exist yet, otherwise
    only the AF3 row is changed.

    Raises:
        NotIdentifiedContextRecordError: If no AF3 record was found.
        NotIdentifiedDuplicatedRecordError: If multiple AF3 records exist.
    """
    sw_table = access.get_table(block, "_software")
    af3_cid_setter = _AF3ItemSetter(new_af3_sw_cit_id)
    if sw_table:
        if "_software.citation_id" not in sw_table.tags:
            edit.add_column(block, "_software", "citation_id", af3_cid_setter)
        else:
            for row in sw_table:
                row["citation_id"] = af3_cid_setter(row, same=True)
    if not af3_cid_setter.sw_found:
        raise NotIdentifiedContextRecordError(
            "_software", context=": AlphaFold 3 not found"
        )


def _get_key_primary(row):
    """For edit.sort(), make sure 'primary' comes first.

    Returns:
        tuple: ``(-1, "")`` for "primary", ``(0, <int>)`` for numeric IDs and
        ``(1, <str>)`` for anything else, so groups never compare values of
        different types.
    """
    if row["id"] == "primary":
        return (-1, "")
    try:
        return (0, int(row["id"]))
    except ValueError:
        return (1, row["id"])


def _ensure_citation_id_first(block):
    """Make sure _citation.id is the first tag of the _citation table.

    If it is not, the category is erased and re-created in front of
    _citation_author with "id" as first item; all values are carried over
    raw.
    """
    table = access.get_table(block, "_citation")
    if table.tags[0] == "_citation.id":
        return

    cif_dict = {"id": list(table.find_column("id"))}
    for tag in table.tags:
        itm = tag.split(".", maxsplit=1)[1]
        if itm == "id":
            continue
        cif_dict[itm] = list(table.find_column(itm))
    table.erase()
    edit.add_category(
        block,
        "_citation",
        item_data=cif_dict,
        index="before:_citation_author",
        raw=True,
    )
def fix_citation(block):
    """Normalise the AlphaFold 3 citation in a `ModelCIF`_ ``block``.

    Ensures that the AlphaFold 3 publication (`PMID 38718835
    <https://pubmed.ncbi.nlm.nih.gov/38718835/>`_) is not marked as the
    "primary" citation, assigns a numeric citation ID instead. Fixes an
    incomplete AlphaFold 3 citation. Replaces the author list with the full
    curated list of names and updates its citation ID. Reorders citations so
    that the primary entry appears first and links the citation to the
    corresponding software record.

    This adjustment is not required for valid `ModelCIF`_ files, but follows
    `ModelArchive`_ conventions where the primary citation must refer to the
    deposited model rather than the software used to generate it.

    Examples:
        >>> from gemmi import cif
        >>> from modelarchive.modelcif import access, fix_af3
        >>> # start with an empty CIF document
        >>> CIF_DATA = '''data_test
        ... _citation.id primary
        ... _citation.country UK
        ... _citation.journal_full Nature
        ... _citation.journal_id_ASTM NATUAS
        ... _citation.journal_id_CSD 0006
        ... _citation.journal_id_ISSN 0028-0836
        ... _citation.journal_volume 630
        ... _citation.page_first 493
        ... _citation.page_last 500
        ... _citation.pdbx_database_id_DOI 10.1038/s41586-024-07487-w
        ... _citation.pdbx_database_id_PubMed 38718835
        ... _citation.title 'Accurate structure prediction of biomolecular ...'
        ... _citation.year 2024
        ... #
        ... loop_
        ... _citation_author.citation_id
        ... _citation_author.name
        ... _citation_author.ordinal
        ... primary "Google DeepMind AlphaFold Team" 1
        ... primary "Isomorphic Labs Team" 2
        ... #
        ... loop_
        ... _software.classification
        ... _software.date
        ... _software.description
        ... _software.name
        ... _software.pdbx_ordinal
        ... _software.type
        ... _software.version
        ... other ? "Structure prediction" AlphaFold 1 package AlphaFold-beta
        ... '''
        >>> block = cif.read_string(CIF_DATA).sole_block()
        >>> fix_af3.fix_citation(block)
        >>> # The usual block.as_string() output would be too much for a
        >>> # docstring, just check some important values.
        >>> table = access.get_table(block, "_citation")
        >>> assert table[0]["id"] == "1"
        >>> table = access.get_table(block, "_citation_author")
        >>> assert table[0]["name"] != "Google DeepMind AlphaFold Team"
        >>> table = access.get_table(block, "_software")
        >>> assert table[0]["citation_id"] == "1"

    Args:
        block (|gemmicifBlock|): CIF block to operate on.

    Returns:
        None

    Raises:
        NotIdentifiedSingleRecordError: If required item is missing from
            _citation category. If item values are not as expected for
            _citation category.
        NotIdentifiedContextRecordError: If no entry for AlphaFold is found
            in _software category.
        NotIdentifiedDuplicatedRecordError: If multiple entries for AlphaFold
            are found in _software category. In that case, the "right" record
            can not be identified.
    """
    old_af3_sw_cit_id = None
    new_af3_sw_cit_id = None
    cat = "_citation"
    itms = ["id", "pdbx_database_id_PubMed"]
    table = access.get_table(block, cat, itms)

    # pick first numeric value not yet taken
    primary_citation_id = _quote(
        _get_ordinal_ids({row["id"] for row in table}, 1)[0]
    )

    # correct IDs and find AF3 citation
    for row in table:
        if row["pdbx_database_id_PubMed"] == "38718835":
            old_af3_sw_cit_id = row["id"]
            if row["id"] == "primary":
                row["id"] = primary_citation_id
                new_af3_sw_cit_id = primary_citation_id
            else:
                new_af3_sw_cit_id = old_af3_sw_cit_id

    if old_af3_sw_cit_id is None or _is_null(old_af3_sw_cit_id):
        # citation available w/o PMID? Fallback option replacing citation
        old_af3_sw_cit_id, new_af3_sw_cit_id = _fix_citation_fallback(
            block, primary_citation_id
        )

    if len(table) > 0:
        # _citation table may have changed, sort it to have "primary" first
        edit.sort(table, "id", key=_get_key_primary)
    # make sure _citation.id comes first
    _ensure_citation_id_first(block)

    # fix authors (completely replace ones for AF3 publication)
    cit_auth_dict = block.get_mmcif_category("_citation_author.")
    cit_auth_dict_new = {"citation_id": [], "name": [], "ordinal": []}
    if cit_auth_dict:
        # keep the authors of all other citations
        for citation_id, name in zip(
            cit_auth_dict["citation_id"], cit_auth_dict["name"]
        ):
            if citation_id != old_af3_sw_cit_id:
                cit_auth_dict_new["citation_id"].append(citation_id)
                cit_auth_dict_new["name"].append(name)
    # note: fixed to be in correct style and without special characters
    af3_authors = [
        "Abramson, J.",
        "Adler, J.",
        "Dunger, J.",
        "Evans, R.",
        "Green, T.",
        "Pritzel, A.",
        "Ronneberger, O.",
        "Willmore, L.",
        "Ballard, A.J.",
        "Bambrick, J.",
        "Bodenstein, S.W.",
        "Evans, D.A.",
        "Hung, C.C.",
        "O'Neill, M.",
        "Reiman, D.",
        "Tunyasuvunakool, K.",
        "Wu, Z.",
        "Zemgulyte, A.",
        "Arvaniti, E.",
        "Beattie, C.",
        "Bertolli, O.",
        "Bridgland, A.",
        "Cherepanov, A.",
        "Congreve, M.",
        "Cowen-Rivers, A.I.",
        "Cowie, A.",
        "Figurnov, M.",
        "Fuchs, F.B.",
        "Gladman, H.",
        "Jain, R.",
        "Khan, Y.A.",
        "Low, C.M.R.",
        "Perlin, K.",
        "Potapenko, A.",
        "Savy, P.",
        "Singh, S.",
        "Stecula, A.",
        "Thillaisundaram, A.",
        "Tong, C.",
        "Yakneen, S.",
        "Zhong, E.D.",
        "Zielinski, M.",
        "Zidek, A.",
        "Bapst, V.",
        "Kohli, P.",
        "Jaderberg, M.",
        "Hassabis, D.",
        "Jumper, J.M.",
    ]
    cit_auth_dict_new["citation_id"].extend(
        [new_af3_sw_cit_id] * len(af3_authors)
    )
    cit_auth_dict_new["name"].extend(af3_authors)
    # Renumber unconditionally: all columns of a category need the same
    # length, also when the block had no _citation_author records before.
    cit_auth_dict_new["ordinal"] = list(
        range(1, len(cit_auth_dict_new["name"]) + 1)
    )
    block.set_mmcif_category("_citation_author.", cit_auth_dict_new)

    _fix_software(block, new_af3_sw_cit_id)
def _is_af3_server(block):
    """Tell if a block was produced by the AF3 server or by the AF3 code.

    The decision is based on the text in _pdbx_data_usage.details: mentioning
    "server" means server, mentioning the AlphaFold 3 GitHub repository means
    local installation.

    Returns:
        bool: True means server, False means "from code".

    Raises:
        NotIdentifiedContextRecordError: If both or none of the two origins
            are indicated.
    """
    # this is a heuristic and may fail!
    table = access.get_table(block, "_pdbx_data_usage", items=["details"])
    details = [row["details"].lower() for row in table]

    github_url = "github.com/google-deepmind/alphafold3"
    is_server = any("server" in text for text in details)
    is_code = any(github_url in text for text in details)

    # exactly one of the two origins has to be indicated
    if is_server == is_code:
        raise NotIdentifiedContextRecordError(
            "_pdbx_data_usage", context=": AlphaFold 3 license type"
        )
    return is_server
def fix_software_location(block):
    """Give the AlphaFold 3 _software entry the right location URL.

    Finds out whether the `ModelCIF`_ ``block`` comes from the AlphaFold 3
    server or from a local installation and stores the matching URL in
    _software.location. A missing column is created; if the column exists,
    only the row for AlphaFold 3 is updated.

    Examples:
        >>> from gemmi import cif
        >>> from modelarchive.modelcif import access, fix_af3
        >>> # start with an empty CIF document
        >>> CIF_DATA = '''data_test
        ... _pdbx_data_usage.details "... alphafoldserver.com/output-terms."
        ... _pdbx_data_usage.id 1
        ... _pdbx_data_usage.type license
        ... _pdbx_data_usage.url ?
        ... #
        ... loop_
        ... _software.classification
        ... _software.date
        ... _software.description
        ... _software.name
        ... _software.pdbx_ordinal
        ... _software.type
        ... _software.version
        ... other ? "Structure prediction" AlphaFold 1 package AlphaFold-beta
        ... '''
        >>> block = cif.read_string(CIF_DATA).sole_block()
        >>> fix_af3.fix_software_location(block)
        >>> # Just check that _software.location exists and has the right value
        >>> table = access.get_table(block, "_software")
        >>> assert "_software.location" in table.tags
        >>> assert table[0]["location"] == "https://alphafoldserver.com/"
        >>> # Change block to look like ModelCIF file from local installation
        >>> table = access.get_table(block, "_pdbx_data_usage")
        >>> table[0]["details"] = "...github.com/google-deepmind/alphafold3..."
        >>> fix_af3.fix_software_location(block)
        >>> # Check _software.location to point to GitHub, now
        >>> table = access.get_table(block, "_software")
        >>> assert table[0]["location"] == \
                "https://github.com/google-deepmind/alphafold3"

    Args:
        block (|gemmicifBlock|): CIF block to operate on.

    Returns:
        None

    Raises:
        NotIdentifiedContextRecordError: If no AlphaFold 3 entry is found in
            the _software table.
        NotIdentifiedContextRecordError: If the origin of the AlphaFold 3
            license could not be identified in the _pdbx_data_usage table.
        NotIdentifiedDuplicatedRecordError: If multiple entries for
            AlphaFold 3 are found in the _software table.
    """
    af3_sw_url = (
        "https://alphafoldserver.com/"
        if _is_af3_server(block)
        else "https://github.com/google-deepmind/alphafold3"
    )

    af3_url_setter = _AF3ItemSetter(af3_sw_url, item="location")
    sw_table = access.get_table(block, "_software")
    if sw_table:
        if "_software.location" in sw_table.tags:
            # column exists, the setter keeps the value of all non-AF3 rows
            for row in sw_table:
                row["location"] = af3_url_setter(row, same=True)
        else:
            edit.add_column(block, "_software", "location", af3_url_setter)
    if not af3_url_setter.sw_found:
        raise NotIdentifiedContextRecordError(
            "_software", context=": AlphaFold 3 not found"
        )
def _get_cat_dict(block, category):
    """Return a category as dict, raise edit.NotFoundCategoryError if the
    category is missing or empty."""
    cat_dict = block.get_mmcif_category(f"{category}.")
    if cat_dict:
        return cat_dict
    raise edit.NotFoundCategoryError(category)


def _get_af3_sw_group(block):
    """Get software_group_id for AF3 (to be used in protocols and QE).

    A single record in _ma_software_group is assumed to point to AF3.
    With multiple records, the group is looked up via the AF3 record in
    _software.

    Raises:
        edit.NotFoundCategoryError: If _ma_software_group or (when needed)
            _software is missing.
        NotIdentifiedDuplicatedRecordError: If _software has multiple AF3
            records.
        NotIdentifiedContextRecordError: If no software group refers to AF3.
    """
    sw_group_table = access.get_table(block, "_ma_software_group")
    if not sw_group_table:
        raise edit.NotFoundCategoryError("_ma_software_group")
    if len(sw_group_table) == 1:
        return sw_group_table[0]["group_id"]

    # Multiple SW groups, look for AF3 entry
    sw_table = access.get_table(block, "_software")
    if not sw_table:
        raise edit.NotFoundCategoryError("_software")
    af3_sw_id = None
    for sw_row in sw_table:
        if not _is_af3_sw_name(sw_row["name"]):
            continue
        if af3_sw_id is not None:
            raise NotIdentifiedDuplicatedRecordError("_software", "AlphaFold")
        af3_sw_id = sw_row["pdbx_ordinal"]

    # first group referring to the AF3 software record wins
    for group_row in sw_group_table:
        if group_row["software_id"] == af3_sw_id:
            return group_row["group_id"]

    raise NotIdentifiedContextRecordError(
        "_software", context=": AlphaFold 3 not found"
    )
def fix_protocol(block):
    """Rewrite the MA protocol as a single well-formed step.

    _ma_data, _ma_data_group, and _ma_protocol_step are written from scratch,
    derived from the existing _ma_target_entity, _ma_model_list and
    _ma_software_group categories. Whatever those three categories contained
    before is silently overwritten.

    Data layout after the call:

    _ma_data:
        One record per target entity (content_type "target") followed by one
        record per model (content_type "model coordinates"). IDs are assigned
        sequentially starting at 1.

    _ma_data_group:
        Group 1 - all target data IDs (input side).
        Group 2 - all model data IDs (output side).

    _ma_protocol_step:
        A single step referencing the AF3 software group, group 1 as input,
        and group 2 as output.

    Examples:
        >>> from gemmi import cif
        >>> from modelarchive.modelcif import access, fix_af3
        >>> # start with an empty CIF document
        >>> CIF_DATA = '''data_test
        ... #
        ... loop_
        ... _entity.id
        ... _entity.pdbx_description
        ... _entity.type
        ... 1 "bestest polymer in universe" polymer
        ... 2 "second best polythingi in universe" polymer
        ... #
        ... loop_
        ... _ma_target_entity.data_id
        ... _ma_target_entity.entity_id
        ... _ma_target_entity.origin
        ... 1 1 .
        ... 1 2 .
        ... #
        ... _ma_model_list.data_id 1
        ... _ma_model_list.model_group_id 1
        ... _ma_model_list.model_group_name "AlphaFold-beta-20231127 (...)"
        ... _ma_model_list.model_id 1
        ... _ma_model_list.model_name "Top ranked model"
        ... _ma_model_list.model_type "Ab initio model"
        ... _ma_model_list.ordinal_id 1
        ... #
        ... loop_
        ... _ma_software_group.group_id
        ... _ma_software_group.ordinal_id
        ... _ma_software_group.software_id
        ... 1 1 1
        ... #
        ... loop_
        ... _software.classification
        ... _software.date
        ... _software.description
        ... _software.name
        ... _software.pdbx_ordinal
        ... _software.type
        ... _software.version
        ... other ? "Structure prediction" AlphaFold 1 package AlphaFold-beta
        ... '''
        >>> block = cif.read_string(CIF_DATA).sole_block()
        >>> fix_af3.fix_protocol(block)
        >>> access.get_table(block, "_entity").erase()
        >>> access.get_table(block, "_ma_data").erase()
        >>> access.get_table(block, "_ma_data_group").erase()
        >>> access.get_table(block, "_ma_model_list").erase()
        >>> access.get_table(block, "_ma_software_group").erase()
        >>> access.get_table(block, "_ma_target_entity").erase()
        >>> access.get_table(block, "_software").erase()
        >>> print(block.as_string())
        data_test
        loop_
        _ma_protocol_step.ordinal_id
        _ma_protocol_step.protocol_id
        _ma_protocol_step.step_id
        _ma_protocol_step.method_type
        _ma_protocol_step.details
        _ma_protocol_step.software_group_id
        _ma_protocol_step.input_data_group_id
        _ma_protocol_step.output_data_group_id
        1 1 1 modeling 'Model generated with AlphaFold 3.' 1 1 2
        <BLANKLINE>

    Args:
        block (|gemmicifBlock|): CIF block to operate on.

    Returns:
        None

    Raises:
        edit.NotFoundCategoryError: If any required source category is
            absent: _entity, _ma_target_entity, _ma_model_list, or
            _ma_software_group.
        edit.NotFoundItemError: If _ma_target_entity.data_id,
            _ma_model_list.data_id or _ma_model_list.model_name are missing.
        NotIdentifiedDuplicatedRecordError: If multiple _ma_software_group
            records exist and the AF3 entry cannot be unambiguously
            identified in _software.
        NotIdentifiedContextRecordError: If multiple _ma_software_group
            records exist but no AF3 entry can be found in _software at all.
    """
    # --- targets (input side)
    entity_dict = _get_cat_dict(block, "_entity")
    entity_descs = dict(zip(entity_dict["id"], entity_dict["pdbx_description"]))
    trg_ent_table = access.get_table(block, "_ma_target_entity")
    if not trg_ent_table:
        raise edit.NotFoundCategoryError("_ma_target_entity")
    # Raise when _ma_target_entity.data_id does not exist, that could mean
    # a lot has changed in the way AF writes ModelCIF files.
    if "_ma_target_entity.data_id" not in trg_ent_table.tags:
        raise edit.NotFoundItemError("_ma_target_entity.data_id")

    ids = []
    names = []
    content_types = []
    data_ids_in = []
    for number, trg_row in enumerate(trg_ent_table, start=1):
        trg_row["data_id"] = str(number)
        data_id = trg_row["data_id"]
        ids.append(data_id)
        data_ids_in.append(data_id)
        names.append(entity_descs[trg_row["entity_id"]])
        content_types.append("target")

    # --- models (output side)
    mdl_list_table = access.get_table(block, "_ma_model_list")
    if not mdl_list_table:
        raise edit.NotFoundCategoryError("_ma_model_list")
    for itm in ("data_id", "model_name"):
        if f"_ma_model_list.{itm}" not in mdl_list_table.tags:
            raise edit.NotFoundItemError(f"_ma_model_list.{itm}")

    data_ids_out = []
    # numbering continues after the targets
    for number, mdl_row in enumerate(mdl_list_table, start=len(ids) + 1):
        mdl_row["data_id"] = str(number)
        data_id = mdl_row["data_id"]
        ids.append(data_id)
        names.append(cif.as_string(mdl_row["model_name"]))
        content_types.append("model coordinates")
        data_ids_out.append(data_id)

    # write data (need to be able to overwrite!)
    # Using set_mmcif_category() here is OK, there are at least two data
    # records, target & model, so it will always be a loop.
    block.set_mmcif_category(
        "_ma_data.",
        {"id": ids, "name": names, "content_type": content_types},
    )
    edit.move_category(block, "_ma_data", "after:_ma_model_list")

    # add 2 data groups (1 for input, 2 for output)
    grouped_ids = data_ids_in + data_ids_out
    block.set_mmcif_category(
        "_ma_data_group.",
        {
            "ordinal_id": list(range(1, len(grouped_ids) + 1)),
            "group_id": [1] * len(data_ids_in) + [2] * len(data_ids_out),
            "data_id": grouped_ids,
        },
    )
    edit.move_category(block, "_ma_data_group", "after:_ma_data")

    # add single protocol step, run by the AF3 software group
    block.set_mmcif_category(
        "_ma_protocol_step.",
        {
            "ordinal_id": [1],
            "protocol_id": [1],
            "step_id": [1],
            "method_type": ["modeling"],
            "details": ["Model generated with AlphaFold 3."],
            "software_group_id": [_get_af3_sw_group(block)],
            "input_data_group_id": [1],
            "output_data_group_id": [2],
        },
    )
    edit.move_category(block, "_ma_protocol_step", "after:_ma_data_group")
def add_per_residue_plddt(block):
    """Add average per-residue pLDDT scores to an AF3 ModelCIF file.

    Adds _ma_qa_metric_local data derived from B-factor values in _atom_site.
    The per-residue pLDDT is computed as the mean over all atoms of a residue.
    Non-polymer residues (missing value in _atom_site.label_seq_id) are
    excluded.

    If _ma_qa_metric_local is already present in the block, the function
    exits early with a warning. If no local pLDDT entry exists in
    _ma_qa_metric, it will be added; if more than one local pLDDT entry is
    found, an exception is raised as this is most likely an error in the
    ModelCIF file.

    This fix targets AF3 files predating version 3.0.1, which lack
    _ma_qa_metric_local.

    Examples:
        >>> from gemmi import cif
        >>> from modelarchive.modelcif import access, fix_af3
        >>> # Please note: the example CIF document for this case has the
        >>> # _atom_site category reduce to the bare minimum to make the
        >>> # mechanics of add_per_residue_plddt() work, to keep the example
        >>> # shorter.
        >>> CIF_DATA = '''data_test
        ... #
        ... loop_
        ... _ma_software_group.group_id
        ... _ma_software_group.ordinal_id
        ... _ma_software_group.software_id
        ... 1 1 1
        ... #
        ... loop_
        ... _software.classification
        ... _software.date
        ... _software.description
        ... _software.name
        ... _software.pdbx_ordinal
        ... _software.type
        ... _software.version
        ... other ? "Structure prediction" AlphaFold 1 package AlphaFold-beta
        ... #
        ... loop_
        ... _atom_site.group_PDB
        ... _atom_site.label_comp_id
        ... _atom_site.label_asym_id
        ... _atom_site.label_seq_id
        ... _atom_site.B_iso_or_equiv
        ... _atom_site.pdbx_PDB_model_num
        ... ATOM MET A 1 35.00 1
        ... ATOM ALA A 2 50.30 1
        ... ATOM THR A 3 65.75 1
        ... '''
        >>> block = cif.read_string(CIF_DATA).sole_block()
        >>> fix_af3.add_per_residue_plddt(block)
        >>> # After execution, the CIF document has categories _ma_qa_metric
        >>> # and _ma_qa_metric_local added
        >>> # There should be only 1 record in _ma_qa_metric
        >>> qa_dict = block.get_mmcif_category("_ma_qa_metric.")
        >>> print(qa_dict)
        {'id': ['1'], 'mode': ['local'], 'name': ['pLDDT'], \
'software_group_id': ['1'], 'type': ['pLDDT']}
        >>> # There should be 3 records of local scores
        >>> table = access.get_table(block, "_ma_qa_metric_local.")
        >>> print("# chain res seqID pLDDT")
        # chain res seqID pLDDT
        >>> for r in table:
        ...     print(
        ...         f"{r['ordinal_id']} {r['label_asym_id']} "
        ...         + f"{r['label_comp_id']} {r['label_seq_id']} "
        ...         + f"{r['metric_value']}"
        ...     )
        1 A MET 1 35.0
        2 A ALA 2 50.3
        3 A THR 3 65.75

    Args:
        block (|gemmicifBlock|): CIF block to operate on.

    Returns:
        None

    Raises:
        RuntimeError: If ``_ma_qa_metric`` contains more than one local pLDDT
            entry, or if a local pLDDT entry has to be added but
            ``_ma_qa_metric`` does not consist of exactly the items id, mode,
            name, software_group_id and type.
    """
    if "_ma_qa_metric_local." in block.get_mmcif_category_names():
        _utils.warn_msg("_ma_qa_metric_local already there, skipped.")
        return

    # find metric ID for pLDDT
    metric_items = ["id", "mode", "name", "software_group_id", "type"]
    table = block.get_mmcif_category("_ma_qa_metric.")
    if len(table) == 0:
        table = {itm: [] for itm in metric_items}
    qa_local = [
        mid
        for mid, mtype, mode in zip(table["id"], table["type"], table["mode"])
        if mtype == "pLDDT" and mode == "local"
    ]
    if len(qa_local) == 1:
        # expected route
        metric_id = qa_local[0]
    elif len(qa_local) == 0:
        # Add it if possible. The new record only provides values for the
        # known items; with any other column layout the columns would end up
        # with different lengths. An explicit check is used instead of an
        # assert so it also holds when Python runs with -O.
        if sorted(table.keys()) != metric_items:
            raise RuntimeError(
                "Unexpected items in _ma_qa_metric, can not add local pLDDT "
                + "entry."
            )
        metric_id = _get_ordinal_ids(table["id"], 1)[0]
        table["id"].append(metric_id)
        table["mode"].append("local")
        table["name"].append("pLDDT")
        table["software_group_id"].append(_get_af3_sw_group(block))
        table["type"].append("pLDDT")
        # set_mmcif_category() will create a loop, even for single records...
        # that is OK here as having multiple metrics (at least 1 global and
        # 1 local) is more than common.
        block.set_mmcif_category("_ma_qa_metric.", table)
    else:
        raise RuntimeError(
            "Unexpected number of local pLDDT entries in _ma_qa_metric."
        )

    # code adapted from src/alphafold3/model/mmcif_metadata.py as modified in
    # https://github.com/google-deepmind/alphafold3/commit/121716e
    res_id_keys = [
        "label_asym_id",
        "label_seq_id",
        "label_comp_id",
        "pdbx_PDB_model_num",
    ]
    table = access.get_table(
        block, "_atom_site", res_id_keys + ["B_iso_or_equiv"]
    )
    # collect the B-factors of all atoms per residue, in order of appearance
    plddt_grouped_by_res = {}
    for row in table:
        group_key = tuple(row[v] for v in res_id_keys)
        plddt_grouped_by_res.setdefault(group_key, []).append(
            cif.as_number(row["B_iso_or_equiv"])
        )
    qa_local = {
        "ordinal_id": [],
        "label_asym_id": [],
        "label_comp_id": [],
        "label_seq_id": [],
        "metric_id": [],
        "metric_value": [],
        "model_id": [],
    }
    # Note: ordinal IDs count all residues, including skipped ones.
    for ordinal_id, (
        (chain_id, res_id, res_name, model_id),
        res_plddts,
    ) in enumerate(plddt_grouped_by_res.items(), start=1):
        # skip non-polymer residues (note AF3 includes those)
        if not cif.is_null(res_id):
            res_plddt = round(np.mean(res_plddts), 2)
            qa_local["ordinal_id"].append(str(ordinal_id))
            qa_local["label_asym_id"].append(chain_id)
            qa_local["label_seq_id"].append(res_id)
            qa_local["label_comp_id"].append(res_name)
            qa_local["metric_id"].append(metric_id)
            qa_local["metric_value"].append(str(res_plddt))
            qa_local["model_id"].append(model_id)
    block.set_mmcif_category("_ma_qa_metric_local.", qa_local, raw=True)
def _add_auth_comp_id_2_atom_site(block):
    """Add item auth_comp_id to category _atom_site.

    The values are produced by a copy-from-row callback built for
    label_comp_id; the column is requested at position 17.
    """
    edit.add_column(
        block,
        "_atom_site",
        "auth_comp_id",
        edit.make_copy_value_in_row("label_comp_id"),
        pos=17,
    )


# Shared callback for the *_scheme helpers below: copies "mon_id" of a row.
_return_mon_id = edit.make_copy_value_in_row("mon_id")


def _add_pdb_mon_id_2_pdbx_poly_seq_scheme(block):
    """Add item pdb_mon_id to _pdbx_poly_seq_scheme, copied from mon_id."""
    edit.add_column(
        block,
        "_pdbx_poly_seq_scheme",
        "pdb_mon_id",
        _return_mon_id,
        pos=9,
    )


def _add_pdb_mon_id_2_pdbx_branch_scheme(block):
    """Add item pdb_mon_id to _pdbx_branch_scheme, copied from mon_id.

    A ``edit.NotFoundCategoryError`` is ignored on purpose (presumably
    raised when the block has no _pdbx_branch_scheme, i.e. no branched
    entities).
    """
    try:
        edit.add_column(
            block,
            "_pdbx_branch_scheme",
            "pdb_mon_id",
            _return_mon_id,
            pos=9,
        )
    except edit.NotFoundCategoryError:
        pass


def _add_pdbx_entity_branch_list_2_data(block):
    """Add category _pdbx_entity_branch_list to data block.

    The category is derived row by row from _pdbx_branch_scheme (entity_id,
    num and mon_id -> comp_id) and inserted right before it. Nothing happens
    if _pdbx_branch_scheme yields no table.
    """
    # first check that pdbx_branch_scheme exists
    pbs = access.get_table(block, "_pdbx_branch_scheme")
    if not pbs:
        return

    # gather values
    bl_items = {"entity_id": [], "num": [], "comp_id": []}
    for row in pbs:
        bl_items["entity_id"].append(row["entity_id"])
        bl_items["num"].append(row["num"])
        bl_items["comp_id"].append(row["mon_id"])

    # place above pdbx_branch_scheme
    edit.add_category(
        block,
        "_pdbx_entity_branch_list",
        bl_items,
        index="before:_pdbx_branch_scheme",
    )


def _add_ndb_seq_num_2_pdbx_nonpoly_scheme(block):
    """Add item ndb_seq_num to category _pdbx_nonpoly_scheme.

    Values come from a per-"asym_id" residue counter callback. A
    ``edit.NotFoundCategoryError`` is ignored on purpose (presumably raised
    when the block has no _pdbx_nonpoly_scheme).
    """
    try:
        edit.add_column(
            block,
            "_pdbx_nonpoly_scheme",
            "ndb_seq_num",
            edit.make_res_per_chain_counter("asym_id"),
            pos=6,
        )
    except edit.NotFoundCategoryError:
        pass


def _relabel_lone_sugars(block):
    """AF3 has single sugars marked as 'branched' entities (wrong) but lists
    them in '_pdbx_nonpoly_scheme' (right). So switch the '_entity.type' to
    'non-polymer' as it appears in RCSB.

    An entity counts as a lone sugar if its type is 'branched' (case
    insensitive) and its ID does not occur in _pdbx_branch_scheme.entity_id.
    Nothing is changed if _pdbx_branch_scheme yields no table.
    """
    cats = {}
    cats["_entity"] = access.get_table(block, "_entity", items=["id", "type"])
    cats["_pdbx_branch_scheme"] = access.get_table(
        block, "_pdbx_branch_scheme", items=["entity_id"]
    )
    if not cats["_pdbx_branch_scheme"]:
        return

    # collect the entity IDs of real branched entities once
    branched_ids = {row["entity_id"] for row in cats["_pdbx_branch_scheme"]}
    for ent_row in cats["_entity"]:
        if ent_row["type"].upper() != "BRANCHED":
            continue
        if ent_row["id"] not in branched_ids:
            ent_row["type"] = "non-polymer"


# _entity items that must be identical for two entities to be merged
ENT_COLS = ["pdbx_description", "type"]


def _delete_record(cat, item, valuep, valuec, eitems, block, cache):
    """Delete the child record from a category after comparing it to the
    parent record.

    Args:
        cat (str): Category name, also used as key into ``cache``.
        item (str): Item identifying the records.
        valuep: Value of ``item`` for the parent record (kept).
        valuec: Value of ``item`` for the child record (removed).
        eitems (list[str]): Items that must be equal in parent and child.
        block: CIF block, only read if ``cat`` is not in ``cache`` yet.
        cache (dict): Tables by category name, filled on demand.

    Raises:
        NotIdentifiedSingleRecordError: If no child record is found.
        RuntimeError: If parent and child differ in any of ``eitems``.
    """
    # Hard to reduce arguments
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    if cat not in cache:
        cache[cat] = access.get_table(block, cat, items=[item] + eitems)

    # fetch values for comparison/ delete child record
    ntp = {}
    ntc = {}
    cidx = None
    for row in cache[cat]:
        if row[item] == valuep:
            for i in eitems:
                ntp[i] = row[i]
        elif row[item] == valuec:
            cidx = row.row_index
            for i in eitems:
                ntc[i] = row[i]
    if cidx is None:
        raise NotIdentifiedSingleRecordError(cat, item=item, value=valuec)
    if ntp != ntc:
        raise RuntimeError(
            f"Records to be merged not equal: {ntp} vs. {ntc} ({item}: "
            + f"{valuep}/ {valuec})"
        )
    cache[cat].remove_row(cidx)


def _merge_entities(mntp, mntc, etype, block, cache):
    """Merge two entities.

    Removes the child entity ``mntc`` and re-points every reference to it at
    the parent entity ``mntp``. ``etype`` (the entity type, compared case
    insensitively) selects which type-specific categories are touched.
    """
    # _entity
    _delete_record("_entity", "id", mntp, mntc, ENT_COLS, block, cache)
    # _ma_target_entity
    _delete_record(
        "_ma_target_entity",
        "entity_id",
        mntp,
        mntc,
        ["data_id", "origin"],
        block,
        cache,
    )
    # _ma_target_entity_instance
    _update_records(
        "_ma_target_entity_instance",
        "entity_id",
        mntp,
        mntc,
        block,
        cache,
    )
    # _struct_asym
    _update_records("_struct_asym", "entity_id", mntp, mntc, block, cache)
    # _atom_site
    _update_records("_atom_site", "label_entity_id", mntp, mntc, block, cache)
    if etype.upper() == "POLYMER":
        # _entity_poly
        _aggregate_entity_poly(mntp, mntc, block, cache)
        # _entity_poly_seq
        _delete_records("_entity_poly_seq", "entity_id", mntc, cache)
        # _pdbx_poly_seq_scheme
        _update_records(
            "_pdbx_poly_seq_scheme", "entity_id", mntp, mntc, block, cache
        )
    elif etype.upper() == "BRANCHED":
        # _pdbx_entity_branch_list
        _update_records(
            "_pdbx_entity_branch_list", "entity_id", mntp, mntc, block, cache
        )
        # _pdbx_branch_scheme
        _update_records(
            "_pdbx_branch_scheme", "entity_id", mntp, mntc, block, cache
        )
    elif etype.upper() == "NON-POLYMER":
        # _pdbx_nonpoly_scheme
        _update_records(
            "_pdbx_nonpoly_scheme", "entity_id", mntp, mntc, block, cache
        )


def _reduce_duplicated_entities(block):
    """Multimeric AF3 ModelCIF files may have entities duplicated. When seeing
    the same molecular entity multiple times, reduce it to a single molecular
    entity but multiple chains.

    Entities are compared by the concatenation of their mon_id values
    gathered from _entity_poly_seq, _pdbx_nonpoly_scheme and
    _pdbx_branch_scheme. Afterwards entity IDs are renumbered to be
    sequential.

    Raises:
        RuntimeError: If an entity has no mon_id in any of the three
            categories.
    """
    cats = {}
    # check for duplicated entities
    cats["_entity"] = access.get_table(
        block, "_entity", items=["id"] + ENT_COLS
    )
    ments = {}
    for row in cats["_entity"]:
        ments[row["id"]] = {"seq": "", "type": row["type"]}
    # fetch polymers
    cats["_entity_poly_seq"] = access.get_table(
        block, "_entity_poly_seq", items=["entity_id", "mon_id"]
    )
    for row in cats["_entity_poly_seq"]:
        ments[row["entity_id"]]["seq"] += row["mon_id"]
    # fetch non-polymers
    cats["_pdbx_nonpoly_scheme"] = access.get_table(
        block, "_pdbx_nonpoly_scheme", items=["entity_id", "mon_id"]
    )
    for row in cats["_pdbx_nonpoly_scheme"]:
        ments[row["entity_id"]]["seq"] += row["mon_id"]
    # fetch branched entities
    cats["_pdbx_branch_scheme"] = access.get_table(
        block, "_pdbx_branch_scheme", items=["entity_id", "mon_id"]
    )
    for row in cats["_pdbx_branch_scheme"]:
        ments[row["entity_id"]]["seq"] += row["mon_id"]
    # look for missed entities
    for k, v in ments.items():
        if len(v["seq"]) == 0:
            raise RuntimeError(f"Empty entity found: '{k}'")

    # fix duplicated entities: take the last remaining entity as child and
    # merge it into the first remaining entity with the same sequence
    while len(ments) > 1:
        kc, vc = ments.popitem()
        for kp, vp in ments.items():
            if vp["seq"] == vc["seq"]:
                _merge_entities(kp, kc, vp["type"], block, cats)
                break

    # make entities run sequentially (close holes)
    _make_entity_ids_sequential(block, cats)


def _make_entity_ids_sequential(block, cache):
    """After removing duplicated entities, IDs may not be sequential anymore.
    Fix that.

    Walks the cached _entity table in row order and renumbers every entity
    whose ID is not its 1-based row position, updating all referencing
    categories along the way.
    """
    # "_entity" is expected in the cache already, this point should not be
    # reachable w/o reading it at least once.
    # NOTE(review): "_ma_target_entity" is also read straight from the cache
    # by _update_record(); that looks safe only after a merge has happened -
    # confirm for inputs that have gaps in entity IDs but no duplicates.
    cat = "_entity"
    last_id = 0
    for row in cache[cat]:
        cur_id = cif.as_int(row["id"])
        last_id = last_id + 1
        if cur_id != last_id:
            cur_id = str(cur_id)
            last_id_str = str(last_id)
            _update_record("_entity", "id", cur_id, last_id, cache)
            _update_records(
                "_atom_site",
                "label_entity_id",
                last_id_str,
                cur_id,
                block,
                cache,
            )
            _update_record(
                "_ma_target_entity", "entity_id", cur_id, last_id, cache
            )
            _update_records(
                "_ma_target_entity_instance",
                "entity_id",
                last_id_str,
                cur_id,
                block,
                cache,
            )
            _update_records(
                "_struct_asym", "entity_id", last_id_str, cur_id, block, cache
            )
            if row["type"].upper() == "POLYMER":
                _update_records(
                    "_entity_poly",
                    "entity_id",
                    last_id_str,
                    cur_id,
                    block,
                    cache,
                )
                _update_records(
                    "_entity_poly_seq",
                    "entity_id",
                    last_id_str,
                    cur_id,
                    block,
                    cache,
                )
                _update_records(
                    "_pdbx_poly_seq_scheme",
                    "entity_id",
                    last_id_str,
                    cur_id,
                    block,
                    cache,
                )
            elif row["type"].upper() == "BRANCHED":
                _update_records(
                    "_pdbx_entity_branch_list",
                    "entity_id",
                    last_id_str,
                    cur_id,
                    block,
                    cache,
                )
                _update_records(
                    "_pdbx_branch_scheme",
                    "entity_id",
                    last_id_str,
                    cur_id,
                    block,
                    cache,
                )
            elif row["type"].upper() == "NON-POLYMER":
                _update_records(
                    "_pdbx_nonpoly_scheme",
                    "entity_id",
                    last_id_str,
                    cur_id,
                    block,
                    cache,
                )


def _update_records(cat, item, valuep, valuec, block, cache):
    """Set ``item`` to ``valuep`` in every record where it equals ``valuec``.

    The table for ``cat`` is fetched from ``block`` and stored in ``cache``
    if it is not cached yet.
    """
    # Hard to reduce arguments
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    if cat not in cache:
        cache[cat] = access.get_table(block, cat, items=[item])
    # re-point child records to the parent value
    for row in cache[cat]:
        if row[item] == valuec:
            row[item] = valuep


def _update_record(cat, item, value_old, value_new, cache):
    """Exchange the value of a single record.

    Only the first record with ``item == value_old`` is changed;
    ``value_new`` is stored as ``str``. The table must be in ``cache``
    already, it is not fetched here.
    """
    # Hard to reduce arguments
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    for row in cache[cat]:
        if row[item] == value_old:
            row[item] = str(value_new)
            break


def _aggregate_entity_poly(valuep, valuec, block, cache):
    """For entity_poly, that needs reduction of duplicated entities, gather
    chain names.

    The comma separated pdbx_strand_id values of all child records
    (entity_id == ``valuec``) are appended to the parent record
    (entity_id == ``valuep``), then the child records are removed.

    Raises:
        RuntimeError: If there is no parent record or if a child record has
            a different 'type' than the parent.
    """
    # _entity_poly.entity_id
    # _entity_poly.pdbx_strand_id
    # _entity_poly.type
    cat = "_entity_poly"
    if cat not in cache:
        cache[cat] = access.get_table(
            block, cat, items=["entity_id", "pdbx_strand_id", "type"]
        )
    # find parent row
    rowp = None
    for row in cache[cat]:
        if row["entity_id"] == valuep:
            rowp = row
            typep = row["type"]
            break
    if rowp is None:
        raise RuntimeError(f"No '_entity_poly' with entity_id={valuep} found.")
    # get child records
    strand_ids = rowp["pdbx_strand_id"].split(",")
    idxsc = []
    for row in cache[cat]:
        if row["entity_id"] == valuec:
            if row["type"] != typep:
                raise RuntimeError(
                    "'entity_poly' with duplicated entities mismatch at "
                    + f"'type': '{typep}' vs '{row['type']}'"
                )
            strand_ids.extend(row["pdbx_strand_id"].split(","))
            idxsc.append(row.row_index)
    # Update the parent BEFORE deleting rows: removing a child row located
    # above the parent would shift the row the parent handle points at.
    rowp["pdbx_strand_id"] = ",".join(strand_ids)
    # delete from the back so the remaining indices stay valid
    for i in sorted(idxsc, reverse=True):
        cache[cat].remove_row(i)


def _delete_records(cat, item, valuec, cache):
    """Delete multiple records with a certain item and value.

    The table must be in ``cache`` already. Finding no matching record is
    not an error.
    """
    # get indices to be deleted
    cidxs = []
    for row in cache[cat]:
        if row[item] == valuec:
            cidxs.append(row.row_index)
    # delete from the back so the remaining indices stay valid
    for i in sorted(cidxs, reverse=True):
        cache[cat].remove_row(i)


def _fix_atom_names(block):
    """Sometimes AF3 doesn't get the atom names following IUPAC.

    Upper-cases every value of _atom_type.symbol and _atom_site.type_symbol
    that contains a lower-case character.
    """

    def _fix_atom_symbol(table):
        """Upper-case the first column of each row, if needed."""
        for row in table:
            if any(char.islower() for char in row[0]):
                row[0] = row[0].upper()

    table = access.get_table(block, "_atom_type", items=["symbol"])
    _fix_atom_symbol(table)
    table = access.get_table(block, "_atom_site", items=["type_symbol"])
    _fix_atom_symbol(table)


def _load_cache(cache_file):
    """Load or start the cache for chem.comps.

    Returns the JSON content of ``cache_file`` (a ``Path``) or an empty dict
    if the file does not exist.
    """
    # ToDo: update mechanism
    # cache is of form {"<SMILES>": <COMPOUND>}
    cache = {}
    if cache_file.exists():
        with open(cache_file, encoding="ascii") as jfh:
            cache = json.load(jfh)
    return cache


def _save_cache(cache, cache_file):
    """Store cache on disk as JSON, overwriting ``cache_file``."""
    # ToDo: update mechanism
    with open(cache_file, "w", encoding="ascii") as jfh:
        json.dump(cache, jfh)


def _get_chem_comp_by_smiles(smiles, cache):
    """Fetch a chem. component from RCSB using SMILES.

    Results (including misses, stored as ``None``) are written to ``cache``
    so every SMILES string is queried only once.

    Returns:
        The "chem_comp" part of the RCSB data API answer, or ``None`` if the
        search answer is not JSON or does not have exactly one hit.

    Raises:
        RuntimeError: If one of the two requests fails or the single hit has
            a score other than 1.0.
    """
    # 1. Try to solve by cache
    if smiles in cache:
        return cache[smiles]

    # 2. get ID
    payload = {
        "query": {
            "type": "terminal",
            "service": "chemical",
            "parameters": {
                "value": smiles,
                "type": "descriptor",
                "descriptor_type": "SMILES",
                "match_type": "graph-exact",
            },
        },
        "return_type": "mol_definition",
    }
    response = requests.post(
        "https://search.rcsb.org/rcsbsearch/v2/query",
        json=payload,
        headers={"Content-Type": "application/json"},
        timeout=60,
    )
    if not response.ok:
        raise RuntimeError(
            f"Querying chem.comp. name failed for SMILES '{smiles}'."
        )
    try:
        json_rspnse = response.json()
    except requests.exceptions.JSONDecodeError:
        # a successful answer without a JSON body is treated as "not found"
        cache[smiles] = None
        return cache[smiles]
    if json_rspnse["total_count"] != 1:
        cache[smiles] = None
        return cache[smiles]
    json_rspnse = json_rspnse["result_set"][0]
    if json_rspnse["score"] != 1.0:
        raise RuntimeError(f"Score not 1.0 for SMILES '{smiles}'.")

    # 3. get more info about the ligand to fill the _chem_comp table
    response = requests.get(
        "https://data.rcsb.org/rest/v1/core/chemcomp/"
        + f"{json_rspnse['identifier']}",
        timeout=60,
    )
    if not response.ok:
        raise RuntimeError(
            "Querying chem.comp. data failed for ID "
            + f"'{json_rspnse['identifier']}'."
        )
    json_rspnse = response.json()
    cache[smiles] = json_rspnse["chem_comp"]
    return cache[smiles]


def _update_items(cat, items, exchange, block):
    """Replace the values of ``items`` in category ``cat`` via the mapping
    ``exchange`` (old value -> new value).

    Items missing from the category are skipped. A value missing from
    ``exchange`` raises ``KeyError``, except for 'HOH'.
    """
    table = access.get_table(block, cat)
    for row in table:
        for itm in items:
            try:
                row[itm] = exchange[row[itm]]
            except RuntimeError as exc:
                if str(exc) == f"Column name not found: {itm}":
                    pass
                else:
                    raise  # pragma: no cover (more of a gemmi.cif thing)
            except KeyError as exc:
                # str(KeyError) is the quoted key, as_string() unquotes it
                if cif.as_string(str(exc)) == "HOH":
                    # RCSB seems to not have HOH in _chem_comp in all PDB
                    # entries, so we allow it missing here, too
                    pass
                else:
                    raise


def _fix_chem_comp_id_ligands(block, cache_file):
    """Fix the 'LIG_<CHAR>' naming scheme.

    Every _chem_comp record with name '?' is looked up at RCSB by its
    SMILES string. Hits get formula, formula_weight, id, name and type
    replaced; misses are renumbered to the reserved IDs 01-99. Records
    sharing a SMILES string with an earlier record are removed. Finally the
    new IDs are propagated to the categories referencing chem. components.

    Args:
        block: CIF block to operate on.
        cache_file (str | Path): JSON cache file for RCSB look-ups, read at
            the start and always rewritten.
    """
    if not isinstance(cache_file, Path):
        cache_file = Path(cache_file)
    cache = _load_cache(cache_file)

    table = access.get_table(block, "_chem_comp")
    del_lst = []  # row indices of duplicated ligands
    smiles_lst = {}  # SMILES -> new ID
    unk_count = 1  # next free reserved ID
    updtd = {}  # old ID -> new ID
    update_needed = False
    for row in table:
        updtd[row["id"]] = row["id"]
        if row["name"] == "?":
            update_needed = True
            smiles = cif.as_string(row["pdbx_smiles"])
            if smiles in smiles_lst:
                del_lst.append(row.row_index)
                updtd[row["id"]] = smiles_lst[smiles]
                continue
            comp = _get_chem_comp_by_smiles(smiles, cache)
            if comp is None:
                # Compounds not found in the PDB use the reserved IDs 01-99 to
                # preserve the SMILES strings. Using 'UNL' instead would mean
                # to wipe the SMILES string as 'UNL' is defined with every
                # attribute blank to match all unknown ligands.
                # '99' itself is still a valid reserved ID, so only stop
                # once the counter has moved past it.
                if unk_count > 99:
                    _utils.warn_msg(
                        "No. of unknown ligands exceeds 99, won't renumber "
                        + f"chem.comp. '{row['id']}'.",
                    )
                    continue
                updtd[row["id"]] = f"{unk_count:02d}"
                smiles_lst[smiles] = f"{unk_count:02d}"
                row["id"] = f"{unk_count:02d}"
                unk_count += 1
                continue
            updtd[row["id"]] = _quote(comp["id"])
            smiles_lst[smiles] = _quote(comp["id"])
            for i in ["formula", "formula_weight", "id", "name", "type"]:
                try:
                    row[i] = _quote(comp[i])
                except TypeError:
                    # non-string values (e.g. numbers) are stored as is
                    row[i] = str(comp[i])
    # remove rows with duplicated ligands, from the back to keep indices valid
    for i in sorted(del_lst, reverse=True):
        table.remove_row(i)

    _save_cache(cache, cache_file)

    # Update categories
    if not update_needed:
        return
    cat_2_updt = [
        ("_atom_site", ["label_comp_id"]),
        ("_entity_poly_seq", ["parent_mon_id", "mon_id"]),
        ("_pdbx_branch_scheme", ["mon_id", "pdb_mon_id", "auth_mon_id"]),
        ("_pdbx_nonpoly_scheme", ["mon_id", "pdb_mon_id", "auth_mon_id"]),
        ("_pdbx_poly_seq_scheme", ["mon_id", "pdb_mon_id", "auth_mon_id"]),
    ]
    for cat in cat_2_updt:
        _update_items(cat[0], cat[1], updtd, block)
def fix_modelcif_issues(block, compdict_cache=".compdict_cache"):
    """Apply all small AF3 ModelCIF corrections to a block, in fixed order.

    Corrections, in the order they are applied:

    * _atom_site.auth_comp_id gets added
    * _pdbx_poly_seq_scheme.pdb_mon_id gets added
    * _pdbx_branch_scheme.pdb_mon_id gets added
    * _pdbx_entity_branch_list gets added when _pdbx_branch_scheme exists
    * _pdbx_nonpoly_scheme.ndb_seq_num gets added
    * single sugars mistakenly marked as 'branched' entity get relabelled to
      'non-polymer' in the _entity category
    * duplicated molecular entities are reduced to a single molecular entity
      (AF3 adds a molecular entity per copy of a molecule)
    * atom type symbols are turned into upper case
    * ligand naming scheme ``LIG_<CHARACTER>`` is replaced with proper
      molecule names (if possible, via RCSB)

    Args:
        block (|gemmicifBlock|): CIF block to operate on.
        compdict_cache (str | Path): Path to the cache file for RCSB API calls
            for chemical compounds. Defaults to ``.compdict_cache``.

    Returns:
        None

    Raises:
        NotIdentifiedSingleRecordError: If a record to be deleted cannot be
            found.
        RuntimeError: If entities to be merged have differing data, if an
            empty entity is found, if no _entity_poly record is found for an
            entity ID, if polymer types of duplicated entities mismatch, or
            if a SMILES string cannot be identified via the RCSB API.
    """
    # The order matters: later steps work on what earlier steps added.
    block_only_steps = (
        _add_auth_comp_id_2_atom_site,
        _add_pdb_mon_id_2_pdbx_poly_seq_scheme,
        _add_pdb_mon_id_2_pdbx_branch_scheme,
        _add_pdbx_entity_branch_list_2_data,
        _add_ndb_seq_num_2_pdbx_nonpoly_scheme,
        _relabel_lone_sugars,
        _reduce_duplicated_entities,
        _fix_atom_names,
    )
    for apply_fix in block_only_steps:
        apply_fix(block)
    # the ligand fix is the only step that needs the RCSB cache file
    _fix_chem_comp_id_ligands(block, compdict_cache)
def add_data_from_json_files(
    block,
    input_path,
    full_qe_path,
    summary_qe_path,
    out_zip_path,
    pairwise_in_zip=True,
    use_local_pairwise_if_possible=False,
):
    """Add QA metrics and metadata from AF3 JSON files to a ModelCIF block.

    Reads the AF3 input JSON, full confidence JSON, and summary confidence
    JSON to populate QA metric categories in ``block``. Packs the input JSON
    and, optionally, pairwise QA scores into a ZIP archive at
    ``out_zip_path``.

    Updates _ma_qa_metric, _ma_qa_metric_global, _ma_qa_metric_feature,
    _ma_qa_metric_local_pairwise, _ma_qa_metric_feature_pairwise,
    _ma_entry_associated_files, _ma_associated_archive_file_details,
    _audit_conform, and _ma_software_parameter (the latter only if model
    seeds or recycle counts are present in the JSON files).

    The function derives feature lists from _atom_site. Per-residue pairwise
    scores are written to _ma_qa_metric_local_pairwise when
    ``use_local_pairwise_if_possible`` is ``True`` and all tokens are polymer
    residues (no ``HETATM``); in all other cases
    _ma_qa_metric_feature_pairwise is used.

    Args:
        block (|gemmicifBlock|): mmCIF block to be updated in place. Must
            already contain _atom_site, _entity, _ma_qa_metric,
            _ma_qa_metric_global, _ma_software_group, and _entry categories.
        input_path (str): Path to the AF3 input JSON file. Server output:
            ``<JOBNAME>_job_request.json``. ``modelSeeds`` may be given as
            integers or as strings holding integers.
        full_qe_path (str): Path to the JSON file containing per-atom and
            per-token confidence arrays (pLDDT, PAE, contact probabilities).
            Server output: ``<JOBNAME>_full_data_<N>.json``; code output:
            ``<JOBNAME>_confidences.json``.
        summary_qe_path (str): Path to the JSON file containing summary
            confidence values (pTM, ipTM, ranking score, etc.). Server
            output: ``<JOBNAME>_summary_confidences_<N>.json``; code output:
            ``<JOBNAME>_summary_confidences.json``.
        out_zip_path (str): Path for the output ZIP archive. The archive
            always contains ``input.json`` (a copy of ``input_path``). If
            ``pairwise_in_zip`` is ``True``, a ``pairwise_qa.cif`` file with
            the pairwise QA metrics is also included. The value of
            _ma_entry_associated_files.file_url is set to the bare filename
            (i.e. without any directory component), so the main CIF file and
            the ZIP must reside in the same directory.
        pairwise_in_zip (bool): If ``True``, pairwise QA scores are written
            to a separate ``pairwise_qa.cif`` file and packaged inside the
            ZIP archive rather than embedded directly in ``block``. Defaults
            to ``True``.
        use_local_pairwise_if_possible (bool): If ``True`` and every token in
            the structure is a polymer residue (no ``HETATM`` records),
            _ma_qa_metric_local_pairwise is used for pairwise token scores
            instead of _ma_qa_metric_feature_pairwise. Defaults to ``False``.

    Returns:
        None

    Raises:
        RuntimeError: If ``full_qe_path`` or ``summary_qe_path`` contain
            score keys that are not listed in ``known_scores``.
        AssertionError: If a sanity check fails, e.g. the input JSON is a
            list with more than one job, a model seed is not an integer,
            _atom_site holds more than one model, or the chain/residue IDs
            of the JSON files do not match _atom_site.
    """
    # No idea how to reduce args, not going fixing the "to many of this and
    # that" messages at the moment
    # pylint: disable=too-many-positional-arguments,too-many-arguments
    # pylint: disable=too-many-locals,too-many-branches,too-many-statements
    sw_params_group_id = 1

    def _add_sw_param(sw_params, data_type, name, value):
        """Helper to add stuff to SW params."""
        sw_params["parameter_id"].append(len(sw_params["parameter_id"]) + 1)
        sw_params["group_id"].append(sw_params_group_id)
        sw_params["data_type"].append(data_type)
        sw_params["name"].append(name)
        sw_params["value"].append(value)

    def _add_feature(
        feature_dict, feature_type, entity_type, specific_dict, **kwargs
    ):
        """Helper to add features.

        Valid feature_type: "entity instance", "residue", "atom"
        Extra arguments used to fill specific_dict besides ordinal_id and
        feature_id.
        Can set feature_dict to None to avoid updating it and only update
        specific_dict.
        """
        if feature_dict is None:
            feature_id = None
        else:
            feature_id = len(feature_dict["feature_id"]) + 1
            feature_dict["feature_id"].append(feature_id)
            feature_dict["feature_type"].append(cif.quote(feature_type))
            feature_dict["entity_type"].append(entity_type)
        specific_dict["ordinal_id"].append(len(specific_dict["ordinal_id"]) + 1)
        specific_dict["feature_id"].append(feature_id)
        for k, v in kwargs.items():
            specific_dict[k].append(v)
        return feature_id

    # dict for SW params
    sw_params = {
        "parameter_id": [],
        "group_id": [],
        "data_type": [],
        "name": [],
        "value": [],
    }

    # anything we can fetch from input json?
    with open(input_path, encoding="utf8") as jfh:
        input_data = json.load(jfh)
    if isinstance(input_data, list):
        assert len(input_data) == 1
        input_data = input_data[0]
    assert isinstance(input_data, dict)
    if "modelSeeds" in input_data:
        modelseeds = input_data["modelSeeds"]
        # Seeds may come as strings (e.g. "42") or as plain integers (42);
        # compare via str() so both forms pass, anything else is rejected.
        assert all(str(int(s)) == str(s) for s in modelseeds)
        _add_sw_param(
            sw_params,
            "integer-csv",
            "modelSeeds",
            ",".join(str(s) for s in modelseeds),
        )

    # get entity types
    entity_dict = block.get_mmcif_category("_entity.", raw=True)
    entity_types = dict(zip(entity_dict["id"], entity_dict["type"]))

    # parse atom_site into features
    atom_feature = {"ordinal_id": [], "feature_id": [], "atom_id": []}
    # NOTE: res_feature also works to set _ma_qa_metric_local_pairwise
    # (ignore ordinal_id and feature_id for that)
    res_feature = {
        "ordinal_id": [],
        "feature_id": [],
        "label_asym_id": [],
        "label_comp_id": [],
        "label_seq_id": [],
    }
    ch_feature = {"ordinal_id": [], "feature_id": [], "label_asym_id": []}
    feature = {"feature_id": [], "feature_type": [], "entity_type": []}
    # IDs ordered to match AF3 scores
    atom_feature_ids = []
    atom_feature_id_map = {}  # from atom ID to feature ID
    token_feature_ids = []  # all None if use_local_pairwise
    asym_feature_ids = []
    # for sanity checks (vs AF3 json)
    atom_chain_ids = []
    token_chain_ids = []
    token_res_ids = []
    # also keep track of model IDs to fill scores
    model_ids = []
    # go row by row
    # -> token-logic: HETATM gets per atom, rest per residue
    atom_site_cols = [
        "group_PDB",
        "id",
        "label_asym_id",
        "label_comp_id",
        "label_entity_id",
        "label_seq_id",
        "pdbx_PDB_model_num",
        "auth_seq_id",
    ]
    # ToDo: replace with access.get_table()
    table = block.find("_atom_site.", atom_site_cols)
    # first check if we can use local pairwise for token based scores
    all_res_tokens = all(row["group_PDB"] != "HETATM" for row in table)
    use_local_pairwise = all_res_tokens and use_local_pairwise_if_possible
    # do all atom features first
    for row in table:
        # check model ID
        if row["pdbx_PDB_model_num"] not in model_ids:
            model_ids.append(row["pdbx_PDB_model_num"])
        # every atom gets a token anyway
        atom_feature_id = _add_feature(
            feature_dict=feature,
            feature_type="atom",
            entity_type=entity_types[row["label_entity_id"]],
            specific_dict=atom_feature,
            atom_id=row["id"],
        )
        atom_feature_ids.append(atom_feature_id)
        atom_feature_id_map[row["id"]] = atom_feature_id
        atom_chain_ids.append(row["label_asym_id"])
    # then all chain features
    cur_label_asym_id = None
    for row in table:
        # check if new chain
        if row["label_asym_id"] != cur_label_asym_id:
            cur_label_asym_id = row["label_asym_id"]
            asym_feature_ids.append(
                _add_feature(
                    feature_dict=feature,
                    feature_type="entity instance",
                    entity_type=entity_types[row["label_entity_id"]],
                    specific_dict=ch_feature,
                    label_asym_id=row["label_asym_id"],
                )
            )
    # then all tokens for pairwise scores
    cur_label_asym_id = None
    cur_label_seq_id = None
    for row in table:
        # check if new chain
        if row["label_asym_id"] != cur_label_asym_id:
            cur_label_asym_id = row["label_asym_id"]
            cur_label_seq_id = None
        # check if new token
        if row["group_PDB"] == "HETATM":
            token_feature_ids.append(atom_feature_id_map[row["id"]])
            token_chain_ids.append(row["label_asym_id"])
            token_res_ids.append(cif.as_int(row["auth_seq_id"]))
            assert not use_local_pairwise
        elif row["label_seq_id"] != cur_label_seq_id:
            cur_label_seq_id = row["label_seq_id"]
            res_feature_id = _add_feature(
                feature_dict=None if use_local_pairwise else feature,
                feature_type="residue",
                entity_type=entity_types[row["label_entity_id"]],
                specific_dict=res_feature,
                label_asym_id=row["label_asym_id"],
                label_comp_id=row["label_comp_id"],
                label_seq_id=row["label_seq_id"],
            )
            token_feature_ids.append(res_feature_id)
            token_chain_ids.append(row["label_asym_id"])
            token_res_ids.append(cif.as_int(row["auth_seq_id"]))
    # cannot deal with multiple models!
    assert len(model_ids) == 1
    model_id = model_ids[0]

    # write them all
    block.set_mmcif_category("_ma_feature_list.", feature, raw=True)
    block.set_mmcif_category("_ma_atom_feature.", atom_feature, raw=True)
    if not use_local_pairwise:
        # for convenience we abuse of res_feature to fill
        # _ma_qa_metric_local_pairwise which does not use the feature list!
        block.set_mmcif_category(
            "_ma_poly_residue_feature.", res_feature, raw=True
        )
    block.set_mmcif_category(
        "_ma_entity_instance_feature.", ch_feature, raw=True
    )

    # get all scores
    with open(summary_qe_path, encoding="utf8") as jfh:
        summary_qe = json.load(jfh)
    with open(full_qe_path, encoding="utf8") as jfh:
        full_qe = json.load(jfh)

    # some sanity checks
    assert full_qe["atom_chain_ids"] == atom_chain_ids
    assert full_qe["token_chain_ids"] == token_chain_ids
    assert full_qe["token_res_ids"] == token_res_ids

    # handle extra stuff
    if "num_recycles" in summary_qe:
        num_recycles = summary_qe["num_recycles"]
        assert int(num_recycles) - num_recycles == 0
        _add_sw_param(sw_params, "integer", "num_recycles", num_recycles)

    # collect scores
    known_extra_keys = ["num_recycles"]
    all_scores = {
        k: v for k, v in summary_qe.items() if k not in known_extra_keys
    }
    known_extra_keys = ["atom_chain_ids", "token_chain_ids", "token_res_ids"]
    for k, v in full_qe.items():
        assert k not in all_scores
        if k not in known_extra_keys:
            all_scores[k] = v

    # info on known scores
    if use_local_pairwise:
        token_pairwise_mode = "local-pairwise"
        token_pairwise_scores = "per-residue-pair"
    else:
        token_pairwise_mode = "per-feature-pair"
        token_pairwise_scores = "per-token-pair"
    # keys in all_scores linked to fill _ma_qa_metric + "scores" for handling:
    # global, per-chain, per-chain-pair, per-atom, per-residue-pair,
    # per-token-pair
    known_scores = {
        "chain_iptm": {
            "mode": "per-feature",
            "name": "ipTM per chain",
            "type": "ipTM",
            "type_other_details": False,
            "scores": "per-chain",
        },
        "chain_pair_iptm": {
            "mode": "per-feature-pair",
            "name": "ipTM per chain pair",
            "type": "ipTM",
            "type_other_details": False,
            "scores": "per-chain-pair",
        },
        "chain_pair_pae_min": {
            "mode": "per-feature-pair",
            "name": "min. PAE per chain pair",
            "type": "PAE",
            "type_other_details": False,
            "scores": "per-chain-pair",
        },
        "chain_ptm": {
            "mode": "per-feature",
            "name": "pTM per chain",
            "type": "pTM",
            "type_other_details": False,
            "scores": "per-chain",
        },
        "fraction_disordered": {
            "mode": "global",
            "name": "fraction of prediction which is disordered",
            "type": "normalized score",
            "type_other_details": False,
            "scores": "global",
        },
        "has_clash": {
            "mode": "global",
            "name": "significant number of clashing atoms?",
            "type": "boolean",
            "type_other_details": False,
            "scores": "global",
        },
        "iptm": {
            "mode": "global",
            "name": "ipTM",
            "type": "ipTM",
            "type_other_details": False,
            "scores": "global",
        },
        "ptm": {
            "mode": "global",
            "name": "pTM",
            "type": "pTM",
            "type_other_details": False,
            "scores": "global",
        },
        # ranking_score = 0.8*ipTM + 0.2*pTM + 0.5*disorder − 100*has_clash
        # -> exp_range = [-100, 1.5]
        "ranking_score": {
            "mode": "global",
            "name": "ranking score",
            "type": "other",
            "type_other_details": "Combined score in range [-100, 1.5] "
            + "(higher is better)",
            "scores": "global",
        },
        "atom_plddts": {
            "mode": "per-feature",
            "name": "pLDDT per atom",
            "type": "pLDDT to polymer",
            "type_other_details": False,
            "scores": "per-atom",
        },
        "contact_probs": {
            "mode": token_pairwise_mode,
            "name": "contact probability per token pair",
            "type": "contact probability",
            "type_other_details": False,
            "scores": token_pairwise_scores,
        },
        "pae": {
            "mode": token_pairwise_mode,
            "name": "PAE per token pair",
            "type": "PAE",
            "type_other_details": False,
            "scores": token_pairwise_scores,
        },
    }
    extra_scores = sorted(set(all_scores) - set(known_scores))
    if extra_scores:
        raise RuntimeError(
            f"Unexpected scores {sorted(extra_scores)} found which cannot "
            f"be handled"
        )

    # add QA metrics
    qa_metric = block.get_mmcif_category("_ma_qa_metric.")
    # get IDs to use for each score
    metric_ids = dict(
        zip(
            all_scores.keys(),
            _get_ordinal_ids(qa_metric["id"], len(all_scores)),
        )
    )
    # fix existing entries
    if "type_other_details" not in qa_metric:
        qa_metric["type_other_details"] = [False] * len(qa_metric["id"])
    # NOTE(review): every pre-existing metric is re-typed here - presumably
    # only the local pLDDT metric exists at this point; confirm.
    for idx in range(len(qa_metric["name"])):
        qa_metric["type"][idx] = "pLDDT to polymer"
    # add new metrics
    for col, val_list in qa_metric.items():
        if col == "id":
            val_list.extend([metric_ids[score_key] for score_key in all_scores])
        elif col in ["mode", "name", "type", "type_other_details"]:
            val_list.extend(
                [known_scores[score_key][col] for score_key in all_scores]
            )
        elif col == "software_group_id":
            af3_sw_group = _get_af3_sw_group(block)
            val_list.extend([af3_sw_group] * len(metric_ids))
        else:
            val_list.extend([False] * len(metric_ids))
    block.set_mmcif_category("_ma_qa_metric.", qa_metric)

    # add scores
    def _add_scores(qa_dict, metric_id, **kwargs):
        """Helper to add values specific for QA dict.

        Uses constant model_id from enclosing function!
        """
        num_items = None
        scores = kwargs["metric_value"]
        for k, vl in kwargs.items():
            assert len(vl) == len(scores)
            # remove Nones from output
            vcut = [v for v, s in zip(vl, scores) if s is not None]
            qa_dict[k].extend(vcut)
            if num_items is None:
                num_items = len(vcut)
            else:
                assert num_items == len(vcut)
        # add ordinals, model ids and metric ids
        qa_dict["ordinal_id"].extend(
            [(len(qa_dict["ordinal_id"]) + idx + 1) for idx in range(num_items)]
        )
        qa_dict["model_id"].extend([model_id] * num_items)
        qa_dict["metric_id"].extend([metric_id] * num_items)

    # qa_global keys: metric_id, metric_value, model_id, ordinal_id
    qa_global = block.get_mmcif_category("_ma_qa_metric_global.")
    qa_feature = {
        k: []
        for k in [
            "ordinal_id",
            "model_id",
            "feature_id",
            "metric_id",
            "metric_value",
        ]
    }
    qa_local_pair = {
        k: []
        for k in [
            "ordinal_id",
            "label_asym_id_1",
            "label_asym_id_2",
            "label_comp_id_1",
            "label_comp_id_2",
            "label_seq_id_1",
            "label_seq_id_2",
            "metric_id",
            "metric_value",
            "model_id",
        ]
    }
    qa_feature_pair = {
        k: []
        for k in [
            "ordinal_id",
            "feature_id_1",
            "feature_id_2",
            "metric_id",
            "metric_value",
            "model_id",
        ]
    }
    for score_key, scores in all_scores.items():
        metric_id = metric_ids[score_key]
        score_handling = known_scores[score_key]["scores"]
        if score_handling == "global":
            _add_scores(
                qa_global,
                metric_id,
                metric_value=[scores],
            )
        elif score_handling == "per-chain":
            _add_scores(
                qa_feature,
                metric_id,
                feature_id=asym_feature_ids,
                metric_value=scores,
            )
        elif score_handling == "per-chain-pair":
            for idx, sub_scores in enumerate(scores):
                _add_scores(
                    qa_feature_pair,
                    metric_id,
                    feature_id_1=[asym_feature_ids[idx]]
                    * len(asym_feature_ids),
                    feature_id_2=asym_feature_ids,
                    metric_value=sub_scores,
                )
        elif score_handling == "per-atom":
            _add_scores(
                qa_feature,
                metric_id,
                feature_id=atom_feature_ids,
                metric_value=scores,
            )
        elif score_handling == "per-residue-pair":
            label_asym_id = res_feature["label_asym_id"]
            label_comp_id = res_feature["label_comp_id"]
            label_seq_id = res_feature["label_seq_id"]
            for idx, sub_scores in enumerate(scores):
                _add_scores(
                    qa_local_pair,
                    metric_id,
                    label_asym_id_1=[label_asym_id[idx]] * len(label_asym_id),
                    label_asym_id_2=label_asym_id,
                    label_comp_id_1=[label_comp_id[idx]] * len(label_comp_id),
                    label_comp_id_2=label_comp_id,
                    label_seq_id_1=[label_seq_id[idx]] * len(label_seq_id),
                    label_seq_id_2=label_seq_id,
                    metric_value=sub_scores,
                )
        elif score_handling == "per-token-pair":
            for idx, sub_scores in enumerate(scores):
                _add_scores(
                    qa_feature_pair,
                    metric_id,
                    feature_id_1=[token_feature_ids[idx]]
                    * len(token_feature_ids),
                    feature_id_2=token_feature_ids,
                    metric_value=sub_scores,
                )

    # start accompanying file if desired
    entry = block.get_mmcif_category("_entry.")
    assert len(entry["id"]) == 1
    entry_id = entry["id"][0]  # also needed for _ma_entry_associated_files
    if pairwise_in_zip:
        # create new file
        d = cif.Document()
        block_zip = d.add_new_block(block.name)
        edit.add_category(block_zip, "_entry", entry)
        block_zip.set_mmcif_category(
            "_entry_link.",
            {
                "id": ["1"],
                "entry_id": [entry_id],
                "details": [
                    "This file is an associated file consisting of pairwise QA "
                    + "metrics. This is a partial mmCIF file and can be "
                    + "validated by merging with the main mmCIF file "
                    + "containing the model coordinates and other associated "
                    + "data."
                ],
            },
        )
        block_zip.set_mmcif_category(
            "_ma_qa_metric.", block.get_mmcif_category("_ma_qa_metric.")
        )
        block_for_pairwise_qa = block_zip
    else:
        block_for_pairwise_qa = block

    # write all QA categories
    block.set_mmcif_category("_ma_qa_metric_global.", qa_global, raw=True)
    block.set_mmcif_category("_ma_qa_metric_feature.", qa_feature, raw=True)
    block_for_pairwise_qa.set_mmcif_category(
        "_ma_qa_metric_local_pairwise.", qa_local_pair, raw=True
    )
    block_for_pairwise_qa.set_mmcif_category(
        "_ma_qa_metric_feature_pairwise.", qa_feature_pair, raw=True
    )

    # add SW params to block if needed
    if len(sw_params["parameter_id"]) > 0:
        sw_group_dict = block.get_mmcif_category("_ma_software_group.")
        sw_group_dict["parameter_group_id"] = [sw_params_group_id]
        block.set_mmcif_category("_ma_software_group.", sw_group_dict)
        block.set_mmcif_category("_ma_software_parameter.", sw_params)
        edit.move_category(
            block,
            "_ma_software_parameter",
            "after:_ma_software_group",
        )

    # update _audit_conform to 1.4.7 (needed for all the *feature* tables)
    block.set_mmcif_category(
        "_audit_conform.",
        {
            # taken from python-modelcif
            "dict_location": [
                "https://raw.githubusercontent.com/ihmwg/ModelCIF/80e1e22/"
                + "dist/mmcif_ma.dic"
            ],
            "dict_name": ["mmcif_ma.dic"],
            "dict_version": ["1.4.7"],
        },
    )

    # add info for accompanying zip file (note: this blindly overwrites data)
    if not isinstance(out_zip_path, Path):
        out_zip_path = Path(out_zip_path)
    archive_file_id = "1"
    block.set_mmcif_category(
        "_ma_entry_associated_files.",
        {
            "id": [archive_file_id],
            "entry_id": [entry_id],
            "file_url": [out_zip_path.name],
            "file_type": ["archive"],
            "file_format": ["zip"],
            "file_content": ["archive with multiple files"],
        },
    )
    associated_archive_file_details = {
        "id": ["1"],
        "archive_file_id": [archive_file_id],
        "file_path": ["input.json"],
        "file_format": ["json"],
        "file_content": ["other"],
        "description": ["Input data provided to AlphaFold 3"],
    }
    if pairwise_in_zip:
        to_add = {
            "id": "2",
            "archive_file_id": archive_file_id,
            "file_path": "pairwise_qa.cif",
            "file_format": "cif",
            "file_content": "QA metrics",
            "description": "Pairwise QA metrics",
        }
        for k, v in to_add.items():
            associated_archive_file_details[k].append(v)
    block.set_mmcif_category(
        "_ma_associated_archive_file_details.", associated_archive_file_details
    )

    # package files
    with zipfile.ZipFile(out_zip_path, "w", zipfile.ZIP_BZIP2) as cif_zip:
        cif_zip.write(input_path, arcname="input.json")
        if pairwise_in_zip:
            cif_zip.writestr("pairwise_qa.cif", block_zip.as_string())
# LocalWords: CIF homomeric mdl gemmi cif modelarchive modelcif af BLANKLINE # LocalWords: Args gemmicifBlock NotFoundItemError RuntimeError str mmcif qa # LocalWords: AlphaFold PMID ModelArchive ModelCIF NotFoundCategoryError qe # LocalWords: NotIdentifiedSingleRecordError pdbx pLDDT auth pdb mon ndb PAE # LocalWords: NotIdentifiedContextRecordError nonpoly num IUPAC ligand LIG # LocalWords: NotIdentifiedDuplicatedRecordError RCSB compdict HETATM json # LocalWords: JOBNAME pTM ipTM url bool