class TrajectoryHook:
    """Hook for automatically recording LangGraph execution.

    Use as an async context manager: entering starts a recorder session and
    exiting ends it. Node functions wrapped with ``wrap_node`` report their
    start/end events, emitted messages, outputs and errors to the active
    session; when no session is active they simply run without recording.
    """

    def __init__(self, recorder: "TrajectoryRecorder"):
        self.recorder = recorder
        # Id of the active recording session; None outside ``async with``.
        self._session_id: Optional[str] = None

    @staticmethod
    async def _call_node(node_func: Callable, state: Any) -> Any:
        """Call *node_func* with *state*, awaiting the result if awaitable.

        This handles plain ``def`` nodes, ``async def`` nodes, and callable
        objects with an async ``__call__`` (which
        ``asyncio.iscoroutinefunction`` does not recognise) uniformly.
        """
        import inspect

        result = node_func(state)
        if inspect.isawaitable(result):
            result = await result
        return result

    def wrap_node(self, node_name: str, node_func: Callable) -> Callable:
        """Wrap a node function to record its execution.

        Args:
            node_name: Name attached to every event recorded for this node.
            node_func: The node callable; may be synchronous or asynchronous.

        Returns:
            An async function taking the node state. It runs *node_func* and,
            while a session is active, records ``node_start``/``node_end``
            events, any message objects in ``result["messages"]``, the node
            output, and (on failure) the error before re-raising it.
        """

        @wraps(node_func)
        async def wrapped_node(state: Dict[str, Any]) -> Any:
            # Snapshot the session once so every event of this invocation
            # targets the same session even if the hook exits mid-run.
            session_id = self._session_id
            if not session_id:
                return await self._call_node(node_func, state)

            # Record node start
            await self.recorder.record_event(
                session_id,
                node_name=node_name,
                event_type="node_start",
                data={
                    "state_keys": list(state.keys()) if isinstance(state, dict) else None
                },
            )

            # Guard only the node itself: failures in the recorder calls below
            # propagate unchanged instead of being misreported as node errors.
            try:
                result = await self._call_node(node_func, state)
            except Exception as e:
                await self.recorder.record_error(
                    session_id,
                    error_type=type(e).__name__,
                    error_message=str(e),
                    node_name=node_name,
                )
                raise

            # Record node end
            await self.recorder.record_event(
                session_id,
                node_name=node_name,
                event_type="node_end",
                data={"has_result": result is not None},
            )

            # Record messages if present
            if isinstance(result, dict) and "messages" in result:
                messages = result["messages"]
                if isinstance(messages, list):
                    for msg in messages:
                        # Only record message-like objects (those with .content).
                        if hasattr(msg, "content"):
                            await self.recorder.record_message(session_id, msg)

            # Record node output
            await self.recorder.record_node_output(session_id, node_name, result)

            return result

        return wrapped_node

    async def __aenter__(self):
        """Start a recording session and make it the active one."""
        self._session_id = await self.recorder.start_session()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """End the active recording session.

        The session is reported as successful only when the ``async with``
        body exited without an exception; exceptions are not suppressed.
        The active session is cleared even if ``end_session`` raises.
        """
        if self._session_id:
            try:
                await self.recorder.end_session(
                    self._session_id, success=exc_type is None
                )
            finally:
                self._session_id = None