"""ModelCIF files generated by AlphaFold 3 deviate from the official ModelCIF
definition dictionary in specific cases. In particular, for homomeric
assemblies, each molecular entity copy is written as a separate entity in the
CIF document, instead of defining a single entity referenced multiple times.
This module provides functionality to correct the deviations.
For ModelCIF conversions done for ModelArchive, we first fix entity descriptions
and then call the following functions in order:
:func:`fix_modelcif_issues`,
:func:`fix_citation`,
:func:`fix_software_location`,
:func:`fix_model_name`,
:func:`fix_protocol`,
:func:`add_per_residue_plddt`, and
:func:`add_data_from_json_files` (if AF3 JSON files available and with
`pairwise_in_zip=True` and `use_local_pairwise_if_possible=True`).
"""
# Copyright (c) 2026, SIB - Swiss Institute of Bioinformatics and
# Biozentrum - University of Basel
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# pylint: disable=too-many-lines
from pathlib import Path
import json
import zipfile
from gemmi import cif
import numpy as np
import requests
from .. import _utils
from . import access
from . import edit
def _is_null(value):
    """Tell whether ``value`` is a CIF null marker (adapted from gemmi).

    Only a value of length one whose single element is "?" (unknown) or "."
    (inapplicable) counts as null.
    """
    # ToDo: gemmi may expose this as a public function in the future.
    if len(value) != 1:
        return False
    return value[0] in ("?", ".")
def _char_table(c):
    """Classify character ``c`` for CIF value quoting (adapted from gemmi).

    Returns 1 for characters allowed in an unquoted value, 2 for tab, line
    feed, carriage return and space, and 0 for everything else. The code
    point is folded into 0..255 first; codes 128-255 always map to 0.
    """
    # One string per 16 code points of 0x00-0x7F, one digit per code point.
    ascii_rows = (
        "0000000002200200",  # 0x00-0x0F
        "0000000000000000",  # 0x10-0x1F
        "2100011011111111",  # 0x20-0x2F
        "1111111111101111",  # 0x30-0x3F
        "1111111111111111",  # 0x40-0x4F
        "1111111111101110",  # 0x50-0x5F
        "1111111111111111",  # 0x60-0x6F
        "1111111111101010",  # 0x70-0x7F
    )
    code = ord(c) % 256
    if code >= 128:
        return 0
    row_idx, col_idx = divmod(code, 16)
    return int(ascii_rows[row_idx][col_idx])
def _quote(v):
    """Format ``v`` as a CIF value, quoting only when necessary.

    Adapted from gemmi, but double quotes are preferred over single quotes to
    match AlphaFold output:

    - non-empty, non-null values of ordinary characters only: unchanged;
    - values containing a line break, or both quote characters: written as
      a semicolon-delimited text field;
    - otherwise wrapped in double quotes, or in single quotes if ``v``
      contains a double quote.
    """
    plain = (
        all(_char_table(ch) == 1 for ch in v)
        and len(v) > 0
        and not _is_null(v)
    )
    if plain:
        return v
    if "\n" in v or ('"' in v and "'" in v):
        # Text field: ';' + value, then the closing ';' on its own line.
        return ";" + v + "\n;"
    delimiter = "'" if '"' in v else '"'
    return delimiter + v + delimiter
[docs]
def fix_model_name(block, mdl_rank):
    """Normalise _ma_model_list.model_name for given rank.

    AlphaFold 3 sets _ma_model_list.model_name to "Top ranked model" for
    all models, regardless of their rank. This function rewrites the value such
    that only ``mdl_rank == 1`` is labelled "Top ranked model". All other ranks
    are renamed to "#<``mdl_rank``> ranked model".

    Examples:
        >>> from gemmi import cif
        >>> from modelarchive.modelcif import fix_af3
        >>> # get sample CIF data
        >>> cif_data = '''data_test
        ... _ma_model_list.data_id 1
        ... _ma_model_list.model_name "Top ranked model"
        ... _ma_model_list.model_type "Ab initio model"
        ... _ma_model_list.ordinal_id 1
        ... '''
        >>> block = cif.read_string(cif_data).sole_block()
        >>> fix_af3.fix_model_name(block, 2)
        >>> print(block.as_string())
        data_test
        _ma_model_list.data_id 1
        _ma_model_list.model_name "#2 ranked model"
        _ma_model_list.model_type "Ab initio model"
        _ma_model_list.ordinal_id 1
        <BLANKLINE>
        >>> fix_af3.fix_model_name(block, 1)
        >>> print(block.as_string())
        data_test
        _ma_model_list.data_id 1
        _ma_model_list.model_name "Top ranked model"
        _ma_model_list.model_type "Ab initio model"
        _ma_model_list.ordinal_id 1
        <BLANKLINE>

    Args:
        block (|gemmicifBlock|): CIF block to operate on.
        mdl_rank (int): Rank of the AlphaFold 3 model. If ``mdl_rank == 1``, the
            name is set to "Top ranked model".

    Returns:
        None

    Raises:
        edit.NotFoundItemError: If no _ma_model_list.model_name record can be
            found in ``block``.
        RuntimeError: If the ``_ma_model_list`` category contains more than one
            row.
    """
    mdl_name = (
        "Top ranked model" if mdl_rank == 1 else f"#{mdl_rank} ranked model"
    )
    table = access.get_table(block, "_ma_model_list", items=["model_name"])
    if not table:
        raise edit.NotFoundItemError(
            msg="File is missing _ma_model_list.model_name, single model "
            + "required"
        )
    if len(table) != 1:
        raise RuntimeError("File must have a single model in _ma_model_list.")
    # Row assignment stores the raw CIF token, so quote it ourselves.
    table[0]["model_name"] = _quote(mdl_name)
def _get_ordinal_ids(cur_ids, num_ids_needed):
    """Find the lowest positive integer IDs not yet in use.

    Args:
        cur_ids (Iterable[str]): IDs already taken, as strings (hopefully
            something like "1", "2", "3", ...). Non-numeric IDs never
            collide with the generated ones.
        num_ids_needed (int): Number of new IDs to provide.

    Returns:
        list[str]: The ``num_ids_needed`` smallest positive integers not in
        ``cur_ids``, as strings in increasing order. Empty if
        ``num_ids_needed`` is less than 1.
    """
    # A set keeps each membership test O(1); callers may pass plain lists.
    taken = set(cur_ids)
    free_ids = []
    candidate = 0
    # Terminates because ``taken`` is finite.
    while len(free_ids) < num_ids_needed:
        candidate += 1
        if str(candidate) not in taken:
            free_ids.append(str(candidate))
    return free_ids
[docs]
class NotIdentifiedRecordError(RuntimeError):
    """Base exception for table records that can not be identified.

    Not meant to be raised itself; the more specific "NotIdentified..."
    exceptions in this module derive from it, so callers can catch all of
    them with a single ``except`` clause.

    Args:
        msg (str): Exception message.
    """

    def __init__(self, msg):
        # Only the message is forwarded, so ``str(exc) == msg``.
        super().__init__(msg)
[docs]
class NotIdentifiedDuplicatedRecordError(NotIdentifiedRecordError):
    """Exception if a duplicated record is found in a table.

    Attributes:
        category (str): Category with the non-unique records.
        record_id (str): Identifier of the duplicated record.

    Args:
        category (str): Category in which the duplicated records were found.
        record_id (str): Identifier for the duplicated record. Not bound to a
            specific item on purpose.
    """

    def __init__(self, category, record_id):
        self.category = category
        self.record_id = record_id
        msg = (
            f"Duplicated records found in category '{category}' for "
            + f"'{record_id}'"
        )
        super().__init__(msg)
[docs]
class NotIdentifiedSingleRecordError(NotIdentifiedRecordError):
    """Exception if a specific record can not be identified in a table.

    Attributes:
        category (str): Affected category.
        item (str|None): Affected item.

    Args:
        category (str): Affected category.
        item (str, optional): Item that is missing or, together with
            ``value``, mismatching. Extends the exception message.
        value (str, optional): Value involved when a record was found but
            with a mismatching value at ``item``. Ignored without ``item``.
    """

    def __init__(self, category, item=None, value=None):
        self.category = category
        self.item = item
        if item is None:
            detail = ""
        elif value is None:
            detail = f", missing item '{item}'"
        else:
            detail = f", mismatch at item '{item}={value}'"
        super().__init__(
            f"Could not identify record in category '{category}'{detail}."
        )
[docs]
class NotIdentifiedContextRecordError(NotIdentifiedRecordError):
    """Exception if a record for a specific context can not be identified.

    Attributes:
        category (str): Category for which the exception is raised.
        item (str|None): Involved item, if any.

    Args:
        category (str): Affected category.
        item (str, optional): Affected item.
        context (str, optional): Appended verbatim to the message, so it
            should carry its own separator (e.g. ": AlphaFold 3 not found").
    """

    def __init__(self, category, item=None, context=None):
        self.category = category
        self.item = item
        parts = [f"Could not identify record in category '{category}'"]
        if item is not None:
            parts.append(f", item '{item}'")
        if context is not None:
            parts.append(context)
        parts.append(".")
        super().__init__("".join(parts))
def _fix_citation_fallback(block, primary_citation_id):
    """Fix the AlphaFold 3 citation when it could not be found by its PMID.

    If ``_citation`` has no records, a complete AlphaFold 3 citation with ID
    ``primary_citation_id`` is added. Otherwise the first record whose values
    either equal the expected AlphaFold 3 publication data or are null/empty
    is completed with the expected values. Its ID is replaced by
    ``primary_citation_id`` if it was null or "primary".

    Args:
        block (|gemmicifBlock|): CIF block to operate on.
        primary_citation_id (str): CIF value to use as ID for the AlphaFold 3
            citation if it is newly added or its current ID is null or
            "primary".

    Returns:
        tuple: ``(old_af3_sw_cit_id, new_af3_sw_cit_id)``. ``old`` is the raw
        ``_citation.id`` of the matched record, or ``None`` if the citation
        was newly added. ``new`` is the ID the citation has now.

    Raises:
        NotIdentifiedSingleRecordError: If an expected item is missing from
            ``_citation`` or no record matches the expected values. The latter
            reports the first mismatch of the last record checked.
    """
    exp_cit = {
        "country": "UK",
        "journal_full": "Nature",
        "journal_id_ASTM": "NATUAS",
        "journal_id_CSD": "0006",
        "journal_id_ISSN": "0028-0836",
        "journal_volume": "630",
        "page_first": "493",
        "page_last": "500",
        "pdbx_database_id_DOI": "10.1038/s41586-024-07487-w",
        "pdbx_database_id_PubMed": "38718835",
        "title": "Accurate structure prediction of biomolecular interactions "
        + "with AlphaFold 3",
        "year": "2024",
    }
    cat = "_citation."
    cit_table = access.get_table(block, cat)
    num_rows = len(cit_table) if cit_table else 0
    if num_rows == 0:
        # No citation at all: add the AF3 one from scratch.
        new_af3_sw_cit_id = primary_citation_id
        block.set_pairs(cat, {"id": new_af3_sw_cit_id, **exp_cit})
        return None, new_af3_sw_cit_id
    # check that all items exist
    for itm in exp_cit:
        if f"{cat}{itm}" not in cit_table.tags:
            raise NotIdentifiedSingleRecordError(cat[:-1], item=itm)
    # search for the first record with all values matching or null
    match_idx = None
    mismatch = None
    for idx in range(num_rows):
        mismatch = next(
            (
                (key, val)
                for key, val in exp_cit.items()
                # gemmi as_string() translates "?" and "." to "", so null
                # values count as matching.
                if cif.as_string(cit_table[idx][f"{cat}{key}"])
                not in ("", val)
            ),
            None,
        )
        if mismatch is None:
            match_idx = idx
            break
    if match_idx is None:
        # num_rows > 0 here, so 'mismatch' is set.
        bad_item, bad_value = mismatch
        raise NotIdentifiedSingleRecordError(
            cat[:-1], item=bad_item, value=bad_value
        )
    old_af3_sw_cit_id = cit_table[match_idx]["id"]
    if old_af3_sw_cit_id in ("?", ".", "primary"):
        new_af3_sw_cit_id = primary_citation_id
    else:
        new_af3_sw_cit_id = old_af3_sw_cit_id
    # complete the matched record (raw assignment, hence _quote())
    cit_table[match_idx]["id"] = new_af3_sw_cit_id
    for key, val in exp_cit.items():
        cit_table[match_idx][key] = _quote(val)
    return old_af3_sw_cit_id, new_af3_sw_cit_id
