# create a SMILES column by converting InChI to SMILES
def inchi_to_smiles_openbabel(inchi_str):
    """Convert an InChI string to a SMILES string using Open Babel.

    Args:
        inchi_str: InChI identifier. Missing DataFrame values (NaN/None) and
            blank strings are accepted and yield None.

    Returns:
        The SMILES string, or None when the input is missing, Open Babel
        cannot parse it, the output is empty, or the conversion raises.
    """
    # missing values in a pandas column arrive as NaN (float) or None; passing
    # them to ReadString only produces a caught exception and a noisy message
    if not isinstance(inchi_str, str) or not inchi_str.strip():
        return None
    try:
        # create Open Babel OBMol object from InChI
        ob_conversion = openbabel.OBConversion()
        ob_conversion.SetInAndOutFormats("inchi", "smiles")
        mol = openbabel.OBMol()
        if not ob_conversion.ReadString(mol, inchi_str):
            return None
        # the SMILES writer may append a tab-separated molecule title and a
        # trailing newline; keep only the SMILES field
        smiles = ob_conversion.WriteString(mol).strip().split("\t", 1)[0]
        return smiles or None
    except Exception as e:
        # best-effort per-row conversion: report and continue with the next row
        print(f"Error converting InChI to SMILES: {inchi_str}. Error: {e}")
        return None


# apply the conversion to each InChI in the dataframe
drug_df['smiles'] = drug_df['inchi'].apply(inchi_to_smiles_openbabel)
medical_kg = nx.DiGraph()

# extract drug entities from DrugBank
# limit to the first MAX_DEMO_DRUGS unique drug names for the demo; drugs outside
# this subset get no node in the graph
MAX_DEMO_DRUGS = 50
drug_entities = drug_df['name'].dropna().unique().tolist()[:MAX_DEMO_DRUGS]

# create drug nodes
medical_kg.add_nodes_from(drug_entities, type='drug')

# add biomedical entities (proteins, targets, diseases)
protein_entities = ["Cytochrome P450", "Albumin", "P-glycoprotein", "GABA Receptor",
                    "Serotonin Receptor", "Beta-Adrenergic Receptor", "ACE", "HMGCR"]
disease_entities = ["Hypertension", "Diabetes", "Depression", "Epilepsy", "Asthma",
                    "Rheumatoid Arthritis", "Parkinson's Disease"]
medical_kg.add_nodes_from(protein_entities, type='protein')
medical_kg.add_nodes_from(disease_entities, type='disease')
# add relationships (based on common drug mechanisms and interactions)
# drug-protein relationships
drug_protein_relations = [
    ("Warfarin", "binds_to", "Albumin"),
    ("Atorvastatin", "inhibits", "HMGCR"),
    ("Diazepam", "modulates", "GABA Receptor"),
    ("Fluoxetine", "inhibits", "Serotonin Receptor"),
    ("Phenytoin", "induces", "Cytochrome P450"),
    ("Metoprolol", "blocks", "Beta-Adrenergic Receptor"),
    ("Lisinopril", "inhibits", "ACE"),
    ("Rifampin", "induces", "P-glycoprotein"),
    ("Carbamazepine", "induces", "Cytochrome P450"),
    ("Verapamil", "inhibits", "P-glycoprotein"),
]

# drug-disease relationships
drug_disease_relations = [
    ("Lisinopril", "treats", "Hypertension"),
    ("Metformin", "treats", "Diabetes"),
    ("Fluoxetine", "treats", "Depression"),
    ("Phenytoin", "treats", "Epilepsy"),
    ("Albuterol", "treats", "Asthma"),
    ("Methotrexate", "treats", "Rheumatoid Arthritis"),
    ("Levodopa", "treats", "Parkinson's Disease"),
]

# example drug-drug interactions as (drug, relation, drug, mechanism) tuples
# NOTE(review): several mechanism labels do not look pharmacologically plausible
# for their pairs (e.g. a statin-concentration effect for Tetrahydrofolic acid +
# L-Histidine); treat them as illustrative placeholders, not clinical facts.
drug_drug_interactions = [
    ("Goserelin", "interacts_with", "Desmopressin", "increases_anticoagulant_effect"),
    ("Goserelin", "interacts_with", "Cetrorelix", "increases_bleeding_risk"),
    ("Cyclosporine", "interacts_with", "Felypressin", "decreases_efficacy"),
    ("Octreotide", "interacts_with", "Cyanocobalamin", "increases_hypoglycemia_risk"),
    ("Tetrahydrofolic acid", "interacts_with", "L-Histidine", "increases_statin_concentration"),
    ("S-Adenosylmethionine", "interacts_with", "Pyruvic acid", "decreases_efficacy"),
    ("L-Phenylalanine", "interacts_with", "Biotin", "increases_sedation"),
    ("Choline", "interacts_with", "L-Lysine", "decreases_efficacy"),
]

