| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160 |
- """Simplify transcript and call LLM for summary with token-aware truncation."""
- import json
- import boto3
- import tiktoken
- from openai import OpenAI
- from shared import S3_BUCKET, get_ssm, update_job
# S3 client built once at import time and shared by every call to handler().
s3 = boto3.client(service_name='s3')
# Token window per model name (the name is the SSM value 'llm/model'); unknown
# models use the 'default' entry. handler() subtracts the system-prompt and
# template overhead from this value to get the transcript budget.
# The default assumes a 128k input window, with the 16k output allowance noted
# as a separate budget.
# NOTE(review): many providers (e.g. OpenAI) count input and output against a
# single window -- confirm for the configured model before relying on 128k
# being input-only.
MODEL_CONTEXT = dict(
    default=128_000,
)
# System prompt ("You are a professional meeting-minutes assistant."), passed as
# `instructions` to the LLM call.
SYSTEM_PROMPT = '你是一个专业的会议纪要助手。'
# User prompt asking for structured minutes in Simplified Chinese (Mandarin):
# topic, participants (by Speaker label), key discussion points, decisions and
# action items, and schedule. `{transcript}` is substituted with str.format(),
# so any other literal braces added to this template must be doubled ({{ }}).
# handler() also strips `{transcript}` from it to measure the fixed overhead.
USER_PROMPT_TEMPLATE = """请根据以下会议转录内容,使用简体中文普通话,生成一份结构化的会议纪要,包含:
1. 会议主题
2. 参与人员(用 Speaker 标识)
3. 关键讨论点
4. 决议事项和 Action Items
5. 时间安排
转录内容:
{transcript}"""
# Module-level tokenizer used by count_tokens() / truncate_to_tokens().
# cl100k_base is the GPT-4 / GPT-3.5 encoding; for other model families
# (Claude, most Chinese LLMs) it only approximates their own tokenizers.
# Loading can fail (presumably e.g. when tiktoken cannot obtain its encoding
# file); the helpers then fall back to a character-based estimate. This is a
# deliberate best-effort, but the failure is logged instead of swallowed.
try:
    _enc = tiktoken.get_encoding('cl100k_base')
except Exception as exc:
    print(f"[summarize] tiktoken unavailable, using heuristic token counts: {exc!r}")
    _enc = None
def count_tokens(text: str) -> int:
    """Return the (estimated) number of tokens in *text*.

    When the module tokenizer ``_enc`` is loaded this is the exact
    cl100k_base count. Otherwise a heuristic is used: about 1.5 tokens per
    non-ASCII character (e.g. Chinese) plus about 1.3 tokens per ASCII word.

    Args:
        text: Text to measure.

    Returns:
        Token count (exact or estimated), always >= 0.
    """
    if _enc:
        # disallowed_special=() treats strings such as "<|endoftext|>" as plain
        # text; tiktoken's default raises ValueError if they appear.
        return len(_enc.encode(text, disallowed_special=()))
    # A flat per-character ratio badly under-counts CJK text (which would let
    # truncation overshoot the context), so estimate the two scripts apart.
    non_ascii = sum(1 for ch in text if not ch.isascii())
    ascii_words = len(''.join(ch if ch.isascii() else ' ' for ch in text).split())
    return int(non_ascii * 1.5 + ascii_words * 1.3)
def truncate_to_tokens(text: str, max_tokens: int) -> tuple[str, int]:
    """Truncate *text* to at most *max_tokens* tokens, keeping its beginning.

    Args:
        text: Text to shorten.
        max_tokens: Token budget; a value <= 0 yields an empty result.

    Returns:
        ``(kept_text, token_count)``. If no cut was needed, *kept_text* is
        *text* and *token_count* its (exact or estimated) size; after a cut,
        *token_count* is *max_tokens*, an upper bound for the kept prefix.
    """
    if max_tokens <= 0:
        # Guard: a negative slice bound below would keep all but the last
        # tokens instead of none.
        return '', 0
    if _enc:
        tokens = _enc.encode(text, disallowed_special=())
        if len(tokens) <= max_tokens:
            return text, len(tokens)
        # The cut can land inside a character spread over several byte-level
        # tokens (common for CJK); errors='ignore' drops the partial bytes
        # instead of emitting U+FFFD replacement characters.
        truncated = _enc.decode(tokens[:max_tokens], errors='ignore')
        return truncated, max_tokens
    # Fallback: shorten by characters in proportion to the estimated ratio.
    total = count_tokens(text)
    if total <= max_tokens:
        return text, total
    cut = int(len(text) * max_tokens / total)
    return text[:cut], max_tokens