def _is_af3_sw_name(sw_name):
    """Return True if a _software.name value denotes AlphaFold.

    Case-insensitive prefix match on "alphafold".
    """
    prefix = "alphafold"
    lowered = sw_name.lower()
    return lowered[: len(prefix)] == prefix
class _AF3ItemSetter:
    """Callable providing per-row values for one _software item.

    Meant as value callback when adding or rewriting a _software column. The
    AlphaFold row gets ``value``. With ``same=True`` every other row keeps
    its current value of ``item``; with ``same=False`` (default) it gets "?".
    A second AlphaFold row raises, so the AlphaFold record stays unambiguous.

    Attributes:
        item (str): Item (column) name read for non-AlphaFold rows.
        sw_found (bool): Whether an AlphaFold row has been seen.
        value (str): Value for the AlphaFold row.

    Args:
        value (str): Value for the AlphaFold row.
        item (str, optional): Item name, defaults to "citation_id".
    """

    # This is a tiny helper-class with the purpose of storing a state, only
    # for local use, disable Pylint warning
    # pylint: disable=too-few-public-methods

    def __init__(self, value, item="citation_id"):
        self.item = item
        self.sw_found = False
        self.value = value

    def __call__(self, row, same=False):
        if not _is_af3_sw_name(row["name"]):
            return row[self.item] if same else "?"
        if self.sw_found:
            raise NotIdentifiedDuplicatedRecordError(
                "_software", "AlphaFold"
            )
        self.sw_found = True
        return self.value
def _fix_software(block, new_af3_sw_cit_id):
    """Link the AlphaFold 3 _software record to its _citation record.

    Writes ``new_af3_sw_cit_id`` to _software.citation_id of the AlphaFold
    row. If the column exists, all other rows keep their values. If it does
    not, it is added via edit.add_column() with an _AF3ItemSetter callback.

    Raises:
        NotIdentifiedContextRecordError: If no AlphaFold row was seen.
        NotIdentifiedDuplicatedRecordError: If more than one AlphaFold row
            exists.
    """
    sw_table = access.get_table(block, "_software")
    setter = _AF3ItemSetter(new_af3_sw_cit_id)
    has_column = bool(sw_table) and "_software.citation_id" in sw_table.tags
    if has_column:
        for sw_row in sw_table:
            sw_row["citation_id"] = setter(sw_row, same=True)
    elif sw_table:
        edit.add_column(block, "_software", "citation_id", setter)
    if not setter.sw_found:
        raise NotIdentifiedContextRecordError(
            "_software", context=": AlphaFold 3 not found"
        )
def _get_key_primary(row):
    """Sort key for _citation rows, for use with edit.sort().

    Order: "primary" first, then numeric IDs by value, then all other IDs
    alphabetically.
    """
    cit_id = row["id"]
    if cit_id == "primary":
        return (-1, "")
    try:
        numeric_id = int(cit_id)
    except ValueError:
        return (1, cit_id)
    return (0, numeric_id)
def _ensure_citation_id_first(block):
    """Make sure _citation.id is the first item of the _citation category.

    If it is not, the category is rebuilt from its raw values, with ``id``
    moved to the front and the other items kept in their original order. It
    is re-inserted before _citation_author. Does nothing if no _citation
    table is found.
    """
    table = access.get_table(block, "_citation")
    if not table or table.tags[0] == "_citation.id":
        return
    # Raw column values ("id" first); written back with raw=True below, so
    # existing quoting is preserved.
    cif_dict = {"id": list(table.find_column("id"))}
    for tag in table.tags:
        itm = tag.split(".", maxsplit=1)[1]
        if itm != "id":
            cif_dict[itm] = list(table.find_column(itm))
    table.erase()
    edit.add_category(
        block,
        "_citation",
        item_data=cif_dict,
        index="before:_citation_author",
        raw=True,
    )
[docs]
def fix_citation(block):
    """Normalise the AlphaFold 3 citation in a `ModelCIF`_ ``block``.

    Ensures that the AlphaFold 3 publication
    (`PMID 38718835 <https://pubmed.ncbi.nlm.nih.gov/38718835/>`_) is not marked
    as the "primary" citation, assigns a numeric citation ID instead. Fixes an
    incomplete AlphaFold 3 citation. Replaces the author list with the full
    curated list of names and updates its citation ID. Reorders citations so
    that the primary entry appears first and links the citation to the
    corresponding software record.

    This adjustment is not required for valid `ModelCIF`_ files, but follows
    `ModelArchive`_ conventions where the primary citation must refer to
    the deposited model rather than the software used to generate it.

    Examples:
        >>> from gemmi import cif
        >>> from modelarchive.modelcif import access, fix_af3
        >>> # start with a minimal CIF document
        >>> CIF_DATA = '''data_test
        ... _citation.id primary
        ... _citation.country UK
        ... _citation.journal_full Nature
        ... _citation.journal_id_ASTM NATUAS
        ... _citation.journal_id_CSD 0006
        ... _citation.journal_id_ISSN 0028-0836
        ... _citation.journal_volume 630
        ... _citation.page_first 493
        ... _citation.page_last 500
        ... _citation.pdbx_database_id_DOI 10.1038/s41586-024-07487-w
        ... _citation.pdbx_database_id_PubMed 38718835
        ... _citation.title 'Accurate structure prediction of biomolecular ...'
        ... _citation.year 2024
        ... #
        ... loop_
        ... _citation_author.citation_id
        ... _citation_author.name
        ... _citation_author.ordinal
        ... primary "Google DeepMind AlphaFold Team" 1
        ... primary "Isomorphic Labs Team" 2
        ... #
        ... loop_
        ... _software.classification
        ... _software.date
        ... _software.description
        ... _software.name
        ... _software.pdbx_ordinal
        ... _software.type
        ... _software.version
        ... other ? "Structure prediction" AlphaFold 1 package AlphaFold-beta
        ... '''
        >>> block = cif.read_string(CIF_DATA).sole_block()
        >>> fix_af3.fix_citation(block)
        >>> # The usual block.as_string() output would be too much for a
        >>> # docstring, just check some important values.
        >>> table = access.get_table(block, "_citation")
        >>> assert table[0]["id"] == "1"
        >>> table = access.get_table(block, "_citation_author")
        >>> assert table[0]["name"] != "Google DeepMind AlphaFold Team"
        >>> table = access.get_table(block, "_software")
        >>> assert table[0]["citation_id"] == "1"

    Args:
        block (|gemmicifBlock|): CIF block to operate on.

    Returns:
        None

    Raises:
        edit.NotFoundCategoryError: If _software category can not be found.
        NotIdentifiedSingleRecordError: If required item is missing from
            _citation category. If item values are not as expected for
            _citation category.
        NotIdentifiedDuplicatedRecordError: If multiple entries for AlphaFold
            are found in _software category. In that case, the "right"
            record can not be identified.
        NotIdentifiedContextRecordError: If no AlphaFold entry is found in
            the _software category.
    """
    old_af3_sw_cit_id = None
    new_af3_sw_cit_id = None
    cat = "_citation"
    itms = ["id", "pdbx_database_id_PubMed"]
    table = access.get_table(block, cat, itms)
    # pick first numeric value not yet taken
    primary_citation_id = _quote(
        _get_ordinal_ids(set(row["id"] for row in table), 1)[0]
    )
    # correct IDs and find AF3 citation (last matching record wins)
    for row in table:
        if row["pdbx_database_id_PubMed"] == "38718835":
            old_af3_sw_cit_id = row["id"]
            if row["id"] == "primary":
                row["id"] = primary_citation_id
                new_af3_sw_cit_id = primary_citation_id
            else:
                new_af3_sw_cit_id = old_af3_sw_cit_id
    if old_af3_sw_cit_id is None or _is_null(old_af3_sw_cit_id):
        # citation available w/o PMID? Fallback option replacing citation
        old_af3_sw_cit_id, new_af3_sw_cit_id = _fix_citation_fallback(
            block, primary_citation_id
        )
    if len(table) > 0:
        # _citation table may have changed, sort it to have "primary" first
        edit.sort(table, "id", key=_get_key_primary)
        # make sure _citation.id comes first
        _ensure_citation_id_first(block)
    _replace_af3_citation_authors(
        block, old_af3_sw_cit_id, new_af3_sw_cit_id
    )
    _fix_software(block, new_af3_sw_cit_id)


def _replace_af3_citation_authors(block, old_af3_sw_cit_id, new_af3_sw_cit_id):
    """Rewrite _citation_author with the curated AlphaFold 3 author list.

    Authors whose citation_id equals ``old_af3_sw_cit_id`` are dropped, all
    others are kept in their original order. The curated AlphaFold 3 authors
    are appended under ``new_af3_sw_cit_id``. Ordinals are renumbered from 1
    across all records, also when no _citation_author category existed.
    """
    # note: fixed to be in correct style and without special characters
    af3_authors = [
        "Abramson, J.",
        "Adler, J.",
        "Dunger, J.",
        "Evans, R.",
        "Green, T.",
        "Pritzel, A.",
        "Ronneberger, O.",
        "Willmore, L.",
        "Ballard, A.J.",
        "Bambrick, J.",
        "Bodenstein, S.W.",
        "Evans, D.A.",
        "Hung, C.C.",
        "O'Neill, M.",
        "Reiman, D.",
        "Tunyasuvunakool, K.",
        "Wu, Z.",
        "Zemgulyte, A.",
        "Arvaniti, E.",
        "Beattie, C.",
        "Bertolli, O.",
        "Bridgland, A.",
        "Cherepanov, A.",
        "Congreve, M.",
        "Cowen-Rivers, A.I.",
        "Cowie, A.",
        "Figurnov, M.",
        "Fuchs, F.B.",
        "Gladman, H.",
        "Jain, R.",
        "Khan, Y.A.",
        "Low, C.M.R.",
        "Perlin, K.",
        "Potapenko, A.",
        "Savy, P.",
        "Singh, S.",
        "Stecula, A.",
        "Thillaisundaram, A.",
        "Tong, C.",
        "Yakneen, S.",
        "Zhong, E.D.",
        "Zielinski, M.",
        "Zidek, A.",
        "Bapst, V.",
        "Kohli, P.",
        "Jaderberg, M.",
        "Hassabis, D.",
        "Jumper, J.M.",
    ]
    citation_ids = []
    names = []
    cit_auth_dict = block.get_mmcif_category("_citation_author.")
    if cit_auth_dict:
        for citation_id, name in zip(
            cit_auth_dict["citation_id"], cit_auth_dict["name"]
        ):
            if citation_id != old_af3_sw_cit_id:
                citation_ids.append(citation_id)
                names.append(name)
    citation_ids.extend([new_af3_sw_cit_id] * len(af3_authors))
    names.extend(af3_authors)
    block.set_mmcif_category(
        "_citation_author.",
        {
            "citation_id": citation_ids,
            "name": names,
            # Always renumber: without a prior _citation_author category an
            # empty ordinal column would not match the other columns' length.
            "ordinal": list(range(1, len(names) + 1)),
        },
    )
def _is_af3_server(block):
    """Guess whether ``block`` was produced by the AF3 server or AF3 code.

    Heuristic (may fail!) over _pdbx_data_usage.details, case-insensitive:
    a mention of "server" means server, a mention of the AF3 GitHub
    repository means local code.

    Returns:
        bool: True for server, False for code.

    Raises:
        NotIdentifiedContextRecordError: If neither or both hints are found.
    """
    github_url = "github.com/google-deepmind/alphafold3"
    usage_table = access.get_table(
        block, "_pdbx_data_usage", items=["details"]
    )
    details = [r["details"].lower() for r in usage_table]
    is_server = any("server" in text for text in details)
    is_code = any(github_url in text for text in details)
    if is_server == is_code:
        raise NotIdentifiedContextRecordError(
            "_pdbx_data_usage", context=": AlphaFold 3 license type"
        )
    return is_server
