Exposing a LangChain Streaming API in NestJS

Using the @Sse() Decorator to Return a ReadableStream

Integrating LangChain into a backend framework is a common requirement in modern web development. NestJS, a popular Node.js framework, offers powerful tools for building scalable server-side applications, and its @Sse() decorator makes it easy to expose LangChain's streaming API so the frontend can render responses in real time. This chapter takes a detailed look at integrating LangChain into NestJS and exposing a streaming API.

Basic NestJS Integration

First, let's look at a basic LangChain integration in NestJS:

typescript
// langchain.module.ts
import { Module } from '@nestjs/common';
import { LangchainService } from './langchain.service';
import { LangchainController } from './langchain.controller';

@Module({
  controllers: [LangchainController],
  providers: [LangchainService],
  exports: [LangchainService]
})
export class LangchainModule {}

// langchain.service.ts
import { Injectable } from '@nestjs/common';
// Note: these import paths follow the legacy `langchain` package layout;
// newer LangChain releases expose the same classes from @langchain/openai
// and @langchain/core instead.
import { ChatOpenAI } from 'langchain/chat_models/openai';
import { PromptTemplate } from 'langchain/prompts';
import { StringOutputParser } from 'langchain/schema/output_parser';

@Injectable()
export class LangchainService {
  private readonly llm: ChatOpenAI;
  
  constructor() {
    // The OpenAI API key is read from the OPENAI_API_KEY environment
    // variable by default.
    this.llm = new ChatOpenAI({
      modelName: 'gpt-3.5-turbo',
      temperature: 0.7,
    });
  }
  
  async generateText(prompt: string): Promise<string> {
    const promptTemplate = PromptTemplate.fromTemplate('{input}');
    const chain = promptTemplate.pipe(this.llm).pipe(new StringOutputParser());
    return await chain.invoke({ input: prompt });
  }
  
  async *streamText(prompt: string): AsyncGenerator<string> {
    const promptTemplate = PromptTemplate.fromTemplate('{input}');
    const chain = promptTemplate.pipe(this.llm).pipe(new StringOutputParser());
    const stream = await chain.stream({ input: prompt });
    
    for await (const chunk of stream) {
      yield chunk;
    }
  }
}

Implementing a Streaming API with the @Sse() Decorator

The @Sse() decorator is NestJS's built-in mechanism for Server-Sent Events, which makes it a natural fit for exposing LangChain's streaming API. Note that @Sse() registers a GET route, and the browser's EventSource API can only issue GET requests, so streaming endpoints receive their input from the query string rather than from a request body:

typescript
// langchain.controller.ts
import {
  Controller,
  Post,
  Body,
  Query,
  Sse,
  MessageEvent,
  Logger,
} from '@nestjs/common';
import { LangchainService } from './langchain.service';
import { Observable, from } from 'rxjs';
import { map } from 'rxjs/operators';

interface GenerateTextDto {
  prompt: string;
}

@Controller('langchain')
export class LangchainController {
  private readonly logger = new Logger(LangchainController.name);
  
  constructor(private readonly langchainService: LangchainService) {}
  
  @Post('generate')
  async generateText(@Body() body: GenerateTextDto): Promise<{ result: string }> {
    try {
      const result = await this.langchainService.generateText(body.prompt);
      return { result };
    } catch (error) {
      this.logger.error('Text generation failed', error);
      throw error;
    }
  }
  
  // @Sse() registers a GET route, and EventSource cannot send a request
  // body, so the prompt is passed as a query-string parameter instead.
  @Sse('stream')
  streamText(@Query('prompt') prompt: string): Observable<MessageEvent> {
    this.logger.log(`Starting streaming generation: ${(prompt ?? '').substring(0, 50)}...`);
    
    // Convert the AsyncGenerator into an Observable
    return from(this.langchainService.streamText(prompt)).pipe(
      map((chunk: string) => {
        // Wrap each chunk as a MessageEvent
        return {
          data: chunk,
          type: 'text-chunk',
        };
      }),
      // Error handling: a real application would likely add catchError here
    );
  }
}
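
On the client side, this endpoint can be consumed with the browser's built-in EventSource API. A minimal sketch, assuming the server runs on http://localhost:3001:

typescript
// Browser-side consumption of the streaming endpoint above
const prompt = encodeURIComponent('Tell me a short story');
const source = new EventSource(
  `http://localhost:3001/langchain/stream?prompt=${prompt}`
);

// Each MessageEvent with type 'text-chunk' arrives as a named SSE event
source.addEventListener('text-chunk', (event) => {
  console.log('chunk:', (event as MessageEvent).data);
});

// Close the connection on error (EventSource reconnects by default)
source.onerror = () => source.close();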

Advanced Streaming API Implementation

A more advanced streaming API can support multiple payload types and more robust error handling:

typescript
// advanced-langchain.service.ts
import { Injectable, Logger } from '@nestjs/common';
import { ChatOpenAI } from 'langchain/chat_models/openai';
import { PromptTemplate } from 'langchain/prompts';
import { StringOutputParser } from 'langchain/schema/output_parser';
import { Document } from 'langchain/document';
import { MemoryVectorStore } from 'langchain/vectorstores/memory';
import { OpenAIEmbeddings } from 'langchain/embeddings/openai';

export interface StreamResponse {
  type: 'text' | 'thinking' | 'tool' | 'error' | 'done';
  content: any;
  timestamp: number;
}

@Injectable()
export class AdvancedLangchainService {
  private readonly logger = new Logger(AdvancedLangchainService.name);
  private readonly llm: ChatOpenAI;
  private readonly embeddings: OpenAIEmbeddings;
  private vectorStore: MemoryVectorStore;
  
  constructor() {
    this.llm = new ChatOpenAI({
      modelName: 'gpt-3.5-turbo',
      temperature: 0.7,
      streaming: true, // enable token-by-token streaming
    });
    
    this.embeddings = new OpenAIEmbeddings();
    this.vectorStore = new MemoryVectorStore(this.embeddings);
  }
  
  async initializeVectorStore(documents: Document[]): Promise<void> {
    await this.vectorStore.addDocuments(documents);
    this.logger.log(`Vector store initialized with ${documents.length} documents`);
  }
  
  async *streamChat(
    messages: Array<{ role: string; content: string }>,
    options?: { useRAG?: boolean }
  ): AsyncGenerator<StreamResponse> {
    try {
      let context = '';
      
      // If RAG is enabled, retrieve relevant documents first
      if (options?.useRAG && messages.length > 0) {
        const lastMessage = messages[messages.length - 1];
        if (lastMessage.role === 'user') {
          yield {
            type: 'thinking',
            content: 'Retrieving relevant documents...',
            timestamp: Date.now(),
          };
          
          const docs = await this.vectorStore.similaritySearch(
            lastMessage.content,
            3
          );
          
          context = docs.map(doc => doc.pageContent).join('\n\n');
          
          yield {
            type: 'thinking',
            content: `Found ${docs.length} relevant documents`,
            timestamp: Date.now(),
          };
        }
      }
      
      // Build the prompt
      let promptTemplate: PromptTemplate;
      if (context) {
        promptTemplate = PromptTemplate.fromTemplate(
          `Answer the question based on the following documents:

Documents:
{context}

Conversation history:
{history}

User question: {question}

Answer:`
        );
      } else {
        promptTemplate = PromptTemplate.fromTemplate(
          `Conversation history:
{history}

User question: {question}

Answer:`
        );
      }
      
      // Format the conversation history: include both user and assistant
      // turns, but leave out the final user message, which is passed
      // separately as {question}
      const history = messages
        .slice(0, -1)
        .map(m => `${m.role}: ${m.content}`)
        .join('\n');
      
      const lastUserMessage = messages.findLast(m => m.role === 'user');
      const question = lastUserMessage?.content || '';
      
      const formattedPrompt = await promptTemplate.format({
        context,
        history,
        question,
      });
      
      // Stream the response
      const stream = await this.llm.stream(formattedPrompt);
      
      let fullResponse = '';
      for await (const chunk of stream) {
        const content = chunk.content as string;
        fullResponse += content;
        
        yield {
          type: 'text',
          content,
          timestamp: Date.now(),
        };
      }
      
      yield {
        type: 'done',
        content: fullResponse,
        timestamp: Date.now(),
      };
      
    } catch (error) {
      this.logger.error('Streaming chat generation failed', error);
      yield {
        type: 'error',
        content: error.message || 'An error occurred while generating the response',
        timestamp: Date.now(),
      };
    }
  }
  
  async *streamDocumentQA(
    question: string,
    // documentIds is accepted for future filtering but unused in this
    // simplified implementation
    documentIds?: string[]
  ): AsyncGenerator<StreamResponse> {
    try {
      yield {
        type: 'thinking',
        content: 'Analyzing the question...',
        timestamp: Date.now(),
      };
      
      // Retrieve relevant documents
      const docs = await this.vectorStore.similaritySearch(question, 5);
      
      yield {
        type: 'thinking',
        content: `Retrieved ${docs.length} relevant documents`,
        timestamp: Date.now(),
      };
      
      // Build the RAG prompt
      const context = docs.map(doc => doc.pageContent).join('\n\n');
      
      const promptTemplate = PromptTemplate.fromTemplate(
        `Answer the question based on the following documents. If the documents contain no relevant information, say that the question cannot be answered.

Documents:
{context}

Question: {question}

Answer:`
      );
      
      const formattedPrompt = await promptTemplate.format({
        context,
        question,
      });
      
      // Stream the answer
      const stream = await this.llm.stream(formattedPrompt);
      
      for await (const chunk of stream) {
        const content = chunk.content as string;
        yield {
          type: 'text',
          content,
          timestamp: Date.now(),
        };
      }
      
      // Emit source document information
      yield {
        type: 'done',
        content: {
          sources: docs.map((doc, index) => ({
            id: index,
            source: doc.metadata?.source || 'unknown',
            relevance: 'high', // simplified placeholder score
          })),
        },
        timestamp: Date.now(),
      };
      
    } catch (error) {
      this.logger.error('Streaming document Q&A failed', error);
      yield {
        type: 'error',
        content: error.message || 'An error occurred while answering from the documents',
        timestamp: Date.now(),
      };
    }
  }
}
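
Before RAG can return anything useful, the vector store needs documents. A minimal seeding sketch; the document contents and sources here are illustrative, and since MemoryVectorStore lives in memory, it must be re-seeded on every restart:

typescript
// Seeding the vector store programmatically, e.g. from a bootstrap script
import { Document } from 'langchain/document';

async function seed(service: AdvancedLangchainService) {
  await service.initializeVectorStore([
    new Document({
      pageContent: 'LangChain is a framework for building LLM-powered applications.',
      metadata: { source: 'intro.md' },
    }),
    new Document({
      pageContent: 'NestJS exposes Server-Sent Events through the @Sse() decorator.',
      metadata: { source: 'nestjs-sse.md' },
    }),
  ]);
}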

Advanced Controller Implementation

The controller below supports several streaming operations. Because @Sse() routes respond to GET, complex payloads are passed as a JSON-encoded query parameter here; a production setup might instead POST the payload to create a session and then open the stream with the session id:

typescript
// advanced-langchain.controller.ts
import {
  Controller,
  Post,
  Body,
  Sse,
  MessageEvent,
  Logger,
  Query,
  HttpCode,
  HttpStatus,
} from '@nestjs/common';
import { AdvancedLangchainService, StreamResponse } from './advanced-langchain.service';
import { Observable, from, throwError } from 'rxjs';
import { map, catchError } from 'rxjs/operators';

interface ChatMessage {
  role: string;
  content: string;
}

interface StreamChatDto {
  messages: ChatMessage[];
  useRAG?: boolean;
}

interface StreamDocumentQADto {
  question: string;
  documentIds?: string[];
}

@Controller('advanced-langchain')
export class AdvancedLangchainController {
  private readonly logger = new Logger(AdvancedLangchainController.name);
  
  constructor(private readonly langchainService: AdvancedLangchainService) {}
  
  @Post('initialize')
  @HttpCode(HttpStatus.OK)
  async initializeVectorStore(
    @Body() body: { documents: Array<{ content: string; metadata?: any }> }
  ): Promise<{ success: boolean; message: string }> {
    try {
      const documents = body.documents.map(doc => ({
        pageContent: doc.content,
        metadata: doc.metadata || {},
      }));
      
      await this.langchainService.initializeVectorStore(documents);
      
      return {
        success: true,
        message: `Vector store initialized with ${documents.length} documents`,
      };
    } catch (error) {
      this.logger.error('Vector store initialization failed', error);
      return {
        success: false,
        message: 'Vector store initialization failed: ' + error.message,
      };
    }
  }
  
  // @Sse() registers a GET route and EventSource cannot send a request
  // body, so the chat payload travels as a JSON-encoded query parameter
  @Sse('stream-chat')
  streamChat(@Query('payload') payload: string): Observable<MessageEvent> {
    this.logger.log('Starting streaming chat');
    
    let body: StreamChatDto;
    try {
      body = JSON.parse(payload);
    } catch {
      return throwError(() => new Error('payload must be valid JSON'));
    }
    
    // Validate input
    if (!body.messages || body.messages.length === 0) {
      return throwError(() => new Error('The message list must not be empty'));
    }
    
    // Convert the AsyncGenerator into an Observable
    const stream = from(this.langchainService.streamChat(
      body.messages,
      { useRAG: body.useRAG }
    ));
    
    return stream.pipe(
      map((response: StreamResponse) => {
        return {
          data: JSON.stringify(response),
          type: response.type,
        };
      }),
      catchError((error) => {
        this.logger.error('Streaming chat error', error);
        return throwError(() => error);
      })
    );
  }
  
  @Sse('stream-document-qa')
  streamDocumentQA(@Query('payload') payload: string): Observable<MessageEvent> {
    let body: StreamDocumentQADto;
    try {
      body = JSON.parse(payload);
    } catch {
      return throwError(() => new Error('payload must be valid JSON'));
    }
    
    if (!body.question) {
      return throwError(() => new Error('The question must not be empty'));
    }
    
    this.logger.log(`Starting document Q&A: ${body.question.substring(0, 50)}...`);
    
    const stream = from(this.langchainService.streamDocumentQA(
      body.question,
      body.documentIds
    ));
    
    return stream.pipe(
      map((response: StreamResponse) => {
        return {
          data: JSON.stringify(response),
          type: response.type,
        };
      }),
      catchError((error) => {
        this.logger.error('Document Q&A stream error', error);
        return throwError(() => error);
      })
    );
  }
  
  // Unified endpoint: dispatches to a streaming operation based on ?type=
  @Sse('stream')
  streamUnified(
    @Query('type') type: 'chat' | 'document-qa',
    @Query('payload') payload: string,
  ): Observable<MessageEvent> {
    this.logger.log(`Starting unified streaming: ${type}`);
    
    let body: any;
    try {
      body = JSON.parse(payload);
    } catch {
      return throwError(() => new Error('payload must be valid JSON'));
    }
    
    let stream: AsyncGenerator<StreamResponse>;
    
    switch (type) {
      case 'chat':
        stream = this.langchainService.streamChat(
          body.messages,
          { useRAG: body.useRAG }
        );
        break;
        
      case 'document-qa':
        stream = this.langchainService.streamDocumentQA(
          body.question,
          body.documentIds
        );
        break;
        
      default:
        return throwError(() => new Error(`Unsupported stream type: ${type}`));
    }
    
    return from(stream).pipe(
      map((response: StreamResponse) => {
        return {
          data: JSON.stringify(response),
          type: response.type,
        };
      }),
      catchError((error) => {
        this.logger.error(`Unified streaming error (${type})`, error);
        return throwError(() => error);
      })
    );
  }
}
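
For completeness, the module that wires the advanced service and controller together might look like this (a minimal sketch following the pattern of LangchainModule above):

typescript
// advanced-langchain.module.ts
import { Module } from '@nestjs/common';
import { AdvancedLangchainService } from './advanced-langchain.service';
import { AdvancedLangchainController } from './advanced-langchain.controller';

@Module({
  controllers: [AdvancedLangchainController],
  providers: [AdvancedLangchainService],
  exports: [AdvancedLangchainService],
})
export class AdvancedLangchainModule {}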

A Complete Application Example

Let's walk through a complete example showing how the LangChain streaming API is integrated and used in a NestJS application:

typescript
// main.ts
import { NestFactory } from '@nestjs/core';
import { AppModule } from './app.module';
import { Logger, ValidationPipe } from '@nestjs/common';

async function bootstrap() {
  const app = await NestFactory.create(AppModule);
  const logger = new Logger('Bootstrap');
  
  // Global pipe configuration
  app.useGlobalPipes(new ValidationPipe({
    whitelist: true,
    forbidNonWhitelisted: true,
    transform: true,
  }));
  
  // CORS configuration
  app.enableCors({
    origin: process.env.FRONTEND_URL || 'http://localhost:3000',
    methods: 'GET,HEAD,PUT,PATCH,POST,DELETE,OPTIONS',
    credentials: true,
  });
  
  // No extra SSE configuration is needed: for @Sse() routes NestJS sets
  // the response headers (Content-Type: text/event-stream,
  // Cache-Control: no-cache, Connection: keep-alive) automatically.
  
  const port = process.env.PORT || 3001;
  await app.listen(port);
  
  logger.log(`Application is running on port ${port}`);
}
bootstrap();

// app.module.ts
import { Module } from '@nestjs/common';
import { LangchainModule } from './langchain/langchain.module';
import { AdvancedLangchainModule } from './advanced-langchain/advanced-langchain.module';

@Module({
  imports: [
    LangchainModule,
    AdvancedLangchainModule,
  ],
})
export class AppModule {}

// Client usage example (frontend code)
/*
// Streaming chat client
async function streamChat(messages: Array<{ role: string; content: string }>) {
  // The @Sse() route responds to GET, so the payload travels as a
  // JSON-encoded query parameter
  const payload = encodeURIComponent(JSON.stringify({ messages, useRAG: true }));
  const response = await fetch(
    `http://localhost:3001/advanced-langchain/stream-chat?payload=${payload}`
  );
  
  if (!response.body) {
    throw new Error('Empty response body');
  }
  
  const reader = response.body.getReader();
  const decoder = new TextDecoder();
  
  try {
    while (true) {
      const { done, value } = await reader.read();
      
      if (done) {
        console.log('Stream finished');
        break;
      }
      
      const chunk = decoder.decode(value);
      
      // Parse the SSE data
      // Note: a production parser should buffer partial lines, since a
      // network chunk may end mid-line or mid-event
      const lines = chunk.split('\n');
      for (const line of lines) {
        if (line.startsWith('data: ')) {
          const data = line.slice(6);
          const response = JSON.parse(data);
          
          switch (response.type) {
            case 'text':
              // Handle a text chunk
              console.log('text chunk:', response.content);
              break;
              
            case 'thinking':
              // Handle thinking/status updates
              console.log('thinking:', response.content);
              break;
              
            case 'done':
              // Handle completion info
              console.log('done:', response.content);
              break;
              
            case 'error':
              // Handle errors
              console.error('error:', response.content);
              break;
          }
        }
      }
    }
  } finally {
    reader.releaseLock();
  }
}

// Usage
streamChat([
  { role: 'user', content: 'Hello, can you introduce LangChain?' }
]);
*/
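
Since the SSE routes are plain GET endpoints, the browser's EventSource can consume the advanced chat stream as well. A sketch using the named events emitted above (the payload-parameter format follows the controller shown earlier):

typescript
// EventSource-based client for the advanced chat stream (browser code)
const chatPayload = encodeURIComponent(JSON.stringify({
  messages: [{ role: 'user', content: 'Hello, can you introduce LangChain?' }],
  useRAG: true,
}));
const chatSource = new EventSource(
  `http://localhost:3001/advanced-langchain/stream-chat?payload=${chatPayload}`
);

// Each StreamResponse arrives as a named SSE event matching its type field
chatSource.addEventListener('text', (event) => {
  const { content } = JSON.parse((event as MessageEvent).data);
  console.log('text chunk:', content);
});
chatSource.addEventListener('done', () => chatSource.close());
chatSource.onerror = () => chatSource.close();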

Error Handling and Monitoring

A solid error-handling and monitoring setup can be built with an interceptor:

typescript
// langchain.interceptor.ts
import {
  Injectable,
  NestInterceptor,
  ExecutionContext,
  CallHandler,
  Logger,
} from '@nestjs/common';
import { Observable } from 'rxjs';
import { tap, catchError } from 'rxjs/operators';

@Injectable()
export class LangchainInterceptor implements NestInterceptor {
  private readonly logger = new Logger(LangchainInterceptor.name);
  
  intercept(context: ExecutionContext, next: CallHandler): Observable<any> {
    const startTime = Date.now();
    const request = context.switchToHttp().getRequest();
    
    this.logger.log(`Handling request: ${request.method} ${request.url}`);
    
    return next.handle().pipe(
      // For streaming responses, log on completion rather than on every
      // emitted event
      tap({
        complete: () => {
          const duration = Date.now() - startTime;
          this.logger.log(`Request completed: ${request.method} ${request.url} (${duration}ms)`);
        },
      }),
      catchError((error) => {
        const duration = Date.now() - startTime;
        this.logger.error(
          `Request failed: ${request.method} ${request.url} (${duration}ms)`,
          error.stack
        );
        throw error;
      })
    );
  }
}

// Applying the interceptor to a controller
// langchain.controller.ts
import { UseInterceptors } from '@nestjs/common';
import { LangchainInterceptor } from './langchain.interceptor';

@Controller('langchain')
@UseInterceptors(LangchainInterceptor)
export class LangchainController {
  // controller implementation...
}
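
Long-lived SSE connections can also be dropped by proxies during idle gaps between tokens. One common mitigation, sketched below with illustrative names, is to merge a periodic heartbeat event into the stream; the heartbeat stops once the real stream completes, and NestJS unsubscribes automatically when the client disconnects:

typescript
// sse-heartbeat.ts: keep idle SSE connections alive (illustrative helper)
import { MessageEvent } from '@nestjs/common';
import { Observable, interval, merge } from 'rxjs';
import { endWith, ignoreElements, map, share, takeUntil } from 'rxjs/operators';

export function withHeartbeat(
  source: Observable<MessageEvent>,
  periodMs = 15_000,
): Observable<MessageEvent> {
  // share() so the completion notifier and merge() observe one subscription
  const shared = source.pipe(share());
  // Emits a single value once the source completes
  const done$ = shared.pipe(ignoreElements(), endWith(true));
  
  const heartbeat = interval(periodMs).pipe(
    map((): MessageEvent => ({ type: 'ping', data: '' })),
    takeUntil(done$), // stop pinging once the real stream is done
  );
  
  return merge(shared, heartbeat);
}

// Usage in a controller: return withHeartbeat(stream.pipe(map(...)));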

Summary

With NestJS's @Sse() decorator, exposing LangChain's streaming API is straightforward:

  1. Basic integration - simple text generation and streaming
  2. Advanced features - RAG support, multiple payload types, and richer interactions
  3. Error handling - robust error handling and recovery mechanisms
  4. Monitoring - integrated logging and performance tracking
  5. Flexible configuration - multiple streaming modes and parameters

This integration style lets LangChain fit naturally into a modern web application architecture and gives the frontend a real-time, responsive AI experience.

In the next chapter, Consuming Streaming Responses on the Frontend: Response.body.pipeTo(ReadableStreamDefaultReader), we will look at how to process and display streamed data on the frontend.