query

Function

Searches for images related to the given text, then sends the retrieved images together with the generation prompt to an LLM, which generates a new image.

Prototype

def query(text, llm_config, *args, **kwargs)
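The behavior described above can be pictured as a two-step control flow: retrieve, then generate. The following is a minimal sketch, not the mx_rag implementation. `query_sketch`, `retrieve_images`, and `generate_image` are hypothetical stand-ins for the chain's method, its retriever, and its multimodal model, and the way the request is packaged here is an assumption.

```python
from typing import Callable, Dict, List


def query_sketch(
    text: str,
    retrieve_images: Callable[[str], List[str]],
    generate_image: Callable[[List[str], str, str], str],
    **kwargs,
) -> Dict[str, str]:
    """Hypothetical outline of Img2ImgChain.query (not the library code)."""
    prompt = kwargs["prompt"]             # required keyword argument
    size = kwargs.get("size", "512*512")  # optional; documented default is 512*512
    # Step 1: look up images related to the query text.
    images = retrieve_images(text)
    # Step 2: send the retrieved images, the prompt, and the size to the model.
    data = generate_image(images, prompt, size)
    # Step 3: return the prompt together with the Base64-encoded image data.
    return {"prompt": prompt, "result": data}
```

Passing the retriever and model as plain callables keeps the sketch runnable without mx_rag; in the real chain these roles are filled by the `Retriever` and `Img2ImgMultiModel` objects shown in the example.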

Parameters

| Parameter | Data Type | Required/Optional | Description |
|---|---|---|---|
| text | String | Required | Text describing the image to search for. The length must be in the range (0, 1 × 1000 × 1000]. |
| llm_config | LLMParameterConfig | Optional | Inherited from the parent class; not used by this method. |
| args | List | Optional | Not used. |
| kwargs["prompt"] | String | Required | Image generation prompt, passed as a keyword argument. The length must be in the range (0, 1 × 1024 × 1024]. |
| kwargs["size"] | String | Optional | Image size, passed as a keyword argument in the format "height*width" (for example, "512*512"). The supported sizes depend on the LLM. The value must match the regular expression `^\d{1,5}\*\d{1,5}$`. Default: "512*512". |
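The constraints in the parameter table can be checked before calling `query`. The following sketch uses a hypothetical helper, `check_query_args`, which is not part of mx_rag. It applies the documented length ranges and the size pattern, then splits a valid size into height and width following the "height*width" order stated in the table.

```python
import re
from typing import Tuple

# Documented limits: exclusive lower bound 0, inclusive upper bound.
MAX_TEXT_LEN = 1 * 1000 * 1000
MAX_PROMPT_LEN = 1 * 1024 * 1024
SIZE_PATTERN = re.compile(r"^\d{1,5}\*\d{1,5}$")


def check_query_args(text: str, prompt: str, size: str = "512*512") -> Tuple[int, int]:
    """Hypothetical pre-check; returns (height, width) parsed from size."""
    if not 0 < len(text) <= MAX_TEXT_LEN:
        raise ValueError("text length must be in (0, 1000000]")
    if not 0 < len(prompt) <= MAX_PROMPT_LEN:
        raise ValueError("prompt length must be in (0, 1048576]")
    # fullmatch also rejects a trailing newline, which "$" alone would allow.
    if not SIZE_PATTERN.fullmatch(size):
        raise ValueError('size must look like "height*width", e.g. "512*512"')
    height, width = (int(part) for part in size.split("*"))
    return height, width
```

For example, `check_query_args("boy", "knight", "768*1024")` returns `(768, 1024)`, while "512x512" is rejected because the separator must be `*`.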

Return Value

| Data Type | Description |
|---|---|
| Dict | `{"prompt": prompt, "result": data}`, where data is the Base64-encoded image data. |
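Because `result` holds Base64-encoded image data, it must be decoded before it can be written to disk. The following is a minimal sketch using Python's standard `base64` module. The helper name `save_result_image`, the default file name, and the `.png` extension are assumptions, since this page does not state which image format the model returns.

```python
import base64
from pathlib import Path


def save_result_image(llm_data: dict, path: str = "generated.png") -> int:
    """Decode the Base64 'result' field and write the raw bytes to path.

    Hypothetical helper; the file extension is an assumption because the
    actual image format depends on the model. Returns the number of bytes.
    """
    # validate=True raises binascii.Error on characters outside the Base64 alphabet.
    image_bytes = base64.b64decode(llm_data["result"], validate=True)
    Path(path).write_bytes(image_bytes)
    return len(image_bytes)
```

Applied to the dictionary printed in the example below, `save_result_image(llm_data)` would write the generated image to `generated.png` in the working directory.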

Example

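# In this example, the LLM generates a new image based on the retrieved image and the generation prompt.
from paddle.base import libpaddle  # not referenced directly below
from mx_rag.chain import Img2ImgChain
from mx_rag.llm import Img2ImgMultiModel
from mx_rag.retrievers import Retriever
from mx_rag.storage.vectorstore import MindFAISS
from mx_rag.storage.document_store import SQLiteDocstore
from mx_rag.embedding.local import ImageEmbedding
from mx_rag.utils import ClientParam
dev = 0  # ID of the device used by the embedding model and the vector index
# Chinese-CLIP ViT-B-16 embedding model used to vectorize the query text
img_emb = ImageEmbedding(model_name="ViT-B-16", model_path="/path/to/chinese-clip-vit-base-patch16", dev_id=dev)
# Vector store that loads an existing image index; x_dim must match the embedding dimension
img_vector_store = MindFAISS(x_dim=512,
                             devs=[dev],
                             load_local_index="/path/to/image_faiss.index",
                             auto_save=True)
# Document store holding the records referenced by the vector index
chunk_store = SQLiteDocstore(db_path="/path/to/sql.db")
# Retriever that embeds the query and returns the top match (k=1), filtered by score_threshold
img_retriever = Retriever(vector_store=img_vector_store, document_store=chunk_store,
                          embed_func=img_emb.embed_documents, k=1, score_threshold=0.5)
# Client for the image-to-image model service; replace the URL placeholder with the real endpoint
multi_model = Img2ImgMultiModel(model_name="sd",
                                url="img to image url",
                                client_param=ClientParam(ca_file="/path/to/ca.crt"))
img2img_chain = Img2ImgChain(multi_model=multi_model, retriever=img_retriever)
# Retrieve an image matching the text, then generate a new image from it and the prompt
llm_data = img2img_chain.query("Search for an image of the little boy",
                               prompt="he is a knight, wearing armor, big sword in right hand. Blur the background, focus on the knight")
print(llm_data)  # {"prompt": ..., "result": <Base64-encoded image data>}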