qinfengge


Spring AI (4): Continuous Conversation

In the previous articles, we implemented simple calls that could only handle a single exchange. That is neither realistic nor elegant: who has a conversation that consists of just one sentence? Except, perhaps, for the one below:

[Image]

If an AI can only manage a single exchange, all we can say is: dear AI, go cool off for a while.
So how do we get the model to hold a continuous conversation? The key is memory: remembering both the user's questions and the model's own outputs, so that the model can give reasonable answers based on the previous context.

Prompt

Remember what we covered in the first chapter?

In fact, there are many kinds of prompts, and they can be used in many ways. They are not just keywords; they are also the key to multi-turn conversations.

Looking at how a Prompt is created, you can see that it accepts two kinds of arguments: a single message, or a list of messages.

[Image: the Prompt constructors]
A message's MessageType can be one of the following four values:

[Image: the MessageType enum]

Notice how they map onto the participants in a conversation:

    USER("user"),           // The user's input
    ASSISTANT("assistant"), // The model's output
    SYSTEM("system"),       // The system prompt (the model's persona)
    FUNCTION("function");   // The result of a function call

Think about a real conversation: you say one sentence and the other person replies with another. Their reply has to fit what came before; otherwise it would be nonsense. Continuous conversation with a model works the same way: on every turn, we pass the previous context to the model so it can see how the new input relates to what came before. That is what the message list is for.
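To make the idea concrete, here is a minimal stdlib-only sketch. ConversationSketch, Role and ChatMessage are hypothetical stand-ins for Spring AI's Message types (only the four role names come from MessageType above). It shows that every turn sends the whole accumulated list, not just the latest question.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stdlib model of a multi-turn prompt; not Spring AI's API. */
class ConversationSketch {
    // Same four roles as Spring AI's MessageType
    enum Role { USER, ASSISTANT, SYSTEM, FUNCTION }

    record ChatMessage(Role role, String content) {}

    private final List<ChatMessage> history = new ArrayList<>();

    ConversationSketch(String systemPrompt) {
        history.add(new ChatMessage(Role.SYSTEM, systemPrompt));
    }

    /** Appends the user's turn and returns a snapshot of everything the model would receive. */
    List<ChatMessage> userSays(String text) {
        history.add(new ChatMessage(Role.USER, text));
        return List.copyOf(history);
    }

    /** Records the model's reply so that the next prompt carries it as context. */
    void modelReplied(String text) {
        history.add(new ChatMessage(Role.ASSISTANT, text));
    }

    public static void main(String[] args) {
        ConversationSketch chat = new ConversationSketch("you are a helpful AI assistant");
        chat.userSays("Hello");
        chat.modelReplied("Hello! How can I assist you?");
        // The second prompt carries the whole exchange: SYSTEM, USER, ASSISTANT, USER
        chat.userSays("Who are you?").forEach(m -> System.out.println(m.role() + ": " + m.content()));
    }
}
```

The same shape appears in the chatMessage debug output later in this article: one system message, then alternating user and assistant messages.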

Implementation

With the principle explained, let's get started. We shouldn't discard the work from the previous articles, though; we'll integrate it. So next, we'll build one complete feature that combines streaming output, function calls, and continuous conversation.

First, initialize the client:

    private static final String BASEURL = "https://xxx";

    private static final String TOKEN = "sk-xxxx";

    /**
     * Create OpenAiChatClient
     * @return OpenAiChatClient
     */
    private static OpenAiChatClient getClient(){
        OpenAiApi openAiApi = new OpenAiApi(BASEURL, TOKEN);
        return new OpenAiChatClient(openAiApi, OpenAiChatOptions.builder()
                .withModel("gpt-3.5-turbo-1106")
                .withTemperature(0.8F)
                .build());
    }

Note that some older OpenAI models do not support function calls or streaming output, and some have a context limit of only 4K tokens. Choosing such a model when creating the client may lead to a 400 BAD REQUEST.

Next, we need to save the conversation history. First, create a Map:

private static Map<String, List<Message>> chatMessage = new ConcurrentHashMap<>();

The Map's key is the session ID, and its value is that session's message history. Each session must have its own unique ID; otherwise, conversations will get mixed up.
Then, on every turn, we pass in the session ID and the user's input and append the input to the session's message list:

    /**
     * Build the message list for this turn
     * @param id Session ID
     * @param message User's input message
     * @return The session's message history, including the new input
     */
    private List<Message> getMessages(String id, String message) {
        String systemPrompt = "{prompt}";
        SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt);

        Message userMessage = new UserMessage(message);

        Message systemMessage = systemPromptTemplate.createMessage(MapUtil.of("prompt", "you are a helpful AI assistant"));

        List<Message> messages = chatMessage.get(id);

        // No history yet for this session: start a new list with the system prompt first
        if (messages == null) {
            messages = new ArrayList<>();
            messages.add(systemMessage);
        }
        // Append the user's input to the history
        messages.add(userMessage);

        return messages;
    }

On the first round, there is no history for the session yet, so the systemMessage is added before the user's input, effectively giving the model its persona.
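As an alternative to the null check, ConcurrentHashMap.computeIfAbsent can create and seed a session's list in one atomic step and store it in the map right away, which would also make the later chatMessage.put(id, messages) redundant. The sketch below is hypothetical: it uses plain "role: content" strings instead of Spring AI messages, and wrapping the list with Collections.synchronizedList is an assumption, motivated by the fact that in the controller the list is written both by the request thread and by the Reactor callback that appends the reply.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical sketch: per-session history keyed by session ID, seeded with the system prompt on first use. */
class SessionHistory {
    private static final String SYSTEM_PROMPT = "system: you are a helpful AI assistant";

    // session ID -> that session's messages (plain "role: content" strings for brevity)
    private final Map<String, List<String>> sessions = new ConcurrentHashMap<>();

    /** Returns the session's history with the user's input appended; the system prompt is added only once. */
    List<String> getMessages(String id, String userInput) {
        List<String> messages = sessions.computeIfAbsent(id, key -> {
            // Runs atomically, and only when this session has no history yet
            List<String> fresh = Collections.synchronizedList(new ArrayList<>());
            fresh.add(SYSTEM_PROMPT);
            return fresh;
        });
        messages.add("user: " + userInput);
        return messages;
    }

    public static void main(String[] args) {
        SessionHistory store = new SessionHistory();
        store.getMessages("session-a", "Hello");
        System.out.println(store.getMessages("session-a", "Who are you?"));
        System.out.println(store.getMessages("session-b", "Hi"));
    }
}
```

Because each session ID gets its own list, two browser tabs with different IDs never see each other's history.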

Next, register the functions the model can call:

/**
     * Initialize function calls
     * @return ChatOptions
     */
    private ChatOptions initFunc(){
        return OpenAiChatOptions.builder().withFunctionCallbacks(List.of(
                FunctionCallbackWrapper.builder(new MockWeatherService()).withName("weather").withDescription("Get the weather in location").build(),
                FunctionCallbackWrapper.builder(new WbHotService()).withName("wbHot").withDescription("Get the hot list of Weibo").build(),
                FunctionCallbackWrapper.builder(new TodayNews()).withName("todayNews").withDescription("60s watch world news").build(),
                FunctionCallbackWrapper.builder(new DailyEnglishFunc()).withName("dailyEnglish").withDescription("A daily inspirational sentence in English").build())).build();
    }

For information related to functions, please refer to Chapter 3.
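Conceptually, each FunctionCallbackWrapper pairs a name and a description (which the model sees) with a java.util.function.Function that does the work. The stdlib-only sketch below models just that name-to-function lookup; FunctionRegistry and its canned return values are hypothetical and are not Spring AI's API, which additionally converts the model's JSON arguments and returns the result to the model for you.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical name -> function registry; a conceptual model, not Spring AI's FunctionCallbackWrapper. */
class FunctionRegistry {
    // What the model is told about a function, plus the code that actually runs
    record Entry(String description, Function<String, String> fn) {}

    private final Map<String, Entry> entries = new LinkedHashMap<>();

    FunctionRegistry register(String name, String description, Function<String, String> fn) {
        entries.put(name, new Entry(description, fn));
        return this;
    }

    /** Name/description pairs, i.e. what would be advertised to the model alongside the prompt. */
    Map<String, String> catalog() {
        Map<String, String> out = new LinkedHashMap<>();
        entries.forEach((name, e) -> out.put(name, e.description()));
        return out;
    }

    /** Runs the function the model asked for by name; unknown names return an error string instead of throwing. */
    String call(String name, String argumentsJson) {
        Entry e = entries.get(name);
        return e == null ? "unknown function: " + name : e.fn().apply(argumentsJson);
    }

    public static void main(String[] args) {
        FunctionRegistry registry = new FunctionRegistry()
                // Canned, hypothetical results; real functions would call a weather or quote service
                .register("weather", "Get the weather in location", json -> "{\"temp\":25,\"unit\":\"C\"}")
                .register("dailyEnglish", "A daily inspirational sentence in English", json -> "Keep going.");
        System.out.println(registry.catalog().keySet());
        System.out.println(registry.call("weather", "{\"location\":\"Hangzhou\"}"));
    }
}
```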

Finally, the output. Since the end result is a web page, we use server push, specifically SSE (Server-Sent Events). For an introduction to SSE, see my earlier blog post on message pushing.
In short, here is an SSE utility class:

@Component
@Slf4j
public class SseEmitterUtils {
    /**
     * Current connection count
     */
    private static AtomicInteger count = new AtomicInteger(0);

    /**
     * Store SseEmitter information
     */
    private static Map<String, SseEmitter> sseEmitterMap = new ConcurrentHashMap<>();

    /**
     * Create user connection and return SseEmitter
     * @param key userId
     * @return SseEmitter
     */
    public static SseEmitter connect(String key) {
        if (sseEmitterMap.containsKey(key)) {
            return sseEmitterMap.get(key);
        }

        try {
            // Set timeout, 0 means no expiration. Default is 30 seconds
            SseEmitter sseEmitter = new SseEmitter(0L);
            // Register callbacks
            sseEmitter.onCompletion(completionCallBack(key));
            sseEmitter.onError(errorCallBack(key));
            sseEmitter.onTimeout(timeoutCallBack(key));
            sseEmitterMap.put(key, sseEmitter);
            // Increment count
            count.getAndIncrement();
            return sseEmitter;
        } catch (Exception e) {
            log.error("Error creating new SSE connection, current connection key: {}", key, e);
        }
        return null;
    }

    /**
     * Send message to specified user
     * @param key userId
     * @param message Message content
     */
    public static void sendMessage(String key, String message) {
        if (sseEmitterMap.containsKey(key)) {
            try {
                sseEmitterMap.get(key).send(message);
            } catch (IOException e) {
                log.error("User[{}] push error:{}", key, e.getMessage());
                remove(key);
            }
        }
    }

    /**
     * Send a message to every connection whose key starts with groupId (keys are expected to be groupId + userId)
     * @param groupId Group ID
     * @param message Message content
     */
    public static void groupSendMessage(String groupId, String message) {
        if (!CollectionUtils.isEmpty(sseEmitterMap)) {
            sseEmitterMap.forEach((k, v) -> {
                try {
                    if (k.startsWith(groupId)) {
                        v.send(message, MediaType.APPLICATION_JSON);
                    }
                } catch (IOException e) {
                    log.error("User[{}] push error:{}", k, e.getMessage());
                    remove(k);
                }
            });
        }
    }

    /**
     * Broadcast message
     * @param message Message content
     */
    public static void batchSendMessage(String message) {
        sseEmitterMap.forEach((k, v) -> {
            try {
                v.send(message, MediaType.APPLICATION_JSON);
            } catch (IOException e) {
                log.error("User[{}] push error:{}", k, e.getMessage());
                remove(k);
            }
        });
    }

    /**
     * Send a message to each user in the given set
     * @param message Message content
     * @param ids User ID collection
     */
    public static void batchSendMessage(String message, Set<String> ids) {
        ids.forEach(userId -> sendMessage(userId, message));
    }

    /**
     * Remove connection
     * @param key userId
     */
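    public static void remove(String key) {
        // Only decrement when a connection was actually removed; the completion, timeout and
        // error callbacks (and sendMessage's catch block) can all call remove for the same key
        if (sseEmitterMap.remove(key) != null) {
            count.getAndDecrement();
            log.info("Removed connection: {}", key);
        }
    }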

    /**
     * Get the keys of all current connections
     * @return List of connection keys
     */
    public static List<String> getIds() {
        return new ArrayList<>(sseEmitterMap.keySet());
    }

    /**
     * Get current connection count
     * @return int
     */
    public static int getCount() {
        return count.intValue();
    }

    private static Runnable completionCallBack(String key) {
        return () -> {
            log.info("Connection ended: {}", key);
            remove(key);
        };
    }

    private static Runnable timeoutCallBack(String key) {
        return () -> {
            log.info("Connection timed out: {}", key);
            remove(key);
        };
    }

    private static Consumer<Throwable> errorCallBack(String key) {
        return throwable -> {
            log.info("Connection error: {}", key);
            remove(key);
        };
    }
}
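For reference, here is what a push looks like on the wire: in the text/event-stream format, an event is one or more "data:" lines followed by a blank line, and the browser's EventSource turns each event into a message event. The helper below is a hypothetical stdlib sketch written from the SSE format rules, not SseEmitter's internals. One detail matters for streamed model output: EventSource drops a single space right after "data:", so the sketch always writes "data: " to preserve a chunk's own leading space (for example " world").

```java
/** Hypothetical helper that frames text as one Server-Sent Events message (text/event-stream). */
class SseFrame {
    /** Each payload line becomes a "data: " line; a blank line terminates the event. */
    static String format(String data) {
        StringBuilder sb = new StringBuilder();
        // limit -1 keeps trailing empty strings, so a payload ending in "\n" keeps its last (empty) line
        for (String line : data.split("\n", -1)) {
            // The space after the colon is stripped by the client, so the payload's own spaces survive
            sb.append("data: ").append(line).append('\n');
        }
        return sb.append('\n').toString();
    }

    public static void main(String[] args) {
        // Two chunks of a streamed reply, as they would appear on the wire
        System.out.print(format("Hello!"));
        System.out.print(format(" How can I assist you?"));
    }
}
```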

So can we start chatting now? Not yet: for a web page, we still need to work out the details. The web versions of mainstream AI chat products mainly offer two things:

  1. Quick questions: users can ask a question directly on the home page.
  2. Saved conversations: each conversation is kept separately, and users can return to any conversation at any time.

Our implementation:

  1. Create an endpoint that returns a session ID when the user starts a conversation from the home page.
  2. Bind every subsequent input to the session ID returned in step 1, until the browser is refreshed or a new session is created.

So the first step is:

    /**
     * Create a new session and return its ID
     * (the {message} path segment is accepted but not used)
     */
    @GetMapping("/init/{message}")
    public String init() {
        return UUID.randomUUID().toString();
    }

It simply returns a random UUID to the front end, which uses it as the session ID.

Finally, with the session ID in hand, we can bind the conversation to it and stream the output:

    @GetMapping("chat/{id}/{message}")
    public SseEmitter chat(@PathVariable String id, @PathVariable String message, HttpServletResponse response) {

        response.setHeader("Content-type", "text/html;charset=UTF-8");
        response.setCharacterEncoding("UTF-8");

        OpenAiChatClient client = getClient();
        SseEmitter emitter = SseEmitterUtils.connect(id);
        List<Message> messages = getMessages(id, message);
        System.err.println("chatMessage size: " + messages.size());
        System.err.println("chatMessage: " + chatMessage);

        if (messages.size() > MAX_MESSAGE) {
            SseEmitterUtils.sendMessage(id, "Too many conversation rounds, please try again later🤔");
        } else {
            // Get the model's output stream
            Flux<ChatResponse> stream = client.stream(new Prompt(messages, initFunc()));

            // Push each chunk to the browser over SSE and collect the chunks as strings
            Mono<String> result = stream
                    .flatMap(it -> {
                        StringBuilder sb = new StringBuilder();
                        String content = it.getResult().getOutput().getContent();
                        Optional.ofNullable(content).ifPresent(c -> {
                            SseEmitterUtils.sendMessage(id, c);
                            sb.append(c);
                        });
                        return Mono.just(sb.toString());
                    })
                    // Concatenate the chunks into a single string
                    .reduce((a, b) -> a + b)
                    .defaultIfEmpty("");

            // When the stream completes, append the full reply to the history as an AssistantMessage
            result.subscribe(finalContent -> messages.add(new AssistantMessage(finalContent)));

            // Save this session's history list in chatMessage
            chatMessage.put(id, messages);
        }
        return emitter;
    }

First, we use the response to set the encoding to UTF-8 so the text doesn't come out garbled.
Then we use SseEmitterUtils to connect to the session's emitter.
Next, getMessages returns the session's message history with the new input appended.
Then we compare the history size against MAX_MESSAGE; once it exceeds this limit, we stop calling the model, mainly to keep costs down.

private static final Integer MAX_MESSAGE = 10;

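MAX_MESSAGE is 10, but that allows only 5 rounds of conversation. The history holds the system message plus one user input and one model output per round, so when the user sends their k-th message the list already has 2k entries; the 6th message makes it 12, which exceeds the limit. For example, here is the history of a session in its second round: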

chatMessage: {e2578f9e-8d71-4531-a6af-400a80fb6569=[SystemMessage{content='you are a helpful AI assistant', properties={}, messageType=SYSTEM}, UserMessage{content='Hello', properties={}, messageType=USER}, AssistantMessage{content='Hello! How can I assist you?', properties={}, messageType=ASSISTANT}, UserMessage{content='Who are you?', properties={}, messageType=USER}]}
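The round arithmetic can be checked with a tiny simulation of the controller's check. RoundLimit is a hypothetical helper that only counts list entries: one system message, then per round a user message, the size check, and the model's reply.

```java
/** Hypothetical helper that replays the MAX_MESSAGE check by counting history entries only. */
class RoundLimit {
    /** Returns how many rounds get a model reply before the size check refuses a new message. */
    static int answeredRounds(int maxMessages) {
        int size = 1;   // the system message added on the first round
        int rounds = 0;
        while (true) {
            size++;                     // getMessages appends the user's input first
            if (size > maxMessages) {   // same comparison as messages.size() > MAX_MESSAGE
                return rounds;
            }
            size++;                     // the model's reply is appended after streaming
            rounds++;
        }
    }

    public static void main(String[] args) {
        System.out.println(answeredRounds(10)); // 5
    }
}
```

With this ordering, a limit of MAX_MESSAGE allows MAX_MESSAGE / 2 rounds (integer division), so an odd limit such as 11 still gives 5.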

Finally, the model's output:

Flux<ChatResponse> stream = client.stream(new Prompt(messages, initFunc()));

We call stream with a Prompt built from the message history and the functions.
Once we have the output stream, SseEmitterUtils.sendMessage(id, content) pushes each chunk of content to the corresponding session.

The last step is to add the model's output to the history, so that the model knows what it has already answered and doesn't answer it again.
If we leave it out, the model will try to answer every previous user input.
For example, in the first round you ask "Introduce Hangzhou," and the AI responds normally.
In the second round you ask "What are the famous attractions in Hangzhou?" Since the AI can't tell whether it already answered the first question, it tends to answer both at once, as if you had asked "Introduce Hangzhou. What are the famous attractions in Hangzhou?"
The same happens in the third and fourth rounds.

So how do we get the complete output from the stream?
At first, I used stream.subscribe with a StringBuilder to append each chunk to sb, but sb was always empty when I read it. After asking Claude, I learned that Flux is asynchronous: subscribe returns immediately and the chunks arrive later, so my code read sb before anything had been appended. In the end, I switched to Mono.

In this code, the flatMap callback creates a new StringBuilder sb, appends the content of each response to it, and returns a Mono that emits sb.toString(); flatMap merges these into a stream of strings.
Next, the reduce operator combines all of those strings into a single Mono<String>. Its argument is a merging function that combines the accumulated value with the current one; here it is plain string concatenation.
Finally, we subscribe to this Mono and, in its callback, add the final content to messages. If the stream produces no responses, reduce completes without a value and the callback never runs, so defaultIfEmpty("") makes it emit an empty string instead.
This way, we reliably get the full content of the streamed response and add it to messages.
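Reactor isn't part of the JDK, so the sketch below uses the JDK's own java.util.concurrent.Flow API as a stand-in to show the same principle: chunks arrive asynchronously, and the full reply only exists once the stream signals completion. StreamAccumulator is hypothetical; its onNext accumulation and onComplete callback play the roles that reduce and the subscribe callback play in the Reactor code above.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow;
import java.util.concurrent.SubmissionPublisher;

/** Hypothetical JDK-only analogue of "reduce the stream, then act on the result". */
class StreamAccumulator {

    /** Publishes the chunks asynchronously; the returned future completes with the joined text when the stream ends. */
    static CompletableFuture<String> collect(List<String> chunks) {
        CompletableFuture<String> reply = new CompletableFuture<>();
        SubmissionPublisher<String> publisher = new SubmissionPublisher<>();
        publisher.subscribe(new Flow.Subscriber<String>() {
            private final StringBuilder sb = new StringBuilder();

            @Override public void onSubscribe(Flow.Subscription s) { s.request(Long.MAX_VALUE); }
            // onNext runs on a pool thread, so code that reads the buffer right after
            // subscribe() returns would race with it: the pitfall described above.
            @Override public void onNext(String chunk) { sb.append(chunk); }
            @Override public void onError(Throwable t) { reply.completeExceptionally(t); }
            // Only here is the full reply known; an empty stream completes with "".
            @Override public void onComplete() { reply.complete(sb.toString()); }
        });
        chunks.forEach(publisher::submit);
        publisher.close(); // onComplete is signalled after the buffered chunks are delivered
        return reply;
    }

    public static void main(String[] args) {
        collect(List.of("Hangzhou is ", "the capital of ", "Zhejiang."))
                .thenAccept(full -> System.out.println("store as assistant message: " + full))
                .join();
    }
}
```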

Finally, we are done! 😎

Oh, I almost forgot the front end. Since I'm not very familiar with front-end development, I used openui to generate the page and styles, and then had Claude write the interaction logic. The result:

<!doctype html>
<html>

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <script src="https://cdn.tailwindcss.com"></script>
</head>

<body class="bg-zinc-100 dark:bg-zinc-800 min-h-screen p-4">
    <div class="flex flex-col h-full">
        <div id="messages" class="flex-1 overflow-y-auto p-4 space-y-4">
            <div class="flex items-end">
                <img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
                <div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg max-w-xs">Hi~(⁄ ⁄•⁄ω⁄•⁄ ⁄)⁄</div>
            </div>
        </div>
        <div class="p-2">
            <input type="text" id="messageInput" placeholder="Please enter a message..."
                class="w-full p-2 rounded-lg border-2 border-zinc-300 dark:border-zinc-600 focus:outline-none focus:border-blue-500 dark:focus:border-blue-400">
            <button onclick="sendMessage()"
                class="mt-2 w-full bg-blue-500 hover:bg-blue-600 dark:bg-blue-600 dark:hover:bg-blue-700 text-white p-2 rounded-lg">Send</button>
        </div>
    </div>
    <script>
        let sessionId; // Used to store session ID

        // Send HTTP request and handle response
        function sendHTTPRequest(url, method = 'GET', body = null) {
            return new Promise((resolve, reject) => {
                const xhr = new XMLHttpRequest();
                xhr.open(method, url, true);
                xhr.onload = () => {
                    if (xhr.status >= 200 && xhr.status < 300) {
                        resolve(xhr.response);
                    } else {
                        reject(xhr.statusText);
                    }
                };
                xhr.onerror = () => reject(xhr.statusText);
                if (body) {
                    xhr.setRequestHeader('Content-Type', 'application/json');
                    xhr.send(JSON.stringify(body));
                } else {
                    xhr.send();
                }
            });
        }

        // Handle SSE stream returned by the server
        function handleSSEStream(stream) {
            console.log('Stream started');
            console.log(stream);
            const messagesContainer = document.getElementById('messages');
            const responseDiv = document.createElement('div');
            responseDiv.className = 'flex items-end';
            responseDiv.innerHTML = `
    <img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
    <div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg max-w-xs"></div>
  `;
            messagesContainer.appendChild(responseDiv);

            const messageContentDiv = responseDiv.querySelector('div');

            // Listen for 'message' events, triggered when the backend sends new data
            stream.onmessage = function (event) {
                const data = event.data;
                console.log('Received data:', data);
                messageContentDiv.textContent += data;
                messagesContainer.scrollTop = messagesContainer.scrollHeight;
            };
        }

        // Send message
        function sendMessage() {
            const input = document.getElementById('messageInput');
            const message = input.value.trim();
            if (message) {
                const messagesContainer = document.getElementById('messages');
                const newMessageDiv = document.createElement('div');
                newMessageDiv.className = 'flex items-end justify-end';
                newMessageDiv.innerHTML = `
          <div class="mr-2 p-2 bg-green-200 dark:bg-green-700 rounded-lg max-w-xs"></div>
          <img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
        `;
                // Insert the user's text as plain text, not HTML
                newMessageDiv.querySelector('div').textContent = message;
                messagesContainer.appendChild(newMessageDiv);
                input.value = '';
                messagesContainer.scrollTop = messagesContainer.scrollHeight;

                // Encode the message so characters such as '?' or '#' survive in the URL path
                const encoded = encodeURIComponent(message);

                // When sending the message for the first time, send an init request to get the session ID
                if (!sessionId) {
                    console.log('init');
                    sendHTTPRequest(`http://127.0.0.1:8868/pro/init/${encoded}`, 'GET')
                        .then(response => {
                            sessionId = response; // Store session ID
                            return handleSSEStream(new EventSource(`http://127.0.0.1:8868/pro/chat/${sessionId}/${encoded}`))
                        });

                } else {
                    // Subsequent requests are sent directly to the chat interface
                    handleSSEStream(new EventSource(`http://127.0.0.1:8868/pro/chat/${sessionId}/${encoded}`))
                }
            }
        }
    </script>
</body>

</html>

Final Effect

[Image: the finished chat page]

PS: The front end could be optimized further, for example by showing conversation history and rendering the output as Markdown. If you're interested, you can use AI tools to modify it.
