qinfengge

qinfengge

醉后不知天在水,满船清梦压星河
github
email
telegram

スプリングAI(四)連続対話

以前の記事では、私たちが実装したのは単純な呼び出しで、1 回の対話しか実現できませんでした。これは現実に合わず、優雅でもありません。誰が対話を 1 文だけで終えるのでしょうか。下の例外を除いて

image

もし AI も 1 回しか対話できないのなら、私たちも AI 大先生、あなたは一旦お引き取りください と言うしかありません。では、モデルに連続して対話させるにはどうすればよいのでしょうか。重要なのは 記憶 です。ユーザーの質問を覚え、同時にモデル自身の出力も記憶することです。そうすれば、モデルは 前文 に基づいて合理的な回答を導き出すことができます。

プロンプト#

第一章で言ったことを覚えていますか?

実際、プロンプトの種類は多く、遊び方も多様で、単なるヒントだけでなく、複数回の対話の鍵でもあります。

プロンプトの作成では、2 種類のパラメータを受け取ることができます。一つは単一のメッセージ、もう一つはメッセージの集合です。

image
メッセージ内の MessageType には、以下の 4 つのタイプがあります。

image

これが一致しましたね。

    USER("user"),  // ユーザーの入力

	ASSISTANT("assistant"), // モデルの出力

	SYSTEM("system"), // モデルのキャラクター

	FUNCTION("function"); // 関数

想像してみてください。現実で誰かと対話しているとき、あなたが一言言えば、相手も一言返します。そして、彼の言葉は必ず 前文 と関連性がなければなりません。そうでなければ、まるで馬の口と驢馬の口が合わないようなものです。したがって、モデルで連続して対話するための鍵も同じです。すなわち、毎回の対話で前文をモデルに渡し、対応する文脈関係を理解させることです。これが メッセージ集合 の役割です。

実装#

簡単な原理については話しましたので、次は直接実装に入ります。しかし、前の数記事での努力を忘れてはいけません。すべてを統合する必要があります。したがって、次は ストリーミング出力関数呼び出し連続対話 を含む完全な機能を実装します。

まず、クライアントの初期化

    private static final String BASEURL = "https://xxx";

    private static final String TOKEN = "sk-xxxx";

    /**
     * OpenAiChatClientを作成
     * @return OpenAiChatClient
     */
    private static OpenAiChatClient getClient(){
        OpenAiApi openAiApi = new OpenAiApi(BASEURL, TOKEN);
        return new OpenAiChatClient(openAiApi, OpenAiChatOptions.builder()
                .withModel("gpt-3.5-turbo-1106")
                .withTemperature(0.8F)
                .build());
    }

この作成ステップでは、OpenAI のいくつかの古いモデルは関数呼び出しやストリーミング出力をサポートしていないことに注意が必要です。一部の古いモデルの最大トークン数は 4K しかありません。作成エラーは 400 BAD REQUEST を引き起こす可能性があります。

次に、履歴情報の保存
まず、Map を作成します。

private static Map<String, List<Message>> chatMessage = new ConcurrentHashMap<>();

ここで Map のキーはセッション ID に対応し、値は履歴メッセージです。セッションには必ず対応するユニーク ID が必要です。そうでなければ、混乱が生じます。
次に、毎回の会話でセッション ID とユーザーの入力を渡し、対応する入力をメッセージ集合に追加します。

/**
     * プロンプトを返す
     * @param message ユーザーが入力したメッセージ
     * @return プロンプト
     */
    private List<Message> getMessages(String id, String message) {
        String systemPrompt = "{prompt}";
        SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt);

        Message userMessage = new UserMessage(message);

        Message systemMessage = systemPromptTemplate.createMessage(MapUtil.of("prompt", "あなたは役に立つAIアシスタントです"));

        List<Message> messages = chatMessage.get(id);

        // メッセージが取得できなかった場合、新しいメッセージを作成し、システムプロンプトとユーザー入力メッセージをメッセージリストに追加します
        if (messages == null){
            messages = new ArrayList<>();
            messages.add(systemMessage);
            messages.add(userMessage);
        } else {
            messages.add(userMessage);
        }

        return messages;
    }

