鸽了一段時間,spring ai 已經出到 1.0 了,和 0.8.1 相比差別還是挺大的。正好最近時間很是寬裕可以折騰一下。
項目配置#
初始化項目的配置有 2 種方式,一種是在創建時直接選擇對應依賴
另一種就是手動配置了
在 maven
中加入
<repositories>
<repository>
<id>spring-milestones</id>
<name>Spring Milestones</name>
<url>https://repo.spring.io/milestone</url>
<snapshots>
<enabled>false</enabled>
</snapshots>
</repository>
<repository>
<id>spring-snapshots</id>
<name>Spring Snapshots</name>
<url>https://repo.spring.io/snapshot</url>
<releases>
<enabled>false</enabled>
</releases>
</repository>
</repositories>
接著添加 Dependency Management
<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-bom</artifactId>
<version>1.0.0-SNAPSHOT</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
最後加入對應大語言模型的依賴
<!-- OpenAI依賴 -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
</dependency>
<!-- Ollama依賴 -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
</dependency>
接著,編寫配置文件
spring:
ai:
ollama:
base-url: http://127.0.0.1:11434/
chat:
model: qwen2:7b
openai:
base-url: https://xxx
api-key: sk-xxx
chat:
options:
model: gpt-3.5-turbo
server:
port: 8868
我在配置文件裡面配置了 2 個模型,一個是 ollama
的 一個是 openai
的,其他的模型可以自己看文檔配置。
調用#
在 1.0 版本中調用方式有所改變,主要是實例化的對象有變化。
在最新版本中新增了一個 Chat Client API
, 當然上個版本中的 Chat Model API
也還在。
他們的區別如下
api | 範圍 | 作用 |
---|---|---|
Chat Client API | 適合單一模型,全局唯一。多模型配置會產生衝突 | 最頂層抽象,此 API 可調用所有模型,方便快速切換 |
Chat Model API | 單例模式,每個模型唯一 | 每個模型有其具體的實現 |
Chat Client#
因為 Chat Client 默認是全局唯一的,所以在配置文件中只可以配置單個模型,否則會在初始化創建 bean 時產生衝突
以下是官方的示例代碼
@RestController
class MyController {
private final ChatClient chatClient;
public MyController(ChatClient.Builder chatClientBuilder) {
this.chatClient = chatClientBuilder.build();
}
@GetMapping("/ai")
String generation(String userInput) {
return this.chatClient.prompt()
.user(userInput)
.call()
.content();
}
}
同時在創建時還可以指定一些模型的默認參數
創建一個配置類
@Configuration
class Config {
@Bean
ChatClient chatClient(ChatClient.Builder builder) {
return builder.defaultSystem("You are a friendly chat bot that answers question in the voice of a Pirate")
.build();
}
}
在使用時使用 @Autowired
注入
要使用多模型的配置,需要關閉 ChatClient.Builder 的自動配置
spring:
ai:
chat:
client:
enabled: false
接著創建對應的配置文件,以 openai 為例
/**
* @author LiZhiAo
* @date 2024/6/19 20:47
*/
@Component
@RequiredArgsConstructor
public class OpenAiConfig {
private final OpenAiChatModel openAiChatModel;
public ChatClient openAiChatClient() {
ChatClient.Builder builder = ChatClient.builder(openAiChatModel);
builder.defaultSystem("你是一个友善的人工智能,会根据用户的提问进行回答");
return ChatClient.create(openAiChatModel);
}
}
然後就可以指定調用的模型了
// 注入
private final OpenAiConfig openAiConfig;
// 調用
Flux<ChatResponse> stream = openAiConfig.openAiChatClient().prompt(new Prompt(messages)).stream().chatResponse();
Chat Model#
每個模型擁有其對應的 Chat Model,同樣根據配置文件自動裝配
以 OpenAiChatModel
為例,通過源碼可以看到裝配過程。
所以調用也很簡單
// 注入
private final OpenAiChatModel openAiChatModel;
// 調用
Flux<ChatResponse> stream = openAiChatModel.stream(new Prompt(messages));
qwen2#
在之前一段時間,我嘗試使用 LM Studio
安裝 llama3
並開啟 Local Inference Server
進行調試。
遺憾的是,簡單調用確實能夠成功,但是在流式輸出方面卻總是出錯
沒辦法,最後還是使用 ollama
+ Open WebUI
的方式開啟的本地模型 API。
安裝步驟#
安裝環境以 Windows 電腦為例,且具備 NVIDIA 顯卡,其餘方式請查看 Open WebUI 的安裝方法
- 安裝 ollama (可選)
- 安裝 Docker Desktop
- 運行鏡像
如果已進行第 1 步,在電腦上安裝 ollama如果你略過的第 1 步,則可以選擇下面自帶 ollama 的鏡像docker run -d -p 3000:8080 --gpus all --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:cuda
docker run -d -p 3000:8080 --gpus=all -v ollama:/root/.ollama -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:ollama
- 下載模型
容器運行後進入 web 管理頁面進行模型的下載
以qwen2
為例,在模型拉取時輸入qwen2:7b
下載 qwen2 的 7B 版本
在第 2,3 步運行時可能產生 CUDA 的問題
Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 500: named symbol not found
搜索發現 N 卡驅動版本在 555.85 可能發生
解決方法也很簡單,把 Docker Desktop 更新到最新版即可。
實測下來 qwen2:7b
的中文回覆要比 llama3:8b
的好非常多,剩下的缺點就是不支持多模態,不過好像開發團隊已經在做了🎉
總結#
後端代碼#
完整的 Controller 如下
@RestController
@RequestMapping("/llama3")
@CrossOrigin
@RequiredArgsConstructor
public class llama3Controller {
private final OpenAiConfig openAiConfig;
private final OllamaConfig ollamaConfig;
private static final Integer MAX_MESSAGE = 10;
private static Map<String, List<Message>> chatMessage = new ConcurrentHashMap<>();
/**
* 返回提示詞
* @param message 用戶輸入的消息
* @return Prompt
*/
private List<Message> getMessages(String id, String message) {
String systemPrompt = "{prompt}";
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt);
Message userMessage = new UserMessage(message);
Message systemMessage = systemPromptTemplate.createMessage(MapUtil.of("prompt", "you are a helpful AI assistant"));
List<Message> messages = chatMessage.get(id);
// 如果未獲取到消息,則創建新的消息並將系統提示和用戶輸入的消息添加到消息列表中
if (messages == null){
messages = new ArrayList<>();
messages.add(systemMessage);
messages.add(userMessage);
} else {
messages.add(userMessage);
}
return messages;
}
/**
* 創建連接
*/
@SneakyThrows
@GetMapping("/init/{message}")
public String init() {
return String.valueOf(UUID.randomUUID());
}
@GetMapping("chat/{id}/{message}")
public SseEmitter chat(@PathVariable String id, @PathVariable String message, HttpServletResponse response) {
response.setHeader("Content-type", "text/html;charset=UTF-8");
response.setCharacterEncoding("UTF-8");
SseEmitter emitter = SseEmitterUtils.connect(id);
List<Message> messages = getMessages(id, message);
System.err.println("chatMessage大小: " + messages.size());
System.err.println("chatMessage: " + chatMessage);
if (messages.size() > MAX_MESSAGE){
SseEmitterUtils.sendMessage(id, "對話次數過多,請稍後重試🤔");
}else {
// 獲取模型的輸出流
Flux<ChatResponse> stream = ollamaConfig.ollamaChatClient().prompt(new Prompt(messages)).stream().chatResponse();
// 把流裡面的消息使用SSE發送
Mono<String> result = stream
.flatMap(it -> {
StringBuilder sb = new StringBuilder();
Optional.ofNullable(it.getResult().getOutput().getContent()).ifPresent(content -> {
SseEmitterUtils.sendMessage(id, content);
sb.append(content);
});
return Mono.just(sb.toString());
})
// 將消息拼接成字符串
.reduce((a, b) -> a + b)
.defaultIfEmpty("");
// 將消息存儲到chatMessage中的AssistantMessage
result.subscribe(finalContent -> messages.add(new AssistantMessage(finalContent)));
// 將消息存儲到chatMessage中
chatMessage.put(id, messages);
}
return emitter;
}
}
前端代碼#
讓 gpt 改了下前端頁面,使其支持 MD 的渲染和代碼高亮
<!doctype html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<script src="https://cdn.tailwindcss.com"></script>
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.3.1/styles/default.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.3.1/highlight.min.js"></script>
</head>
<body class="bg-zinc-100 dark:bg-zinc-800 min-h-screen p-4">
<div class="flex flex-col h-full">
<div id="messages" class="flex-1 overflow-y-auto p-4 space-y-4">
<div class="flex items-end">
<img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
<div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg w-auto max-w-full">嗨~(⁄ ⁄•⁄ω⁄•⁄ ⁄)⁄</div>
</div>
</div>
<div class="p-2">
<input type="text" id="messageInput" placeholder="請輸入消息..."
class="w-full p-2 rounded-lg border-2 border-zinc-300 dark:border-zinc-600 focus:outline-none focus:border-blue-500 dark:focus:border-blue-400">
<button onclick="sendMessage()"
class="mt-2 w-full bg-blue-500 hover:bg-blue-600 dark:bg-blue-600 dark:hover:bg-blue-700 text-white p-2 rounded-lg">發送</button>
</div>
</div>
<script>
let sessionId; // 用於存儲會話 ID
let markdownBuffer = ''; // 緩衝區
// 初始化 marked 和 highlight.js
marked.setOptions({
highlight: function (code, lang) {
const language = hljs.getLanguage(lang) ? lang : 'plaintext';
return hljs.highlight(code, { language }).value;
}
});
// 發送 HTTP 請求並處理響應
function sendHTTPRequest(url, method = 'GET', body = null) {
return new Promise((resolve, reject) => {
const xhr = new XMLHttpRequest();
xhr.open(method, url, true);
xhr.onload = () => {
if (xhr.status >= 200 && xhr.status < 300) {
resolve(xhr.response);
} else {
reject(xhr.statusText);
}
};
xhr.onerror = () => reject(xhr.statusText);
if (body) {
xhr.setRequestHeader('Content-Type', 'application/json');
xhr.send(JSON.stringify(body));
} else {
xhr.send();
}
});
}
// 處理服務器返回的 SSE 流
function handleSSEStream(stream) {
console.log('Stream started');
const messagesContainer = document.getElementById('messages');
const responseDiv = document.createElement('div');
responseDiv.className = 'flex items-end';
responseDiv.innerHTML = `
<img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
<div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg w-auto max-w-full"></div>
`;
messagesContainer.appendChild(responseDiv);
const messageContentDiv = responseDiv.querySelector('div');
// 監聽 'message' 事件,當後端發送新的數據時觸發
stream.onmessage = function (event) {
const data = event.data;
console.log('Received data:', data);
// 將接收到的數據追加到緩衝區
markdownBuffer += data;
// 嘗試將緩衝區解析為 Markdown 並顯示
messageContentDiv.innerHTML = marked.parse(markdownBuffer);
// 使用 highlight.js 進行代碼高亮
document.querySelectorAll('pre code').forEach((block) => {
hljs.highlightElement(block);
});
// 保持滾動條在底部
messagesContainer.scrollTop = messagesContainer.scrollHeight;
};
}
// 發送消息
function sendMessage() {
const input = document.getElementById('messageInput');
const message = input.value.trim();
if (message) {
const messagesContainer = document.getElementById('messages');
const newMessageDiv = document.createElement('div');
newMessageDiv.className = 'flex items-end justify-end';
newMessageDiv.innerHTML = `
<div class="mr-2 p-2 bg-green-200 dark:bg-green-700 rounded-lg max-w-xs">
${message}
</div>
<img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
`;
messagesContainer.appendChild(newMessageDiv);
input.value = '';
messagesContainer.scrollTop = messagesContainer.scrollHeight;
// 第一次發送消息時,發送 init 請求獲取會話 ID
if (!this.sessionId) {
console.log('init');
sendHTTPRequest(`http://127.0.0.1:8868/llama3/init/${message}`, 'GET')
.then(response => {
this.sessionId = response; // 存儲會話 ID
return handleSSEStream(new EventSource(`http://127.0.0.1:8868/llama3/chat/${this.sessionId}/${message}`))
});
} else {
// 之後的請求直接發送到 chat 接口
handleSSEStream(new EventSource(`http://127.0.0.1:8868/llama3/chat/${this.sessionId}/${message}`))
}
}
}
</script>
</body>
</html>
Spring AI
Open WebUI
2024 最新 Spring AI 零基礎入門到精通教程(一套輕鬆搞定 AI 大模型應用開發)
上面視頻對應文檔(密碼:wrp6)