Skip to content

Commit 8e9005b

Browse files
committed
Add visualization methods
1 parent f3a0459 commit 8e9005b

File tree

5 files changed

+308
-3
lines changed

5 files changed

+308
-3
lines changed

‎.gitignore‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ __pycache__
33
dickens/
44
book.txt
55
lightrag-dev/
6-
.idea/
6+
.idea/
7+
dist/

‎README.md‎

Lines changed: 139 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ This repository hosts the code of LightRAG. The structure of this code is based
2222
</div>
2323

2424
## 🎉 News
25+
- [x] [2024.10.20]🎯🎯📢📢We add two methods to visualize the graph.
2526
- [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
2627
- [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
2728
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
@@ -221,7 +222,11 @@ with open("./newText.txt") as f:
221222

222223
### Graph Visualization
223224

224-
* Generate html file
225+
<details>
226+
<summary> Graph visualization with html </summary>
227+
228+
* The following code can be found in `examples/graph_visual_with_html.py`
229+
225230
```python
226231
import networkx as nx
227232
from pyvis.network import Network
@@ -238,6 +243,137 @@ net.from_nx(G)
238243
# Save and display the network
239244
net.show('knowledge_graph.html')
240245
```
246+
247+
</details>
248+
249+
<details>
250+
<summary> Graph visualization with Neo4j </summary>
251+
252+
* The following code can be found in `examples/graph_visual_with_neo4j.py`
253+
254+
```python
255+
import os
256+
import json
257+
from lightrag.utils import xml_to_json
258+
from neo4j import GraphDatabase
259+
260+
# Constants
261+
WORKING_DIR = "./dickens"
262+
BATCH_SIZE_NODES = 500
263+
BATCH_SIZE_EDGES = 100
264+
265+
# Neo4j connection credentials
266+
NEO4J_URI = "bolt://localhost:7687"
267+
NEO4J_USERNAME = "neo4j"
268+
NEO4J_PASSWORD = "your_password"
269+
270+
def convert_xml_to_json(xml_path, output_path):
271+
"""Converts XML file to JSON and saves the output."""
272+
if not os.path.exists(xml_path):
273+
print(f"Error: File not found - {xml_path}")
274+
return None
275+
276+
json_data = xml_to_json(xml_path)
277+
if json_data:
278+
with open(output_path, 'w', encoding='utf-8') as f:
279+
json.dump(json_data, f, ensure_ascii=False, indent=2)
280+
print(f"JSON file created: {output_path}")
281+
return json_data
282+
else:
283+
print("Failed to create JSON data")
284+
return None
285+
286+
def process_in_batches(tx, query, data, batch_size):
287+
"""Process data in batches and execute the given query."""
288+
for i in range(0, len(data), batch_size):
289+
batch = data[i:i + batch_size]
290+
tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
291+
292+
def main():
293+
# Paths
294+
xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
295+
json_file = os.path.join(WORKING_DIR, 'graph_data.json')
296+
297+
# Convert XML to JSON
298+
json_data = convert_xml_to_json(xml_file, json_file)
299+
if json_data is None:
300+
return
301+
302+
# Load nodes and edges
303+
nodes = json_data.get('nodes', [])
304+
edges = json_data.get('edges', [])
305+
306+
# Neo4j queries
307+
create_nodes_query = """
308+
UNWIND $nodes AS node
309+
MERGE (e:Entity {id: node.id})
310+
SET e.entity_type = node.entity_type,
311+
e.description = node.description,
312+
e.source_id = node.source_id,
313+
e.displayName = node.id
314+
REMOVE e:Entity
315+
WITH e, node
316+
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
317+
RETURN count(*)
318+
"""
319+
320+
create_edges_query = """
321+
UNWIND $edges AS edge
322+
MATCH (source {id: edge.source})
323+
MATCH (target {id: edge.target})
324+
WITH source, target, edge,
325+
CASE
326+
WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
327+
WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
328+
WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
329+
WHEN edge.keywords CONTAINS 'located' THEN 'located'
330+
WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
331+
ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
332+
END AS relType
333+
CALL apoc.create.relationship(source, relType, {
334+
weight: edge.weight,
335+
description: edge.description,
336+
keywords: edge.keywords,
337+
source_id: edge.source_id
338+
}, target) YIELD rel
339+
RETURN count(*)
340+
"""
341+
342+
set_displayname_and_labels_query = """
343+
MATCH (n)
344+
SET n.displayName = n.id
345+
WITH n
346+
CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
347+
RETURN count(*)
348+
"""
349+
350+
# Create a Neo4j driver
351+
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
352+
353+
try:
354+
# Execute queries in batches
355+
with driver.session() as session:
356+
# Insert nodes in batches
357+
session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
358+
359+
# Insert edges in batches
360+
session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
361+
362+
# Set displayName and labels
363+
session.run(set_displayname_and_labels_query)
364+
365+
except Exception as e:
366+
print(f"Error occurred: {e}")
367+
368+
finally:
369+
driver.close()
370+
371+
if __name__ == "__main__":
372+
main()
373+
```
374+
375+
</details>
376+
241377
## Evaluation
242378
### Dataset
243379
The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
@@ -484,8 +620,9 @@ def extract_queries(file_path):
484620
.
485621
├── examples
486622
│ ├── batch_eval.py
623+
│ ├── graph_visual_with_html.py
624+
│ ├── graph_visual_with_neo4j.py
487625
│ ├── generate_query.py
488-
│ ├── graph_visual.py
489626
│ ├── lightrag_azure_openai_demo.py
490627
│ ├── lightrag_bedrock_demo.py
491628
│ ├── lightrag_hf_demo.py
File renamed without changes.
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import os
2+
import json
3+
from lightrag.utils import xml_to_json
4+
from neo4j import GraphDatabase
5+
6+
# Constants
7+
WORKING_DIR = "./dickens"
8+
BATCH_SIZE_NODES = 500
9+
BATCH_SIZE_EDGES = 100
10+
11+
# Neo4j connection credentials
12+
NEO4J_URI = "bolt://localhost:7687"
13+
NEO4J_USERNAME = "neo4j"
14+
NEO4J_PASSWORD = "your_password"
15+
16+
def convert_xml_to_json(xml_path, output_path):
17+
"""Converts XML file to JSON and saves the output."""
18+
if not os.path.exists(xml_path):
19+
print(f"Error: File not found - {xml_path}")
20+
return None
21+
22+
json_data = xml_to_json(xml_path)
23+
if json_data:
24+
with open(output_path, 'w', encoding='utf-8') as f:
25+
json.dump(json_data, f, ensure_ascii=False, indent=2)
26+
print(f"JSON file created: {output_path}")
27+
return json_data
28+
else:
29+
print("Failed to create JSON data")
30+
return None
31+
32+
def process_in_batches(tx, query, data, batch_size):
33+
"""Process data in batches and execute the given query."""
34+
for i in range(0, len(data), batch_size):
35+
batch = data[i:i + batch_size]
36+
tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
37+
38+
def main():
39+
# Paths
40+
xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
41+
json_file = os.path.join(WORKING_DIR, 'graph_data.json')
42+
43+
# Convert XML to JSON
44+
json_data = convert_xml_to_json(xml_file, json_file)
45+
if json_data is None:
46+
return
47+
48+
# Load nodes and edges
49+
nodes = json_data.get('nodes', [])
50+
edges = json_data.get('edges', [])
51+
52+
# Neo4j queries
53+
create_nodes_query = """
54+
UNWIND $nodes AS node
55+
MERGE (e:Entity {id: node.id})
56+
SET e.entity_type = node.entity_type,
57+
e.description = node.description,
58+
e.source_id = node.source_id,
59+
e.displayName = node.id
60+
REMOVE e:Entity
61+
WITH e, node
62+
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
63+
RETURN count(*)
64+
"""
65+
66+
create_edges_query = """
67+
UNWIND $edges AS edge
68+
MATCH (source {id: edge.source})
69+
MATCH (target {id: edge.target})
70+
WITH source, target, edge,
71+
CASE
72+
WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
73+
WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
74+
WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
75+
WHEN edge.keywords CONTAINS 'located' THEN 'located'
76+
WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
77+
ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
78+
END AS relType
79+
CALL apoc.create.relationship(source, relType, {
80+
weight: edge.weight,
81+
description: edge.description,
82+
keywords: edge.keywords,
83+
source_id: edge.source_id
84+
}, target) YIELD rel
85+
RETURN count(*)
86+
"""
87+
88+
set_displayname_and_labels_query = """
89+
MATCH (n)
90+
SET n.displayName = n.id
91+
WITH n
92+
CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
93+
RETURN count(*)
94+
"""
95+
96+
# Create a Neo4j driver
97+
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
98+
99+
try:
100+
# Execute queries in batches
101+
with driver.session() as session:
102+
# Insert nodes in batches
103+
session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
104+
105+
# Insert edges in batches
106+
session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
107+
108+
# Set displayName and labels
109+
session.run(set_displayname_and_labels_query)
110+
111+
except Exception as e:
112+
print(f"Error occurred: {e}")
113+
114+
finally:
115+
driver.close()
116+
117+
if __name__ == "__main__":
118+
main()

‎lightrag/utils.py‎

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from functools import wraps
99
from hashlib import md5
1010
from typing import Any, Union
11+
import xml.etree.ElementTree as ET
1112

1213
import numpy as np
1314
import tiktoken
@@ -183,3 +184,51 @@ def list_of_list_to_csv(data: list[list]):
183184
def save_data_to_file(data, file_name):
184185
with open(file_name, "w", encoding="utf-8") as f:
185186
json.dump(data, f, ensure_ascii=False, indent=4)
187+
188+
def xml_to_json(xml_file):
189+
try:
190+
tree = ET.parse(xml_file)
191+
root = tree.getroot()
192+
193+
# Print the root element's tag and attributes to confirm the file has been correctly loaded
194+
print(f"Root element: {root.tag}")
195+
print(f"Root attributes: {root.attrib}")
196+
197+
data = {
198+
"nodes": [],
199+
"edges": []
200+
}
201+
202+
# Use namespace
203+
namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
204+
205+
for node in root.findall('.//node', namespace):
206+
node_data = {
207+
"id": node.get('id').strip('"'),
208+
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
209+
"description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
210+
"source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
211+
}
212+
data["nodes"].append(node_data)
213+
214+
for edge in root.findall('.//edge', namespace):
215+
edge_data = {
216+
"source": edge.get('source').strip('"'),
217+
"target": edge.get('target').strip('"'),
218+
"weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
219+
"description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
220+
"keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
221+
"source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
222+
}
223+
data["edges"].append(edge_data)
224+
225+
# Print the number of nodes and edges found
226+
print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
227+
228+
return data
229+
except ET.ParseError as e:
230+
print(f"Error parsing XML file: {e}")
231+
return None
232+
except Exception as e:
233+
print(f"An error occurred: {e}")
234+
return None

0 commit comments

Comments
 (0)