ここで、もし最初の対話でメッセージリストが空であれば、systemMessage も追加され、モデルにキャラクターを初期化することになります。

次に、関数の作成

/**
     * 関数呼び出しを初期化
     * @return ChatOptions
     */
    private ChatOptions initFunc(){
        return OpenAiChatOptions.builder().withFunctionCallbacks(List.of(
                FunctionCallbackWrapper.builder(new MockWeatherService()).withName("weather").withDescription("場所の天気を取得").build(),
                FunctionCallbackWrapper.builder(new WbHotService()).withName("wbHot").withDescription("Weiboのホットリストを取得").build(),
                FunctionCallbackWrapper.builder(new TodayNews()).withName("todayNews").withDescription("60秒で世界のニュースを視聴").build(),
                FunctionCallbackWrapper.builder(new DailyEnglishFunc()).withName("dailyEnglish").withDescription("英語の毎日のインスピレーションの文").build())).build();
    }

関数に関する詳細は第三章を参照してください。

最後に、出力です。
ここで言及しておくと、最終的な実装効果はウェブページであるため、サーバーが能動的にプッシュする機能を使用します。ここでは SSE を使用しており、SSE の紹介については以前のブログで書いたメッセージプッシュを参照してください。
要するに、以下は SSE のユーティリティクラスです。

@Component
@Slf4j
public class SseEmitterUtils {
    /**
     * 現在の接続数
     */
    private static AtomicInteger count = new AtomicInteger(0);

    /**
     * SseEmitter情報を保存
     */
    private static Map<String, SseEmitter> sseEmitterMap = new ConcurrentHashMap<>();

    /**
     * ユーザー接続を作成し、SseEmitterを返す
     * @param key userId
     * @return SseEmitter
     */
    public static SseEmitter connect(String key) {
        if (sseEmitterMap.containsKey(key)) {
            return sseEmitterMap.get(key);
        }

        try {
            // タイムアウト時間を設定、0は期限なし。デフォルトは30秒
            SseEmitter sseEmitter = new SseEmitter(0L);
            // コールバックを登録
            sseEmitter.onCompletion(completionCallBack(key));
            sseEmitter.onError(errorCallBack(key));
            sseEmitter.onTimeout(timeoutCallBack(key));
            sseEmitterMap.put(key, sseEmitter);
            // 数量+1
            count.getAndIncrement();
            return sseEmitter;
        } catch (Exception e) {
            log.info("新しいSSE接続の作成中に例外が発生しました。現在の接続キーは:{}", key);
        }
        return null;
    }

    /**
     * 指定されたユーザーにメッセージを送信
     * @param key userId
     * @param message メッセージ内容
     */
    public static void sendMessage(String key, String message) {
        if (sseEmitterMap.containsKey(key)) {
            try {
                sseEmitterMap.get(key).send(message);
            } catch (IOException e) {
                log.error("ユーザー[{}]のプッシュ中に例外が発生しました:{}", key, e.getMessage());
                remove(key);
            }
        }
    }

    /**
     * 同じグループの人にメッセージを発信。条件:key + groupId
     * @param groupId グループID
     * @param message メッセージ内容
     */
    public static void groupSendMessage(String groupId, String message) {
        if (!CollectionUtils.isEmpty(sseEmitterMap)) {
            sseEmitterMap.forEach((k, v) -> {
                try {
                    if (k.startsWith(groupId)) {
                        v.send(message, MediaType.APPLICATION_JSON);
                    }
                } catch (IOException e) {
                    log.error("ユーザー[{}]のプッシュ中に例外が発生しました:{}", k, e.getMessage());
                    remove(k);
                }
            });
        }
    }

    /**
     * ブロードキャストメッセージを発信
     * @param message メッセージ内容
     */
    public static void batchSendMessage(String message) {
        sseEmitterMap.forEach((k, v) -> {
            try {
                v.send(message, MediaType.APPLICATION_JSON);
            } catch (IOException e) {
                log.error("ユーザー[{}]のプッシュ中に例外が発生しました:{}", k, e.getMessage());
                remove(k);
            }
        });
    }