[docs]
def fix_software_location(block):
    """Ensures the AlphaFold 3 _software entry has a correct location URL.

    Decides from _pdbx_data_usage whether the `ModelCIF`_ ``block`` comes from
    the AlphaFold 3 server or a local installation and writes the matching URL
    to _software.location of the AlphaFold row. A missing column is created;
    an existing one only changes for the AlphaFold row.

    Examples:
        >>> from gemmi import cif
        >>> from modelarchive.modelcif import access, fix_af3
        >>> # start with a minimal CIF document
        >>> CIF_DATA = '''data_test
        ... _pdbx_data_usage.details "... alphafoldserver.com/output-terms."
        ... _pdbx_data_usage.id 1
        ... _pdbx_data_usage.type license
        ... _pdbx_data_usage.url ?
        ... #
        ... loop_
        ... _software.classification
        ... _software.date
        ... _software.description
        ... _software.name
        ... _software.pdbx_ordinal
        ... _software.type
        ... _software.version
        ... other ? "Structure prediction" AlphaFold 1 package AlphaFold-beta
        ... '''
        >>> block = cif.read_string(CIF_DATA).sole_block()
        >>> fix_af3.fix_software_location(block)
        >>> # Just check that _software.location exists and has the right value
        >>> table = access.get_table(block, "_software")
        >>> assert "_software.location" in table.tags
        >>> assert table[0]["location"] == "https://alphafoldserver.com/"
        >>> # Change block to look like ModelCIF file from local installation
        >>> table = access.get_table(block, "_pdbx_data_usage")
        >>> table[0]["details"] = "...github.com/google-deepmind/alphafold3..."
        >>> fix_af3.fix_software_location(block)
        >>> # Check _software.location to point to GitHub, now
        >>> table = access.get_table(block, "_software")
        >>> assert table[0]["location"] == \
        "https://github.com/google-deepmind/alphafold3"

    Args:
        block (|gemmicifBlock|): CIF block to operate on.

    Returns:
        None

    Raises:
        NotIdentifiedContextRecordError: If no AlphaFold 3 entry is found in
            the _software table.
        NotIdentifiedContextRecordError: If the origin of the AlphaFold 3
            license could not be identified in the _pdbx_data_usage table.
        NotIdentifiedDuplicatedRecordError: If multiple entries for AlphaFold
            3 are found in the _software table.
    """
    af3_sw_url = (
        "https://alphafoldserver.com/"
        if _is_af3_server(block)
        else "https://github.com/google-deepmind/alphafold3"
    )
    sw_table = access.get_table(block, "_software")
    url_setter = _AF3ItemSetter(af3_sw_url, item="location")
    if sw_table:
        if "_software.location" in sw_table.tags:
            for sw_row in sw_table:
                sw_row["location"] = url_setter(sw_row, same=True)
        else:
            edit.add_column(block, "_software", "location", url_setter)
    if not url_setter.sw_found:
        raise NotIdentifiedContextRecordError(
            "_software", context=": AlphaFold 3 not found"
        )
def _get_cat_dict(block, category):
    """Return ``category`` as produced by ``block.get_mmcif_category()``.

    ``category`` is given without the trailing dot.

    Raises:
        edit.NotFoundCategoryError: If the lookup yields nothing.
    """
    cat_dict = block.get_mmcif_category(category + ".")
    if cat_dict:
        return cat_dict
    raise edit.NotFoundCategoryError(category)
def _get_af3_sw_group(block):
    """Return the _ma_software_group.group_id to use for AlphaFold 3.

    Used for protocol steps and QA metrics. With a single
    _ma_software_group record its group_id is returned without further
    checks. Otherwise the AlphaFold row of _software is located and the
    group whose software_id equals its pdbx_ordinal is returned.

    Raises:
        edit.NotFoundCategoryError: If _ma_software_group (or, with several
            groups, _software) is not found.
        NotIdentifiedDuplicatedRecordError: If _software has more than one
            AlphaFold row.
        NotIdentifiedContextRecordError: If no group references AlphaFold.
    """
    group_table = access.get_table(block, "_ma_software_group")
    if not group_table:
        raise edit.NotFoundCategoryError("_ma_software_group")
    if len(group_table) == 1:
        return group_table[0]["group_id"]
    # Multiple SW groups, look for AF3 entry
    sw_table = access.get_table(block, "_software")
    if not sw_table:
        raise edit.NotFoundCategoryError("_software")
    af3_ordinal = None
    for sw_row in sw_table:
        if not _is_af3_sw_name(sw_row["name"]):
            continue
        if af3_ordinal is not None:
            raise NotIdentifiedDuplicatedRecordError(
                "_software", "AlphaFold"
            )
        af3_ordinal = sw_row["pdbx_ordinal"]
    for group_row in group_table:
        if group_row["software_id"] == af3_ordinal:
            return group_row["group_id"]
    raise NotIdentifiedContextRecordError(
        "_software", context=": AlphaFold 3 not found"
    )
[docs]
def fix_protocol(block):
    """Fix the MA protocol to a single well-formed step.

    Rewrites _ma_data, _ma_data_group, and _ma_protocol_step from scratch based
    on the existing _ma_target_entity, _ma_data_ref_db, _ma_model_list and
    _ma_software_group categories. All source categories are checked before
    the block is modified.

    .. warning::
        - Existing _ma_data, _ma_data_group, and _ma_protocol_step categories
          in ``block`` are overwritten without checking their prior contents.
        - It is important that :func:`fix_modelcif_issues` and
          :func:`fix_model_name` are called before this since this function
          uses existing data from _ma_model_list and _ma_data_ref_db set
          there.

    Data layout after the call:

    _ma_data:
        One record per target entity (content_type "target"), followed by
        one record per _ma_data_ref_db entry, if any (content_type
        "reference database"), followed by one record per model
        (content_type "model coordinates"). Names are taken from
        _entity.pdbx_description, _ma_data_ref_db.name and
        _ma_model_list.model_name, respectively. IDs are assigned
        sequentially starting at 1 and written back to the data_id items of
        the source categories.
    _ma_data_group:
        Group 1 - all input data IDs (targets and reference databases).
        Group 2 - all model data IDs (output side).
    _ma_protocol_step:
        A single step referencing the AF3 software group, group 1
        as input, and group 2 as output.

    Examples:
        >>> from gemmi import cif
        >>> from modelarchive.modelcif import access, fix_af3
        >>> # start with a minimal CIF document
        >>> CIF_DATA = '''data_test
        ... #
        ... loop_
        ... _entity.id
        ... _entity.pdbx_description
        ... _entity.type
        ... 1 "bestest polymer in universe" polymer
        ... 2 "second best polythingi in universe" polymer
        ... #
        ... loop_
        ... _ma_target_entity.data_id
        ... _ma_target_entity.entity_id
        ... _ma_target_entity.origin
        ... 1 1 .
        ... 1 2 .
        ... #
        ... _ma_model_list.data_id 1
        ... _ma_model_list.model_group_id 1
        ... _ma_model_list.model_group_name "AlphaFold-beta-20231127 (...)"
        ... _ma_model_list.model_id 1
        ... _ma_model_list.model_name "Top ranked model"
        ... _ma_model_list.model_type "Ab initio model"
        ... _ma_model_list.ordinal_id 1
        ... #
        ... loop_
        ... _ma_software_group.group_id
        ... _ma_software_group.ordinal_id
        ... _ma_software_group.software_id
        ... 1 1 1
        ... #
        ... loop_
        ... _software.classification
        ... _software.date
        ... _software.description
        ... _software.name
        ... _software.pdbx_ordinal
        ... _software.type
        ... _software.version
        ... other ? "Structure prediction" AlphaFold 1 package AlphaFold-beta
        ... '''
        >>> block = cif.read_string(CIF_DATA).sole_block()
        >>> fix_af3.fix_protocol(block)
        >>> access.get_table(block, "_entity").erase()
        >>> access.get_table(block, "_ma_data").erase()
        >>> access.get_table(block, "_ma_data_group").erase()
        >>> access.get_table(block, "_ma_model_list").erase()
        >>> access.get_table(block, "_ma_software_group").erase()
        >>> access.get_table(block, "_ma_target_entity").erase()
        >>> access.get_table(block, "_software").erase()
        >>> print(block.as_string())
        data_test
        loop_
        _ma_protocol_step.ordinal_id
        _ma_protocol_step.protocol_id
        _ma_protocol_step.step_id
        _ma_protocol_step.method_type
        _ma_protocol_step.details
        _ma_protocol_step.software_group_id
        _ma_protocol_step.input_data_group_id
        _ma_protocol_step.output_data_group_id
        1 1 1 modeling 'Model generated with AlphaFold 3.' 1 1 2
        <BLANKLINE>

    Args:
        block (|gemmicifBlock|): CIF block to operate on.

    Returns:
        None

    Raises:
        edit.NotFoundCategoryError: If any required source category is
            absent: _entity, _ma_target_entity, _ma_model_list, or
            _ma_software_group.
        edit.NotFoundItemError: If _ma_target_entity.data_id,
            _ma_model_list.data_id or _ma_model_list.model_name are missing.
        NotIdentifiedDuplicatedRecordError: If multiple _ma_software_group
            records exist and the AF3 entry cannot be unambiguously identified
            in _software.
        NotIdentifiedContextRecordError: If multiple _ma_software_group records
            exist but no AF3 entry can be found in _software at all.
    """
    # Look up and validate all source data first, so that a missing category
    # or item does not leave the block partially rewritten.
    entity_dict = _get_cat_dict(block, "_entity")
    entity_descs = dict(zip(entity_dict["id"], entity_dict["pdbx_description"]))
    trg_ent_table = access.get_table(block, "_ma_target_entity")
    if not trg_ent_table:
        raise edit.NotFoundCategoryError("_ma_target_entity")
    # Raise when _ma_target_entity.data_id does not exist, that could mean
    # a lot has changed in the way AF writes ModelCIF files.
    if "_ma_target_entity.data_id" not in trg_ent_table.tags:
        raise edit.NotFoundItemError("_ma_target_entity.data_id")
    mdl_list_table = access.get_table(block, "_ma_model_list")
    if not mdl_list_table:
        raise edit.NotFoundCategoryError("_ma_model_list")
    for itm in ["data_id", "model_name"]:
        if f"_ma_model_list.{itm}" not in mdl_list_table.tags:
            raise edit.NotFoundItemError(f"_ma_model_list.{itm}")
    af3_sw_group = _get_af3_sw_group(block)
    data_ref_db = access.get_table(block, "_ma_data_ref_db")

    # collect data to add
    data_dict = {"id": [], "name": [], "content_type": []}
    data_ids_in = []
    data_ids_out = []
    # add targets
    for i, row in enumerate(trg_ent_table, start=1):
        row["data_id"] = str(i)
        data_dict["id"].append(row["data_id"])
        data_ids_in.append(row["data_id"])
        data_dict["name"].append(entity_descs[row["entity_id"]])
        data_dict["content_type"].append("target")
    # add ref. DBs (input data as well)
    for i, row in enumerate(data_ref_db, start=len(data_dict["id"]) + 1):
        row["data_id"] = str(i)
        data_dict["id"].append(row["data_id"])
        data_ids_in.append(row["data_id"])
        # Row values are raw CIF tokens (possibly quoted). Unquote them, as
        # done for the model name below, so set_mmcif_category() does not
        # quote them a second time. Null values are passed on as None ("?").
        ref_db_name = row["name"]
        data_dict["name"].append(
            None if cif.is_null(ref_db_name) else cif.as_string(ref_db_name)
        )
        data_dict["content_type"].append("reference database")
    # add model
    for i, row in enumerate(mdl_list_table, start=len(data_dict["id"]) + 1):
        row["data_id"] = str(i)
        data_dict["id"].append(row["data_id"])
        data_dict["name"].append(cif.as_string(row["model_name"]))
        data_dict["content_type"].append("model coordinates")
        data_ids_out.append(row["data_id"])
    # write data (need to be able to overwrite!)
    # Using set_mmcif_category() here is OK, there are at least two data
    # records, target & model, so it will always be a loop.
    block.set_mmcif_category("_ma_data.", data_dict)
    edit.move_category(block, "_ma_data", "after:_ma_model_list")
    # add 2 data groups (1 for input, 2 for output)
    num_data_ids = len(data_ids_in) + len(data_ids_out)
    block.set_mmcif_category(
        "_ma_data_group.",
        {
            "ordinal_id": list(range(1, num_data_ids + 1)),
            "group_id": [1] * len(data_ids_in) + [2] * len(data_ids_out),
            "data_id": data_ids_in + data_ids_out,
        },
    )
    edit.move_category(block, "_ma_data_group", "after:_ma_data")
    # add single protocol step
    block.set_mmcif_category(
        "_ma_protocol_step.",
        {
            "ordinal_id": [1],
            "protocol_id": [1],
            "step_id": [1],
            "method_type": ["modeling"],
            "details": ["Model generated with AlphaFold 3."],
            "software_group_id": [af3_sw_group],
            "input_data_group_id": [1],
            "output_data_group_id": [2],
        },
    )
    edit.move_category(block, "_ma_protocol_step", "after:_ma_data_group")
