summarize.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. """Simplify transcript and call LLM for summary with token-aware truncation."""
  2. import json
  3. import boto3
  4. import tiktoken
  5. from openai import OpenAI
  6. from shared import S3_BUCKET, get_ssm, update_job
  7. s3 = boto3.client('s3')
  8. # 模型上下文窗口大小 (input tokens)
  9. # 保守估计,预留 output tokens 和 system prompt
  10. MODEL_CONTEXT = {
  11. 'default': 128000, # 输入 128k,输出 16k,分开计算
  12. }
  13. SYSTEM_PROMPT = '你是一个专业的会议纪要助手。'
  14. USER_PROMPT_TEMPLATE = """请根据以下会议转录内容,使用简体中文普通话,生成一份结构化的会议纪要,包含:
  15. 1. 会议主题
  16. 2. 参与人员(用 Speaker 标识)
  17. 3. 关键讨论点
  18. 4. 决议事项和 Action Items
  19. 5. 时间安排
  20. 转录内容:
  21. {transcript}"""
  22. # 初始化 tokenizer (cl100k_base 适用于 GPT-4/Claude/大多数中文模型)
  23. try:
  24. _enc = tiktoken.get_encoding('cl100k_base')
  25. except Exception:
  26. _enc = None
  27. def count_tokens(text: str) -> int:
  28. """计算 token 数"""
  29. if _enc:
  30. return len(_enc.encode(text))
  31. # fallback: 中文约 1.5 token/字, 英文约 1.3 token/word
  32. return int(len(text) * 0.7)
  33. def truncate_to_tokens(text: str, max_tokens: int) -> tuple[str, int]:
  34. """截断文本到指定 token 数,返回 (截断后文本, 实际 token 数)"""
  35. if _enc:
  36. tokens = _enc.encode(text)
  37. if len(tokens) <= max_tokens:
  38. return text, len(tokens)
  39. truncated = _enc.decode(tokens[:max_tokens])
  40. return truncated, max_tokens
  41. # fallback: 按字符估算
  42. ratio = max_tokens / max(count_tokens(text), 1)
  43. if ratio >= 1:
  44. return text, count_tokens(text)
  45. cut = int(len(text) * ratio)
  46. return text[:cut], max_tokens
  47. def handler(event, context):
  48. job_id = event['job_id']
  49. update_job(job_id, status='SUMMARIZING')
  50. # Load transcribe output
  51. obj = s3.get_object(Bucket=S3_BUCKET, Key=f"jobs/{job_id}/transcribe-output.json")
  52. data = json.loads(obj['Body'].read())
  53. # Simplify: merge items by speaker
  54. simplified = simplify_transcript(data)
  55. # Save simplified
  56. s3.put_object(
  57. Bucket=S3_BUCKET,
  58. Key=f"jobs/{job_id}/simplified.txt",
  59. Body=simplified.encode('utf-8'),
  60. )
  61. # LLM config
  62. llm_url = get_ssm('llm/api_url')
  63. llm_key = get_ssm('llm/api_key')
  64. llm_model = get_ssm('llm/model')
  65. client = OpenAI(api_key=llm_key, base_url=llm_url.rstrip('/'), timeout=600)
  66. # 计算可用 token 预算
  67. context_limit = MODEL_CONTEXT.get(llm_model, MODEL_CONTEXT['default'])
  68. system_tokens = count_tokens(SYSTEM_PROMPT) + 20 # 消息格式开销
  69. template_tokens = count_tokens(USER_PROMPT_TEMPLATE.replace('{transcript}', ''))
  70. available = context_limit - system_tokens - template_tokens
  71. # 截断转录到可用 token 数
  72. transcript_text, used_tokens = truncate_to_tokens(simplified, available)
  73. total_tokens = count_tokens(simplified)
  74. print(f"[summarize] 转录总 token: {total_tokens}, 可用: {available}, 实际输入: {used_tokens}")
  75. if used_tokens < total_tokens:
  76. print(f"[summarize] 截断: {total_tokens} -> {used_tokens} tokens ({used_tokens/total_tokens*100:.1f}%)")
  77. prompt = USER_PROMPT_TEMPLATE.format(transcript=transcript_text)
  78. response = client.responses.create(
  79. model=llm_model,
  80. instructions=SYSTEM_PROMPT,
  81. input=[{'role': 'user', 'content': prompt}],
  82. max_output_tokens=128000,
  83. )
  84. summary = response.output_text
  85. # 记录 usage
  86. usage = response.usage
  87. if usage:
  88. print(f"[summarize] LLM usage: input={usage.input_tokens}, output={usage.output_tokens}, total={usage.total_tokens}")
  89. # Save summary
  90. s3.put_object(
  91. Bucket=S3_BUCKET,
  92. Key=f"jobs/{job_id}/summary.md",
  93. Body=summary.encode('utf-8'),
  94. )
  95. update_job(job_id, status='SUMMARIZED')
  96. return {**event, 'summary': summary}
  97. def simplify_transcript(data):
  98. items = data.get('results', {}).get('items', [])
  99. paragraphs = []
  100. current_speaker = None
  101. current_text = []
  102. current_start = None
  103. for item in items:
  104. speaker = item.get('speaker_label', '')
  105. content = item['alternatives'][0]['content']
  106. if item['type'] == 'punctuation':
  107. if current_text:
  108. current_text.append(content)
  109. continue
  110. if speaker != current_speaker and current_text:
  111. paragraphs.append((current_speaker, current_start, ''.join(current_text).strip()))
  112. current_text = []
  113. if speaker != current_speaker:
  114. current_speaker = speaker
  115. current_start = item.get('start_time', '')
  116. if current_text:
  117. prev = current_text[-1]
  118. if prev and prev[-1].isascii() and content[0].isascii():
  119. current_text.append(' ')
  120. current_text.append(content)
  121. if current_text:
  122. paragraphs.append((current_speaker, current_start, ''.join(current_text).strip()))
  123. lines = []
  124. for spk, start, text in paragraphs:
  125. ts = ''
  126. if start:
  127. secs = float(start)
  128. ts = f"[{int(secs//60):02d}:{int(secs%60):02d}]"
  129. lines.append(f"**{spk}** {ts}")
  130. lines.append(text)
  131. lines.append('')
  132. return '\n'.join(lines)