atrade/src/main/kotlin/network/RagService.kt

81 lines
2.9 KiB
Kotlin
Raw Normal View History

2026-01-21 18:30:03 +09:00
// src/main/kotlin/network/RagService.kt
2026-01-21 18:59:55 +09:00
import VectorStoreTable.metadata
2026-01-21 18:30:03 +09:00
import dev.langchain4j.data.segment.TextSegment
import dev.langchain4j.model.openai.OpenAiChatModel
import dev.langchain4j.model.openai.OpenAiEmbeddingModel
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.SqlExpressionBuilder.plus
import org.jetbrains.exposed.sql.transactions.transaction
import java.time.Duration
object RagService {
// 임베딩 모델 (8081) 및 채팅 모델 (8080) 설정
private val embeddingModel = OpenAiEmbeddingModel.builder()
.baseUrl("http://127.0.0.1:8081/v1")
.apiKey("unused")
.build()
private val chatModel = OpenAiChatModel.builder()
.baseUrl("http://127.0.0.1:8080/v1")
.apiKey("unused")
.timeout(Duration.ofSeconds(60))
.build()
/**
* 텍스트를 임베딩하여 H2 DB에 저장합니다.
*/
fun ingest(text: String, meta: String = "") {
2026-01-21 18:59:55 +09:00
val embeddingVector: DoubleArray = embeddingModel.embed(text).content().vector().map { it.toDouble() }.toDoubleArray()
2026-01-21 18:30:03 +09:00
transaction {
VectorStoreTable.insert {
it[content] = text
it[metadata] = meta
2026-01-21 18:59:55 +09:00
// [수정] 문자열 변환 없이 객체 그대로 전달
it[embedding] = embeddingVector
2026-01-21 18:30:03 +09:00
}
}
println("💾 H2 벡터 저장 완료: ${text.take(15)}...")
}
/**
* 질문과 가장 유사한 정보를 H2에서 검색하여 AI 답변을 생성합니다.
*/
2026-01-21 18:59:55 +09:00
fun askWithContext(question: String): String {
2026-01-21 18:30:03 +09:00
val queryVector = embeddingModel.embed(question).content().vector()
2026-01-21 18:59:55 +09:00
// H2 ARRAY 포맷에 맞춰 (v1, v2, ...) 형태로 변환
val vectorStr = queryVector.joinToString(",", "(", ")")
2026-01-21 18:30:03 +09:00
val context = transaction {
2026-01-21 18:59:55 +09:00
// 코사인 유사도 기준 상위 5개 뉴스 추출
val query = """
SELECT CONTENT FROM VECTOR_STORE
ORDER BY VECTOR_COSINE_SIMILARITY(EMBEDDING, CAST('$vectorStr' AS FLOAT8 ARRAY)) DESC
LIMIT 5
""".trimIndent()
2026-01-21 18:30:03 +09:00
val results = mutableListOf<String>()
exec(query) { rs ->
while (rs.next()) {
2026-01-21 18:59:55 +09:00
results.add(rs.getString("CONTENT"))
2026-01-21 18:30:03 +09:00
}
}
results.joinToString("\n\n")
}
2026-01-21 18:59:55 +09:00
val finalPrompt = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
당신은 실시간 뉴스 분석에 능통한 20 경력의 주식 전문가입니다.
제공된 [참고 자료] 바탕으로 사용자의 질문에 전문적이고 단호하게 답하세요.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
[참고 자료]
$context
[질문]
$question
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
""".trimIndent()
2026-01-21 18:30:03 +09:00
2026-01-21 18:59:55 +09:00
return chatModel.generate(finalPrompt)
2026-01-21 18:30:03 +09:00
}
}