[docs]
def add_per_residue_plddt(block):
    """Add average per-residue pLDDT scores to an AF3 ModelCIF file.

    Adds _ma_qa_metric_local data derived from B-factor values in
    _atom_site. The per-residue pLDDT is computed as the mean over
    all atoms of a residue. Non-polymer residues (missing value in
    _atom_site.label_seq_id) are excluded.

    If _ma_qa_metric_local is already present in the block, the
    function exits early with a warning. If no local pLDDT entry exists
    in _ma_qa_metric, it will be added. Optional _ma_qa_metric items beyond
    id, mode, name, software_group_id and type are left null for the new
    entry. If more than one local pLDDT entry is found, an exception is
    raised, as this is most likely an error in the ModelCIF file.

    This fix targets AF3 files predating version 3.0.1, which lack
    _ma_qa_metric_local.

    Examples:
        >>> from gemmi import cif
        >>> from modelarchive.modelcif import access, fix_af3
        >>> # Please note: the example CIF document for this case has the
        >>> # _atom_site category reduce to the bare minimum to make the
        >>> # mechanics of add_per_residue_plddt() work, to keep the example
        >>> # shorter.
        >>> CIF_DATA = '''data_test
        ... #
        ... loop_
        ... _ma_software_group.group_id
        ... _ma_software_group.ordinal_id
        ... _ma_software_group.software_id
        ... 1 1 1
        ... #
        ... loop_
        ... _software.classification
        ... _software.date
        ... _software.description
        ... _software.name
        ... _software.pdbx_ordinal
        ... _software.type
        ... _software.version
        ... other ? "Structure prediction" AlphaFold 1 package AlphaFold-beta
        ... #
        ... loop_
        ... _atom_site.group_PDB
        ... _atom_site.label_comp_id
        ... _atom_site.label_asym_id
        ... _atom_site.label_seq_id
        ... _atom_site.B_iso_or_equiv
        ... _atom_site.pdbx_PDB_model_num
        ... ATOM MET A 1 35.00 1
        ... ATOM ALA A 2 50.30 1
        ... ATOM THR A 3 65.75 1
        ... '''
        >>> block = cif.read_string(CIF_DATA).sole_block()
        >>> fix_af3.add_per_residue_plddt(block)
        >>> # After execution, the CIF document has categories _ma_qa_metric
        >>> # and _ma_qa_metric_local added
        >>> # There should be only 1 record in _ma_qa_metric
        >>> qa_dict = block.get_mmcif_category("_ma_qa_metric.")
        >>> print(qa_dict)
        {'id': ['1'], 'mode': ['local'], 'name': ['pLDDT'], \
        'software_group_id': ['1'], 'type': ['pLDDT']}
        >>> # There should be 3 records of local scores
        >>> table = access.get_table(block, "_ma_qa_metric_local.")
        >>> print("# chain res seqID pLDDT")
        # chain res seqID pLDDT
        >>> for r in table:
        ...     print(
        ...         f"{r['ordinal_id']} {r['label_asym_id']} "
        ...         + f"{r['label_comp_id']} {r['label_seq_id']} "
        ...         + f"{r['metric_value']}"
        ...     )
        1 A MET 1 35.0
        2 A ALA 2 50.3
        3 A THR 3 65.75

    Args:
        block (|gemmicifBlock|): CIF block to operate on.

    Returns:
        None

    Raises:
        RuntimeError: If ``_ma_qa_metric`` contains more than one local
            pLDDT entry, or if a local pLDDT entry has to be added but
            ``_ma_qa_metric`` lacks one of the items id, mode, name,
            software_group_id or type.
    """
    if "_ma_qa_metric_local." in block.get_mmcif_category_names():
        _utils.warn_msg("_ma_qa_metric_local already there, skipped.")
        return
    # find metric ID for pLDDT
    required_items = ("id", "mode", "name", "software_group_id", "type")
    qa_metric = block.get_mmcif_category("_ma_qa_metric.")
    if len(qa_metric) == 0:
        qa_metric = {itm: [] for itm in required_items}
    qa_local = [
        mid
        for mid, mtype, mode in zip(
            qa_metric["id"], qa_metric["type"], qa_metric["mode"]
        )
        if mtype == "pLDDT" and mode == "local"
    ]
    if len(qa_local) == 1:
        # expected route
        metric_id = qa_local[0]
    elif len(qa_local) == 0:
        # add it if possible; explicit check instead of assert, which would
        # vanish under "python -O"
        missing_items = [itm for itm in required_items if itm not in qa_metric]
        if missing_items:
            raise RuntimeError(
                "Cannot add local pLDDT entry, _ma_qa_metric lacks item(s): "
                + ", ".join(missing_items)
                + "."
            )
        metric_id = _get_ordinal_ids(qa_metric["id"], 1)[0]
        new_metric = {
            "id": metric_id,
            "mode": "local",
            "name": "pLDDT",
            "software_group_id": _get_af3_sw_group(block),
            "type": "pLDDT",
        }
        # Any further (optional) item gets None for the new record, which
        # gemmi writes as "?".
        for itm, values in qa_metric.items():
            values.append(new_metric.get(itm))
        # set_mmcif_category() will create a loop, even for single records...
        # that is OK here as having multiple metrics (at least 1 global and
        # 1 local) is more than common.
        block.set_mmcif_category("_ma_qa_metric.", qa_metric)
    else:
        raise RuntimeError(
            "Unexpected number of local pLDDT entries in _ma_qa_metric."
        )
    # code adapted from src/alphafold3/model/mmcif_metadata.py as modified in
    # https://github.com/google-deepmind/alphafold3/commit/121716e
    res_id_keys = [
        "label_asym_id",
        "label_seq_id",
        "label_comp_id",
        "pdbx_PDB_model_num",
    ]
    atom_site = access.get_table(
        block, "_atom_site", res_id_keys + ["B_iso_or_equiv"]
    )
    plddt_grouped_by_res = {}
    for row in atom_site:
        group_key = tuple(row[v] for v in res_id_keys)
        plddt_grouped_by_res.setdefault(group_key, []).append(
            cif.as_number(row["B_iso_or_equiv"])
        )
    qa_local = {
        "ordinal_id": [],
        "label_asym_id": [],
        "label_comp_id": [],
        "label_seq_id": [],
        "metric_id": [],
        "metric_value": [],
        "model_id": [],
    }
    # ordinal_id counts every residue group, so skipped non-polymer groups
    # leave gaps; the IDs stay unique.
    for ordinal_id, (
        (chain_id, res_id, res_name, model_id),
        res_plddts,
    ) in enumerate(plddt_grouped_by_res.items(), start=1):
        # skip non-polymer residues (note AF3 includes those)
        if not cif.is_null(res_id):
            res_plddt = round(np.mean(res_plddts), 2)
            qa_local["ordinal_id"].append(str(ordinal_id))
            qa_local["label_asym_id"].append(chain_id)
            qa_local["label_seq_id"].append(res_id)
            qa_local["label_comp_id"].append(res_name)
            qa_local["metric_id"].append(metric_id)
            qa_local["metric_value"].append(str(res_plddt))
            qa_local["model_id"].append(model_id)
    # values are raw CIF tokens taken from _atom_site, hence raw=True
    block.set_mmcif_category("_ma_qa_metric_local.", qa_local, raw=True)
def _add_auth_comp_id_2_atom_site(block):
    """Add item auth_comp_id to category _atom_site, if missing.

    The new column is inserted at position 17; each row receives a copy of
    its own label_comp_id value.
    """
    copy_label_comp_id = edit.make_copy_value_in_row("label_comp_id")
    edit.add_column(
        block, "_atom_site", "auth_comp_id", copy_label_comp_id, pos=17
    )
# Value factory shared by _add_pdb_mon_id_2_pdbx_poly_seq_scheme and
# _add_pdb_mon_id_2_pdbx_branch_scheme: fills a new column with a copy of the
# row's mon_id value.
_return_mon_id = edit.make_copy_value_in_row("mon_id")
def _add_pdb_mon_id_2_pdbx_poly_seq_scheme(block):
    """Add item pdb_mon_id to _pdbx_poly_seq_scheme, if missing.

    Values are copied from _pdbx_poly_seq_scheme.mon_id of the same row (via
    the module-level ``_return_mon_id`` factory); the new column is inserted
    at position 9. Like the branched and non-polymer counterparts, a block
    without a _pdbx_poly_seq_scheme category (no polymer chains) is left
    untouched instead of failing with ``edit.NotFoundCategoryError``.
    """
    try:
        edit.add_column(
            block,
            "_pdbx_poly_seq_scheme",
            "pdb_mon_id",
            _return_mon_id,
            pos=9,
        )
    except edit.NotFoundCategoryError:
        # no polymer in the model, hence nothing to add
        pass
def _add_pdb_mon_id_2_pdbx_branch_scheme(block):
    """Add item pdb_mon_id to _pdbx_branch_scheme, if missing.

    Values are copied from _pdbx_branch_scheme.mon_id of the same row (via the
    module-level ``_return_mon_id`` factory); the new column is inserted at
    position 9. A block without branched entities (no _pdbx_branch_scheme
    category) is left untouched.
    """
    try:
        edit.add_column(
            block,
            "_pdbx_branch_scheme",
            "pdb_mon_id",
            _return_mon_id,
            pos=9,
        )
    except edit.NotFoundCategoryError:
        # no branched entities in the model, hence nothing to add
        pass
def _add_pdbx_entity_branch_list_2_data(block):
    """Add category _pdbx_entity_branch_list to data block.

    The list is derived from _pdbx_branch_scheme (entity_id, num and
    mon_id -> comp_id) and placed right before that category. Nothing is done
    if _pdbx_branch_scheme is missing/empty or if _pdbx_entity_branch_list is
    already present.

    _pdbx_entity_branch_list describes entities while _pdbx_branch_scheme has
    one record per residue of every instance (chain), so a residue repeated
    by several instances of the same entity is only listed once.
    """
    # first check that pdbx_branch_scheme exists
    pbs = access.get_table(block, "_pdbx_branch_scheme")
    if not pbs:
        return
    # do not add the category a second time
    if access.get_table(block, "_pdbx_entity_branch_list"):
        return
    # gather values, one record per (entity_id, comp_id, num)
    bl_items = {"entity_id": [], "num": [], "comp_id": []}
    seen = set()
    for row in pbs:
        key = (row["entity_id"], row["mon_id"], row["num"])
        if key in seen:
            continue
        seen.add(key)
        bl_items["entity_id"].append(row["entity_id"])
        bl_items["num"].append(row["num"])
        bl_items["comp_id"].append(row["mon_id"])
    # place above pdbx_branch_scheme
    edit.add_category(
        block,
        "_pdbx_entity_branch_list",
        bl_items,
        index="before:_pdbx_branch_scheme",
    )
def _add_ndb_seq_num_2_pdbx_nonpoly_scheme(block):
    """Add item ndb_seq_num to category _pdbx_nonpoly_scheme, if missing.

    The values come from a per-chain residue counter keyed on asym_id; the
    column is inserted at position 6. Blocks without _pdbx_nonpoly_scheme
    are left as they are.
    """
    try:
        per_chain_counter = edit.make_res_per_chain_counter("asym_id")
        edit.add_column(
            block,
            "_pdbx_nonpoly_scheme",
            "ndb_seq_num",
            per_chain_counter,
            pos=6,
        )
    except edit.NotFoundCategoryError:
        # no non-polymer scheme in this block: nothing to number
        return
def _relabel_lone_sugars(block):
    """AF3 has single sugars marked as 'branched' entities (wrong) but lists
    them in '_pdbx_nonpoly_scheme' (right). So switch the '_entity.type' to
    'non-polymer' as it appears in RCSB.

    Every 'branched' entity without any record in _pdbx_branch_scheme is
    treated as a single sugar. This includes the case where
    _pdbx_branch_scheme is missing entirely, i.e. when all sugars of the
    model are single sugars.
    """
    entities = access.get_table(block, "_entity", items=["id", "type"])
    branch_scheme = access.get_table(
        block, "_pdbx_branch_scheme", items=["entity_id"]
    )
    # entities with real (multi-residue) branches; iterating a missing
    # category yields no rows, so an absent scheme gives an empty set
    branched_ids = {row["entity_id"] for row in branch_scheme}
    for ent_row in entities:
        if ent_row["type"].upper() != "BRANCHED":
            continue
        if ent_row["id"] not in branched_ids:
            ent_row["type"] = "non-polymer"
