鸽了一段时间,spring ai 已经出到 1.0 了,和 0.8.1 相比差别还是挺大的。正好最近时间很是宽裕可以折腾一下。
项目配置#
初始化项目的配置有 2 种方式,一种是在创建时直接选择对应依赖
另一种就是手动配置了
在 maven
中加入
<repositories>
<repository>
<id>spring-milestones</id>
<name>Spring Milestones</name>
<url>https://repo.spring.io/milestone</url>
<snapshots>
<enabled>false</enabled>
</snapshots>
</repository>
<repository>
<id>spring-snapshots</id>
<name>Spring Snapshots</name>
<url>https://repo.spring.io/snapshot</url>
<releases>
<enabled>false</enabled>
</releases>
</repository>
</repositories>
接着添加 Dependency Management
<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-bom</artifactId>
<version>1.0.0-SNAPSHOT</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
最后加入对应大语言模型的依赖
<!-- OpenAI依赖 -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
</dependency>
<!-- Ollama依赖 -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
</dependency>
接着,编写配置文件
spring:
ai:
ollama:
base-url: http://127.0.0.1:11434/
chat:
model: qwen2:7b
openai:
base-url: https://xxx
api-key: sk-xxx
chat:
options:
model: gpt-3.5-turbo
server:
port: 8868
我在配置文件里面配置了 2 个模型,一个是 ollama
的 一个是 openai
的,其他的模型可以自己看文档配置。
调用#
在 1.0 版本中调用方式有所改变,主要是实例化的对象有变化。
在最新版本中新增了一个 Chat Client API
, 当然上一个版本中的 Chat Model API
也还在。
他们的区别如下
api | 范围 | 作用 |
---|---|---|
Chat Client API | 适合单一模型,全局唯一。多模型配置会产生冲突 | 最顶层抽象,此 API 可调用所有模型,方便快速切换 |
Chat Model API | 单例模式,每个模型唯一 | 每个模型有其具体的实现 |
Chat Client#
因为 Chat Client 默认是全局唯一的,所以在配置文件中只可以配置单个模型,否则会在初始化创建 bean 时产生冲突
以下是官方的示例代码
@RestController
class MyController {
private final ChatClient chatClient;
public MyController(ChatClient.Builder chatClientBuilder) {
this.chatClient = chatClientBuilder.build();
}
@GetMapping("/ai")
String generation(String userInput) {
return this.chatClient.prompt()
.user(userInput)
.call()
.content();
}
}
同时在创建时还可以指定一些模型的默认参数
创建一个配置类
@Configuration
class Config {
@Bean
ChatClient chatClient(ChatClient.Builder builder) {
return builder.defaultSystem("You are a friendly chat bot that answers question in the voice of a Pirate")
.build();
}
}
在使用时使用 @Autowired
注入
要使用多模型的配置,需要关闭 ChatClient.Builder 的自动配置
spring:
ai:
chat:
client:
enabled: false
接着创建对应的配置文件,以 openai 为例
/**
* @author LiZhiAo
* @date 2024/6/19 20:47
*/
@Component
@RequiredArgsConstructor
public class OpenAiConfig {
private final OpenAiChatModel openAiChatModel;
public ChatClient openAiChatClient() {
ChatClient.Builder builder = ChatClient.builder(openAiChatModel);
builder.defaultSystem("你是一个友善的人工智能,会根据用户的提问进行回答");
return ChatClient.create(openAiChatModel);
}
}
然后就可以指定调用的模型了
// 注入
private final OpenAiConfig openAiConfig;
// 调用
Flux<ChatResponse> stream = openAiConfig.openAiChatClient().prompt(new Prompt(messages)).stream().chatResponse();
Chat Model#
每个模型拥有其对应的 Chat Model,同样根据配置文件自动装配
以 OpenAiChatModel
为例,通过源码可以看到装配过程。
所以调用也很简单
// 注入
private final OpenAiChatModel openAiChatModel;
// 调用
Flux<ChatResponse> stream = openAiChatModel.stream(new Prompt(messages));
qwen2#
在之前一段时间,我尝试使用 LM Studio
安装 llama3
并开启 Local Inference Server
进行调试。
遗憾的是,简单调用确实能够成功,但是在流式输出方面却总是出错
没办法,最后还是使用 ollama
+ Open WebUI
的方式开启的本地模型 API。
安装步骤#
安装环境以 Windows 电脑为例,且具备 NVIDIA 显卡,其余方式请查看 Open WebUI 的安装方法
- 安装 ollama (可选)
- 安装 Docker Desktop
- 运行镜像
如果已进行第 1 步,在电脑上安装 ollama如果你略过的第 1 步,则可以选择下面自带 ollama 的镜像docker run -d -p 3000:8080 --gpus all --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:cuda
docker run -d -p 3000:8080 --gpus=all -v ollama:/root/.ollama -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:ollama
- 下载模型
容器运行后进入 web 管理页面进行模型的下载
以qwen2
为例,在模型拉取时输入qwen2:7b
下载 qwen2 的 7B 版本
在第 2,3 步运行时可能产生 CUDA 的问题
Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 500: named symbol not found
搜索发现 N 卡驱动版本在 555.85 可能发生
解决方法也很简单,把 Docker Desktop 更新到最新版即可。
实测下来 qwen2:7b
的中文回复要比 llama3:8b
的好非常多,剩下的缺点就是不支持多模态,不过好像开发团队已经在做了🎉
总结#
后端代码#
完整的 Controller 如下
@RestController
@RequestMapping("/llama3")
@CrossOrigin
@RequiredArgsConstructor
public class llama3Controller {
private final OpenAiConfig openAiConfig;
private final OllamaConfig ollamaConfig;
private static final Integer MAX_MESSAGE = 10;
private static Map<String, List<Message>> chatMessage = new ConcurrentHashMap<>();
/**
* 返回提示词
* @param message 用户输入的消息
* @return Prompt
*/
private List<Message> getMessages(String id, String message) {
String systemPrompt = "{prompt}";
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt);
Message userMessage = new UserMessage(message);
Message systemMessage = systemPromptTemplate.createMessage(MapUtil.of("prompt", "you are a helpful AI assistant"));
List<Message> messages = chatMessage.get(id);
// 如果未获取到消息,则创建新的消息并将系统提示和用户输入的消息添加到消息列表中
if (messages == null){
messages = new ArrayList<>();
messages.add(systemMessage);
messages.add(userMessage);
} else {
messages.add(userMessage);
}
return messages;
}
/**
* 创建连接
*/
@SneakyThrows
@GetMapping("/init/{message}")
public String init() {
return String.valueOf(UUID.randomUUID());
}
@GetMapping("chat/{id}/{message}")
public SseEmitter chat(@PathVariable String id, @PathVariable String message, HttpServletResponse response) {
response.setHeader("Content-type", "text/html;charset=UTF-8");
response.setCharacterEncoding("UTF-8");
SseEmitter emitter = SseEmitterUtils.connect(id);
List<Message> messages = getMessages(id, message);
System.err.println("chatMessage大小: " + messages.size());
System.err.println("chatMessage: " + chatMessage);
if (messages.size() > MAX_MESSAGE){
SseEmitterUtils.sendMessage(id, "对话次数过多,请稍后重试🤔");
}else {
// 获取模型的输出流
Flux<ChatResponse> stream = ollamaConfig.ollamaChatClient().prompt(new Prompt(messages)).stream().chatResponse();
// 把流里面的消息使用SSE发送
Mono<String> result = stream
.flatMap(it -> {
StringBuilder sb = new StringBuilder();
Optional.ofNullable(it.getResult().getOutput().getContent()).ifPresent(content -> {
SseEmitterUtils.sendMessage(id, content);
sb.append(content);
});
return Mono.just(sb.toString());
})
// 将消息拼接成字符串
.reduce((a, b) -> a + b)
.defaultIfEmpty("");
// 将消息存储到chatMessage中的AssistantMessage
result.subscribe(finalContent -> messages.add(new AssistantMessage(finalContent)));
// 将消息存储到chatMessage中
chatMessage.put(id, messages);
}
return emitter;
}
}
前端代码#
让 gpt 改了下前端页面,使其支持 MD 的渲染和代码高亮
<!doctype html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<script src="https://cdn.tailwindcss.com"></script>
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.3.1/styles/default.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.3.1/highlight.min.js"></script>
</head>
<body class="bg-zinc-100 dark:bg-zinc-800 min-h-screen p-4">
<div class="flex flex-col h-full">
<div id="messages" class="flex-1 overflow-y-auto p-4 space-y-4">
<div class="flex items-end">
<img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
<div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg w-auto max-w-full">嗨~(⁄ ⁄•⁄ω⁄•⁄ ⁄)⁄</div>
</div>
</div>
<div class="p-2">
<input type="text" id="messageInput" placeholder="请输入消息..."
class="w-full p-2 rounded-lg border-2 border-zinc-300 dark:border-zinc-600 focus:outline-none focus:border-blue-500 dark:focus:border-blue-400">
<button onclick="sendMessage()"
class="mt-2 w-full bg-blue-500 hover:bg-blue-600 dark:bg-blue-600 dark:hover:bg-blue-700 text-white p-2 rounded-lg">发送</button>
</div>
</div>
<script>
let sessionId; // 用于存储会话 ID
let markdownBuffer = ''; // 缓冲区
// 初始化 marked 和 highlight.js
marked.setOptions({
highlight: function (code, lang) {
const language = hljs.getLanguage(lang) ? lang : 'plaintext';
return hljs.highlight(code, { language }).value;
}
});
// 发送 HTTP 请求并处理响应
function sendHTTPRequest(url, method = 'GET', body = null) {
return new Promise((resolve, reject) => {
const xhr = new XMLHttpRequest();
xhr.open(method, url, true);
xhr.onload = () => {
if (xhr.status >= 200 && xhr.status < 300) {
resolve(xhr.response);
} else {
reject(xhr.statusText);
}
};
xhr.onerror = () => reject(xhr.statusText);
if (body) {
xhr.setRequestHeader('Content-Type', 'application/json');
xhr.send(JSON.stringify(body));
} else {
xhr.send();
}
});
}
// 处理服务器返回的 SSE 流
function handleSSEStream(stream) {
console.log('Stream started');
const messagesContainer = document.getElementById('messages');
const responseDiv = document.createElement('div');
responseDiv.className = 'flex items-end';
responseDiv.innerHTML = `
<img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
<div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg w-auto max-w-full"></div>
`;
messagesContainer.appendChild(responseDiv);
const messageContentDiv = responseDiv.querySelector('div');
// 监听 'message' 事件,当后端发送新的数据时触发
stream.onmessage = function (event) {
const data = event.data;
console.log('Received data:', data);
// 将接收到的数据追加到缓冲区
markdownBuffer += data;
// 尝试将缓冲区解析为 Markdown 并显示
messageContentDiv.innerHTML = marked.parse(markdownBuffer);
// 使用 highlight.js 进行代码高亮
document.querySelectorAll('pre code').forEach((block) => {
hljs.highlightElement(block);
});
// 保持滚动条在底部
messagesContainer.scrollTop = messagesContainer.scrollHeight;
};
}
// 发送消息
function sendMessage() {
const input = document.getElementById('messageInput');
const message = input.value.trim();
if (message) {
const messagesContainer = document.getElementById('messages');
const newMessageDiv = document.createElement('div');
newMessageDiv.className = 'flex items-end justify-end';
newMessageDiv.innerHTML = `
<div class="mr-2 p-2 bg-green-200 dark:bg-green-700 rounded-lg max-w-xs">
${message}
</div>
<img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
`;
messagesContainer.appendChild(newMessageDiv);
input.value = '';
messagesContainer.scrollTop = messagesContainer.scrollHeight;
// 第一次发送消息时,发送 init 请求获取会话 ID
if (!this.sessionId) {
console.log('init');
sendHTTPRequest(`http://127.0.0.1:8868/llama3/init/${message}`, 'GET')
.then(response => {
this.sessionId = response; // 存储会话 ID
return handleSSEStream(new EventSource(`http://127.0.0.1:8868/llama3/chat/${this.sessionId}/${message}`))
});
} else {
// 之后的请求直接发送到 chat 接口
handleSSEStream(new EventSource(`http://127.0.0.1:8868/llama3/chat/${this.sessionId}/${message}`))
}
}
}
</script>
</body>
</html>
Spring AI
Open WebUI
2024 最新 Spring AI 零基础入门到精通教程(一套轻松搞定 AI 大模型应用开发)
上面视频对应文档(密码:wrp6)