    /**
     * メッセージをグループに送信
     * @param message メッセージ内容
     * @param ids ユーザーID集合
     */
    public static void batchSendMessage(String message, Set<String> ids) {
        ids.forEach(userId -> sendMessage(userId, message));
    }

    /**
     * 接続を削除
     * @param key userId
     */
    public static void remove(String key) {
        sseEmitterMap.remove(key);
        // 数量-1
        count.getAndDecrement();
        log.info("接続を削除しました:{}", key);
    }

    /**
     * 現在の接続情報を取得
     * @return Map
     */
    public static List<String> getIds() {
        return new ArrayList<>(sseEmitterMap.keySet());
    }

    /**
     * 現在の接続数を取得
     * @return int
     */
    public static int getCount() {
        return count.intValue();
    }

    private static Runnable completionCallBack(String key) {
        return () -> {
            log.info("接続終了:{}", key);
            remove(key);
        };
    }

    private static Runnable timeoutCallBack(String key) {
        return () -> {
            log.info("接続タイムアウト:{}", key);
            remove(key);
        };
    }

    private static Consumer<Throwable> errorCallBack(String key) {
        return throwable -> {
            log.info("接続異常:{}", key);
            remove(key);
        };
    }
}

次に対話ができるようになるのでしょうか?いいえ、ウェブページの場合、具体的な実装を考える必要があります。主流の AI モデルのウェブ端は主に以下の 2 つの側面があります。

  1. 迅速な質問、ユーザーはホームページで直接質問できます。
  2. 対話情報の保存、各対話はユニークで、ユーザーはいつでも特定の対話に戻ることができます。

具体的な実装は次の通りです。

  1. インターフェースを作成し、ユーザーがホームページにアクセスしたときにこのインターフェースを呼び出してセッション ID を返します。
  2. 次のユーザー入力は、最初のステップで返されたセッション ID にバインドされます。ブラウザをリフレッシュするか、新しいセッションを作成しない限り。

最初のステップは次の通りです。

 /**
     * 接続を作成
     */
    @SneakyThrows
    @GetMapping("/init/{message}")
    public String init() {
        return String.valueOf(UUID.randomUUID());
    }

UUID をフロントエンドに返します。

最後にセッション ID があれば、セッションにバインドして出力できます。

@GetMapping("chat/{id}/{message}")
    public SseEmitter chat(@PathVariable String id, @PathVariable String message, HttpServletResponse response) {

        response.setHeader("Content-type", "text/html;charset=UTF-8");
        response.setCharacterEncoding("UTF-8");

        OpenAiChatClient client = getClient();
        SseEmitter emitter = SseEmitterUtils.connect(id);
        List<Message> messages = getMessages(id, message);
        System.err.println("chatMessageのサイズ: " + messages.size());
        System.err.println("chatMessage: " + chatMessage);

        if (messages.size() > MAX_MESSAGE){
            SseEmitterUtils.sendMessage(id, "対話回数が多すぎます。後で再試行してください🤔");
        }else {
            // モデルの出力ストリームを取得
            Flux<ChatResponse> stream = client.stream(new Prompt(messages,initFunc()));

            // ストリーム内のメッセージをSSEで送信
            Mono<String> result = stream
                    .flatMap(it -> {
                        StringBuilder sb = new StringBuilder();
                        String content = it.getResult().getOutput().getContent();
                        Optional.ofNullable(content).ifPresent(r -> {
                            SseEmitterUtils.sendMessage(id, content);
                            sb.append(content);
                        });
                        return Mono.just(sb.toString());
                    })
                    // メッセージを文字列に結合
                    .reduce((a, b) -> a + b)
                    .defaultIfEmpty("");

            // メッセージをchatMessageのAssistantMessageに保存
            result.subscribe(finalContent -> messages.add(new AssistantMessage(finalContent)));

            // メッセージをchatMessageに保存
            chatMessage.put(id, messages);

        }
        return emitter;

    }

