跳到主要内容

3. 状态管理

"状态是 agent 的记忆。没有适当的状态管理,agents 无法学习、恢复或改进。"

状态管理是指捕获、持久化和恢复 agent 状态(State)跨执行的实践。它使 agents 能够从故障中恢复,在长时间运行的工作流中维护上下文,并在多个 agent 实例之间协调。


3.1 Agent 状态类型

状态分类

对话状态(Conversation State)

对话历史和上下文。

@Document(collection = "conversation_state")
public class ConversationState {
@Id
private String id;

private String agentId;
private String userId;
private String sessionId;

private List<Message> messages;
private Map<String, Object> context;
private Map<String, Object> metadata;

private Instant createdAt;
private Instant updatedAt;

@Getter
@Setter
public static class Message {
private String role; // user, assistant, system, tool
private String content;
private Instant timestamp;
private Map<String, Object> metadata;
}
}

@Service
public class ConversationStateManager {

@Autowired
private MongoTemplate mongoTemplate;

public ConversationState save(ConversationState state) {
state.setUpdatedAt(Instant.now());
return mongoTemplate.save(state);
}

public Optional<ConversationState> load(String sessionId) {
return Optional.ofNullable(
mongoTemplate.findById(
sessionId,
ConversationState.class
)
);
}

public void addMessage(
String sessionId,
Message message
) {
ConversationState state = load(sessionId)
.orElseGet(() -> createNew(sessionId));

state.getMessages().add(message);
save(state);
}

public List<Message> getRecentMessages(
String sessionId,
int limit
) {
return load(sessionId)
.map(ConversationState::getMessages)
.map(messages -> messages.stream()
.skip(Math.max(0, messages.size() - limit))
.toList())
.orElse(List.of());
}
}

任务状态(Task State)

当前任务的进度和状态。

@Document(collection = "task_state")
public class TaskState {
@Id
private String id;

private String agentId;
private String taskId;
private String userId;

private TaskStatus status;
private String currentStep;
private List<String> completedSteps;

private Map<String, Object> inputData;
private Map<String, Object> outputData;
private Map<String, Object> intermediateData;

private ErrorInfo lastError;
private int retryCount;

private Instant createdAt;
private Instant startedAt;
private Instant completedAt;

public enum TaskStatus {
PENDING,
IN_PROGRESS,
BLOCKED,
COMPLETED,
FAILED,
CANCELLED
}

@Getter
@Setter
public static class ErrorInfo {
private String code;
private String message;
private String stackTrace;
private Instant timestamp;
}
}

@Service
public class TaskStateManager {

@Autowired
private MongoTemplate mongoTemplate;

public TaskState create(TaskState state) {
state.setCreatedAt(Instant.now());
state.setStatus(TaskStatus.PENDING);
return mongoTemplate.save(state);
}

public TaskState update(
String taskId,
Consumer<TaskState> updater
) {
TaskState state = load(taskId);
updater.accept(state);
state.setUpdatedAt(Instant.now());
return mongoTemplate.save(state);
}

public void markInProgress(String taskId) {
update(taskId, state -> {
state.setStatus(TaskStatus.IN_PROGRESS);
if (state.getStartedAt() == null) {
state.setStartedAt(Instant.now());
}
});
}

public void markCompleted(
String taskId,
Map<String, Object> output
) {
update(taskId, state -> {
state.setStatus(TaskStatus.COMPLETED);
state.setCompletedAt(Instant.now());
state.setOutputData(output);
});
}

public void markFailed(
String taskId,
String error,
String stackTrace
) {
update(taskId, state -> {
state.setStatus(TaskStatus.FAILED);
state.setCompletedAt(Instant.now());

TaskState.ErrorInfo errorInfo = new TaskState.ErrorInfo();
errorInfo.setCode("TASK_FAILED");
errorInfo.setMessage(error);
errorInfo.setStackTrace(stackTrace);
errorInfo.setTimestamp(Instant.now());
state.setLastError(errorInfo);
});
}
}

记忆状态(Memory State)

学到的信息和知识。

@Document(collection = "memory_state")
public class MemoryState {
@Id
private String id;

private String agentId;
private String userId;

private Map<String, Object> facts; // Entity memories(实体记忆)
private List<Memory> episodicMemories; // Past experiences(过去经历)
private Map<String, Object> semanticMemory; // Knowledge(知识)

private Instant createdAt;
private Instant updatedAt;

@Getter
@Setter
public static class Memory {
private String id;
private String type;
private String content;
private Map<String, Object> metadata;
private Instant timestamp;
private double importance; // 0-1
}
}