def handler(event, context):
    """Produce LLM meeting minutes for one transcription job.

    Steps: mark the job ``SUMMARIZING``; load
    ``jobs/{job_id}/transcribe-output.json`` from ``S3_BUCKET``; merge it into
    speaker paragraphs and store them as ``jobs/{job_id}/simplified.txt``;
    truncate that text to the model's input-token budget; request a summary
    from the OpenAI-compatible endpoint configured in SSM; store the reply as
    ``jobs/{job_id}/summary.md`` and mark the job ``SUMMARIZED``.

    Args:
        event: Input payload; must contain ``'job_id'``.
        context: Runtime context (not used).

    Returns:
        A copy of *event* with ``'summary'`` set to the generated Markdown.

    Raises:
        KeyError: if *event* has no ``'job_id'``. Errors from S3, SSM or the
        LLM call propagate unchanged.
    """
    job_id = event['job_id']
    update_job(job_id, status='SUMMARIZING')
    # NOTE(review): nothing below resets the status on failure, so an exception
    # leaves the job in SUMMARIZING; confirm the orchestrator records failures.

    # Load the raw transcription result.
    obj = s3.get_object(Bucket=S3_BUCKET, Key=f"jobs/{job_id}/transcribe-output.json")
    data = json.loads(obj['Body'].read())

    # Merge word-level items into per-speaker paragraphs and persist them.
    simplified = simplify_transcript(data)
    s3.put_object(
        Bucket=S3_BUCKET,
        Key=f"jobs/{job_id}/simplified.txt",
        Body=simplified.encode('utf-8'),
    )

    # LLM connection settings.
    llm_url = get_ssm('llm/api_url')
    llm_key = get_ssm('llm/api_key')
    llm_model = get_ssm('llm/model')
    client = OpenAI(api_key=llm_key, base_url=llm_url.rstrip('/'), timeout=600)

    # Output cap. 128k-context models typically allow ~16k output tokens and
    # reject requests above their cap; 16000 also stays below a 16,384 limit.
    max_output_tokens = 16000

    # Input budget = window - system prompt (+20 for message framing) - fixed
    # template text - output reservation. Many providers (e.g. OpenAI) count
    # input and output against one window, so the output must be reserved to
    # stay within it. Clamped at 0 so a small window never yields a negative
    # budget.
    context_limit = MODEL_CONTEXT.get(llm_model, MODEL_CONTEXT['default'])
    system_tokens = count_tokens(SYSTEM_PROMPT) + 20
    template_tokens = count_tokens(USER_PROMPT_TEMPLATE.replace('{transcript}', ''))
    available = max(context_limit - system_tokens - template_tokens - max_output_tokens, 0)

    # Keep the beginning of the transcript that fits the budget.
    transcript_text, used_tokens = truncate_to_tokens(simplified, available)
    total_tokens = count_tokens(simplified)
    print(f"[summarize] 转录总 token: {total_tokens}, 可用: {available}, 实际输入: {used_tokens}")
    if used_tokens < total_tokens:
        print(f"[summarize] 截断: {total_tokens} -> {used_tokens} tokens ({used_tokens/total_tokens*100:.1f}%)")

    prompt = USER_PROMPT_TEMPLATE.format(transcript=transcript_text)
    response = client.responses.create(
        model=llm_model,
        instructions=SYSTEM_PROMPT,
        input=[{'role': 'user', 'content': prompt}],
        max_output_tokens=max_output_tokens,
    )
    summary = response.output_text

    # Log token usage reported by the API, when present.
    usage = response.usage
    if usage:
        print(f"[summarize] LLM usage: input={usage.input_tokens}, output={usage.output_tokens}, total={usage.total_tokens}")

    s3.put_object(
        Bucket=S3_BUCKET,
        Key=f"jobs/{job_id}/summary.md",
        Body=summary.encode('utf-8'),
    )
    update_job(job_id, status='SUMMARIZED')
    return {**event, 'summary': summary}
def simplify_transcript(data):
    """Render a transcription result as readable per-speaker paragraphs.

    Reads ``data['results']['items']`` (a missing key means "no items"). Each
    item has ``type`` and ``alternatives[0]['content']``, plus optional
    ``speaker_label`` and ``start_time``. Consecutive non-punctuation items
    from the same speaker form one paragraph. A space is inserted only
    between two ASCII neighbours, so CJK text stays unspaced. Punctuation
    attaches to the running paragraph and is dropped when no paragraph text
    exists yet.

    Returns:
        For each paragraph: a ``**<speaker>** [MM:SS]`` header, the paragraph
        text and a blank line, all joined with newlines. The result is an
        empty string when there are no items. The timestamp is the
        paragraph's first ``start_time`` (minutes are not folded into hours)
        and is left empty when that value is missing or empty.
    """
    items = data.get('results', {}).get('items', [])

    paragraphs = []  # (speaker, start_time, text)
    speaker_now = None
    start_now = None
    pieces = []

    for item in items:
        label = item.get('speaker_label', '')
        token = item['alternatives'][0]['content']

        if item['type'] == 'punctuation':
            if pieces:
                pieces.append(token)
            continue

        if label != speaker_now:
            # Speaker turn changed: close the open paragraph (if any) and
            # start a new one stamped with this item's start time.
            if pieces:
                paragraphs.append((speaker_now, start_now, ''.join(pieces).strip()))
                pieces = []
            speaker_now = label
            start_now = item.get('start_time', '')

        if pieces:
            last = pieces[-1]
            if last and last[-1].isascii() and token[0].isascii():
                pieces.append(' ')
        pieces.append(token)

    if pieces:
        paragraphs.append((speaker_now, start_now, ''.join(pieces).strip()))

    rendered = []
    for speaker, start, text in paragraphs:
        stamp = ''
        if start:
            seconds = float(start)
            stamp = f"[{int(seconds // 60):02d}:{int(seconds % 60):02d}]"
        rendered.extend((f"**{speaker}** {stamp}", text, ''))
    return '\n'.join(rendered)
|