# _entity items (besides "id") loaded when reducing duplicated entities; two
# entities are only merged if these items are equal (see _delete_record).
ENT_COLS = ["pdbx_description", "type"]
def _delete_record(cat, item, valuep, valuec, eitems, block, cache):
    """Delete the child record of ``cat`` after checking it equals its parent.

    The record whose ``item`` equals ``valuec`` (child) is removed, provided
    that the items listed in ``eitems`` hold the same values as in the record
    whose ``item`` equals ``valuep`` (parent). The category table is loaded
    into ``cache`` if not there yet.

    Raises:
        NotIdentifiedSingleRecordError: If no child record exists.
        RuntimeError: If parent and child differ in any of ``eitems``.
    """
    # Hard to reduce arguments
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    if cat not in cache:
        cache[cat] = access.get_table(block, cat, items=[item] + eitems)
    table = cache[cat]
    parent_vals = {}
    child_vals = {}
    child_idx = None
    for row in table:
        key = row[item]
        if key == valuep:
            parent_vals.update({col: row[col] for col in eitems})
        elif key == valuec:
            child_idx = row.row_index
            child_vals.update({col: row[col] for col in eitems})
    if child_idx is None:
        raise NotIdentifiedSingleRecordError(cat, item=item, value=valuec)
    if parent_vals != child_vals:
        raise RuntimeError(
            f"Records to be merged not equal: {parent_vals} vs. {child_vals} "
            + f"({item}: {valuep}/ {valuec})"
        )
    table.remove_row(child_idx)
def _merge_entities(mntp, mntc, etype, block, cache):
    """Merge entity ``mntc`` (child) into entity ``mntp`` (parent).

    The child's _entity and _ma_target_entity records are deleted after
    checking they equal the parent's. References to the child from
    instance-level categories (_ma_target_entity_instance, _struct_asym,
    _atom_site and the type-specific scheme categories) are redirected to the
    parent. Entity-level, type-specific records of the child are folded into
    the parent (_entity_poly) or deleted (_entity_poly_seq,
    _pdbx_entity_branch_list) since the parent already has identical ones.

    Args:
        mntp (str): ID of the entity to keep.
        mntc (str): ID of the duplicated entity to remove.
        etype (str): _entity.type of the entities.
        block (|gemmicifBlock|): CIF block to operate on.
        cache (dict): Category name -> table cache shared between helpers.
    """
    # _entity
    _delete_record("_entity", "id", mntp, mntc, ENT_COLS, block, cache)
    # _ma_target_entity
    _delete_record(
        "_ma_target_entity",
        "entity_id",
        mntp,
        mntc,
        ["data_id", "origin"],
        block,
        cache,
    )
    # _ma_target_entity_instance
    _update_records(
        "_ma_target_entity_instance",
        "entity_id",
        mntp,
        mntc,
        block,
        cache,
    )
    # _struct_asym
    _update_records("_struct_asym", "entity_id", mntp, mntc, block, cache)
    # _atom_site
    _update_records("_atom_site", "label_entity_id", mntp, mntc, block, cache)
    etype = etype.upper()
    if etype == "POLYMER":
        # _entity_poly
        _aggregate_entity_poly(mntp, mntc, block, cache)
        # _entity_poly_seq
        _delete_records("_entity_poly_seq", "entity_id", mntc, cache)
        # _pdbx_poly_seq_scheme
        _update_records(
            "_pdbx_poly_seq_scheme", "entity_id", mntp, mntc, block, cache
        )
    elif etype == "BRANCHED":
        # _pdbx_entity_branch_list is entity-level (like _entity_poly_seq):
        # relabelling the child's rows would duplicate the parent's rows, so
        # they are dropped instead.
        cat = "_pdbx_entity_branch_list"
        if cat not in cache:
            cache[cat] = access.get_table(block, cat, items=["entity_id"])
        _delete_records(cat, "entity_id", mntc, cache)
        # _pdbx_branch_scheme
        _update_records(
            "_pdbx_branch_scheme", "entity_id", mntp, mntc, block, cache
        )
    elif etype == "NON-POLYMER":
        # _pdbx_nonpoly_scheme
        _update_records(
            "_pdbx_nonpoly_scheme", "entity_id", mntp, mntc, block, cache
        )
def _reduce_duplicated_entities(block):
    """Multimeric AF3 ModelCIF files may have entities duplicated.

    When seeing the same molecular entity multiple times, reduce it to a single
    molecular entity but multiple chains. Two entities are considered the same
    if they have the same _entity.type and the same list of monomers (mon_id
    from _entity_poly_seq, _pdbx_nonpoly_scheme and _pdbx_branch_scheme).
    Afterwards entity IDs are renumbered to close the gaps.

    Raises:
        RuntimeError: If an entity without any monomer is found, or (via the
            merge helpers) if entities to be merged have differing data.
    """
    cats = {}
    # check for duplicated entities
    cats["_entity"] = access.get_table(
        block, "_entity", items=["id"] + ENT_COLS
    )
    # entity ID -> type and list of monomers. Comparing lists instead of
    # concatenated strings avoids false matches between comp IDs of
    # different lengths (e.g. ["AB", "C"] vs. ["A", "BC"]).
    ments = {
        row["id"]: {"seq": [], "type": row["type"]} for row in cats["_entity"]
    }
    # fetch polymers, non-polymers and branched entities
    for cat in ("_entity_poly_seq", "_pdbx_nonpoly_scheme",
                "_pdbx_branch_scheme"):
        cats[cat] = access.get_table(
            block, cat, items=["entity_id", "mon_id"]
        )
        for row in cats[cat]:
            ments[row["entity_id"]]["seq"].append(row["mon_id"])
    # look for missed entities
    for k, v in ments.items():
        if not v["seq"]:
            raise RuntimeError(f"Empty entity found: '{k}'")
    # fix duplicated entities: merge each entity, starting from the back, into
    # the first earlier entity of the same type with the same monomers
    while len(ments) > 1:
        kc, vc = ments.popitem()
        for kp, vp in ments.items():
            if vp["type"] == vc["type"] and vp["seq"] == vc["seq"]:
                _merge_entities(kp, kc, vp["type"], block, cats)
                break
    # make entities run sequentially (close holes)
    _make_entity_ids_sequential(block, cats)
def _make_entity_ids_sequential(block, cache):
    """After removing duplicated entities, IDs may not be sequential anymore.
    Fix that.

    Every _entity record whose ID differs from its 1-based position in the
    category gets that position as new ID, and all references to the old ID
    are updated. Entity IDs are assumed to be integers listed in increasing
    order (so a new ID never collides with one still to be processed).

    Args:
        block (|gemmicifBlock|): CIF block to operate on.
        cache (dict): Category name -> table cache shared between helpers;
            missing categories are loaded on demand.
    """
    # categories referencing _entity.id independent of the entity type
    common_refs = [
        ("_atom_site", "label_entity_id"),
        ("_ma_target_entity", "entity_id"),
        ("_ma_target_entity_instance", "entity_id"),
        ("_struct_asym", "entity_id"),
    ]
    # categories referencing _entity.id only for specific entity types
    type_refs = {
        "POLYMER": [
            ("_entity_poly", "entity_id"),
            ("_entity_poly_seq", "entity_id"),
            ("_pdbx_poly_seq_scheme", "entity_id"),
        ],
        "BRANCHED": [
            ("_pdbx_entity_branch_list", "entity_id"),
            ("_pdbx_branch_scheme", "entity_id"),
        ],
        "NON-POLYMER": [
            ("_pdbx_nonpoly_scheme", "entity_id"),
        ],
    }
    cat = "_entity"
    if cat not in cache:
        cache[cat] = access.get_table(block, cat, items=["id"] + ENT_COLS)
    for new_id, row in enumerate(cache[cat], start=1):
        old_id = cif.as_int(row["id"])
        if old_id == new_id:
            continue
        old_id_str = str(old_id)
        new_id_str = str(new_id)
        row["id"] = new_id_str
        # _update_records loads missing categories itself and updates all
        # matching rows (e.g. several _ma_target_entity records per entity)
        for ref_cat, ref_item in common_refs + type_refs.get(
            row["type"].upper(), []
        ):
            _update_records(
                ref_cat, ref_item, new_id_str, old_id_str, block, cache
            )
def _update_records(cat, item, valuep, valuec, block, cache):
"""Exchange the value of a record with another."""
# Hard to reduce arguments
# pylint: disable=too-many-arguments,too-many-positional-arguments
if cat not in cache:
cache[cat] = access.get_table(block, cat, items=[item])
# get indeces to be deleted
for row in cache[cat]:
if row[item] == valuec:
row[item] = valuep
def _update_record(cat, item, value_old, value_new, cache):
"""Exchange the value of a single record."""
# Hard to reduce arguments
# pylint: disable=too-many-arguments,too-many-positional-arguments
# Does not seem to be needed, records to be deleted are most likely vistied
# beforehand.
# if cat not in cache:
# cache[cat] = access.get_table(block, cat, items=[item])
for row in cache[cat]:
if row[item] == value_old:
row[item] = str(value_new)
break
def _aggregate_entity_poly(valuep, valuec, block, cache):
"""For entity_poly, that needs reduction of duplicated entities, gather
chain names.
"""
# _entity_poly.entity_id
# _entity_poly.pdbx_strand_id
# _entity_poly.type
cat = "_entity_poly"
if cat not in cache:
cache[cat] = access.get_table(
block, cat, items=["entity_id", "pdbx_strand_id", "type"]
)
# find parent row
rowp = None
for row in cache[cat]:
if row["entity_id"] == valuep:
rowp = row
typep = row["type"]
break
if rowp is None:
raise RuntimeError(f"No '_entity_poly' with entity_id={valuep} found.")
# get child records
strand_ids = rowp["pdbx_strand_id"].split(",")
idxsc = []
for row in cache[cat]:
if row["entity_id"] == valuec:
if row["type"] != typep:
raise RuntimeError(
"'entity_poly' with duplicated entities mismatch at "
+ f"'type': '{typep}' vs '{row['type']}'"
)
strand_ids.extend(row["pdbx_strand_id"].split(","))
idxsc.append(row.row_index)
# update & delete
for i in sorted(idxsc, reverse=True):
cache[cat].remove_row(i)
rowp["pdbx_strand_id"] = ",".join(strand_ids)
def _delete_records(cat, item, valuec, cache):
"""Delete multiple records with a certain item and value."""
# if cat not in cache:
# cache[cat] = access.get_table(block, cat, items=[item])
# get indeces to be deleted
cidxs = []
for row in cache[cat]:
if row[item] == valuec:
cidxs.append(row.row_index)
# if len(cidxs) == 0:
# raise NotIdentifiedSingleRecordError(cat, item=item, value=valuec)
for i in sorted(cidxs, reverse=True):
cache[cat].remove_row(i)
def _fix_atom_names(block):
    """Upper-case element symbols in _atom_type and _atom_site.

    Despite the function name, atom names are not touched: values of
    _atom_type.symbol and _atom_site.type_symbol containing any lower-case
    character (presumably AF3 writes mixed case such as 'Cl') are converted
    to upper case.
    """

    def _fix_atom_symbol(table):
        """Upper-case the first column of every row of ``table`` in place."""
        for row in table:
            symbol = row[0]
            if any(char.islower() for char in symbol):
                row[0] = symbol.upper()

    table = access.get_table(block, "_atom_type", items=["symbol"])
    _fix_atom_symbol(table)
    table = access.get_table(block, "_atom_site", items=["type_symbol"])
    _fix_atom_symbol(table)
def _load_cache(cache_file):
"""Load or start the cache file for chem.comps."""
# ToDo: update mechanism
# cache is of form {"<SMILES>": <COMPOUND>}
cache = {}
if cache_file.exists():
with open(cache_file, encoding="ascii") as jfh:
cache = json.load(jfh)
return cache
def _save_cache(cache, cache_file):
"""Store cache on disk."""
# ToDo: update mechanism
with open(cache_file, "w", encoding="ascii") as jfh:
json.dump(cache, jfh)
def _get_chem_comp_by_smiles(smiles, cache):
"""Fetch a chem. component from RCSB using SMILES."""
# 1. Try to solve by cache
if smiles in cache:
return cache[smiles]
# 2. get ID
payload = {
"query": {
"type": "terminal",
"service": "chemical",
"parameters": {
"value": smiles,
"type": "descriptor",
"descriptor_type": "SMILES",
"match_type": "graph-exact",
},
},
"return_type": "mol_definition",
}
response = requests.post(
"https://search.rcsb.org/rcsbsearch/v2/query",
json=payload,
headers={"Content-Type": "application/json"},
timeout=60,
)
if not response.ok:
raise RuntimeError(
f"Querying chem.comp. name failed for SMILES '{smiles}'."
)
try:
json_rspnse = response.json()
except requests.exceptions.JSONDecodeError:
cache[smiles] = None
return cache[smiles]
if json_rspnse["total_count"] != 1:
cache[smiles] = None
return cache[smiles]
json_rspnse = json_rspnse["result_set"][0]
if json_rspnse["score"] != 1.0:
raise RuntimeError(f"Score not 1.0 for SMILES '{smiles}'.")
# 3. get more info about the ligand to fill the _chem_comp table
response = requests.get(
"https://data.rcsb.org/rest/v1/core/chemcomp/"
+ f"{json_rspnse['identifier']}",
timeout=60,
)
if not response.ok:
raise RuntimeError(
"Querying chem.comp. data failed for ID "
+ f"'{json_rspnse['identifier']}'."
)
json_rspnse = response.json()
cache[smiles] = json_rspnse["chem_comp"]
return cache[smiles]
def _update_items(cat, items, exchange, block):
    """Map the values of ``items`` in category ``cat`` through ``exchange``.

    Every value of the listed items is replaced by ``exchange[value]``. Items
    the category does not have are skipped, and so are values equal to HOH
    that are missing from ``exchange``; any other unknown value raises
    KeyError.
    """
    table = access.get_table(block, cat)
    for row in table:
        for itm in items:
            try:
                row[itm] = exchange[row[itm]]
            except RuntimeError as exc:
                # the item does not exist in this category: skip it
                if str(exc) != f"Column name not found: {itm}":
                    raise  # pragma: no cover (more of a gemmi.cif thing)
            except KeyError as exc:
                # RCSB seems to not have HOH in _chem_comp in all PDB
                # entries, so we allow it missing here, too (str() of a
                # KeyError is the quoted key, hence cif.as_string)
                if cif.as_string(str(exc)) != "HOH":
                    raise
