Unable to configure streaming on a SageMaker endpoint with LangChain

import json
from typing import Any, Dict, List
from uuid import UUID

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains.question_answering import load_qa_chain
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.schema.output import LLMResult
from langchain_community.llms.sagemaker_endpoint import (
    LLMContentHandler,
    SagemakerEndpoint,
)
from tenacity import RetryCallState
example_doc_1 = """
Peter and Elizabeth took a taxi to attend the night party in the city. While at the party, Elizabeth collapsed and was rushed to the hospital.
Since she was diagnosed with a brain injury, the doctor told Peter to stay beside her until she got well.
Therefore, Peter stayed with her at the hospital for 3 days without leaving.
"""

docs = [
    Document(
        page_content=example_doc_1,
    )
]
query = """How long was Elizabeth hospitalized?
"""

prompt_template = """Use the following pieces of context to answer the question at the end.

{context}

Question: {question}
Answer:"""
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)


class MyCustomHandler(BaseCallbackHandler):
    """Prints a marker for each callback so I can see which events fire."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Should fire once per token when streaming works.
        print("*********************** new token:", token)

    def on_text(
        self,
        text: str,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> Any:
        print("***********on text")
        return super().on_text(text, run_id=run_id, parent_run_id=parent_run_id, **kwargs)

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> Any:
        print("************** on llm end ***********************")
        return super().on_llm_end(
            response, run_id=run_id, parent_run_id=parent_run_id, **kwargs
        )

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> Any:
        print("***************ON ERROR ********************", error)
        return super().on_llm_error(
            error, run_id=run_id, parent_run_id=parent_run_id, **kwargs
        )

    def on_retry(
        self,
        retry_state: RetryCallState,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> Any:
        print("***************** on retry **********************")
        return super().on_retry(
            retry_state, run_id=run_id, parent_run_id=parent_run_id, **kwargs
        )

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: List[str] | None = None,
        metadata: Dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        print("************on llm start ******************")
        return super().on_llm_start(
            serialized,
            prompts,
            run_id=run_id,
            parent_run_id=parent_run_id,
            tags=tags,
            metadata=metadata,
            **kwargs,
        )


class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Ask the endpoint to stream, in addition to the usual generation kwargs.
        input_str = json.dumps({"inputs": prompt, **model_kwargs, "stream": True})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # Parses the non-streaming response shape: [{"generated_text": "..."}]
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]


content_handler = ContentHandler()
callback_f = MyCustomHandler()

llm = SagemakerEndpoint(
    streaming=True,
    endpoint_name="endpoint",
    region_name="region",
    model_kwargs={"temperature": 0.7, "max_new_tokens": 400},
    content_handler=content_handler,
    callbacks=[callback_f],
    verbose=True,
)
chain = load_qa_chain(
    llm=llm,
    prompt=PROMPT,
)

chain.invoke({"question": query, "input_documents": docs})
Running this raises:

Traceback (most recent call last):
  File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain_community\llms\sagemaker_endpoint.py", line 345, in _call
    resp = json.loads(line)
           ^^^^^^^^^^^^^^^^
  File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\json\__init__.py", line 346, in loads
    return _default_decoder.decode(s)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\json\decoder.py", line 337, in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\json\decoder.py", line 357, in raw_decode
    raise JSONDecodeError("Expecting value", s, err.value) from None
json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "F:\poc\aws_sage_maker\sagemaker_langchain.py", line 127, in <module>
    out = chain.invoke({"question": query, "input_documents": docs})
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\base.py", line 162, in invoke
    raise e
  File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\base.py", line 156, in invoke
    self._call(inputs, run_manager=run_manager)
  File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\combine_documents\base.py", line 136, in _call
    output, extra_return_dict = self.combine_docs(
                                ^^^^^^^^^^^^^^^^^^
  File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\combine_documents\stuff.py", line 244, in combine_docs
    return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\llm.py", line 293, in predict
    return self(kwargs, callbacks=callbacks)[self.output_key]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\site-packages\langchain_core\_api\deprecation.py", line 145, in warning_emitting_wrapper
    return wrapped(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\base.py", line 363, in __call__
    return self.invoke(
           ^^^^^^^^^^^^
  File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\base.py", line 162, in invoke
    raise e
  File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\base.py", line 156, in invoke
    self._call(inputs, run_manager=run_manager)
  File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\llm.py", line 103, in _call
    response = self.generate([inputs], run_manager=run_manager)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain\chains\llm.py", line 115, in generate
    return self.llm.generate_prompt(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\site-packages\langchain_core\language_models\llms.py", line 568, in generate_prompt
    return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\site-packages\langchain_core\language_models\llms.py", line 741, in generate
    output = self._generate_helper(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\site-packages\langchain_core\language_models\llms.py", line 605, in _generate_helper
    raise e
  File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\site-packages\langchain_core\language_models\llms.py", line 592, in _generate_helper
    self._generate(
  File "C:\Users\NYL\anaconda3\envs\sagemaker\Lib\site-packages\langchain_core\language_models\llms.py", line 1177, in _generate
    self._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
  File "C:\Users\NYL\AppData\Roaming\Python\Python311\site-packages\langchain_community\llms\sagemaker_endpoint.py", line 355, in _call
    raise ValueError(f"Error raised by streaming inference endpoint: {e}")
ValueError: Error raised by streaming inference endpoint: Expecting value: line 1 column 1 (char 0)

I am trying to stream responses from a SageMaker endpoint, but the code throws the error above. It looks like the json.loads(line) call inside sagemaker_endpoint.py is what fails, which suggests the first byte of each line coming back from the endpoint is not valid JSON.
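For reference, here is a minimal sketch of how that exact error can arise, assuming the endpoint runs a Hugging Face TGI container (I have not verified the exact bytes my endpoint streams). TGI is documented to emit server-sent-event style lines with a "data:" prefix, and such a line is not valid JSON on its own:

import json

# Hypothetical raw line from a TGI streaming response; the "data:" prefix
# is SSE framing, not part of the JSON payload (assumption, not verified
# against my actual endpoint).
line = b'data:{"token": {"id": 123, "text": "Peter", "special": false}}'

try:
    json.loads(line)  # fails on the "data:" prefix
except json.JSONDecodeError as e:
    print(e)  # Expecting value: line 1 column 1 (char 0)

# Stripping the SSE prefix first parses cleanly:
payload = json.loads(line[len(b"data:"):])
print(payload["token"]["text"])  # Peter

If that assumption about the response format is right, the failure would be in how the raw stream lines are parsed, not in my ContentHandler, since transform_output is never reached.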