まず、response を使用して返すエンコーディングを UTF-8 に設定し、文字化けを防ぎます。
次に、SseEmitterUtils を使用して対応するセッションに接続します。
次に、getMessages を使用して対応するセッションの履歴メッセージを取得します。
次に、MAX_MESSAGE を使用して対話の回数を判断します。これがここでの値を超えた場合、モデルの出力を呼び出さないようにします。これは主にコストを削減するためです。

private static final Integer MAX_MESSAGE = 10;

ここでは 10 回の対話と書かれていますが、実際には 5 回の対話です。なぜなら、履歴メッセージの size で判断しており、履歴メッセージにはユーザーの入力とモデルの出力が含まれているため、2 で割る必要があります。

chatMessage: {e2578f9e-8d71-4531-a6af-400a80fb6569=[SystemMessage{content='あなたは役に立つAIアシスタントです', properties={}, messageType=SYSTEM}, UserMessage{content='こんにちは', properties={}, messageType=USER}, AssistantMessage{content='こんにちは!私に何かお手伝いできることはありますか?', properties={}, messageType=ASSISTANT}, UserMessage{content='あなたは誰ですか', properties={}, messageType=USER}]}

最後に、モデルの出力です。

Flux<ChatResponse> stream = client.stream(new Prompt(messages,initFunc()));

stream ストリームを使用し、プロンプトに履歴メッセージと関数を渡します。
出力ストリームを取得したら、SseEmitterUtils.sendMessage(id, content); を使用してストリーム内の内容を対応するセッションに送信します。

最後のステップとして、モデルの出力も履歴メッセージに追加する必要があります。モデルが以前に回答したことを知り、再度回答しないようにするためです。
もし追加しなければ、モデルはユーザーの前のすべての入力に対してすべて回答しようとします。
例えば、最初のラウンドで「杭州について紹介してください」と尋ねると、AI の回答は正常です。
次のラウンドで「杭州にはどんな有名な観光地がありますか」と尋ねると、AI は前のラウンドで回答したかどうかを知らないため、これら 2 つの質問を同時に回答しようとします。つまり、「杭州について紹介してください、杭州にはどんな有名な観光地がありますか」となります。
3 ラウンド目、4 ラウンド目も同様に同時に回答します。

では、ストリームから完全な出力を取得するにはどうすればよいのでしょうか?
最初は stream.subscribeStringBuilder を使用してストリーム内の内容を sb に追加しましたが、sb は常に null でした。Claude に尋ねたところ、Flux は非同期であることがわかりました。最終的に Mono を使用して処理しました。

このコードでは、flatMap のコールバック関数内で新しい StringBuilder インスタンスsbを作成しています。そして、各応答の内容をsbに追加し、Mono<String>を返してsb.toString()の結果を発射します。
次に、reduce オペレーターを使用してすべてのMono<String>を 1 つのMono<String>に統合します。reduce のパラメータはマージ関数で、前の値と現在の値を新しい値にマージします。ここでは、文字列結合操作を使用してすべての応答内容を結合します。
最後に、このMono<String>を購読し、そのコールバック関数内で最終的な内容をメッセージに追加します。応答がない場合は、defaultIfEmpty("")を使用して空の文字列を発射し、null ではなくします。
この方法で、ストリーミング応答のすべての内容を正しく取得し、メッセージに追加することができます。

最後に、すべてが完了しました😎

ああ、忘れていました。フロントエンドがまだ残っていますが、私はフロントエンドにあまり詳しくありません。
そこで、openuiという AI ツールを使用してページとスタイルを作成し、Claude にインターフェースロジックを作成してもらいました。結果として、次のようなものが得られました。

<!doctype html>
<html>

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <script src="https://cdn.tailwindcss.com"></script>
</head>