def _fix_chem_comp_id_ligands(block, cache_file):
    """Fix the 'LIG_<CHAR> naming scheme.

    Every _chem_comp record without a name ('?') is identified at RCSB via its
    SMILES string (see :func:`_get_chem_comp_by_smiles`):

    * rows repeating an already processed SMILES string are removed and
      mapped to the ID given to the first occurrence,
    * if the compound is found, the row gets its id, name, type, formula and
      formula weight from RCSB; if that ID is already used by another row of
      _chem_comp (or was assigned to an earlier ligand), the row is removed,
    * if it is not found, the row gets one of the reserved IDs 01-99 (keeping
      its SMILES string); beyond 99 unknown ligands a warning is issued and
      the row keeps its ID.

    If any record lacked a name, changed IDs are propagated to the categories
    referencing _chem_comp.id. RCSB results are cached in ``cache_file``,
    which is written back even if a lookup fails midway.

    Args:
        block (|gemmicifBlock|): CIF block to operate on.
        cache_file (str | Path): JSON cache file for RCSB lookups.

    Raises:
        RuntimeError: If an RCSB query fails (see
            :func:`_get_chem_comp_by_smiles`).
    """
    # Not refactoring right now, ignore Pylint
    # pylint: disable=too-many-branches,too-many-statements
    if not isinstance(cache_file, Path):
        cache_file = Path(cache_file)
    cache = _load_cache(cache_file)
    table = access.get_table(block, "_chem_comp")
    del_lst = []
    smiles_lst = {}
    unk_count = 1
    update_needed = False
    # Sometimes ligands are redundant, a LIG_<CHAR> can be a compound that
    # already exists in _chem_comp. Use _chem_comp.id for comparison. Whole
    # table needs to be scanned before the real loop, because of sorting of
    # _chem_comp. This shouldn't mean too much overhead, as all IDs need to be
    # stored anyways.
    updtd = {row["id"]: row["id"] for row in table}
    try:
        for row in table:
            if row["name"] != "?":
                continue
            update_needed = True
            smiles = cif.as_string(row["pdbx_smiles"])
            if smiles in smiles_lst:
                # same ligand as an earlier row: drop it, reuse that ID
                del_lst.append(row.row_index)
                updtd[row["id"]] = smiles_lst[smiles]
                continue
            comp = _get_chem_comp_by_smiles(smiles, cache)
            if comp is None:
                # Compounds not found in the PDB use the reserved IDs 01-99 to
                # preserve the SMILES strings. Using 'UNL' instead would mean
                # to wipe the SMILES string as 'UNL' is defined with every
                # attribute blank to match all unknown ligands.
                if unk_count > 99:
                    _utils.warn_msg(
                        "No. of unknown ligands exceeds 99, won't renumber "
                        + f"chem.comp. '{row['id']}'.",
                    )
                    continue
                new_id = f"{unk_count:02d}"
                updtd[row["id"]] = new_id
                smiles_lst[smiles] = new_id
                row["id"] = new_id
                unk_count += 1
                continue
            new_id = _quote(comp["id"])
            # compound already listed (or already assigned to another
            # ligand): drop this row; a row already carrying that very ID is
            # not a duplicate of itself
            if new_id != row["id"] and new_id in updtd:
                del_lst.append(row.row_index)
            updtd[row["id"]] = new_id
            # register the new ID so later ligands resolving to the same
            # compound are recognised as duplicates
            updtd.setdefault(new_id, new_id)
            smiles_lst[smiles] = new_id
            for i in ["formula", "formula_weight", "id", "name", "type"]:
                try:
                    row[i] = _quote(comp[i])
                except TypeError:
                    row[i] = str(comp[i])
    finally:
        # keep successful lookups even if a later one failed
        _save_cache(cache, cache_file)
    # remove rows with duplicated ligands
    for i in sorted(del_lst, reverse=True):
        table.remove_row(i)
    # Update categories
    if not update_needed:
        return
    cat_2_updt = [
        ("_atom_site", ["label_comp_id"]),
        ("_entity_poly_seq", ["parent_mon_id", "mon_id"]),
        ("_ma_qa_metric_local", ["label_comp_id"]),
        ("_pdbx_branch_scheme", ["mon_id", "pdb_mon_id", "auth_mon_id"]),
        ("_pdbx_nonpoly_scheme", ["mon_id", "pdb_mon_id", "auth_mon_id"]),
        ("_pdbx_poly_seq_scheme", ["mon_id", "pdb_mon_id", "auth_mon_id"]),
    ]
    for cat, items in cat_2_updt:
        _update_items(cat, items, updtd, block)
def _fix_ma_data_ref_db(block):
    """Replace _ma_data_ref_db.id (wrong) with _ma_data_ref_db.data_id
    (right). Remove duplicate entries from _ma_data_ref_db.

    Only acts if the block has _ma_data_ref_db with an ``id`` item. For every
    unique reference database record (unique over all items except ``id``) a
    new _ma_data record is appended (next free _ma_data.id, or 1 if there is
    no _ma_data yet) with the database name and content_type
    'reference database'. Duplicated records are removed and the ``id``
    column is replaced, at the same position, by ``data_id`` pointing to the
    new _ma_data records. Without an ``id`` item nothing is changed, not even
    duplicates are removed.
    """
    ma_data_ref_db = access.get_table(block, "_ma_data_ref_db")
    if not ma_data_ref_db:
        # _ma_data_ref_db is not required, just skip the fix if not present
        return
    # replace _ma_data_ref_db.id with _ma_data_ref_db.data_id
    if "_ma_data_ref_db.id" in ma_data_ref_db.tags:
        # get next _ma_data.id, if no _ma_data, create one
        ma_data = access.get_table(
            block, "_ma_data", items=["id", "name", "content_type"]
        )
        if len(ma_data) > 0:
            data_id = _utils.get_next_ordinal(ma_data, "_ma_data", "id")
            # rows are appended below via the loop object, so make sure a
            # pair-style (single record) _ma_data becomes a loop first
            ma_data.ensure_loop()
            ma_data = ma_data.loop
        else:
            data_id = 1
            ma_data = block.init_loop(
                "_ma_data.", ["id", "name", "content_type"]
            )
        new_data_ids = []  # data_id for each kept record, in table order
        seen = set()  # remove duplicates
        del_rows = []  # row indices of duplicates, removed after the loop
        for row in ma_data_ref_db:
            # fingerprint of the record over all items but the (wrong) id
            rdb = "|".join(
                row[tag]
                for tag in ma_data_ref_db.tags
                if tag != "_ma_data_ref_db.id"
            )
            # remove duplicates
            if rdb in seen:
                del_rows.append(row.row_index)
                continue
            seen.add(rdb)
            new_data_ids.append(str(data_id))
            ma_data.add_row(
                [str(data_id), row["name"], _quote("reference database")]
            )
            data_id += 1
        # remove duplicates, highest index first so earlier indices stay valid
        for i in reversed(del_rows):
            ma_data_ref_db.remove_row(i)
        ma_data_ref_db.ensure_loop()
        # swap the id column for data_id at the same column position
        pos = list(ma_data_ref_db.tags).index("_ma_data_ref_db.id")
        ma_data_ref_db.loop.remove_column("_ma_data_ref_db.id")
        ma_data_ref_db.loop.add_columns(
            ["_ma_data_ref_db.data_id"], value="?", pos=pos
        )
        # NOTE(review): positional write relies on the table's column
        # positions still matching the loop after remove/add_columns above
        # (same width, same position) -- verify if this code is changed.
        for i, row in enumerate(ma_data_ref_db):
            row[pos] = new_data_ids[i]
[docs]
def fix_modelcif_issues(block, compdict_cache=".compdict_cache"):
    """Fix multiple small issues in AF3 ModelCIF files.

    Things corrected (in this order):

    * _atom_site.auth_comp_id gets added if not present
    * _pdbx_poly_seq_scheme.pdb_mon_id gets added if not present
    * _pdbx_branch_scheme.pdb_mon_id gets added if not present
    * _pdbx_entity_branch_list gets added when _pdbx_branch_scheme exists
    * _pdbx_nonpoly_scheme.ndb_seq_num gets added if not present
    * single sugars mistakenly marked as 'branched' entity will be
      relabelled to 'non-polymer' in the _entity category
    * duplicated molecular entities are reduced to a single molecular entity
      (AF3 adds a molecular entity per copy of a molecule)
    * element symbols (_atom_type.symbol, _atom_site.type_symbol) are
      converted to upper case
    * ligand naming scheme ``LIG_<CHARACTER>`` is replaced with proper
      molecule names (if possible, via RCSB)
    * if _ma_data_ref_db set, replace wrong "id" item with "data_id",
      remove duplicate entries and set necessary data in _ma_data

    Args:
        block (|gemmicifBlock|): CIF block to operate on.
        compdict_cache (str | Path): Path to the cache file for RCSB
            API calls for chemical compounds. Defaults to ``.compdict_cache``.
            The file is read if present and (re)written.

    Returns:
        None

    Raises:
        NotIdentifiedSingleRecordError: If a record to be deleted cannot be
            found.
        RuntimeError: If entities to be merged have differing data, if an
            empty entity is found, if no _entity_poly record is found for an
            entity ID, if polymer types of duplicated entities mismatch, or
            if a query to the RCSB API fails or returns a match with a score
            other than 1.0.
    """
    _add_auth_comp_id_2_atom_site(block)
    _add_pdb_mon_id_2_pdbx_poly_seq_scheme(block)
    _add_pdb_mon_id_2_pdbx_branch_scheme(block)
    _add_pdbx_entity_branch_list_2_data(block)
    _add_ndb_seq_num_2_pdbx_nonpoly_scheme(block)
    _relabel_lone_sugars(block)
    _reduce_duplicated_entities(block)
    _fix_atom_names(block)
    _fix_chem_comp_id_ligands(block, compdict_cache)
    _fix_ma_data_ref_db(block)
