qinfengge

qinfengge

醉后不知天在水,满船清梦压星河
github
email
telegram

spring AI (四) 連續對話

在之前的文章中,我們實現的都是簡單的調用,只能實現 1 次對話。這既不符合現實也不優雅,哪有人對話只對一句的啊。除了下面那個

image

要是 AI 也只能對話一次,那我們也只能說 AI 大爺您先一邊涼快去吧
那怎麼讓模型連續對話呢,重點是 記憶 把用戶的提問記住,同時也把模型自己的輸出也記住,這樣,模型才能根據 上文 得出合理的回答。

Prompt#

還記得第一章說的嗎

事實上,Prompt 的種類很多,玩法也很多樣,不僅僅是提示詞,同樣也是多輪對話的關鍵。

在創建 Prompt 中可以看到其可以接收 2 種參數,一種是單條 message,還有就是 message 集合。

image
而 message 中的 MessageType 又有下面 4 中類型

image

這不就對上了

    USER("user"),  // 用戶的輸入

	ASSISTANT("assistant"), // 模型的輸出

	SYSTEM("system"), // 模型的人設

	FUNCTION("function"); //函數

想像一下,在現實中你和某個人對話,你說一句他說一句,而且他話一定要跟 前文 有對應關係,否則的話就是驢唇不對馬嘴了。那么在模型上連續對話的關鍵和這個是一樣的,即在每次對話時將前文傳遞給模型,使其理解對應的上下文關係。這就是message 集合的作用。

實現#

簡單的原理以及說完了,接下來就直接開搞。不過我們不能忘記前面幾篇文章所作出的努力,要融會貫通。所以接下來就實現一個完整的功能,包含流式輸出函數調用連續對話

首先,初始化客戶端

    private static final String BASEURL = "https://xxx";

    private static final String TOKEN = "sk-xxxx";

    /**
     * 創建OpenAiChatClient
     * @return OpenAiChatClient
     */
    private static OpenAiChatClient getClient(){
        OpenAiApi openAiApi = new OpenAiApi(BASEURL, TOKEN);
        return new OpenAiChatClient(openAiApi, OpenAiChatOptions.builder()
                .withModel("gpt-3.5-turbo-1106")
                .withTemperature(0.8F)
                .build());
    }

在創建這一步需要注意的是,OpenAI 的一些老模型是不支持函數調用和流式輸出的,一些老模型的最大 token 也只有 4K。創建錯誤可能會導致 400 BAD REQUEST

接下來,保存歷史信息
先創建一個 Map

private static Map<String, List<Message>> chatMessage = new ConcurrentHashMap<>();

這裡 Map 的 key 對應的是會話 ID,value 就是歷史消息了。注意會話一定要有對應的唯一 ID,不然的話就會串台了。
然後在每次會話時傳入會話 ID 和用戶的輸入,將對應的輸入放到消息集合裡面

/**
     * 返回提示詞
     * @param message 用戶輸入的消息
     * @return Prompt
     */
    private List<Message> getMessages(String id, String message) {
        String systemPrompt = "{prompt}";
        SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt);

        Message userMessage = new UserMessage(message);

        Message systemMessage = systemPromptTemplate.createMessage(MapUtil.of("prompt", "you are a helpful AI assistant"));

        List<Message> messages = chatMessage.get(id);


        // 如果未獲取到消息,則創建新的消息並將系統提示和用戶輸入的消息添加到消息列表中
        if (messages == null){
            messages = new ArrayList<>();
            messages.add(systemMessage);
            messages.add(userMessage);
        } else {
            messages.add(userMessage);
        }

        return messages;
    }

這裡如果是第一輪對話 message 列表是空的話,還會把 systemMessage 也放進去,相當於給模型初始化一個人設。

然後,創建函數

/**
     * 初始化函數調用
     * @return ChatOptions
     */
    private ChatOptions initFunc(){
        return OpenAiChatOptions.builder().withFunctionCallbacks(List.of(
                FunctionCallbackWrapper.builder(new MockWeatherService()).withName("weather").withDescription("Get the weather in location").build(),
                FunctionCallbackWrapper.builder(new WbHotService()).withName("wbHot").withDescription("Get the hot list of Weibo").build(),
                FunctionCallbackWrapper.builder(new TodayNews()).withName("todayNews").withDescription("60s watch world news").build(),
                FunctionCallbackWrapper.builder(new DailyEnglishFunc()).withName("dailyEnglish").withDescription("A daily inspirational sentence in English").build())).build();
    }

關於函數的相關信息,請查看第三章