<body class="bg-zinc-100 dark:bg-zinc-800 min-h-screen p-4">
    <div class="flex flex-col h-full">
        <div id="messages" class="flex-1 overflow-y-auto p-4 space-y-4">
            <div class="flex items-end">
                <img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
                <div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg max-w-xs">こんにちは~(⁄ ⁄•⁄ω⁄•⁄ ⁄)⁄</div>
            </div>
        </div>
        <div class="p-2">
            <input type="text" id="messageInput" placeholder="メッセージを入力してください..."
                class="w-full p-2 rounded-lg border-2 border-zinc-300 dark:border-zinc-600 focus:outline-none focus:border-blue-500 dark:focus:border-blue-400">
            <button onclick="sendMessage()"
                class="mt-2 w-full bg-blue-500 hover:bg-blue-600 dark:bg-blue-600 dark:hover:bg-blue-700 text-white p-2 rounded-lg">送信</button>
        </div>
    </div>
    <script>
        let sessionId; // セッションIDを保存するための変数

        // HTTPリクエストを送信し、応答を処理
        function sendHTTPRequest(url, method = 'GET', body = null) {
            return new Promise((resolve, reject) => {
                const xhr = new XMLHttpRequest();
                xhr.open(method, url, true);
                xhr.onload = () => {
                    if (xhr.status >= 200 && xhr.status < 300) {
                        resolve(xhr.response);
                    } else {
                        reject(xhr.statusText);
                    }
                };
                xhr.onerror = () => reject(xhr.statusText);
                if (body) {
                    xhr.setRequestHeader('Content-Type', 'application/json');
                    xhr.send(JSON.stringify(body));
                } else {
                    xhr.send();
                }
            });
        }

        // サーバーから返されたSSEストリームを処理
        function handleSSEStream(stream) {
            console.log('ストリームが開始されました');
            console.log(stream);
            const messagesContainer = document.getElementById('messages');
            const responseDiv = document.createElement('div');
            responseDiv.className = 'flex items-end';
            responseDiv.innerHTML = `
    <img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
    <div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg max-w-xs"></div>
  `;
            messagesContainer.appendChild(responseDiv);

            const messageContentDiv = responseDiv.querySelector('div');

            // 'message'イベントをリッスンし、バックエンドが新しいデータを送信したときにトリガー
            stream.onmessage = function (event) {
                const data = event.data;
                console.log('受信したデータ:', data);
                messageContentDiv.textContent += data;
                messagesContainer.scrollTop = messagesContainer.scrollHeight;
            };
        }

        // メッセージを送信
        function sendMessage() {
            const input = document.getElementById('messageInput');
            const message = input.value.trim();
            if (message) {
                const messagesContainer = document.getElementById('messages');
                const newMessageDiv = document.createElement('div');
                newMessageDiv.className = 'flex items-end justify-end';
                newMessageDiv.innerHTML = `
          <div class="mr-2 p-2 bg-green-200 dark:bg-green-700 rounded-lg max-w-xs">
            ${message}
          </div>
          <img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
        `;
                messagesContainer.appendChild(newMessageDiv);
                input.value = '';
                messagesContainer.scrollTop = messagesContainer.scrollHeight;

                // 最初のメッセージを送信する際、initリクエストを送信してセッションIDを取得
                if (!this.sessionId) {
                    console.log('init');
                    sendHTTPRequest(`http://127.0.0.1:8868/pro/init/${message}`, 'GET')
                        .then(response => {
                            this.sessionId = response; // セッションIDを保存
                            return handleSSEStream(new EventSource(`http://127.0.0.1:8868/pro/chat/${this.sessionId}/${message}`))
                        });

                } else {
                    // 以降のリクエストは直接chatインターフェースに送信
                    handleSSEStream(new EventSource(`http://127.0.0.1:8868/pro/chat/${this.sessionId}/${message}`))
                }
            }
        }
    </script>
</body>

</html>

最終的な効果#

image

PS: 実際にはフロントエンドをさらに最適化できます。例えば、履歴セッションを表示したり、出力を Markdown でレンダリングしたりなど。興味がある方は AI ツールを使って修正してみてください。

Spring AI 連続対話

読み込み中...
文章は、創作者によって署名され、ブロックチェーンに安全に保存されています。