The system is built around a central knowledge graph that ties all of its components together. The graph captures directed relationships among drugs, proteins, and diseases, such as a drug that binds_to a protein, inhibits a target, or treats a disease. By placing the graph at the core of the design, every processing step is grounded in structured medical knowledge.
To prepare the data, the system pairs each drug node with two representations: a molecular image and a text description. The first is a molecular image rendered with RDKit from the drug's SMILES string. The second is a text description summarizing the drug's categories, functional groups, and other relevant details. Both the image and the text are linked directly to the corresponding drug node in the graph, so the visual and linguistic features stay aligned with the underlying knowledge structure.
Modeling the graph itself relies on graph convolutional networks (GCNs). These networks learn from each node's position and connections in the graph, producing embeddings that encode how drugs, proteins, and diseases relate to one another. In parallel, multimodal encoders turn the images and texts into feature vectors: a ResNet processes the molecular images, while a BERT model encodes the text descriptions.
Finally, a graph attention network (GAT) fuses the graph embeddings with the visual and textual features. The attention mechanism uses the graph structure to weight the most important features from each modality. The combined representation is then fed into a prediction module that determines whether two drugs will interact. At the same time, the attention weights reveal which graph connections, image regions, or text elements contributed most to the model's decision, providing a clear explanation for every prediction.
The implementation begins by installing and importing the required libraries, ensuring that all the necessary deep learning, graph processing, and cheminformatics packages are available in the environment. It installs PyTorch and torchvision for neural networks, HuggingFace Transformers for text encoding, NetworkX and torch-geometric for graph operations, RDKit and Open Babel for handling molecular structures, and supporting libraries such as pandas, NumPy, and Matplotlib. Once installation completes, the required libraries and modules are imported for use in the subsequent cells.
# install necessary packages
!pip install torch torchvision transformers networkx spacy rdflib rdkit pillow scikit-learn matplotlib seaborn torch-geometric
# pip alone does not provide the Open Babel binaries, so install them via apt-get
!apt-get install openbabel
!pip install openbabel-wheel

# import libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from transformers import BertModel, BertTokenizer
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import json
import os
from rdkit import Chem
from rdkit.Chem import Draw
from PIL import Image
import io
import base64
from openbabel import openbabel
from torch_geometric.data import Data
import torch_geometric.nn as geom_nn

The first step creates a directory for storing drug images. A simplified DrugBank sample is then downloaded from a public repository and saved as a TSV file. This file is loaded into a pandas DataFrame, producing a table with each drug's unique identifier, name, an InChI string for the molecular structure, and descriptive metadata such as categories and groups. This structured dataset lays the groundwork for generating the visual and textual representations in the steps that follow.
# create directory for data storage
!mkdir -p data/drug_images

# download DrugBank sample data (simplified version for demonstration)
!wget -q -O data/drugbank_sample.tsv https://raw.githubusercontent.com/dhimmel/drugbank/gh-pages/data/drugbank-slim.tsv

# load DrugBank data
drug_df = pd.read_csv('data/drugbank_sample.tsv', sep='\t')

The molecular structures are provided in InChI format. These representations need to be converted to SMILES with Open Babel. SMILES, the Simplified Molecular Input Line Entry System, offers a compact, text-based way to describe chemical structures. SMILES strings are compatible with tools such as RDKit, which can render molecular images from them. The code below shows how to perform this conversion.
# create a SMILES column by converting InChI to SMILES
def inchi_to_smiles_openbabel(inchi_str):
    try:
        # create an Open Babel conversion object for InChI -> SMILES
        obConversion = openbabel.OBConversion()
        obConversion.SetInAndOutFormats("inchi", "smiles")
        mol = openbabel.OBMol()
        # convert InChI to a molecule; strip extra newlines or spaces from the output
        if obConversion.ReadString(mol, inchi_str):
            return obConversion.WriteString(mol).strip()
        else:
            return None
    except Exception as e:
        print(f"Error converting InChI to SMILES: {inchi_str}. Error: {e}")
        return None

# apply the conversion to each InChI in the dataframe
drug_df['smiles'] = drug_df['inchi'].apply(inchi_to_smiles_openbabel)
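As a quick sanity check (an illustrative snippet, not part of the original pipeline), the converter can be exercised on a known structure; ethanol's standard InChI should round-trip to a SMILES string equivalent to CCO:

# illustrative sanity check: ethanol's InChI should convert to a SMILES
# string equivalent to "CCO" (Open Babel may emit an equivalent variant)
ethanol_inchi = "InChI=1S/C2H6O/c1-2-3/h3H,2H2,1H3"
print(inchi_to_smiles_openbabel(ethanol_inchi))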
The system then builds a directed medical knowledge graph that captures relationships among drugs, proteins, and diseases. Each node represents a drug, protein, or disease, and each edge encodes an interaction such as binds_to, inhibits, or treats. These connections store expert knowledge about how drugs affect biological targets and conditions.
The graph serves as a structured source of relational information that the model draws on while processing image and text features. By representing domain knowledge explicitly, the graph improves both prediction accuracy and the ability to explain why two drugs might interact.
# initialize a medical knowledge graph
medical_kg = nx.DiGraph()

# extract drug entities from DrugBank
# limit to 50 drugs for demo
drug_entities = drug_df['name'].dropna().unique().tolist()[:50]

# create drug nodes
for drug in drug_entities:
    medical_kg.add_node(drug, type='drug')

# add biomedical entities (proteins, targets, diseases)
protein_entities = ["Cytochrome P450", "Albumin", "P-glycoprotein", "GABA Receptor",
                    "Serotonin Receptor", "Beta-Adrenergic Receptor", "ACE", "HMGCR"]
disease_entities = ["Hypertension", "Diabetes", "Depression", "Epilepsy",
                    "Asthma", "Rheumatoid Arthritis", "Parkinson's Disease"]

for protein in protein_entities:
    medical_kg.add_node(protein, type='protein')

for disease in disease_entities:
    medical_kg.add_node(disease, type='disease')

# add relationships (based on common drug mechanisms and interactions)
# drug-protein relationships
drug_protein_relations = [
("Warfarin", "binds_to", "Albumin"),
("Atorvastatin", "inhibits", "HMGCR"),
("Diazepam", "modulates", "GABA Receptor"),
("Fluoxetine", "inhibits", "Serotonin Receptor"),
("Phenytoin", "induces", "Cytochrome P450"),
("Metoprolol", "blocks", "Beta-Adrenergic Receptor"),
("Lisinopril", "inhibits", "ACE"),
("Rifampin", "induces", "P-glycoprotein"),
("Carbamazepine", "induces", "Cytochrome P450"),
("Verapamil", "inhibits", "P-glycoprotein")
]

# drug-disease relationships
drug_disease_relations = [
("Lisinopril", "treats", "Hypertension"),
("Metformin", "treats", "Diabetes"),
("Fluoxetine", "treats", "Depression"),
("Phenytoin", "treats", "Epilepsy"),
("Albuterol", "treats", "Asthma"),
("Methotrexate", "treats", "Rheumatoid Arthritis"),
("Levodopa", "treats", "Parkinson's Disease")
]

# known drug-drug interactions (sample pairs for demonstration)
drug_drug_interactions = [
("Goserelin", "interacts_with", "Desmopressin", "increases_anticoagulant_effect"),
("Goserelin", "interacts_with", "Cetrorelix", "increases_bleeding_risk"),
("Cyclosporine", "interacts_with", "Felypressin", "decreases_efficacy"),
("Octreotide", "interacts_with", "Cyanocobalamin", "increases_hypoglycemia_risk"),
("Tetrahydrofolic acid", "interacts_with", "L-Histidine", "increases_statin_concentration"),
("S-Adenosylmethionine", "interacts_with", "Pyruvic acid", "decreases_efficacy"),
("L-Phenylalanine", "interacts_with", "Biotin", "increases_sedation"),
("Choline", "interacts_with", "L-Lysine", "decreases_efficacy")
]

# add all relationships to the knowledge graph
for s, r, o in drug_protein_relations:
    if s in medical_kg and o in medical_kg:
        medical_kg.add_edge(s, o, relation=r)

for s, r, o in drug_disease_relations:
    if s in medical_kg and o in medical_kg:
        medical_kg.add_edge(s, o, relation=r)

for s, r, o, mechanism in drug_drug_interactions:
    if s in medical_kg and o in medical_kg:
        medical_kg.add_edge(s, o, relation=r, mechanism=mechanism)
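A quick inspection of the assembled graph can confirm that the nodes and typed edges landed as expected (an illustrative check; the exact counts depend on which DrugBank names survive the 50-drug cut):

# illustrative inspection of the knowledge graph
print(f"Nodes: {medical_kg.number_of_nodes()}, Edges: {medical_kg.number_of_edges()}")
print(list(medical_kg.edges(data=True))[:3])  # peek at a few typed edges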
Each drug is represented by three complementary data types. First, its SMILES notation is converted into a molecule object and rendered as an image with RDKit.

# function to generate molecular structure images using RDKit
def generate_molecule_image(smiles_string, size=(224, 224)):
    try:
        mol = Chem.MolFromSmiles(smiles_string)
        if mol:
            img = Draw.MolToImage(mol, size=size)
            return img
        else:
            return None
    except:
        return None
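A quick check of the helper (illustrative; aspirin's SMILES is used as an arbitrary test molecule):

# illustrative check: render aspirin from its SMILES string
test_img = generate_molecule_image("CC(=O)OC1=CC=CC=C1C(=O)O")
if test_img:
    test_img.save("data/drug_images/aspirin_test.png")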
Second, a descriptive text is constructed by combining the drug's name, categories, group information, and any other available metadata.

# function to create a text description for drugs combining various information
def create_drug_description(row):
    description = f"Drug name: {row['name']}. "
    if pd.notna(row.get('category')):
        description += f"Category: {row['category']}. "
    if pd.notna(row.get('groups')):
        description += f"Groups: {row['groups']}. "
    if pd.notna(row.get('description')):
        description += f"Description: {row['description']}"
    return description

Third, the graph itself is embedded. Each node and relation starts out as a random vector, which is then adjusted iteratively so that, for every true connection, an entity's vector plus its relation's vector lands close to the vector of the connected entity. Over many iterations this yields an embedding space in which connected elements naturally cluster and the direction of each relationship is encoded by its relation vector. The result is a pair of lookup tables that map every node and relation to compact, trainable coordinates reflecting the full structure of the knowledge graph.
# convert NetworkX graph to PyG graph for modern graph neural network processing
def convert_nx_to_pyg(nx_graph):
    # create node mappings
    node_to_idx = {node: i for i, node in enumerate(nx_graph.nodes())}

    # create edge lists
    src_nodes = []
    dst_nodes = []
    edge_types = []
    edge_type_to_idx = {}

    for u, v, data in nx_graph.edges(data=True):
        relation = data.get('relation', 'unknown')
        if relation not in edge_type_to_idx:
            edge_type_to_idx[relation] = len(edge_type_to_idx)
        src_nodes.append(node_to_idx[u])
        dst_nodes.append(node_to_idx[v])
        edge_types.append(edge_type_to_idx[relation])

    # create PyG graph
    edge_index = torch.tensor([src_nodes, dst_nodes], dtype=torch.long)
    edge_type = torch.tensor(edge_types, dtype=torch.long)

    # create node features
    node_types = []
    for node in nx_graph.nodes():
        node_type = nx_graph.nodes[node].get('type', 'unknown')
        node_types.append(node_type)

    # one-hot encode node types
    unique_node_types = sorted(set(node_types))
    node_type_to_idx = {nt: i for i, nt in enumerate(unique_node_types)}
    node_type_features = torch.zeros(len(node_types), len(unique_node_types))
    for i, nt in enumerate(node_types):
        node_type_features[i, node_type_to_idx[nt]] = 1.0

    # create PyG Data object with the proper attributes
    g = Data(
        edge_index=edge_index,
        edge_type=edge_type,
        x=node_type_features  # node features in PyG are stored in 'x'
    )

    # create reverse mappings for later use
    idx_to_node = {idx: node for node, idx in node_to_idx.items()}
    idx_to_edge_type = {idx: rel for rel, idx in edge_type_to_idx.items()}

    return g, node_to_idx, idx_to_node, edge_type_to_idx, idx_to_edge_type

# convert medical_kg to a PyG graph
pyg_graph, node_to_idx, idx_to_node, edge_type_to_idx, idx_to_edge_type = convert_nx_to_pyg(medical_kg)
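Printing the converted object is a quick way to confirm the tensor shapes before wiring it into the model (illustrative; the counts depend on the graph built above):

# illustrative check of the PyG conversion
print(pyg_graph)         # e.g. Data(x=[num_nodes, num_node_types], edge_index=[2, num_edges], edge_type=[num_edges])
print(edge_type_to_idx)  # mapping of relation names to integer ids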
These visual, textual, and structural representations are saved so that the model can fuse them for interaction prediction.

# process drug data to create multimodal representations
drug_data = []

for idx, row in drug_df.iterrows():
    if row['name'] in drug_entities and pd.notna(row.get('smiles')):
        # generate molecule image
        img = generate_molecule_image(row['smiles'])
        if img:
            img_path = f"data/drug_images/{row['drugbank_id']}.png"
            img.save(img_path)

            # create text description
            description = create_drug_description(row)

            # store drug information
            drug_data.append({
                'id': row['drugbank_id'],
                'name': row['name'],
                'smiles': row['smiles'],
                'description': description,
                'image_path': img_path
            })

drug_data_df = pd.DataFrame(drug_data)
MultimodalNodeEncoder defines a single encoder that converts each node's molecular image and its text summary into compatible feature vectors. It first applies a deep convolutional network to the raw chemical drawing, distilling it into a compact visual fingerprint. In parallel, it runs the drug description through a pretrained language model to extract a semantic summary. The outputs of both are then projected into the same vector space, so the visual and textual signals can be combined meaningfully under the guidance of the knowledge graph structure.

# processes visual and textual features for nodes
class MultimodalNodeEncoder(nn.Module):
    def __init__(self, output_dim=128):
        super(MultimodalNodeEncoder, self).__init__()
        # image encoder (ResNet)
        resnet = models.resnet18(pretrained=True)
        # remove the final fully connected layer to get 512 features
        self.image_encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.image_projection = nn.Linear(512, output_dim)

        # text encoder (BERT)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        # BERT base outputs 768 features
        self.text_projection = nn.Linear(768, output_dim)

    def forward(self, image, text):
        # image encoding
        img_features = self.image_encoder(image).squeeze(-1).squeeze(-1)
        img_features = self.image_projection(img_features)

        # text encoding
        encoded_input = self.tokenizer(text, padding=True, truncation=True,
                                       return_tensors="pt", max_length=128)
        # move encoded input to the same device as the image
        input_ids = encoded_input['input_ids'].to(image.device)
        attention_mask = encoded_input['attention_mask'].to(image.device)

        text_outputs = self.text_encoder(input_ids=input_ids,
                                         attention_mask=attention_mask)
        # use the [CLS] token embedding (first token)
        text_features = text_outputs.last_hidden_state[:, 0, :]
        text_features = self.text_projection(text_features)

        return img_features, text_features
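A shape check makes the shared projection concrete (an illustrative sketch; a random tensor stands in for a preprocessed molecule image, and the description string is made up):

# illustrative shape check with a dummy image and a made-up description
encoder = MultimodalNodeEncoder(output_dim=128)
dummy_img = torch.randn(1, 3, 224, 224)  # stands in for a preprocessed molecule image
img_feat, txt_feat = encoder(dummy_img, ["Drug name: Example. Groups: approved."])
print(img_feat.shape, txt_feat.shape)    # both torch.Size([1, 128])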
The KG-guided multimodal model fuses each node's visual, textual, and type embeddings and predicts drug-drug interactions under the guidance of the knowledge graph. It first projects each node's image and text outputs into a shared space and assigns a separate embedding for its node type. These embeddings are then propagated over the graph, so each node blends its own features with signals gathered from its neighbors. An attention step reweights the fused features according to the strength and type of the connections. To assess a pair of drugs, the model takes their refined node embeddings, combines them via concatenation, element-wise product, and difference, and feeds the result into a prediction head that produces an interaction probability. By letting the graph's topology decide how the multimodal signals are fused, the model produces predictions that are both accurate and directly traceable to the underlying network structure.

# define the KG-guided multimodal model
class KGGuidedMultimodalModel(nn.Module):
    def __init__(self, pyg_graph, num_node_types, num_edge_types, node_to_idx, idx_to_node, hidden_dim=128):
        super(KGGuidedMultimodalModel, self).__init__()
        self.pyg_graph = pyg_graph
        self.node_to_idx = node_to_idx
        self.idx_to_node = idx_to_node
        self.hidden_dim = hidden_dim

        # multimodal encoder for processing node-associated data
        self.multimodal_encoder = MultimodalNodeEncoder(output_dim=hidden_dim)

        # node type embeddings
        self.node_type_embedding = nn.Embedding(num_node_types, hidden_dim)

        # graph neural network layers for knowledge graph processing (PyG GCNConv instead of dglnn.GraphConv)
        self.gnn_layers = nn.ModuleList([
            geom_nn.GCNConv(hidden_dim, hidden_dim),
            geom_nn.GCNConv(hidden_dim, hidden_dim),
        ])

        # graph attention network for integrating multimodal features with graph structure (PyG GATConv)
        # set the per-head output dimension so the total output is hidden_dim (not hidden_dim * num_heads)
        self.gat_layer = geom_nn.GATConv(hidden_dim, hidden_dim // 4, heads=4)

        # relation prediction layer - sized to match the concatenated pair representation below
        self.relation_prediction = nn.Sequential(
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 1)
        )

    def get_node_representation(self, node_name, image=None, text=None):
        if node_name not in self.node_to_idx:
            # handle unknown nodes
            return torch.zeros(self.hidden_dim, device=self.pyg_graph.edge_index.device)

        node_idx = self.node_to_idx[node_name]

        # get node type features - use x instead of ndata['type']
        node_type_feat = self.pyg_graph.x[node_idx]
        node_type_embedding = self.node_type_embedding(torch.argmax(node_type_feat))

        # if multimodal data is provided, process it
        if image is not None and text is not None:
            img_feat, text_feat = self.multimodal_encoder(image, text)
            # squeeze out the batch dimension to match shapes
            img_feat = img_feat.squeeze(0)
            text_feat = text_feat.squeeze(0)

            # knowledge graph structure guides how multimodal features are integrated
            # use node_type_embedding as a query to attend to multimodal features
            attention_weights = torch.softmax(
                torch.matmul(
                    torch.stack([img_feat, text_feat, node_type_embedding]),
                    node_type_embedding
                ),
                dim=0
            )

            # weighted combination of features
            combined_feat = (
                attention_weights[0] * img_feat +
                attention_weights[1] * text_feat +
                attention_weights[2] * node_type_embedding
            )
            return combined_feat
        else:
            # for nodes without multimodal data, just use the type embedding
            return node_type_embedding

    def forward(self, drug1_image, drug1_text, drug1_name, drug2_image, drug2_text, drug2_name):
        # process the entire graph first
        device = self.pyg_graph.edge_index.device
        x = torch.zeros((self.pyg_graph.x.size(0), self.hidden_dim), device=device)

        # initialize known node features
        for i, node_name in enumerate([drug1_name, drug2_name]):
            if node_name in self.node_to_idx:
                node_idx = self.node_to_idx[node_name]
                if i == 0:
                    x[node_idx] = self.get_node_representation(node_name, drug1_image, drug1_text)
                else:
                    x[node_idx] = self.get_node_representation(node_name, drug2_image, drug2_text)

        # apply graph convolutions to propagate information - PyG style
        edge_index = self.pyg_graph.edge_index
        for layer in self.gnn_layers:
            x = layer(x, edge_index)
            x = torch.relu(x)

        # apply graph attention to integrate features - PyG style
        x = self.gat_layer(x, edge_index)

        # get final representations for the two drugs
        drug1_idx = self.node_to_idx.get(drug1_name, 0)
        drug2_idx = self.node_to_idx.get(drug2_name, 0)

        drug1_repr = x[drug1_idx]
        drug2_repr = x[drug2_idx]

        # predict interaction
        # concatenate representations in multiple ways to capture the relationship
        concat_repr = torch.cat([
            drug1_repr,
            drug2_repr,
            drug1_repr * drug2_repr,
            torch.abs(drug1_repr - drug2_repr)
        ], dim=0)

        interaction_prob = torch.sigmoid(self.relation_prediction(concat_repr.unsqueeze(0)).squeeze())
        return interaction_prob

To explain how a pair of drugs relates within the larger graph, a focused subgraph is constructed. The procedure first checks the graph for any direct edge between the two drugs and records its attributes if found. Next, it identifies proteins or diseases linked to both drugs, revealing shared mechanisms. Finally, it traces all simple paths up to a given length to expose indirect connections through intermediate nodes. The result is a compact network of the key nodes and edges that captures the domain knowledge behind a predicted interaction and guides the downstream layers to emphasize the most relevant multimodal features.
# function to retrieve the knowledge subgraph relevant to a drug pair
def retrieve_knowledge_subgraph(graph, drug1, drug2, max_path_length=3):
    relevant_knowledge = {
        'direct_interaction': None,
        'common_targets': [],
        'paths': []
    }

    # check for a direct interaction
    if graph.has_edge(drug1, drug2):
        edge_data = graph.get_edge_data(drug1, drug2)
        relevant_knowledge['direct_interaction'] = edge_data

    # find common targets (proteins, diseases)
    drug1_neighbors = set(graph.neighbors(drug1)) if drug1 in graph else set()
    drug2_neighbors = set(graph.neighbors(drug2)) if drug2 in graph else set()

    common_neighbors = drug1_neighbors.intersection(drug2_neighbors)
    for common_node in common_neighbors:
        node_type = graph.nodes[common_node].get('type', '')
        if node_type == 'protein' or node_type == 'disease':
            relevant_knowledge['common_targets'].append(common_node)

    # find paths between drugs (up to max_path_length)
    try:
        paths = list(nx.all_simple_paths(graph, drug1, drug2, cutoff=max_path_length))
        relevant_knowledge['paths'] = paths
    except (nx.NetworkXError, nx.NodeNotFound):
        # handle cases where paths do not exist or nodes are not in the graph
        pass

    return relevant_knowledge
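For example, querying a pair that has a direct edge in the demo graph (illustrative; assumes both names made it into the 50-drug sample above):

# illustrative query for a pair with a direct edge in the demo graph
info = retrieve_knowledge_subgraph(medical_kg, "Goserelin", "Desmopressin")
print(info['direct_interaction'])  # expected: relation 'interacts_with' with its mechanism attribute
print(info['common_targets'])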
The collate function prepares each training batch, first discarding any samples that failed to load or are incomplete. It then stacks all valid molecular images into batched tensors for the drug pair, while gathering the corresponding text summaries and identifiers into parallel lists. The interaction labels are likewise combined into a single tensor. By returning one dictionary of these batched components, or empty placeholders if no valid samples remain, it ensures the model always receives well-formed, homogeneous input despite the heterogeneity and occasional gaps in the underlying data.

# custom collate function to handle None values
def custom_collate_fn(batch):
    # filter out None values
    batch = [item for item in batch if item is not None]

    # return an empty batch if all items were None
    if len(batch) == 0:
        return {
            'drug1_img': torch.tensor([]),
            'drug1_text': [],
            'drug1_name': [],
            'drug2_img': torch.tensor([]),
            'drug2_text': [],
            'drug2_name': [],
            'label': torch.tensor([])
        }

    # process non-None items
    drug1_imgs = torch.stack([item['drug1_img'] for item in batch])
    drug1_texts = [item['drug1_text'] for item in batch]
    drug1_names = [item['drug1_name'] for item in batch]

    drug2_imgs = torch.stack([item['drug2_img'] for item in batch])
    drug2_texts = [item['drug2_text'] for item in batch]
    drug2_names = [item['drug2_name'] for item in batch]

    labels = torch.stack([item['label'] for item in batch])

    return {
        'drug1_img': drug1_imgs,
        'drug1_text': drug1_texts,
        'drug1_name': drug1_names,
        'drug2_img': drug2_imgs,
        'drug2_text': drug2_texts,
        'drug2_name': drug2_names,
        'label': labels
    }
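The empty-batch path can be exercised directly (an illustrative check of the fallback behavior):

# illustrative check: an all-None batch collapses to empty placeholders
empty = custom_collate_fn([None, None])
print(len(empty['drug1_text']), empty['label'].shape)  # 0 torch.Size([0])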
The training examples combine all the known interactions with a matching set of random non-interacting pairs. The process first extracts every known interacting drug pair, then samples an equal number of negative pairs to balance the dataset. Fetching an example loads and preprocesses each drug's molecular image and text summary, skipping any pair with missing data, and yields a record containing both drugs' images, descriptions, names, and a binary label. By pairing positives with negatives, applying consistent image transforms, and handling missing data robustly, the dataset supplies reliable, ready-to-use batches for training the interaction prediction model.

# define the dataset for DDI prediction
class DDIDataset(Dataset):
    def __init__(self, drug_data_df, drug_drug_interactions, medical_kg, node_to_idx, transform=None):
        self.drug_data = drug_data_df
        self.drug_name_to_idx = {row['name']: i for i, row in drug_data_df.iterrows()}
        self.node_to_idx = node_to_idx
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # create pairs of drugs with interaction labels
        self.pairs = []
        drug_names = list(self.drug_name_to_idx.keys())

        # positive samples (known interactions)
        for interaction in drug_drug_interactions:
            drug1, _, drug2, _ = interaction
            if drug1 in drug_names and drug2 in drug_names:
                # 1 for a positive interaction
                self.pairs.append((drug1, drug2, 1))

        positive_pairs = set((d1, d2) for d1, d2, _ in self.pairs)

        # generate some negative samples
        np.random.seed(42)
        neg_count = 0
        max_neg = len(self.pairs)
        while neg_count < max_neg:
            i, j = np.random.choice(len(drug_names), 2, replace=False)
            drug1, drug2 = drug_names[i], drug_names[j]
            if (drug1, drug2) not in positive_pairs and (drug2, drug1) not in positive_pairs:
                # 0 for a negative interaction
                self.pairs.append((drug1, drug2, 0))
                neg_count += 1

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        try:
            drug1_name, drug2_name, label = self.pairs[idx]

            # get drug1 data
            drug1_idx = self.drug_name_to_idx[drug1_name]
            drug1_data = self.drug_data.iloc[drug1_idx]

            # load the drug1 image with error handling
            try:
                drug1_img = Image.open(drug1_data['image_path']).convert('RGB')
                drug1_img = self.transform(drug1_img)
            except Exception as e:
                print(f"Error loading drug1 image for {drug1_name}: {str(e)}")
                return None

            drug1_text = drug1_data['description']

            # get drug2 data
            drug2_idx = self.drug_name_to_idx[drug2_name]
            drug2_data = self.drug_data.iloc[drug2_idx]

            # load the drug2 image with error handling
            try:
                drug2_img = Image.open(drug2_data['image_path']).convert('RGB')
                drug2_img = self.transform(drug2_img)
            except Exception as e:
                print(f"Error loading drug2 image for {drug2_name}: {str(e)}")
                return None

            drug2_text = drug2_data['description']

            return {
                'drug1_img': drug1_img,
                'drug1_text': drug1_text,
                'drug1_name': drug1_name,
                'drug2_img': drug2_img,
                'drug2_text': drug2_text,
                'drug2_name': drug2_name,
                'label': torch.tensor(label, dtype=torch.float32)
            }
        except Exception as e:
            print(f"Error in __getitem__ for index {idx}: {str(e)}")
            return None
Model training begins by moving the model and its graph to the chosen device (GPU or CPU) and runs for a set number of epochs, each split into a training phase and a validation phase. During training, batches of drug pairs pass through the network to produce interaction scores, the binary cross-entropy loss is computed, and the Adam optimizer updates all parameters through backpropagation. Losses and correct-prediction counts are accumulated so that the average training loss and accuracy can be reported at the end of each epoch, and empty or malformed batches are skipped to keep things stable. The process then switches to evaluation mode, running the same steps without gradient updates, to measure validation loss and accuracy.

# training function
def train_kg4mm_model(model, train_loader, val_loader, epochs=5):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.pyg_graph = model.pyg_graph.to(device)

    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    for epoch in range(epochs):
        # training phase
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0
        batch_count = 0

        for batch in train_loader:
            # skip empty batches
            if len(batch['drug1_img']) == 0:
                print("Skipping empty batch")
                continue

            batch_count += 1

            try:
                drug1_img = batch['drug1_img'].to(device)
                drug1_text = batch['drug1_text']
                drug1_name = batch['drug1_name']
                drug2_img = batch['drug2_img'].to(device)
                drug2_text = batch['drug2_text']
                drug2_name = batch['drug2_name']
                labels = batch['label'].to(device)

                # forward pass - processing one pair at a time for clarity
                batch_size = len(drug1_name)
                outputs = torch.zeros(batch_size, 1, device=device)

                for i in range(batch_size):
                    # this loop is for illustration - in practice, handle batch processing more efficiently
                    output = model(
                        drug1_img[i].unsqueeze(0),
                        [drug1_text[i]],
                        drug1_name[i],
                        drug2_img[i].unsqueeze(0),
                        [drug2_text[i]],
                        drug2_name[i]
                    )
                    outputs[i] = output

                # calculate loss
                loss = criterion(outputs, labels.unsqueeze(1))

                # backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

                # calculate accuracy
                predictions = (outputs >= 0.5).float()
                train_correct += (predictions == labels.unsqueeze(1)).sum().item()
                train_total += batch_size

                print(f"Batch {batch_count}: Loss: {loss.item():.4f}")

            except Exception as e:
                print(f"Error processing batch {batch_count}: {str(e)}")
                import traceback
                traceback.print_exc()
                continue

        avg_train_loss = train_loss / max(1, batch_count)
        train_acc = train_correct / max(1, train_total)
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}')

        # validation phase
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        val_batch_count = 0

        with torch.no_grad():
            for batch in val_loader:
                # skip empty batches
                if len(batch['drug1_img']) == 0:
                    continue

                val_batch_count += 1

                try:
                    drug1_img = batch['drug1_img'].to(device)
                    drug1_text = batch['drug1_text']
                    drug1_name = batch['drug1_name']
                    drug2_img = batch['drug2_img'].to(device)
                    drug2_text = batch['drug2_text']
                    drug2_name = batch['drug2_name']
                    labels = batch['label'].to(device)

                    # forward pass - processing one pair at a time for clarity
                    batch_size = len(drug1_name)
                    outputs = torch.zeros(batch_size, 1, device=device)

                    for i in range(batch_size):
                        output = model(
                            drug1_img[i].unsqueeze(0),
                            [drug1_text[i]],
                            drug1_name[i],
                            drug2_img[i].unsqueeze(0),
                            [drug2_text[i]],
                            drug2_name[i]
                        )
                        outputs[i] = output

                    # calculate loss
                    loss = criterion(outputs, labels.unsqueeze(1))
                    val_loss += loss.item()

                    # calculate accuracy
                    predictions = (outputs >= 0.5).float()
                    val_correct += (predictions == labels.unsqueeze(1)).sum().item()
                    val_total += batch_size

                except Exception as e:
                    print(f"Error processing validation batch {val_batch_count}: {str(e)}")
                    continue

        avg_val_loss = val_loss / max(1, val_batch_count)
        val_acc = val_correct / max(1, val_total)  # count actual samples instead of assuming a fixed batch size
        print(f'Epoch {epoch+1}/{epochs}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}')

    return model
The prepared dataset then yields pairs of drug examples, each carrying its molecular image and text summary, and splits them into training and validation sets for performance tracking. Data loaders package these multimodal examples (images, descriptions, and labels) into batches so they flow smoothly into the model. The KG-guided prediction network is instantiated with dimensions derived from the graph's node and edge types, ensuring its layers align with the knowledge graph's structure. Finally, the training loop runs for a fixed number of epochs, alternately updating the model on the training data and measuring its accuracy on the validation set. This sequence completes the transition from data preparation to active, graph-driven learning.

# initialize the dataset
ddi_dataset = DDIDataset(drug_data_df, drug_drug_interactions, medical_kg, node_to_idx)

# split the dataset into train and validation sets
train_size = int(0.8 * len(ddi_dataset))
val_size = len(ddi_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(ddi_dataset, [train_size, val_size])

# create data loaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=custom_collate_fn)

# derive model dimensions from the PyG graph
num_node_types = pyg_graph.x.shape[1]
num_edge_types = len(edge_type_to_idx)

# initialize the KG-guided multimodal model
model = KGGuidedMultimodalModel(pyg_graph, num_node_types, num_edge_types, node_to_idx, idx_to_node)

# train the model
trained_model = train_kg4mm_model(model, train_loader, val_loader, epochs=5)
To make a prediction, the model first loads each drug's processed image and text summary and locates both drugs in the knowledge graph. It then produces a probability score that reflects how the visual, textual, and graph information combine. In parallel, the system checks the graph for any direct link between the two drugs, any proteins or diseases they both connect to, and any simple paths up to a given length that join them. The probability is translated into a low, moderate, or high risk level. An explanation is then assembled that highlights known interaction mechanisms, shared targets, and the key graph paths that guided the decision. Finally, the system offers illustrative clinical recommendations based on the risk level, making clear how the knowledge graph shaped both the prediction and its interpretation.

def predict_interaction(model, drug1_name, drug2_name, drug_data_df, medical_kg):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    # get drug indices
    drug1_idx = drug_data_df[drug_data_df['name'] == drug1_name].index[0]
    drug2_idx = drug_data_df[drug_data_df['name'] == drug2_name].index[0]

    # get drug data
    drug1_data = drug_data_df.iloc[drug1_idx]
    drug2_data = drug_data_df.iloc[drug2_idx]

    # prepare images
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    drug1_img = Image.open(drug1_data['image_path']).convert('RGB')
    drug1_img = transform(drug1_img).unsqueeze(0).to(device)
    drug1_text = [drug1_data['description']]

    drug2_img = Image.open(drug2_data['image_path']).convert('RGB')
    drug2_img = transform(drug2_img).unsqueeze(0).to(device)
    drug2_text = [drug2_data['description']]

    # get the knowledge subgraph for the drug pair
    knowledge = retrieve_knowledge_subgraph(medical_kg, drug1_name, drug2_name)

    # make the prediction
    with torch.no_grad():
        interaction_prob = model(
            drug1_img,
            drug1_text,
            drug1_name,
            drug2_img,
            drug2_text,
            drug2_name
        )

    return interaction_prob.item(), knowledge

def explain_interaction_prediction(drug1_name, drug2_name, probability, knowledge):
    explanation = f"KG-guided multimodal analysis for interaction between {drug1_name} and {drug2_name}:\n\n"

    # interpret the probability
    if probability > 0.8:
        risk_level = "High"
    elif probability > 0.5:
        risk_level = "Moderate"
    else:
        risk_level = "Low"

    explanation += f"Interaction Risk Level: {risk_level} (Probability: {probability:.2f})\n\n"

    # explain based on the knowledge graph structure
    explanation += "Knowledge Graph Analysis:\n"

    if knowledge['direct_interaction']:
        mechanism = knowledge['direct_interaction'].get('mechanism', 'unknown mechanism')
        explanation += f"✓ Direct Connection: The knowledge graph contains a documented interaction between these drugs with {mechanism}.\n\n"

    if knowledge['common_targets']:
        explanation += "✓ Common Target Nodes: These drugs connect to shared entities in the knowledge graph:\n"
        for target in knowledge['common_targets']:
            explanation += f"  - {target}\n"
        explanation += "  This graph structure suggests potential interaction through common binding sites or pathways.\n\n"

    if knowledge['paths'] and len(knowledge['paths']) > 0:
        explanation += "✓ Knowledge Graph Pathways: The model identified these connecting paths in the graph:\n"
        for i, path in enumerate(knowledge['paths'][:3]):
            path_str = " → ".join(path)
            explanation += f"  - Path {i+1}: {path_str}\n"
        explanation += "  These graph structures guided the multimodal feature integration for prediction.\n\n"

    # focus on how the KG structure guided the interpretation
    explanation += "Multimodal Integration Process:\n"
    explanation += "  - Knowledge graph structure determined which drug properties were most relevant\n"
    explanation += "  - Graph neural networks analyzed the local neighborhood of both drug nodes\n"
    explanation += "  - Node position in the graph guided the weighting of visual and textual features\n\n"

    # clinical implications (example - in a real system, this would be more comprehensive)
    if probability > 0.5:
        explanation += "Clinical Recommendations (based on graph analysis):\n"
        explanation += "  - Consider alternative medications not connected in similar graph patterns\n"
        explanation += "  - If co-administration is necessary, monitor for interaction effects\n"
        explanation += "  - Review other drugs connected to the same nodes for potential complications\n"
    else:
        explanation += "Clinical Recommendations (based on graph analysis):\n"
        explanation += "  - Standard monitoring advised\n"
        explanation += "  - The knowledge graph structure suggests minimal interaction concerns\n"

    return explanation

To illustrate the complete workflow, two drugs are selected, and their pre-generated images and text summaries are loaded and preprocessed exactly as during training. These multimodal inputs are passed through the trained model, now in evaluation mode, yielding a probability score that quantifies their interaction risk. In parallel, for visualization and explanation, the procedure extracts the relevant portion of the knowledge graph by collecting any direct connections, shared biological targets, and simple paths up to a given length that join the drugs, and then expands this subgraph with one layer of immediate neighbors for broader context.
The extracted subgraph is laid out with a clear color scheme that distinguishes the two drugs, proteins, diseases, and other entities, making the network structure easy to read at a glance. This is followed by a plain natural-language explanation that ties the probability score to those graph features by highlighting any documented interaction mechanisms, shared targets, and key connecting paths. Together, the risk estimate, the color-coded visualization, and the narrative explanation show how the knowledge graph's topology guided the fusion of visual and textual signals and provide a transparent rationale for the model's prediction.
# example usage
drug_pair = ("Goserelin", "Desmopressin")
prob, knowledge = predict_interaction(trained_model, drug_pair[0], drug_pair[1], drug_data_df, medical_kg)

print(f"Predicted interaction probability between {drug_pair[0]} and {drug_pair[1]}: {prob:.4f}")
print("\nKnowledge Graph Structure Analysis:")
print(f"Direct connection: {knowledge['direct_interaction']}")
print(f"Common target nodes: {knowledge['common_targets']}")
print(f"Graph paths connecting drugs:")
for path in knowledge['paths']:
    print(f"  {' -> '.join(path)}")

# visualize the subgraph for these drugs to show the KG-guided approach
plt.figure(figsize=(12, 8))
subgraph_nodes = set([drug_pair[0], drug_pair[1]])

# add intermediate nodes on the paths to highlight the KG structure
for path in knowledge['paths']:
    subgraph_nodes.update(path)

# add a level of neighbors to show context in the KG
neighbors_to_add = set()
for node in subgraph_nodes:
    if node in medical_kg:
        neighbors_to_add.update(list(medical_kg.neighbors(node))[:3])
subgraph_nodes.update(neighbors_to_add)

subgraph = medical_kg.subgraph(subgraph_nodes)

# use different colors for node types to emphasize the KG structure
node_colors = []
for node in subgraph.nodes():
    if node == drug_pair[0] or node == drug_pair[1]:
        node_colors.append('lightcoral')
    elif subgraph.nodes[node].get('type') == 'protein':
        node_colors.append('lightblue')
    elif subgraph.nodes[node].get('type') == 'disease':
        node_colors.append('lightgreen')
    else:
        node_colors.append('lightgray')

pos = nx.spring_layout(subgraph, seed=42)
nx.draw(subgraph, pos, with_labels=True, node_color=node_colors,
        node_size=2000, arrows=True, arrowsize=20)

edge_labels = {(s, o): subgraph[s][o]['relation'] for s, o in subgraph.edges()}
nx.draw_networkx_edge_labels(subgraph, pos, edge_labels=edge_labels)

plt.title(f"Knowledge Graph Structure Guiding {drug_pair[0]} and {drug_pair[1]} Interaction Analysis")
plt.savefig('kg_guided_interaction_analysis.png')
plt.show()

# show the explanation
explanation = explain_interaction_prediction(drug_pair[0], drug_pair[1], prob, knowledge)
print(explanation)

Tested on Goserelin and Desmopressin, the model returned a probability of 0.54, classifying them as a moderate-risk pair. The knowledge graph revealed a direct interacts_with relation between the two drugs, labeled with the mechanism increases_anticoagulant_effect, and no shared protein or disease connections, so the model focused chiefly on that mechanism. In the subgraph plot, both drugs are highlighted in red and the single directed edge stands out, showing clearly which relationship drove the prediction.
The KG4MM study shows that placing a knowledge graph at the center of the workflow fuses molecular images and text more effectively than single-source approaches. Every prediction is backed by explicit graph evidence (direct edges, shared targets, and connecting paths), which ties the results to real biological relationships. In doing so, KG4MM delivers both stronger predictive power and built-in interpretability, with applications across biochemistry, materials science, and medical diagnostics.