最後,就是輸出
這裡說一下,因為最終的實現效果是網頁,所以使用了服務端主動推送的功能,這裡使用的是 SSE,關於 SSE 的介紹可以看之前博客裡面寫的消息推送
總之,下面是一個 SSE 的工具類

@Component
@Slf4j
public class SseEmitterUtils {
    /**
     * 當前連接數
     */
    private static AtomicInteger count = new AtomicInteger(0);

    /**
     * 存儲 SseEmitter 信息
     */
    private static Map<String, SseEmitter> sseEmitterMap = new ConcurrentHashMap<>();

    /**
     * 創建用戶連接並返回 SseEmitter
     * @param key userId
     * @return SseEmitter
     */
    public static SseEmitter connect(String key) {
        if (sseEmitterMap.containsKey(key)) {
            return sseEmitterMap.get(key);
        }

        try {
            // 設置超時時間,0表示不過期。默認30秒
            SseEmitter sseEmitter = new SseEmitter(0L);
            // 註冊回調
            sseEmitter.onCompletion(completionCallBack(key));
            sseEmitter.onError(errorCallBack(key));
            sseEmitter.onTimeout(timeoutCallBack(key));
            sseEmitterMap.put(key, sseEmitter);
            // 數量+1
            count.getAndIncrement();
            return sseEmitter;
        } catch (Exception e) {
            log.info("創建新的SSE連接異常,當前連接Key為:{}", key);
        }
        return null;
    }

    /**
     * 給指定用戶發送消息
     * @param key userId
     * @param message 消息內容
     */
    public static void sendMessage(String key, String message) {
        if (sseEmitterMap.containsKey(key)) {
            try {
                sseEmitterMap.get(key).send(message);
            } catch (IOException e) {
                log.error("用戶[{}]推送異常:{}", key, e.getMessage());
                remove(key);
            }
        }
    }

    /**
     * 向同組人發布消息,要求:key + groupId
     * @param groupId 群組id
     * @param message 消息內容
     */
    public static void groupSendMessage(String groupId, String message) {
        if (!CollectionUtils.isEmpty(sseEmitterMap)) {
            sseEmitterMap.forEach((k, v) -> {
                try {
                    if (k.startsWith(groupId)) {
                        v.send(message, MediaType.APPLICATION_JSON);
                    }
                } catch (IOException e) {
                    log.error("用戶[{}]推送異常:{}", k, e.getMessage());
                    remove(k);
                }
            });
        }
    }

    /**
     * 廣播群發消息
     * @param message 消息內容
     */
    public static void batchSendMessage(String message) {
        sseEmitterMap.forEach((k, v) -> {
            try {
                v.send(message, MediaType.APPLICATION_JSON);
            } catch (IOException e) {
                log.error("用戶[{}]推送異常:{}", k, e.getMessage());
                remove(k);
            }
        });
    }

    /**
     * 群發消息
     * @param message 消息內容
     * @param ids 用戶id集合
     */
    public static void batchSendMessage(String message, Set<String> ids) {
        ids.forEach(userId -> sendMessage(userId, message));
    }

    /**
     * 移除連接
     * @param key userId
     */
    public static void remove(String key) {
        sseEmitterMap.remove(key);
        // 數量-1
        count.getAndDecrement();
        log.info("移除連接:{}", key);
    }

    /**
     * 獲取當前連接信息
     * @return Map
     */
    public static List<String> getIds() {
        return new ArrayList<>(sseEmitterMap.keySet());
    }

    /**
     * 獲取當前連接數量
     * @return int
     */
    public static int getCount() {
        return count.intValue();
    }

    private static Runnable completionCallBack(String key) {
        return () -> {
            log.info("結束連接:{}", key);
            remove(key);
        };
    }

    private static Runnable timeoutCallBack(String key) {
        return () -> {
            log.info("連接超時:{}", key);
            remove(key);
        };
    }

    private static Consumer<Throwable> errorCallBack(String key) {
        return throwable -> {
            log.info("連接異常:{}", key);
            remove(key);
        };
    }
}

接下來就可以對話了嗎,不,如果是網頁,我們還要思考具體的實現。主流的 AI 模型的網頁端主要有這 2 個方面

  1. 快速提問,用戶可以在首頁直接提問
  2. 保存對話信息,每輪對話都是唯一的,用戶可隨時返回某一個對話

具體實現就是

  1. 創建一個接口,用戶訪問首頁時調用此接口並返回會話 ID
  2. 接下來的用戶輸入都綁定在第一步返回的會話 ID 上,除非刷新瀏覽器或創建新的會話