@Service
public class MemoryStateManager {

@Autowired
private MongoTemplate mongoTemplate;

public void storeFact(
String agentId,
String key,
Object value
) {
Query query = Query.query(
Criteria.where("agentId").is(agentId)
);
Update update = new Update()
.set("facts." + key, value)
.set("updatedAt", Instant.now());

mongoTemplate.upsert(query, update, MemoryState.class);
}

public Optional<Object> getFact(
String agentId,
String key
) {
return Optional.ofNullable(
mongoTemplate.findOne(
Query.query(
Criteria.where("agentId").is(agentId)
),
MemoryState.class
)
).map(state -> state.getFacts().get(key));
}

public void addEpisodicMemory(
String agentId,
String content,
double importance
) {
Memory memory = new Memory();
memory.setId(UUID.randomUUID().toString());
memory.setType("episodic");
memory.setContent(content);
memory.setTimestamp(Instant.now());
memory.setImportance(importance);

Query query = Query.query(
Criteria.where("agentId").is(agentId)
);
Update update = new Update()
.push("episodicMemories", memory)
.set("updatedAt", Instant.now());

mongoTemplate.upsert(query, update, MemoryState.class);
}

public List<Memory> getImportantMemories(
String agentId,
double minImportance,
int limit
) {
return Optional.ofNullable(
mongoTemplate.findOne(
Query.query(
Criteria.where("agentId").is(agentId)
),
MemoryState.class
)
)
.map(MemoryState::getEpisodicMemories)
.map(memories -> memories.stream()
.filter(m -> m.getImportance() >= minImportance)
.sorted(Comparator
.comparing(Memory::getTimestamp)
.reversed())
.limit(limit)
.toList())
.orElse(List.of());
}
}

3.2 持久化策略

何时持久化

Redis 会话存储

用于活跃会话的快速、临时状态。

@Service
public class RedisStateStore {

@Autowired
private RedisTemplate<String, Object> redisTemplate;

private static final Duration SESSION_TTL = Duration.ofHours(24);

public void saveSession(
String sessionId,
AgentState state
) {
String key = "session:" + sessionId;
redisTemplate.opsForValue().set(
key,
serialize(state),
SESSION_TTL
);
}

public Optional<AgentState> loadSession(String sessionId) {
String key = "session:" + sessionId;
Object data = redisTemplate.opsForValue().get(key);
return Optional.ofNullable(data)
.map(this::deserialize);
}

public void deleteSession(String sessionId) {
String key = "session:" + sessionId;
redisTemplate.delete(key);
}

public void updateSession(
String sessionId,
Consumer<AgentState> updater
) {
AgentState state = loadSession(sessionId)
.orElse(new AgentState());
updater.accept(state);
saveSession(sessionId, state);
}

private byte[] serialize(AgentState state) {
// Serialize to JSON/MessagePack
try {
ObjectMapper mapper = new ObjectMapper();
return mapper.writeValueAsBytes(state);
} catch (Exception e) {
throw new RuntimeException("Serialization failed", e);
}
}

private AgentState deserialize(byte[] data) {
try {
ObjectMapper mapper = new ObjectMapper();
return mapper.readValue(data, AgentState.class);
} catch (Exception e) {
throw new RuntimeException("Deserialization failed", e);
}
}
}

PostgreSQL 持久化存储

用于长期保留的持久状态。

@Service
public class PostgreSQLStateStore {

@Autowired
private JdbcTemplate jdbcTemplate;

@Autowired
private ObjectMapper objectMapper;

public void saveState(AgentState state) {
String sql = """
INSERT INTO agent_state (
id, agent_id, user_id, session_id,
state_data, created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (id) DO UPDATE SET
state_data = EXCLUDED.state_data,
updated_at = EXCLUDED.updated_at
""";

try {
String stateJson = objectMapper.writeValueAsString(state);

jdbcTemplate.update(sql,
state.getId(),
state.getAgentId(),
state.getUserId(),
state.getSessionId(),
stateJson,
state.getCreatedAt(),
Instant.now()
);

} catch (JsonProcessingException e) {
throw new RuntimeException("JSON serialization failed", e);
}
}

public Optional<AgentState> loadState(String stateId) {
String sql = "SELECT state_data FROM agent_state WHERE id = ?";

return jdbcTemplate.query(sql,
rs -> rs.next() ?
Optional.of(deserializeState(rs.getString("state_data"))) :
Optional.empty(),
stateId
);
}

public List<AgentState> loadStatesByAgent(String agentId) {
String sql = """
SELECT state_data FROM agent_state
WHERE agent_id = ?
ORDER BY updated_at DESC
""";

return jdbcTemplate.query(sql,
(rs, rowNum) -> deserializeState(rs.getString("state_data")),
agentId
);
}

private AgentState deserializeState(String json) {
try {
return objectMapper.readValue(json, AgentState.class);
} catch (JsonProcessingException e) {
throw new RuntimeException("JSON deserialization failed", e);
}
}
}