# add all relationships to the knowledge graph; a triple is only added when both
# endpoints already exist as nodes, and skipped triples are collected so the
# omission is visible instead of silent
skipped_relations = []
for s, r, o in drug_protein_relations + drug_disease_relations:
    if s in medical_kg and o in medical_kg:
        medical_kg.add_edge(s, o, relation=r)
    else:
        skipped_relations.append((s, r, o))
for s, r, o, mechanism in drug_drug_interactions:
    if s in medical_kg and o in medical_kg:
        medical_kg.add_edge(s, o, relation=r, mechanism=mechanism)
    else:
        skipped_relations.append((s, r, o))

if skipped_relations:
    print(f"Skipped {len(skipped_relations)} relation(s) whose entities are not in the graph:")
    for s, r, o in skipped_relations:
        print(f"  {s} -[{r}]-> {o}")
# convert NetworkX graph to PyG graph for modern graph neural network processing
def convert_nx_to_pyg(nx_graph):
    """Convert a NetworkX graph into a PyTorch Geometric ``Data`` object.

    Nodes are numbered in ``nx_graph.nodes()`` iteration order. Each edge's
    ``'relation'`` attribute (``'unknown'`` if absent) is mapped to an integer
    id in first-seen order. Node features are a one-hot encoding of each node's
    ``'type'`` attribute (``'unknown'`` if absent), with columns ordered by
    sorted type name.

    Args:
        nx_graph: NetworkX graph whose nodes may carry a ``'type'`` attribute
            and whose edges may carry a ``'relation'`` attribute.

    Returns:
        Tuple ``(g, node_to_idx, idx_to_node, edge_type_to_idx,
        idx_to_edge_type)``. ``g`` holds ``edge_index`` (2 x E, long),
        ``edge_type`` (E, long), ``x`` (N x T, float one-hot) and
        ``num_nodes``.
    """
    # create node mappings
    node_to_idx = {node: i for i, node in enumerate(nx_graph.nodes())}

    # create edge lists and the relation vocabulary (ids in first-seen order)
    edge_type_to_idx = {}
    src_nodes, dst_nodes, edge_types = [], [], []
    for u, v, data in nx_graph.edges(data=True):
        relation = data.get('relation', 'unknown')
        edge_types.append(edge_type_to_idx.setdefault(relation, len(edge_type_to_idx)))
        src_nodes.append(node_to_idx[u])
        dst_nodes.append(node_to_idx[v])

    # an edgeless graph still yields a well-formed (2, 0) edge_index
    edge_index = torch.tensor([src_nodes, dst_nodes], dtype=torch.long)
    edge_type = torch.tensor(edge_types, dtype=torch.long)

    # one-hot encode node types (nodes(data=True) iterates in nodes() order)
    node_types = [attrs.get('type', 'unknown') for _, attrs in nx_graph.nodes(data=True)]
    unique_node_types = sorted(set(node_types))
    node_type_to_idx = {nt: i for i, nt in enumerate(unique_node_types)}
    num_nodes = len(node_types)
    type_ids = torch.tensor([node_type_to_idx[nt] for nt in node_types], dtype=torch.long)
    node_type_features = torch.zeros(num_nodes, len(unique_node_types))
    node_type_features[torch.arange(num_nodes), type_ids] = 1.0

    # create PyG Data object; node features in PyG are stored in 'x', and the
    # node count is passed explicitly rather than left to inference
    g = Data(
        edge_index=edge_index,
        edge_type=edge_type,
        x=node_type_features,
        num_nodes=num_nodes,
    )

    # create reverse mappings for later use
    idx_to_node = {idx: node for node, idx in node_to_idx.items()}
    idx_to_edge_type = {idx: relation for relation, idx in edge_type_to_idx.items()}
    return g, node_to_idx, idx_to_node, edge_type_to_idx, idx_to_edge_type


# convert medical_kg to a PyG graph
pyg_graph, node_to_idx, idx_to_node, edge_type_to_idx, idx_to_edge_type = convert_nx_to_pyg(medical_kg)
# These visual, textual, and structured representations are saved so that the model can fuse them for interaction prediction.
# process drug data to create multi-modal representations
import os