那麼第一步就是

 /**
     * 創建連接
     */
    @SneakyThrows
    @GetMapping("/init/{message}")
    public String init() {
        return String.valueOf(UUID.randomUUID());
    }

直接返回 UUID 給前端

最後有了會話 ID 就可以綁定到會話並輸出了。

@GetMapping("chat/{id}/{message}")
    public SseEmitter chat(@PathVariable String id, @PathVariable String message, HttpServletResponse response) {

        response.setHeader("Content-type", "text/html;charset=UTF-8");
        response.setCharacterEncoding("UTF-8");

        OpenAiChatClient client = getClient();
        SseEmitter emitter = SseEmitterUtils.connect(id);
        List<Message> messages = getMessages(id, message);
        System.err.println("chatMessage大小: " + messages.size());
        System.err.println("chatMessage: " + chatMessage);

        if (messages.size() > MAX_MESSAGE){
            SseEmitterUtils.sendMessage(id, "對話次數過多,請稍後重試🤔");
        }else {
            // 獲取模型的輸出流
            Flux<ChatResponse> stream = client.stream(new Prompt(messages,initFunc()));

            // 把流裡面的消息使用SSE發送
            Mono<String> result = stream
                    .flatMap(it -> {
                        StringBuilder sb = new StringBuilder();
                        String content = it.getResult().getOutput().getContent();
                        Optional.ofNullable(content).ifPresent(r -> {
                            SseEmitterUtils.sendMessage(id, content);
                            sb.append(content);
                        });
                        return Mono.just(sb.toString());
                    })
                    // 將消息拼接成字符串
                    .reduce((a, b) -> a + b)
                    .defaultIfEmpty("");

            // 將消息存儲到chatMessage中的AssistantMessage
            result.subscribe(finalContent -> messages.add(new AssistantMessage(finalContent)));

            // 將消息存儲到chatMessage中
            chatMessage.put(id, messages);

        }
        return emitter;

    }

首先使用 response 將返回編碼設置為 UTF-8 防止亂碼
然後使用 SseEmitterUtils連接到對應的會話
接著使用 getMessages 返回獲取對應會話的歷史消息
然後使用 MAX_MESSAGE 對會話輪數進行了判斷,如果大於這裡的值則不再調用模型輸出,這裡主要是降低成本

private static final Integer MAX_MESSAGE = 10;

這裡寫的是 10 輪,其實是 5 輪對話,因為是用歷史消息的 size 判斷的,而歷史消息裡面是包含用戶的輸入和模型的輸出的,所以要除以 2.

chatMessage: {e2578f9e-8d71-4531-a6af-400a80fb6569=[SystemMessage{content='you are a helpful AI assistant', properties={}, messageType=SYSTEM}, UserMessage{content='你好呀', properties={}, messageType=USER}, AssistantMessage{content='你好!需要我的幫助嗎?', properties={}, messageType=ASSISTANT}, UserMessage{content='你是誰啊', properties={}, messageType=USER}]}

最後就是模型的輸出

Flux<ChatResponse> stream = client.stream(new Prompt(messages,initFunc()));

使用了 stream 流,然後 Prompt 傳入歷史消息和函數
獲取到輸出流後使用 SseEmitterUtils.sendMessage(id, content); 把流裡面的內容發送到對應的會話啦。

還有最後一步,我們要把模型的輸出也放到歷史消息裡面,要讓模型知道以及回答過的不用再次回答。
如果不放進去,那麼模型會把用戶前面的所有輸入全部回答。
例如,第一輪問 “介紹下杭州”,此時 AI 的回答是正常的
第二輪接著問 “杭州有哪些著名的景點”,此時 AI 不知道上一輪是否回答過了,所以它傾向於同時回答這 2 個問題,即 “介紹下杭州,杭州有哪些著名的景點”。
第三輪、第四輪同樣會同樣回答。

那麼怎麼從流裡面獲取完整的輸出呢?
最開始我使用 stream.subscribeStringBuilder 將流裡面的內容追加到 sb 裡面,但是 sb 總是為 null。問了 claude 才知道 Flux 是異步的,最後使用了 Mono 進行處理。

在這段代碼中,我們在 flatMap 的回調函數中創建了一個新的 StringBuilder 實例 sb。然後,我們將每個響應的內容追加到 sb 中,並返回一個 Mono 發射 sb.toString () 的結果。
接下來,我們使用 reduce 操作符將所有的 Mono 合併成一個 Mono。reduce 的參數是一個合併函數,它將前一個值和當前值合併成一個新值。在這裡,我們使用字符串連接操作將所有響應內容拼接在一起。
最後,我們訂閱這個 Mono, 並在其回調函數中將最終內容添加到 messages 中。如果沒有任何響應,我們使用 defaultIfEmpty ("") 確保發射一個空字符串,而不是 null。
通過這種方式,我們可以正確地獲取到流式響應的全部內容,並將其添加到 messages 中。