[docs]
def add_data_from_json_files(
block,
input_path,
full_qe_path,
summary_qe_path,
out_zip_path,
pairwise_in_zip=True,
use_local_pairwise_if_possible=False,
):
"""Add QA metrics and metadata from AF3 JSON files to a ModelCIF block.
Reads the AF3 input JSON, full confidence JSON, and summary confidence
JSON to populate QA metric categories in ``block``. Packs the input JSON
and, optionally, pairwise QA scores into a ZIP archive at ``out_zip_path``.
Updates _ma_qa_metric, _ma_qa_metric_global, _ma_qa_metric_feature,
_ma_qa_metric_local_pairwise, _ma_qa_metric_feature_pairwise,
_ma_entry_associated_files, _ma_associated_archive_file_details,
_audit_conform|, and _ma_software_parameter (the latter only if model seeds
or recycle counts are present in the input JSON).
The function derives feature lists from _atom_site. Per-residue pairwise
scores are written to _ma_qa_metric_local_pairwise when
``use_local_pairwise_if_possible`` is ``True`` and all tokens are polymer
residues (no ``HETATM``); in all other cases _ma_qa_metric_feature_pairwise
is used.
Args:
block (|gemmicifBlock|): mmCIF block to be updated in place. Must
already contain _atom_site, _entity, _ma_qa_metric,
_ma_qa_metric_global, _ma_software_group, and _entry categories.
input_path (~pathlib.Path | str): Path to the AF3 input JSON file.
Server output: ``<JOBNAME>_job_request.json``.
full_qe_path (~pathlib.Path | str): Path to the JSON file containing
per-atom and per-token confidence arrays (pLDDT, PAE, contact
probabilities). Server output: ``<JOBNAME>_full_data_<N>.json``;
code output: ``<JOBNAME>_confidences.json``.
summary_qe_path (~pathlib.Path | str): Path to the JSON file containing
summary confidence values (pTM, ipTM, ranking score, etc.). Server
output: ``<JOBNAME>_summary_confidences_<N>.json``; code
output: ``<JOBNAME>_summary_confidences.json``.
out_zip_path (~pathlib.Path | str): Path for the output ZIP archive. The
archive always contains ``input.json`` (a copy of ``input_path``).
If ``pairwise_in_zip`` is ``True``, a ``pairwise_qa.cif`` file
with the pairwise QA metrics is also included. The value of
_ma_entry_associated_files.file_url is set to the bare filename
(i.e. without any directory component), so the main CIF file and
the ZIP must reside in the same directory.
pairwise_in_zip (bool): If ``True``, pairwise QA scores are written
to a separate ``pairwise_qa.cif`` file and packaged inside the
ZIP archive rather than embedded directly in ``block``. Defaults
to ``True``.
use_local_pairwise_if_possible (bool): If ``True`` and every token
in the structure is a polymer residue (no ``HETATM`` records),
_ma_qa_metric_local_pairwise is used for pairwise token scores
instead of _ma_qa_metric_feature_pairwise. Defaults to
``False``.
Returns:
None
Raises:
RuntimeError: If ``full_qe_path`` or ``summary_qe_path`` contain
score keys that are not listed in ``known_scores``.
"""
# No idea how to reduce args, not going fixing the "to many of this and that"
# messages at the moment
# pylint: disable=too-many-positional-arguments,too-many-arguments
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
sw_params_group_id = 1
def _add_sw_param(sw_params, data_type, name, value):
"""Helper to add stuff to SW params."""
sw_params["parameter_id"].append(len(sw_params["parameter_id"]) + 1)
sw_params["group_id"].append(sw_params_group_id)
sw_params["data_type"].append(data_type)
sw_params["name"].append(name)
sw_params["value"].append(value)
def _add_feature(
feature_dict, feature_type, entity_type, specific_dict, **kwargs
):
"""Helper to add features.
Valid feature_type: "entity instance", "residue", "atom"
Extra arguments used to fill specific_dict besides ordinal_id and
feature_id.
Can set feature_dict to None to avoid updating it and only update
specific_dict.
"""
if feature_dict is None:
feature_id = None
else:
feature_id = len(feature_dict["feature_id"]) + 1
feature_dict["feature_id"].append(feature_id)
feature_dict["feature_type"].append(cif.quote(feature_type))
feature_dict["entity_type"].append(entity_type)
specific_dict["ordinal_id"].append(len(specific_dict["ordinal_id"]) + 1)
specific_dict["feature_id"].append(feature_id)
for k, v in kwargs.items():
specific_dict[k].append(v)
return feature_id
# dict for SW params
sw_params = {
"parameter_id": [],
"group_id": [],
"data_type": [],
"name": [],
"value": [],
}
# anything we can fetch from input json?
with open(input_path, encoding="utf8") as jfh:
input_data = json.load(jfh)
if isinstance(input_data, list):
assert len(input_data) == 1
input_data = input_data[0]
assert isinstance(input_data, dict)
if "modelSeeds" in input_data:
modelseeds = input_data["modelSeeds"]
assert all(str(int(s)) == s for s in modelseeds)
_add_sw_param(
sw_params,
"integer-csv",
"modelSeeds",
",".join(str(s) for s in modelseeds),
)
# get entity types
entity_dict = block.get_mmcif_category("_entity.", raw=True)
entity_types = dict(zip(entity_dict["id"], entity_dict["type"]))
# parse atom_site into features
atom_feature = {"ordinal_id": [], "feature_id": [], "atom_id": []}
# NOTE: res_feature also works to set _ma_qa_metric_local_pairwise
# (ignore ordinal_id and feature_id for that)
res_feature = {
"ordinal_id": [],
"feature_id": [],
"label_asym_id": [],
"label_comp_id": [],
"label_seq_id": [],
}
ch_feature = {"ordinal_id": [], "feature_id": [], "label_asym_id": []}
feature = {"feature_id": [], "feature_type": [], "entity_type": []}
# IDs ordered to match AF3 scores
atom_feature_ids = []
atom_feature_id_map = {} # from atom ID to feature ID
token_feature_ids = [] # all None if use_local_pairwise
asym_feature_ids = []
# for sanity checks (vs AF3 json)
atom_chain_ids = []
token_chain_ids = []
token_res_ids = []
# also keep track of model IDs to fill scores
model_ids = []
# go row by row
# -> token-logic: HETATM gets per atom, rest per residue
atom_site_cols = [
"group_PDB",
"id",
"label_asym_id",
"label_comp_id",
"label_entity_id",
"label_seq_id",
"pdbx_PDB_model_num",
"auth_seq_id",
]
# ToDo: replace with access.get_table()
table = block.find("_atom_site.", atom_site_cols)
# first check if we can use local pairwise for token based scores
all_res_tokens = all(row["group_PDB"] != "HETATM" for row in table)
use_local_pairwise = all_res_tokens and use_local_pairwise_if_possible
# do all atom features first
for row in table:
# check model ID
if row["pdbx_PDB_model_num"] not in model_ids:
model_ids.append(row["pdbx_PDB_model_num"])
# every atom gets a token anyway
atom_feature_id = _add_feature(
feature_dict=feature,
feature_type="atom",
entity_type=entity_types[row["label_entity_id"]],
specific_dict=atom_feature,
atom_id=row["id"],
)
atom_feature_ids.append(atom_feature_id)
atom_feature_id_map[row["id"]] = atom_feature_id
atom_chain_ids.append(row["label_asym_id"])
# then all chain features
cur_label_asym_id = None
for row in table:
# check if new chain
if row["label_asym_id"] != cur_label_asym_id:
cur_label_asym_id = row["label_asym_id"]
asym_feature_ids.append(
_add_feature(
feature_dict=feature,
feature_type="entity instance",
entity_type=entity_types[row["label_entity_id"]],
specific_dict=ch_feature,
label_asym_id=row["label_asym_id"],
)
)
# then all tokens for pairwise scores
cur_label_asym_id = None
cur_label_seq_id = None
for row in table:
# check if new chain
if row["label_asym_id"] != cur_label_asym_id:
cur_label_asym_id = row["label_asym_id"]
cur_label_seq_id = None
# check if new token
if row["group_PDB"] == "HETATM":
token_feature_ids.append(atom_feature_id_map[row["id"]])
token_chain_ids.append(row["label_asym_id"])
token_res_ids.append(cif.as_int(row["auth_seq_id"]))
assert not use_local_pairwise
elif row["label_seq_id"] != cur_label_seq_id:
cur_label_seq_id = row["label_seq_id"]
res_feature_id = _add_feature(
feature_dict=None if use_local_pairwise else feature,
feature_type="residue",
entity_type=entity_types[row["label_entity_id"]],
specific_dict=res_feature,
label_asym_id=row["label_asym_id"],
label_comp_id=row["label_comp_id"],
label_seq_id=row["label_seq_id"],
)
token_feature_ids.append(res_feature_id)
token_chain_ids.append(row["label_asym_id"])
token_res_ids.append(cif.as_int(row["auth_seq_id"]))
# cannot deal with multiple models!
assert len(model_ids) == 1
model_id = model_ids[0]
# write them all
block.set_mmcif_category("_ma_feature_list.", feature, raw=True)
block.set_mmcif_category("_ma_atom_feature.", atom_feature, raw=True)
if not use_local_pairwise:
# for convenience we abuse of res_feature to fill
# _ma_qa_metric_local_pairwise which does not use the feature list!
block.set_mmcif_category(
"_ma_poly_residue_feature.", res_feature, raw=True
)
block.set_mmcif_category(
"_ma_entity_instance_feature.", ch_feature, raw=True
)
# get all scores
with open(summary_qe_path, encoding="utf8") as jfh:
summary_qe = json.load(jfh)
with open(full_qe_path, encoding="utf8") as jfh:
full_qe = json.load(jfh)
# some sanity checks
assert full_qe["atom_chain_ids"] == atom_chain_ids
assert full_qe["token_chain_ids"] == token_chain_ids
assert full_qe["token_res_ids"] == token_res_ids
# handle extra stuff
if "num_recycles" in summary_qe:
num_recycles = summary_qe["num_recycles"]
assert int(num_recycles) - num_recycles == 0
_add_sw_param(sw_params, "integer", "num_recycles", num_recycles)
# collect scores
known_extra_keys = ["num_recycles"]
all_scores = {
k: v for k, v in summary_qe.items() if k not in known_extra_keys
}
known_extra_keys = ["atom_chain_ids", "token_chain_ids", "token_res_ids"]
for k, v in full_qe.items():
assert k not in all_scores
if k not in known_extra_keys:
all_scores[k] = v
# info on known scores
if use_local_pairwise:
token_pairwise_mode = "local-pairwise"
token_pairwise_scores = "per-residue-pair"
else:
token_pairwise_mode = "per-feature-pair"
token_pairwise_scores = "per-token-pair"
# keys in all_scores linked to fill _ma_qa_metric + "scores" for handling:
# global, per-chain, per-chain-pair, per-atom, per-residue-pair,
# per-token-pair
known_scores = {
"chain_iptm": {
"mode": "per-feature",
"name": "ipTM per chain",
"type": "ipTM",
"type_other_details": False,
"scores": "per-chain",
},
"chain_pair_iptm": {
"mode": "per-feature-pair",
"name": "ipTM per chain pair",
"type": "ipTM",
"type_other_details": False,
"scores": "per-chain-pair",
},
"chain_pair_pae_min": {
"mode": "per-feature-pair",
"name": "min. PAE per chain pair",
"type": "PAE",
"type_other_details": False,
"scores": "per-chain-pair",
},
"chain_ptm": {
"mode": "per-feature",
"name": "pTM per chain",
"type": "pTM",
"type_other_details": False,
"scores": "per-chain",
},
"fraction_disordered": {
"mode": "global",
"name": "fraction of prediction which is disordered",
"type": "normalized score",
"type_other_details": False,
"scores": "global",
},
"has_clash": {
"mode": "global",
"name": "significant number of clashing atoms?",
"type": "boolean",
"type_other_details": False,
"scores": "global",
},
"iptm": {
"mode": "global",
"name": "ipTM",
"type": "ipTM",
"type_other_details": False,
"scores": "global",
},
"ptm": {
"mode": "global",
"name": "pTM",
"type": "pTM",
"type_other_details": False,
"scores": "global",
},
# ranking_score = 0.8*ipTM + 0.2*pTM + 0.5*disorder − 100*has_clash
# -> exp_range = [-100, 1.5]
"ranking_score": {
"mode": "global",
"name": "ranking score",
"type": "other",
"type_other_details": "Combined score in range [-100, 1.5] "
+ "(higher is better)",
"scores": "global",
},
"atom_plddts": {
"mode": "per-feature",
"name": "pLDDT per atom",
"type": "pLDDT to polymer",
"type_other_details": False,
"scores": "per-atom",
},
"contact_probs": {
"mode": token_pairwise_mode,
"name": "contact probability per token pair",
"type": "contact probability",
"type_other_details": False,
"scores": token_pairwise_scores,
},
"pae": {
"mode": token_pairwise_mode,
"name": "PAE per token pair",
"type": "PAE",
"type_other_details": False,
"scores": token_pairwise_scores,
},
}
extra_scores = sorted(set(all_scores) - set(known_scores))
if extra_scores:
raise RuntimeError(
f"Unexpected scores {sorted(extra_scores)} found which cannot "
f"be handled"
)
# add QA metrics
qa_metric = block.get_mmcif_category("_ma_qa_metric.")
# get IDs to use for each score
metric_ids = dict(
zip(
all_scores.keys(),
_get_ordinal_ids(qa_metric["id"], len(all_scores)),
)
)
# fix existing entries
if "type_other_details" not in qa_metric:
qa_metric["type_other_details"] = [False] * len(qa_metric["id"])
for idx in range(len(qa_metric["name"])):
qa_metric["type"][idx] = "pLDDT to polymer"
# add new metrics
for col, val_list in qa_metric.items():
if col == "id":
val_list.extend([metric_ids[score_key] for score_key in all_scores])
elif col in ["mode", "name", "type", "type_other_details"]:
val_list.extend(
[known_scores[score_key][col] for score_key in all_scores]
)
elif col == "software_group_id":
af3_sw_group = _get_af3_sw_group(block)
val_list.extend([af3_sw_group] * len(metric_ids))
else:
val_list.extend([False] * len(metric_ids))
block.set_mmcif_category("_ma_qa_metric.", qa_metric)
# add scores
def _add_scores(qa_dict, metric_id, **kwargs):
"""Helper to add values specific for QA dict.
Uses constant model_id from enclosing function!
"""
num_items = None
scores = kwargs["metric_value"]
for k, vl in kwargs.items():
assert len(vl) == len(scores)
# remove Nones from output
vcut = [v for v, s in zip(vl, scores) if s is not None]
qa_dict[k].extend(vcut)
if num_items is None:
num_items = len(vcut)
else:
assert num_items == len(vcut)
# add ordinals, model ids and metric ids
qa_dict["ordinal_id"].extend(
[(len(qa_dict["ordinal_id"]) + idx + 1) for idx in range(num_items)]
)
qa_dict["model_id"].extend([model_id] * num_items)
qa_dict["metric_id"].extend([metric_id] * num_items)
# qa_global keys: metric_id, metric_value, model_id, ordinal_id
qa_global = block.get_mmcif_category("_ma_qa_metric_global.")
qa_feature = {
k: []
for k in [
"ordinal_id",
"model_id",
"feature_id",
"metric_id",
"metric_value",
]
}
qa_local_pair = {
k: []
for k in [
"ordinal_id",
"label_asym_id_1",
"label_asym_id_2",
"label_comp_id_1",
"label_comp_id_2",
"label_seq_id_1",
"label_seq_id_2",
"metric_id",
"metric_value",
"model_id",
]
}
qa_feature_pair = {
k: []
for k in [
"ordinal_id",
"feature_id_1",
"feature_id_2",
"metric_id",
"metric_value",
"model_id",
]
}
for score_key, scores in all_scores.items():
metric_id = metric_ids[score_key]
score_handling = known_scores[score_key]["scores"]
if score_handling == "global":
_add_scores(
qa_global,
metric_id,
metric_value=[scores],
)
elif score_handling == "per-chain":
_add_scores(
qa_feature,
metric_id,
feature_id=asym_feature_ids,
metric_value=scores,
)
elif score_handling == "per-chain-pair":
for idx, sub_scores in enumerate(scores):
_add_scores(
qa_feature_pair,
metric_id,
feature_id_1=[asym_feature_ids[idx]]
* len(asym_feature_ids),
feature_id_2=asym_feature_ids,
metric_value=sub_scores,
)
elif score_handling == "per-atom":
_add_scores(
qa_feature,
metric_id,
feature_id=atom_feature_ids,
metric_value=scores,
)
elif score_handling == "per-residue-pair":
label_asym_id = res_feature["label_asym_id"]
label_comp_id = res_feature["label_comp_id"]
label_seq_id = res_feature["label_seq_id"]
for idx, sub_scores in enumerate(scores):
_add_scores(
qa_local_pair,
metric_id,
label_asym_id_1=[label_asym_id[idx]] * len(label_asym_id),
label_asym_id_2=label_asym_id,
label_comp_id_1=[label_comp_id[idx]] * len(label_comp_id),
label_comp_id_2=label_comp_id,
label_seq_id_1=[label_seq_id[idx]] * len(label_seq_id),
label_seq_id_2=label_seq_id,
metric_value=sub_scores,
)
elif score_handling == "per-token-pair":
for idx, sub_scores in enumerate(scores):
_add_scores(
qa_feature_pair,
metric_id,
feature_id_1=[token_feature_ids[idx]]
* len(token_feature_ids),
feature_id_2=token_feature_ids,
metric_value=sub_scores,
)
# start accompanying file if desired
entry = block.get_mmcif_category("_entry.")
assert len(entry["id"]) == 1
entry_id = entry["id"][0] # also needed for _ma_entry_associated_files
if pairwise_in_zip:
# create new file
d = cif.Document()
block_zip = d.add_new_block(block.name)
edit.add_category(block_zip, "_entry", entry)
block_zip.set_mmcif_category(
"_entry_link.",
{
"id": ["1"],
"entry_id": [entry_id],
"details": [
"This file is an associated file consisting of pairwise QA "
+ "metrics. This is a partial mmCIF file and can be "
+ "validated by merging with the main mmCIF file "
+ "containing the model coordinates and other associated "
+ "data."
],
},
)
block_zip.set_mmcif_category(
"_ma_qa_metric.", block.get_mmcif_category("_ma_qa_metric.")
)
block_for_pairwise_qa = block_zip
else:
block_for_pairwise_qa = block
# write all QA categories
block.set_mmcif_category("_ma_qa_metric_global.", qa_global, raw=True)
block.set_mmcif_category("_ma_qa_metric_feature.", qa_feature, raw=True)
block_for_pairwise_qa.set_mmcif_category(
"_ma_qa_metric_local_pairwise.", qa_local_pair, raw=True
)
block_for_pairwise_qa.set_mmcif_category(
"_ma_qa_metric_feature_pairwise.", qa_feature_pair, raw=True
)
# add SW params to block if needed
if len(sw_params["parameter_id"]) > 0:
sw_group_dict = block.get_mmcif_category("_ma_software_group.")
sw_group_dict["parameter_group_id"] = [sw_params_group_id]
block.set_mmcif_category("_ma_software_group.", sw_group_dict)
block.set_mmcif_category("_ma_software_parameter.", sw_params)
edit.move_category(
block,
"_ma_software_parameter",
"after:_ma_software_group",
)
# update _audit_conform to 1.4.7 (needed for all the *feature* tables)
block.set_mmcif_category(
"_audit_conform.",
{
# taken from python-modelcif
"dict_location": [
"https://raw.githubusercontent.com/ihmwg/ModelCIF/80e1e22/"
+ "dist/mmcif_ma.dic"
],
"dict_name": ["mmcif_ma.dic"],
"dict_version": ["1.4.7"],
},
)
# add info for accompanying zip file (note: this blindly overwrites data)
if not isinstance(out_zip_path, Path):
out_zip_path = Path(out_zip_path)
archive_file_id = "1"
block.set_mmcif_category(
"_ma_entry_associated_files.",
{
"id": [archive_file_id],
"entry_id": [entry_id],
"file_url": [out_zip_path.name],
"file_type": ["archive"],
"file_format": ["zip"],
"file_content": ["archive with multiple files"],
},
)
associated_archive_file_details = {
"id": ["1"],
"archive_file_id": [archive_file_id],
"file_path": ["input.json"],
"file_format": ["json"],
"file_content": ["other"],
"description": ["Input data provided to AlphaFold 3"],
}
if pairwise_in_zip:
to_add = {
"id": "2",
"archive_file_id": archive_file_id,
"file_path": "pairwise_qa.cif",
"file_format": "cif",
"file_content": "QA metrics",
"description": "Pairwise QA metrics",
}
for k, v in to_add.items():
associated_archive_file_details[k].append(v)
block.set_mmcif_category(
"_ma_associated_archive_file_details.", associated_archive_file_details
)
# package files
with zipfile.ZipFile(out_zip_path, "w", zipfile.ZIP_BZIP2) as cif_zip:
cif_zip.write(input_path, arcname="input.json")
if pairwise_in_zip:
cif_zip.writestr("pairwise_qa.cif", block_zip.as_string())
def add_json_files_in_archive_file(
    block, input_path, full_qe_path, summary_qe_path, out_zip_path
):
    """Package AF3 JSON files as accompanying data without processing them.

    Writes ModelCIF categories _ma_entry_associated_files and
    _ma_associated_archive_file_details to ``block`` and packages the three AF3
    JSON files into a ZIP archive at ``out_zip_path``. Alternative to
    :func:`add_data_from_json_files` with the same ``..._path`` parameters; use
    this function when the JSON files should be stored as-is rather than parsed
    for QA metrics.

    All three JSON files are checked before ``block`` or ``out_zip_path`` are
    touched. A missing input therefore leaves behind neither a half-updated
    block nor a truncated ZIP archive.

    .. warning::

        Existing _ma_entry_associated_files and
        _ma_associated_archive_file_details categories in ``block``
        are overwritten without checking their prior contents.

    Args:
        block (|gemmicifBlock|): mmCIF block to be updated in place. Must
                                 already contain an _entry category with
                                 exactly one row.
        input_path (~pathlib.Path | str): Path to the AF3 input JSON file.
            Server output: ``<JOBNAME>_job_request.json``.
        full_qe_path (~pathlib.Path | str): Path to the JSON file containing
            per-atom and per-token confidence arrays (pLDDT, PAE, contact
            probabilities). Server output:
            ``<JOBNAME>_full_data_<N>.json``; code output:
            ``<JOBNAME>_confidences.json``.
        summary_qe_path (~pathlib.Path | str): Path to the JSON file
            containing summary confidence values (pTM, ipTM, ranking score,
            etc.). Server output:
            ``<JOBNAME>_summary_confidences_<N>.json``; code output:
            ``<JOBNAME>_summary_confidences.json``.
        out_zip_path (~pathlib.Path | str): Path for the output ZIP archive.
            The archive contains ``input.json``,
            ``summary_confidences.json``, and ``confidences.json``. The
            value of _ma_entry_associated_files.file_url is set to the bare
            filename (i.e. without any directory component), so the main
            CIF file and the ZIP must reside in the same directory.

    Returns:
        None

    Raises:
        FileNotFoundError: If any of ``input_path``, ``full_qe_path`` or
            ``summary_qe_path`` is not an existing regular file. Nothing is
            written in that case.
        AssertionError: If the _entry category of ``block`` does not have
            exactly one row.
    """
    input_path = Path(input_path)
    full_qe_path = Path(full_qe_path)
    summary_qe_path = Path(summary_qe_path)
    out_zip_path = Path(out_zip_path)
    # Single source of truth for the archive content: (source file, name
    # inside the ZIP, description). The list order defines both the row IDs
    # ("1", "2", "3") in _ma_associated_archive_file_details and the order in
    # which files are written to the ZIP, so the two cannot drift apart.
    archive_members = [
        (input_path, "input.json", "Input data provided to AlphaFold 3"),
        (
            summary_qe_path,
            "summary_confidences.json",
            "Summary confidence values for either the whole structure, per "
            + "chain or per chain-pair (e.g. pTM, ipTM)",
        ),
        (
            full_qe_path,
            "confidences.json",
            "Confidence values for full 1D and 2D arrays (e.g. pLDDT, PAE)",
        ),
    ]
    # Validate inputs before mutating the block or truncating the ZIP file.
    for src_path, _, _ in archive_members:
        if not src_path.is_file():
            raise FileNotFoundError(
                f"AF3 JSON file not found (or not a regular file): {src_path}"
            )
    entry = block.get_mmcif_category("_entry.")
    assert len(entry["id"]) == 1
    entry_id = entry["id"][0]
    # add info for accompanying zip file (note: this blindly overwrites data)
    archive_file_id = "1"
    block.set_mmcif_category(
        "_ma_entry_associated_files.",
        {
            "id": [archive_file_id],
            "entry_id": [entry_id],
            "file_url": [out_zip_path.name],
            "file_type": ["archive"],
            "file_format": ["zip"],
            "file_content": ["archive with multiple files"],
        },
    )
    num_members = len(archive_members)
    block.set_mmcif_category(
        "_ma_associated_archive_file_details.",
        {
            "id": [str(idx) for idx in range(1, num_members + 1)],
            "archive_file_id": [archive_file_id] * num_members,
            "file_path": [arcname for _, arcname, _ in archive_members],
            "file_format": ["json"] * num_members,
            "file_content": ["other"] * num_members,
            "description": [desc for _, _, desc in archive_members],
        },
    )
    # package files
    with zipfile.ZipFile(out_zip_path, "w", zipfile.ZIP_BZIP2) as cif_zip:
        for src_path, arcname, _ in archive_members:
            cif_zip.write(src_path, arcname=arcname)
# LocalWords: CIF homomeric mdl gemmi cif modelarchive modelcif af BLANKLINE
# LocalWords: Args gemmicifBlock NotFoundItemError RuntimeError str mmcif qa
# LocalWords: AlphaFold PMID ModelArchive ModelCIF NotFoundCategoryError qe
# LocalWords: NotIdentifiedSingleRecordError pdbx pLDDT auth pdb mon ndb PAE
# LocalWords: NotIdentifiedContextRecordError nonpoly num IUPAC ligand LIG
# LocalWords: NotIdentifiedDuplicatedRecordError RCSB compdict HETATM json
# LocalWords: JOBNAME pTM ipTM url bool