DRUG_IMAGE_DIR = "data/drug_images"
# img.save() fails if the output directory does not exist yet
os.makedirs(DRUG_IMAGE_DIR, exist_ok=True)

# set membership instead of scanning the drug_entities list for every row
selected_drugs = set(drug_entities)
drug_data = []
for _, row in drug_df.iterrows():
    if row['name'] not in selected_drugs or not pd.notna(row.get('smiles')):
        continue

    # generate molecule image; skip the drug when none is produced so that
    # 'image_path' always points at this drug's own image (never an undefined
    # name or the previous drug's file)
    img = generate_molecule_image(row['smiles'])
    if not img:
        print(f"Skipping {row['name']}: could not generate a molecule image")
        continue
    img_path = f"{DRUG_IMAGE_DIR}/{row['drugbank_id']}.png"
    img.save(img_path)

    # create text description
    description = create_drug_description(row)

    # store drug information
    drug_data.append({
        'id': row['drugbank_id'],
        'name': row['name'],
        'smiles': row['smiles'],
        'description': description,
        'image_path': img_path,
    })

drug_data_df = pd.DataFrame(drug_data)
def _get_drug_record(drug_data_df, drug_name):
    """Return the first row of *drug_data_df* whose 'name' equals *drug_name*."""
    matches = drug_data_df[drug_data_df['name'] == drug_name]
    if matches.empty:
        # IndexError keeps compatibility with the original `.index[0]` failure
        raise IndexError(f"Drug {drug_name!r} not found in drug_data_df")
    # positional access on the filtered frame works for any index labels
    return matches.iloc[0]


def _load_image_tensor(image_path, transform, device):
    """Open *image_path* as RGB, apply *transform*, and return a 1-item batch on *device*."""
    # the context manager closes the underlying file handle
    with Image.open(image_path) as img:
        tensor = transform(img.convert('RGB'))
    return tensor.unsqueeze(0).to(device)


