以前の記事では、私たちが実装したのは単純な呼び出しで、1 回の対話しか実現できませんでした。これは現実に合わず、優雅でもありません。誰が対話を 1 文だけで終えるのでしょうか。下の例外を除いて
もし AI も 1 回しか対話できないのなら、私たちも AI 大先生、あなたは一旦お引き取りください と言うしかありません。では、モデルに連続して対話させるにはどうすればよいのでしょうか。重要なのは 記憶 です。ユーザーの質問を覚え、同時にモデル自身の出力も記憶することです。そうすれば、モデルは 前文 に基づいて合理的な回答を導き出すことができます。
プロンプト#
第一章で言ったことを覚えていますか?
実際、プロンプトの種類は多く、遊び方も多様で、単なるヒントだけでなく、複数回の対話の鍵でもあります。
プロンプトの作成では、2 種類のパラメータを受け取ることができます。一つは単一のメッセージ、もう一つはメッセージの集合です。
メッセージ内の MessageType には、以下の 4 つのタイプがあります。
これが一致しましたね。
USER("user"), // ユーザーの入力
ASSISTANT("assistant"), // モデルの出力
SYSTEM("system"), // モデルのキャラクター
FUNCTION("function"); // 関数
想像してみてください。現実で誰かと対話しているとき、あなたが一言言えば、相手も一言返します。そして、彼の言葉は必ず 前文 と関連性がなければなりません。そうでなければ、まるで馬の口と驢馬の口が合わないようなものです。したがって、モデルで連続して対話するための鍵も同じです。すなわち、毎回の対話で前文をモデルに渡し、対応する文脈関係を理解させることです。これが メッセージ集合 の役割です。
実装#
簡単な原理については話しましたので、次は直接実装に入ります。しかし、前の数記事での努力を忘れてはいけません。すべてを統合する必要があります。したがって、次は ストリーミング出力、関数呼び出し、連続対話 を含む完全な機能を実装します。
まず、クライアントの初期化。
private static final String BASEURL = "https://xxx";
private static final String TOKEN = "sk-xxxx";
/**
* OpenAiChatClientを作成
* @return OpenAiChatClient
*/
private static OpenAiChatClient getClient(){
OpenAiApi openAiApi = new OpenAiApi(BASEURL, TOKEN);
return new OpenAiChatClient(openAiApi, OpenAiChatOptions.builder()
.withModel("gpt-3.5-turbo-1106")
.withTemperature(0.8F)
.build());
}
この作成ステップでは、OpenAI のいくつかの古いモデルは関数呼び出しやストリーミング出力をサポートしていないことに注意が必要です。一部の古いモデルの最大トークン数は 4K しかありません。作成エラーは 400 BAD REQUEST
を引き起こす可能性があります。
次に、履歴情報の保存。
まず、Map を作成します。
private static Map<String, List<Message>> chatMessage = new ConcurrentHashMap<>();
ここで Map のキーはセッション ID に対応し、値は履歴メッセージです。セッションには必ず対応するユニーク ID が必要です。そうでなければ、混乱が生じます。
次に、毎回の会話でセッション ID とユーザーの入力を渡し、対応する入力をメッセージ集合に追加します。
/**
* プロンプトを返す
* @param message ユーザーが入力したメッセージ
* @return プロンプト
*/
private List<Message> getMessages(String id, String message) {
String systemPrompt = "{prompt}";
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt);
Message userMessage = new UserMessage(message);
Message systemMessage = systemPromptTemplate.createMessage(MapUtil.of("prompt", "あなたは役に立つAIアシスタントです"));
List<Message> messages = chatMessage.get(id);
// メッセージが取得できなかった場合、新しいメッセージを作成し、システムプロンプトとユーザー入力メッセージをメッセージリストに追加します
if (messages == null){
messages = new ArrayList<>();
messages.add(systemMessage);
messages.add(userMessage);
} else {
messages.add(userMessage);
}
return messages;
}
ここで、もし最初の対話でメッセージリストが空であれば、systemMessage
も追加され、モデルにキャラクターを初期化することになります。
次に、関数の作成。
/**
* 関数呼び出しを初期化
* @return ChatOptions
*/
private ChatOptions initFunc(){
return OpenAiChatOptions.builder().withFunctionCallbacks(List.of(
FunctionCallbackWrapper.builder(new MockWeatherService()).withName("weather").withDescription("場所の天気を取得").build(),
FunctionCallbackWrapper.builder(new WbHotService()).withName("wbHot").withDescription("Weiboのホットリストを取得").build(),
FunctionCallbackWrapper.builder(new TodayNews()).withName("todayNews").withDescription("60秒で世界のニュースを視聴").build(),
FunctionCallbackWrapper.builder(new DailyEnglishFunc()).withName("dailyEnglish").withDescription("英語の毎日のインスピレーションの文").build())).build();
}
関数に関する詳細は第三章を参照してください。
最後に、出力です。
ここで言及しておくと、最終的な実装効果はウェブページであるため、サーバーが能動的にプッシュする機能を使用します。ここでは SSE を使用しており、SSE の紹介については以前のブログで書いたメッセージプッシュを参照してください。
要するに、以下は SSE のユーティリティクラスです。
@Component
@Slf4j
public class SseEmitterUtils {
/**
* 現在の接続数
*/
private static AtomicInteger count = new AtomicInteger(0);
/**
* SseEmitter情報を保存
*/
private static Map<String, SseEmitter> sseEmitterMap = new ConcurrentHashMap<>();
/**
* ユーザー接続を作成し、SseEmitterを返す
* @param key userId
* @return SseEmitter
*/
public static SseEmitter connect(String key) {
if (sseEmitterMap.containsKey(key)) {
return sseEmitterMap.get(key);
}
try {
// タイムアウト時間を設定、0は期限なし。デフォルトは30秒
SseEmitter sseEmitter = new SseEmitter(0L);
// コールバックを登録
sseEmitter.onCompletion(completionCallBack(key));
sseEmitter.onError(errorCallBack(key));
sseEmitter.onTimeout(timeoutCallBack(key));
sseEmitterMap.put(key, sseEmitter);
// 数量+1
count.getAndIncrement();
return sseEmitter;
} catch (Exception e) {
log.info("新しいSSE接続の作成中に例外が発生しました。現在の接続キーは:{}", key);
}
return null;
}
/**
* 指定されたユーザーにメッセージを送信
* @param key userId
* @param message メッセージ内容
*/
public static void sendMessage(String key, String message) {
if (sseEmitterMap.containsKey(key)) {
try {
sseEmitterMap.get(key).send(message);
} catch (IOException e) {
log.error("ユーザー[{}]のプッシュ中に例外が発生しました:{}", key, e.getMessage());
remove(key);
}
}
}
/**
* 同じグループの人にメッセージを発信。条件:key + groupId
* @param groupId グループID
* @param message メッセージ内容
*/
public static void groupSendMessage(String groupId, String message) {
if (!CollectionUtils.isEmpty(sseEmitterMap)) {
sseEmitterMap.forEach((k, v) -> {
try {
if (k.startsWith(groupId)) {
v.send(message, MediaType.APPLICATION_JSON);
}
} catch (IOException e) {
log.error("ユーザー[{}]のプッシュ中に例外が発生しました:{}", k, e.getMessage());
remove(k);
}
});
}
}
/**
* ブロードキャストメッセージを発信
* @param message メッセージ内容
*/
public static void batchSendMessage(String message) {
sseEmitterMap.forEach((k, v) -> {
try {
v.send(message, MediaType.APPLICATION_JSON);
} catch (IOException e) {
log.error("ユーザー[{}]のプッシュ中に例外が発生しました:{}", k, e.getMessage());
remove(k);
}
});
}
/**
* メッセージをグループに送信
* @param message メッセージ内容
* @param ids ユーザーID集合
*/
public static void batchSendMessage(String message, Set<String> ids) {
ids.forEach(userId -> sendMessage(userId, message));
}
/**
* 接続を削除
* @param key userId
*/
public static void remove(String key) {
sseEmitterMap.remove(key);
// 数量-1
count.getAndDecrement();
log.info("接続を削除しました:{}", key);
}
/**
* 現在の接続情報を取得
* @return Map
*/
public static List<String> getIds() {
return new ArrayList<>(sseEmitterMap.keySet());
}
/**
* 現在の接続数を取得
* @return int
*/
public static int getCount() {
return count.intValue();
}
private static Runnable completionCallBack(String key) {
return () -> {
log.info("接続終了:{}", key);
remove(key);
};
}
private static Runnable timeoutCallBack(String key) {
return () -> {
log.info("接続タイムアウト:{}", key);
remove(key);
};
}
private static Consumer<Throwable> errorCallBack(String key) {
return throwable -> {
log.info("接続異常:{}", key);
remove(key);
};
}
}
次に対話ができるようになるのでしょうか?いいえ、ウェブページの場合、具体的な実装を考える必要があります。主流の AI モデルのウェブ端は主に以下の 2 つの側面があります。
- 迅速な質問、ユーザーはホームページで直接質問できます。
- 対話情報の保存、各対話はユニークで、ユーザーはいつでも特定の対話に戻ることができます。
具体的な実装は次の通りです。
- インターフェースを作成し、ユーザーがホームページにアクセスしたときにこのインターフェースを呼び出してセッション ID を返します。
- 次のユーザー入力は、最初のステップで返されたセッション ID にバインドされます。ブラウザをリフレッシュするか、新しいセッションを作成しない限り。
最初のステップは次の通りです。
/**
* 接続を作成
*/
@SneakyThrows
@GetMapping("/init/{message}")
public String init() {
return String.valueOf(UUID.randomUUID());
}
UUID をフロントエンドに返します。
最後にセッション ID があれば、セッションにバインドして出力できます。
@GetMapping("chat/{id}/{message}")
public SseEmitter chat(@PathVariable String id, @PathVariable String message, HttpServletResponse response) {
response.setHeader("Content-type", "text/html;charset=UTF-8");
response.setCharacterEncoding("UTF-8");
OpenAiChatClient client = getClient();
SseEmitter emitter = SseEmitterUtils.connect(id);
List<Message> messages = getMessages(id, message);
System.err.println("chatMessageのサイズ: " + messages.size());
System.err.println("chatMessage: " + chatMessage);
if (messages.size() > MAX_MESSAGE){
SseEmitterUtils.sendMessage(id, "対話回数が多すぎます。後で再試行してください🤔");
}else {
// モデルの出力ストリームを取得
Flux<ChatResponse> stream = client.stream(new Prompt(messages,initFunc()));
// ストリーム内のメッセージをSSEで送信
Mono<String> result = stream
.flatMap(it -> {
StringBuilder sb = new StringBuilder();
String content = it.getResult().getOutput().getContent();
Optional.ofNullable(content).ifPresent(r -> {
SseEmitterUtils.sendMessage(id, content);
sb.append(content);
});
return Mono.just(sb.toString());
})
// メッセージを文字列に結合
.reduce((a, b) -> a + b)
.defaultIfEmpty("");
// メッセージをchatMessageのAssistantMessageに保存
result.subscribe(finalContent -> messages.add(new AssistantMessage(finalContent)));
// メッセージをchatMessageに保存
chatMessage.put(id, messages);
}
return emitter;
}
まず、response
を使用して返すエンコーディングを UTF-8
に設定し、文字化けを防ぎます。
次に、SseEmitterUtils
を使用して対応するセッションに接続します。
次に、getMessages
を使用して対応するセッションの履歴メッセージを取得します。
次に、MAX_MESSAGE
を使用して対話の回数を判断します。これがここでの値を超えた場合、モデルの出力を呼び出さないようにします。これは主にコストを削減するためです。
private static final Integer MAX_MESSAGE = 10;
ここでは 10 回の対話と書かれていますが、実際には 5 回の対話です。なぜなら、履歴メッセージの size
で判断しており、履歴メッセージにはユーザーの入力とモデルの出力が含まれているため、2 で割る必要があります。
chatMessage: {e2578f9e-8d71-4531-a6af-400a80fb6569=[SystemMessage{content='あなたは役に立つAIアシスタントです', properties={}, messageType=SYSTEM}, UserMessage{content='こんにちは', properties={}, messageType=USER}, AssistantMessage{content='こんにちは!私に何かお手伝いできることはありますか?', properties={}, messageType=ASSISTANT}, UserMessage{content='あなたは誰ですか', properties={}, messageType=USER}]}
最後に、モデルの出力です。
Flux<ChatResponse> stream = client.stream(new Prompt(messages,initFunc()));
stream
ストリームを使用し、プロンプトに履歴メッセージと関数を渡します。
出力ストリームを取得したら、SseEmitterUtils.sendMessage(id, content);
を使用してストリーム内の内容を対応するセッションに送信します。
最後のステップとして、モデルの出力も履歴メッセージに追加する必要があります。モデルが以前に回答したことを知り、再度回答しないようにするためです。
もし追加しなければ、モデルはユーザーの前のすべての入力に対してすべて回答しようとします。
例えば、最初のラウンドで「杭州について紹介してください」と尋ねると、AI の回答は正常です。
次のラウンドで「杭州にはどんな有名な観光地がありますか」と尋ねると、AI は前のラウンドで回答したかどうかを知らないため、これら 2 つの質問を同時に回答しようとします。つまり、「杭州について紹介してください、杭州にはどんな有名な観光地がありますか」となります。
3 ラウンド目、4 ラウンド目も同様に同時に回答します。
では、ストリームから完全な出力を取得するにはどうすればよいのでしょうか?
最初は stream.subscribe
と StringBuilder
を使用してストリーム内の内容を sb
に追加しましたが、sb
は常に null でした。Claude に尋ねたところ、Flux は非同期であることがわかりました。最終的に Mono を使用して処理しました。
このコードでは、flatMap のコールバック関数内で新しい StringBuilder インスタンス
sb
を作成しています。そして、各応答の内容をsb
に追加し、Mono<String>
を返してsb.toString()
の結果を発射します。
次に、reduce オペレーターを使用してすべてのMono<String>
を 1 つのMono<String>
に統合します。reduce のパラメータはマージ関数で、前の値と現在の値を新しい値にマージします。ここでは、文字列結合操作を使用してすべての応答内容を結合します。
最後に、このMono<String>
を購読し、そのコールバック関数内で最終的な内容をメッセージに追加します。応答がない場合は、defaultIfEmpty("")
を使用して空の文字列を発射し、null ではなくします。
この方法で、ストリーミング応答のすべての内容を正しく取得し、メッセージに追加することができます。
最後に、すべてが完了しました😎
ああ、忘れていました。フロントエンドがまだ残っていますが、私はフロントエンドにあまり詳しくありません。
そこで、openuiという AI ツールを使用してページとスタイルを作成し、Claude にインターフェースロジックを作成してもらいました。結果として、次のようなものが得られました。
<!doctype html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<script src="https://cdn.tailwindcss.com"></script>
</head>
<body class="bg-zinc-100 dark:bg-zinc-800 min-h-screen p-4">
<div class="flex flex-col h-full">
<div id="messages" class="flex-1 overflow-y-auto p-4 space-y-4">
<div class="flex items-end">
<img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
<div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg max-w-xs">こんにちは~(⁄ ⁄•⁄ω⁄•⁄ ⁄)⁄</div>
</div>
</div>
<div class="p-2">
<input type="text" id="messageInput" placeholder="メッセージを入力してください..."
class="w-full p-2 rounded-lg border-2 border-zinc-300 dark:border-zinc-600 focus:outline-none focus:border-blue-500 dark:focus:border-blue-400">
<button onclick="sendMessage()"
class="mt-2 w-full bg-blue-500 hover:bg-blue-600 dark:bg-blue-600 dark:hover:bg-blue-700 text-white p-2 rounded-lg">送信</button>
</div>
</div>
<script>
let sessionId; // セッションIDを保存するための変数
// HTTPリクエストを送信し、応答を処理
function sendHTTPRequest(url, method = 'GET', body = null) {
return new Promise((resolve, reject) => {
const xhr = new XMLHttpRequest();
xhr.open(method, url, true);
xhr.onload = () => {
if (xhr.status >= 200 && xhr.status < 300) {
resolve(xhr.response);
} else {
reject(xhr.statusText);
}
};
xhr.onerror = () => reject(xhr.statusText);
if (body) {
xhr.setRequestHeader('Content-Type', 'application/json');
xhr.send(JSON.stringify(body));
} else {
xhr.send();
}
});
}
// サーバーから返されたSSEストリームを処理
function handleSSEStream(stream) {
console.log('ストリームが開始されました');
console.log(stream);
const messagesContainer = document.getElementById('messages');
const responseDiv = document.createElement('div');
responseDiv.className = 'flex items-end';
responseDiv.innerHTML = `
<img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
<div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg max-w-xs"></div>
`;
messagesContainer.appendChild(responseDiv);
const messageContentDiv = responseDiv.querySelector('div');
// 'message'イベントをリッスンし、バックエンドが新しいデータを送信したときにトリガー
stream.onmessage = function (event) {
const data = event.data;
console.log('受信したデータ:', data);
messageContentDiv.textContent += data;
messagesContainer.scrollTop = messagesContainer.scrollHeight;
};
}
// メッセージを送信
function sendMessage() {
const input = document.getElementById('messageInput');
const message = input.value.trim();
if (message) {
const messagesContainer = document.getElementById('messages');
const newMessageDiv = document.createElement('div');
newMessageDiv.className = 'flex items-end justify-end';
newMessageDiv.innerHTML = `
<div class="mr-2 p-2 bg-green-200 dark:bg-green-700 rounded-lg max-w-xs">
${message}
</div>
<img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
`;
messagesContainer.appendChild(newMessageDiv);
input.value = '';
messagesContainer.scrollTop = messagesContainer.scrollHeight;
// 最初のメッセージを送信する際、initリクエストを送信してセッションIDを取得
if (!this.sessionId) {
console.log('init');
sendHTTPRequest(`http://127.0.0.1:8868/pro/init/${message}`, 'GET')
.then(response => {
this.sessionId = response; // セッションIDを保存
return handleSSEStream(new EventSource(`http://127.0.0.1:8868/pro/chat/${this.sessionId}/${message}`))
});
} else {
// 以降のリクエストは直接chatインターフェースに送信
handleSSEStream(new EventSource(`http://127.0.0.1:8868/pro/chat/${this.sessionId}/${message}`))
}
}
}
</script>
</body>
</html>
最終的な効果#
PS: 実際にはフロントエンドをさらに最適化できます。例えば、履歴セッションを表示したり、出力を Markdown でレンダリングしたりなど。興味がある方は AI ツールを使って修正してみてください。