3.3 状态恢复

检查点(Checkpointing)

在长时间运行的任务中定期保存状态。

@Service
public class CheckpointService {

@Autowired
private RedisStateStore redisStore;

@Autowired
private PostgreSQLStateStore postgresStore;

private static final Duration CHECKPOINT_INTERVAL =
Duration.ofMinutes(5);

public void executeWithCheckpointing(
String taskId,
Consumer<CheckpointContext> task
) {
CheckpointContext context = new CheckpointContext(taskId);

// Load last checkpoint if exists
loadCheckpoint(context).ifPresent(checkpoint -> {
context.restore(checkpoint);
log.info("Resumed from checkpoint: {}", taskId);
});

Instant lastCheckpoint = Instant.now();

try {
while (!context.isComplete()) {
// Execute task step
task.accept(context);

// Check if we should checkpoint
if (shouldCheckpoint(lastCheckpoint)) {
saveCheckpoint(context);
lastCheckpoint = Instant.now();
}
}

// Final checkpoint
saveCheckpoint(context);

} catch (Exception e) {
// Checkpoint on error
saveCheckpoint(context);
throw e;
}
}

private boolean shouldCheckpoint(Instant lastCheckpoint) {
return Duration.between(
lastCheckpoint,
Instant.now()
).compareTo(CHECKPOINT_INTERVAL) > 0;
}

private void saveCheckpoint(CheckpointContext context) {
Checkpoint checkpoint = Checkpoint.builder()
.taskId(context.getTaskId())
.state(context.getState())
.progress(context.getProgress())
.timestamp(Instant.now())
.build();

// Save to Redis for fast recovery
redisStore.saveCheckpoint(checkpoint);

// Periodically save to PostgreSQL for durability
if (context.getCheckpointCount() % 5 == 0) {
postgresStore.saveCheckpoint(checkpoint);
}

context.incrementCheckpointCount();
}

private Optional<Checkpoint> loadCheckpoint(CheckpointContext context) {
// Try Redis first (fast)
Optional<Checkpoint> checkpoint =
redisStore.loadCheckpoint(context.getTaskId());

if (checkpoint.isEmpty()) {
// Fall back to PostgreSQL (durable)
checkpoint = postgresStore.loadCheckpoint(context.getTaskId());
}

return checkpoint;
}
}

从故障恢复

在故障后恢复状态并继续执行。

@Service
public class StateRecoveryService {

@Autowired
private TaskStateManager taskStateManager;

@Autowired
private CheckpointService checkpointService;

public AgentResult resumeOrExecute(
String taskId,
Supplier<AgentResult> task
) {
// Check if task exists
Optional<TaskState> existingTask =
taskStateManager.load(taskId);

if (existingTask.isEmpty()) {
// Execute new task
return task.get();
}

TaskState taskState = existingTask.get();

// Check status
if (taskState.getStatus() == TaskStatus.COMPLETED) {
log.info("Task already completed: {}", taskId);
return AgentResult.fromTaskState(taskState);
}

if (taskState.getStatus() == TaskStatus.FAILED) {
log.info("Resuming failed task: {}", taskId);

// Increment retry count
taskState.setRetryCount(
taskState.getRetryCount() + 1
);

// Check if exceeded max retries
if (taskState.getRetryCount() > 3) {
log.error("Max retries exceeded for task: {}", taskId);
return AgentResult.failed("Max retries exceeded");
}

// Restore checkpoint
restoreCheckpoint(taskId);

// Resume task
return task.get();
}

// Task is in progress, resume
log.info("Resuming in-progress task: {}", taskId);
return task.get();
}

private void restoreCheckpoint(String taskId) {
checkpointService.executeWithCheckpointing(taskId, context -> {
// Task execution logic here
// Context will be restored from checkpoint
});
}
}

3.4 分布式状态

多 Agent 协调

在多个 agent 实例之间共享状态。

