import json
from typing import Any, Dict, List
from uuid import UUID

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains.question_answering import load_qa_chain
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.schema.output import LLMResult
from langchain_community.llms.sagemaker_endpoint import (
    LLMContentHandler,
    SagemakerEndpoint,
)
from tenacity import RetryCallState
example_doc_1 = """
Peter and Elizabeth took a taxi to attend the night party in the city. While at the party, Elizabeth collapsed and was rushed to the hospital.
Since she was diagnosed with a brain injury, the doctor told Peter to stay beside her until she got well.
Therefore, Peter stayed with her at the hospital for 3 days without leaving.
"""
docs = [
    Document(
        page_content=example_doc_1,
    )
]

query = """How long was Elizabeth hospitalized?
"""

prompt_template = """Use the following pieces of context to answer the question at the end.
{context}
Question: {question}
Answer:"""

PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)
class MyCustomHandler(BaseCallbackHandler):
    """Debug handler that prints a marker whenever a callback event fires."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        print("***********************")

    def on_text(
        self, text: str, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
    ) -> Any:
        print("***********on text")
        return super().on_text(text, run_id=run_id, parent_run_id=parent_run_id, **kwargs)

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> Any:
        print("************** on llm end ***********************")
        return super().on_llm_end(
            response, run_id=run_id, parent_run_id=parent_run_id, **kwargs
        )

    def on_llm_error(
        self, error: BaseException, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
    ) -> Any:
        print("***************ON ERROR ********************", error)
        return super().on_llm_error(error, run_id=run_id, parent_run_id=parent_run_id, **kwargs)

    def on_retry(
        self, retry_state: RetryCallState, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
    ) -> Any:
        print("***************************")
        return super().on_retry(retry_state, run_id=run_id, parent_run_id=parent_run_id, **kwargs)

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: List[str] | None = None,
        metadata: Dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        print("************on llm start ******************")
        return super().on_llm_start(
            serialized,
            prompts,
            run_id=run_id,
            parent_run_id=parent_run_id,
            tags=tags,
            metadata=metadata,
            **kwargs,
        )
class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Request a streaming response from the endpoint.
        input_str = json.dumps({"inputs": prompt, **model_kwargs, "stream": True})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]
content_handler = ContentHandler()
callback_f = MyCustomHandler()

llm = SagemakerEndpoint(
    streaming=True,
    endpoint_name="endpoint",
    region_name="region",
    model_kwargs={"temperature": 0.7, "max_new_tokens": 400},
    content_handler=content_handler,
    callbacks=[callback_f],
    verbose=True,
)

chain = load_qa_chain(
    llm=llm,
    prompt=PROMPT,
)

chain.invoke({"question": query, "input_documents": docs})
Traceback (most recent call last):
  File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain_community\llms\sagemaker_endpoint.py", line 345, in _call
resp = json.loads(line)
^^^^^^^^^^^^^^^^
File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\json_init.py", line 346, in loads
return _default_decoder.decode(s)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\json\decoder.py", line 337, in decode
obj, end = self.raw_decode(s, idx=_w(s, 0).end())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\json\decoder.py", line 357, in raw_decode
raise JSONDecodeError("Expecting value", s, err.value) from None
json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "F:\poc\aws_sage_maker\sagemaker_langchain.py", line 127, in
out = chain.invoke({"question": query, "input_documents": docs})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\base.py", line 162, in invoke
raise e
File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\base.py", line 156, in invoke
self._call(inputs, run_manager=run_manager)
File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\combine_documents\base.py", line 136, in _call
output, extra_return_dict = self.combine_docs(
^^^^^^^^^^^^^^^^^^
File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\combine_documents\stuff.py", line 244, in combine_docs
return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\llm.py", line 293, in predict
return self(kwargs, callbacks=callbacks)[self.output_key]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\site-packages\langchain_core_api\deprecation.py", line 145, in warning_emitting_wrapper
return wrapped(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\base.py", line 363, in call
return self.invoke(
^^^^^^^^^^^^
File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\base.py", line 162, in invoke
raise e
File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\base.py", line 156, in invoke
self._call(inputs, run_manager=run_manager)
File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\llm.py", line 103, in _call
response = self.generate([inputs], run_manager=run_manager)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\llm.py", line 115, in generate
return self.llm.generate_prompt(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\site-packages\langchain_core\language_models\llms.py", line 568, in generate_prompt
return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\site-packages\langchain_core\language_models\llms.py", line 741, in generate
output = self._generate_helper(
^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\site-packages\langchain_core\language_models\llms.py", line 605, in _generate_helper
raise e
File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\site-packages\langchain_core\language_models\llms.py", line 592, in _generate_helper
self._generate(
File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\site-packages\langchain_core\language_models\llms.py", line 1177, in _generate
self._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain_community\llms\sagemaker_endpoint.py", line 355, in _call
raise ValueError(f"Error raised by streaming inference endpoint: {e}")
ValueError: Error raised by streaming inference endpoint: Expecting value: line 1 column 1 (char 0)
I am trying to stream the response from a SageMaker endpoint, but the code above throws the error shown. It looks like the `json.loads(line)` call inside `_call` in `sagemaker_endpoint.py` cannot parse the lines coming back from the streaming endpoint.
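As a sanity check on where the parse fails, here is a minimal sketch. It assumes the endpoint streams SSE-style lines with a `data:` prefix (as TGI containers do); the payload below is hypothetical, as I have not captured my endpoint's actual wire format:

import json

# Hypothetical streamed lines: an empty keep-alive line and a TGI-style
# SSE event. Both fail json.loads with the exact error from the traceback.
for line in (b"", b'data:{"token": {"text": "Hello"}}'):
    try:
        json.loads(line)
    except json.JSONDecodeError as e:
        print(e)  # Expecting value: line 1 column 1 (char 0)

# Stripping the "data:" prefix (and skipping empty lines) parses cleanly,
# which is what makes me suspect the prefix/keep-alives are the problem:
raw = b'data:{"token": {"text": "Hello"}}'
print(json.loads(raw.removeprefix(b"data:")))  # {'token': {'text': 'Hello'}}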