最後就大功告成了😎

哦,忘了,還差一個前端,但是由於我對前端不甚了解。
所以我選擇使用openui這個 AI 工具來幫我寫頁面和樣式,然後用 Claude 幫我寫接口邏輯。於是,我最終得到了這個

<!doctype html>
<html>

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <script src="https://cdn.tailwindcss.com"></script>
</head>

<body class="bg-zinc-100 dark:bg-zinc-800 min-h-screen p-4">
    <div class="flex flex-col h-full">
        <div id="messages" class="flex-1 overflow-y-auto p-4 space-y-4">
            <div class="flex items-end">
                <img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
                <div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg max-w-xs">嗨~(⁄ ⁄•⁄ω⁄•⁄ ⁄)⁄</div>
            </div>
        </div>
        <div class="p-2">
            <input type="text" id="messageInput" placeholder="請輸入消息..."
                class="w-full p-2 rounded-lg border-2 border-zinc-300 dark:border-zinc-600 focus:outline-none focus:border-blue-500 dark:focus:border-blue-400">
            <button onclick="sendMessage()"
                class="mt-2 w-full bg-blue-500 hover:bg-blue-600 dark:bg-blue-600 dark:hover:bg-blue-700 text-white p-2 rounded-lg">發送</button>
        </div>
    </div>
    <script>
        let sessionId; // 用於存儲會話 ID

        // 發送 HTTP 請求並處理響應
        function sendHTTPRequest(url, method = 'GET', body = null) {
            return new Promise((resolve, reject) => {
                const xhr = new XMLHttpRequest();
                xhr.open(method, url, true);
                xhr.onload = () => {
                    if (xhr.status >= 200 && xhr.status < 300) {
                        resolve(xhr.response);
                    } else {
                        reject(xhr.statusText);
                    }
                };
                xhr.onerror = () => reject(xhr.statusText);
                if (body) {
                    xhr.setRequestHeader('Content-Type', 'application/json');
                    xhr.send(JSON.stringify(body));
                } else {
                    xhr.send();
                }
            });
        }

        // 處理服務器返回的 SSE 流
        function handleSSEStream(stream) {
            console.log('Stream started');
            console.log(stream);
            const messagesContainer = document.getElementById('messages');
            const responseDiv = document.createElement('div');
            responseDiv.className = 'flex items-end';
            responseDiv.innerHTML = `
    <img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
    <div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg max-w-xs"></div>
  `;
            messagesContainer.appendChild(responseDiv);

            const messageContentDiv = responseDiv.querySelector('div');

            // 監聽 'message' 事件,當後端發送新的數據時觸發
            stream.onmessage = function (event) {
                const data = event.data;
                console.log('Received data:', data);
                messageContentDiv.textContent += data;
                messagesContainer.scrollTop = messagesContainer.scrollHeight;
            };
        }

        // 發送消息
        function sendMessage() {
            const input = document.getElementById('messageInput');
            const message = input.value.trim();
            if (message) {
                const messagesContainer = document.getElementById('messages');
                const newMessageDiv = document.createElement('div');
                newMessageDiv.className = 'flex items-end justify-end';
                newMessageDiv.innerHTML = `
          <div class="mr-2 p-2 bg-green-200 dark:bg-green-700 rounded-lg max-w-xs">
            ${message}
          </div>
          <img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
        `;
                messagesContainer.appendChild(newMessageDiv);
                input.value = '';
                messagesContainer.scrollTop = messagesContainer.scrollHeight;

                // 第一次發送消息時,發送 init 請求獲取會話 ID
                if (!this.sessionId) {
                    console.log('init');
                    sendHTTPRequest(`http://127.0.0.1:8868/pro/init/${message}`, 'GET')
                        .then(response => {
                            this.sessionId = response; // 存儲會話 ID
                            return handleSSEStream(new EventSource(`http://127.0.0.1:8868/pro/chat/${this.sessionId}/${message}`))
                        });

                } else {
                    // 之後的請求直接發送到 chat 接口
                    handleSSEStream(new EventSource(`http://127.0.0.1:8868/pro/chat/${this.sessionId}/${message}`))
                }
            }
        }
    </script>
</body>

</html>

最終效果#

image

PS: 其實前端可以再優化一下,比如顯示歷史會話,使用 markdown 渲染輸出等。有興趣的可以使用 AI 工具修改下。

Spring AI 連續對話

載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。