def predict_interaction(model, drug1_name, drug2_name, drug_data_df, medical_kg):
    """Predict the interaction probability for a drug pair.

    Args:
        model: multimodal model called as
            ``model(img1, text1, name1, img2, text2, name2)``; assumed to return
            a single-element tensor holding the probability — TODO confirm.
        drug1_name, drug2_name: names present in ``drug_data_df['name']``.
        drug_data_df: DataFrame with 'name', 'description' and 'image_path'
            columns.
        medical_kg: knowledge graph passed to ``retrieve_knowledge_subgraph``.

    Returns:
        ``(probability, knowledge)`` where ``knowledge`` is whatever
        ``retrieve_knowledge_subgraph`` returns for the pair.

    Raises:
        IndexError: if either drug name is not in ``drug_data_df``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    # get drug data
    drug1_data = _get_drug_record(drug_data_df, drug1_name)
    drug2_data = _get_drug_record(drug_data_df, drug2_name)

    # prepare images (ImageNet normalisation statistics)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    drug1_img = _load_image_tensor(drug1_data['image_path'], transform, device)
    drug1_text = [drug1_data['description']]
    drug2_img = _load_image_tensor(drug2_data['image_path'], transform, device)
    drug2_text = [drug2_data['description']]

    # get knowledge subgraph for the drug pair
    knowledge = retrieve_knowledge_subgraph(medical_kg, drug1_name, drug2_name)

    # make prediction
    with torch.no_grad():
        interaction_prob = model(
            drug1_img, drug1_text, drug1_name,
            drug2_img, drug2_text, drug2_name,
        )
    return interaction_prob.item(), knowledge


def explain_interaction_prediction(drug1_name, drug2_name, probability, knowledge):
    """Build a human-readable explanation for a predicted drug interaction.

    Args:
        drug1_name, drug2_name: the drug pair.
        probability: predicted interaction probability. Above 0.8 is "High",
            above 0.5 is "Moderate", otherwise "Low".
        knowledge: mapping with optional keys 'direct_interaction' (edge
            attribute dict, may hold 'mechanism'), 'common_targets' (iterable
            of node names) and 'paths' (list of node-name sequences). Missing
            or empty keys are simply omitted from the explanation.

    Returns:
        The multi-line explanation string.
    """
    # interpret the probability
    if probability > 0.8:
        risk_level = "High"
    elif probability > 0.5:
        risk_level = "Moderate"
    else:
        risk_level = "Low"

    parts = [
        f"KG-guided multimodal analysis for interaction between {drug1_name} and {drug2_name}:\n\n",
        f"Interaction Risk Level: {risk_level} (Probability: {probability:.2f})\n\n",
        # explain based on knowledge graph structure
        "Knowledge Graph Analysis:\n",
    ]

    direct = knowledge.get('direct_interaction')
    if direct:
        mechanism = direct.get('mechanism', 'unknown mechanism')
        parts.append(
            "✓ Direct Connection: The knowledge graph contains a documented "
            f"interaction between these drugs with {mechanism}.\n\n"
        )

    common_targets = knowledge.get('common_targets')
    if common_targets:
        parts.append("✓ Common Target Nodes: These drugs connect to shared entities in the knowledge graph:\n")
        parts.extend(f" - {target}\n" for target in common_targets)
        parts.append(" This graph structure suggests potential interaction through common binding sites or pathways.\n\n")

    paths = knowledge.get('paths')
    if paths:
        parts.append("✓ Knowledge Graph Pathways: The model identified these connecting paths in the graph:\n")
        # show at most the first three paths
        for i, path in enumerate(paths[:3], start=1):
            path_str = " → ".join(map(str, path))
            parts.append(f" - Path {i}: {path_str}\n")
        parts.append(" These graph structures guided the multimodal feature integration for prediction.\n\n")

    # focus on how KG structure guided the interpretation
    parts += [
        "Multimodal Integration Process:\n",
        " - Knowledge graph structure determined which drug properties were most relevant\n",
        " - Graph neural networks analyzed the local neighborhood of both drug nodes\n",
        " - Node position in the graph guided the weighting of visual and textual features\n\n",
    ]

    # clinical implications (example - in a real system, this would be more comprehensive)
    parts.append("Clinical Recommendations (based on graph analysis):\n")
    if probability > 0.5:
        parts += [
            " - Consider alternative medications not connected in similar graph patterns\n",
            " - If co-administration is necessary, monitor for interaction effects\n",
            " - Review other drugs connected to the same nodes for potential complications\n",
        ]
    else:
        parts += [
            " - Standard monitoring advised\n",
            " - The knowledge graph structure suggests minimal interaction concerns\n",
        ]
    return "".join(parts)
drug_pair = ("Goserelin", "Desmopressin")
prob, knowledge = predict_interaction(trained_model, drug_pair[0], drug_pair[1], drug_data_df, medical_kg)
# NOTE(review): 'paths' may be None or empty depending on retrieve_knowledge_subgraph
# (defined elsewhere); normalise to a list so the loops below never fail on None
kg_paths = knowledge.get('paths') or []

print(f"Predicted interaction probability between {drug_pair[0]} and {drug_pair[1]}: {prob:.4f}")
print("\nKnowledge Graph Structure Analysis:")
print(f"Direct connection: {knowledge['direct_interaction']}")
print(f"Common target nodes: {knowledge['common_targets']}")
print("Graph paths connecting drugs:")
for path in kg_paths:
    print(f" {' -> '.join(map(str, path))}")

# visualize the subgraph for these drugs to show the KG-guided approach
plt.figure(figsize=(12, 8))
subgraph_nodes = set(drug_pair)

# add intermediate nodes in paths to highlight the KG structure
for path in kg_paths:
    subgraph_nodes.update(path)

# add up to three neighbors per node for context (on a DiGraph, neighbors()
# follows outgoing edges only)
neighbors_to_add = set()
for node in subgraph_nodes:
    if node in medical_kg:
        neighbors_to_add.update(list(medical_kg.neighbors(node))[:3])
subgraph_nodes.update(neighbors_to_add)
subgraph = medical_kg.subgraph(subgraph_nodes)

# use different colors for node types to emphasize KG structure
type_colors = {'protein': 'lightblue', 'disease': 'lightgreen'}
node_colors = [
    'lightcoral' if node in drug_pair else type_colors.get(attrs.get('type'), 'lightgray')
    for node, attrs in subgraph.nodes(data=True)
]

pos = nx.spring_layout(subgraph, seed=42)
nx.draw(subgraph, pos, with_labels=True, node_color=node_colors,
        node_size=2000, arrows=True, arrowsize=20)
edge_labels = {(s, o): data.get('relation', '') for s, o, data in subgraph.edges(data=True)}
nx.draw_networkx_edge_labels(subgraph, pos, edge_labels=edge_labels)
plt.title(f"Knowledge Graph Structure Guiding {drug_pair[0]} and {drug_pair[1]} Interaction Analysis")
plt.savefig('kg_guided_interaction_analysis.png')
plt.show()

# show explanation
explanation = explain_interaction_prediction(drug_pair[0], drug_pair[1], prob, knowledge)
print(explanation)