@Service
public class DistributedStateService {

@Autowired
private RedisTemplate<String, Object> redisTemplate;

public void publishStateUpdate(
String agentId,
String key,
Object value
) {
String channel = "agent:" + agentId + ":state";

StateUpdate update = StateUpdate.builder()
.agentId(agentId)
.key(key)
.value(value)
.timestamp(Instant.now())
.build();

redisTemplate.convertAndSend(channel, update);

// Also store in shared state
String stateKey = "shared:agent:" + agentId + ":" + key;
redisTemplate.opsForValue().set(stateKey, value);
}

public void subscribeToStateUpdates(
String agentId,
Consumer<StateUpdate> handler
) {
String channel = "agent:" + agentId + ":state";

redisTemplate.getConnectionFactory()
.getConnection()
.subscribe(new MessageListener() {
@Override
public void onMessage(
Message message,
byte[] pattern
) {
try {
String body = new String(message.getBody());
StateUpdate update = objectMapper.readValue(
body,
StateUpdate.class
);
handler.accept(update);

} catch (Exception e) {
log.error("Failed to process state update", e);
}
}
}, channel);
}

public Object getSharedState(
String agentId,
String key
) {
String stateKey = "shared:agent:" + agentId + ":" + key;
return redisTemplate.opsForValue().get(stateKey);
}
}

分布式锁(Distributed Locks)

在多 Agent 场景中防止竞态条件。

@Service
public class DistributedLockService {

@Autowired
private StringRedisTemplate redisTemplate;

private static final Duration LOCK_TTL = Duration.ofSeconds(30);

public Optional<Lock> acquireLock(
String resourceId,
Duration timeout
) {
String lockKey = "lock:" + resourceId;
String lockValue = UUID.randomUUID().toString();

Boolean acquired = redisTemplate.opsForValue()
.setIfAbsent(
lockKey,
lockValue,
LOCK_TTL
);

if (Boolean.TRUE.equals(acquired)) {
return Optional.of(new Lock(lockKey, lockValue));
}

return Optional.empty();
}

public void releaseLock(Lock lock) {
// Lua script to ensure we only release our own lock
String script = """
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
""";

redisTemplate.execute(
new DefaultRedisScript<>(script, Long.class),
Collections.singletonList(lock.getKey()),
lock.getValue()
);
}

@Getter
public static class Lock implements AutoCloseable {
private final String key;
private final String value;

@Override
public void close() {
// Release lock automatically
distributedLockService.releaseLock(this);
}
}
}

// Usage
@Service
public class AgentCoordinator {

@Autowired
private DistributedLockService lockService;

public void executeWithLock(
String resourceId,
Runnable operation
) {
Optional<Lock> lock = lockService.acquireLock(
resourceId,
Duration.ofSeconds(10)
);

if (lock.isEmpty()) {
throw new ConcurrentExecutionException(
"Resource is locked: " + resourceId
);
}

try (Lock acquired = lock.get()) {
operation.run();

} catch (Exception e) {
log.error("Operation failed", e);
throw e;
}
}
}

3.5 状态迁移

模式演进(Schema Evolution)

处理随时间变化的模式。

@Service
public class StateMigrationService {

private final List<StateMigration> migrations = new ArrayList<>();

@PostConstruct
public void init() {
// Register migrations in order
registerMigration(new Migration_1_2());
registerMigration(new Migration_2_3());
}

public AgentState migrate(AgentState oldState) {
int currentVersion = oldState.getSchemaVersion();
AgentState state = oldState;

for (StateMigration migration : migrations) {
if (migration.getFromVersion() == currentVersion) {
state = migration.migrate(state);
currentVersion = migration.getToVersion();
}
}

return state;
}

private static class Migration_1_2 implements StateMigration {
@Override
public int getFromVersion() { return 1; }

@Override
public int getToVersion() { return 2; }

@Override
public AgentState migrate(AgentState state) {
// Add new fields
state.setMetadata(new HashMap<>());

// Rename fields
if (state.containsKey("old_field")) {
state.put("new_field", state.get("old_field"));
state.remove("old_field");
}

state.setSchemaVersion(2);
return state;
}
}
}

3.6 关键要点

状态类型

类型用途存储持续时间
对话状态对话历史Redis小时
任务状态任务进度PostgreSQL
记忆状态学到的信息PostgreSQL永久

持久化策略

  • Redis: 快速、临时、会话状态
  • PostgreSQL: 持久、长期状态
  • S3: 归档、冷存储

恢复模式

执行 → 检查点 → 错误 → 恢复 → 继续

分布式状态

  • 发布/订阅: 实时更新
  • 共享状态: 共享数据
  • 分布式锁: 防止冲突

3.7 下一步

继续您的旅程:


频繁检查点

越频繁地检查点,失败时丢失的工作就越少。平衡检查点频率和开销。

测试恢复

不要假设状态恢复有效。通过在执行中途终止进程定期测试。

使用 TTL

在临时状态上设置 TTL 以防止内存泄漏。并非所有状态都需要永远存在。