Text-to-Image Retrieval

This section describes how to use RAG SDK to retrieve images that match a text query.
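As background, the general idea of text-to-image retrieval can be sketched in a few lines of plain Python. This is only a conceptual illustration, not how RAG SDK is implemented: it assumes that images and the text query are embedded into the same vector space and compared by cosine similarity, where a higher score means a closer match. The helper names, the toy three-dimensional vectors, and the ./cat.jpg entry are hypothetical; the parameter names k and score_threshold are borrowed from the Retriever call in the sample script in this section, and their exact semantics in the SDK may differ.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def retrieve(query_vec, image_index, k=1, score_threshold=0.4):
    """Return up to k (path, score) pairs scoring at least score_threshold.

    Assumption: higher score means a better match. A real vector store
    may use a different metric or threshold convention.
    """
    scored = [(path, cosine_similarity(query_vec, vec))
              for path, vec in image_index.items()]
    scored = [item for item in scored if item[1] >= score_threshold]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


# Hypothetical toy embeddings standing in for real image vectors.
image_index = {
    "./car1.jpg": [0.9, 0.1, 0.0],
    "./car2.jpg": [0.6, 0.6, 0.1],
    "./cat.jpg": [0.0, 0.2, 0.9],
}

# Hypothetical embedding of the text query "Car".
query_vec = [1.0, 0.0, 0.0]

results = retrieve(query_vec, image_index)
print(results[0][0])
```

With k=1, only the single best-scoring image above the threshold is returned; raising k returns more candidates, and raising score_threshold filters out weaker matches.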

Prerequisites

You have completed the operations in Installing RAG SDK.

Sample Process

Procedure

  1. Create a file named retrieve_img_demo.py in any directory with the following content:
    import argparse
    
    from mx_rag.document import LoaderMng
    from mx_rag.document.loader import ImageLoader
    
    from mx_rag.embedding.local import ImageEmbedding
    from mx_rag.knowledge import KnowledgeDB, upload_files
    from mx_rag.knowledge.knowledge import KnowledgeStore
    from mx_rag.retrievers import Retriever
    from mx_rag.storage.document_store import SQLiteDocstore
    from mx_rag.storage.vectorstore import MindFAISS
    
    
    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument('--query', type=str, required=True,
                            help="Text query used to search for images")
        # action='append' collects every --image-path occurrence into a list.
        parser.add_argument('--image-path', type=str, action='append', required=True,
                            help="Path of an image to import into the database; repeat for multiple images")
    
        args = vars(parser.parse_args())
        images: list[str] = args.pop("image_path")
        query = args.pop("query")
        loader_mng = LoaderMng()
        loader_mng.register_loader(ImageLoader, [".jpg"])
    
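        # Device ID used for embedding and vector search.
        dev = 0
        # Replace the model_path placeholder with the actual path of the CLIP model.
        img_emb = ImageEmbedding("ViT-B-16", model_path="path to clip model", dev_id=dev)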
    
        img_vector_store = MindFAISS(x_dim=512, devs=[dev],
                                     load_local_index="./image_faiss.index",
                                     auto_save=True)
        chunk_store = SQLiteDocstore(db_path="./sql.db")
    
        # Initialize the relational database for knowledge management.
        knowledge_store = KnowledgeStore(db_path="./sql.db")
    
        user_id = "fc557af8-5973-4893-9624-4a510c3e18fb"
        knowledge_store.add_knowledge("test", user_id=user_id)
    
        knowledge_db = KnowledgeDB(knowledge_store=knowledge_store, chunk_store=chunk_store, vector_store=img_vector_store,
                                   knowledge_name="test", white_paths=["/home"], user_id=user_id)
    
        upload_files(knowledge_db, images, loader_mng=loader_mng,
                     embed_func=img_emb.embed_images, force=True)
    
        # Embed the text query and search the image vector store for the best match.
        img_retriever = Retriever(vector_store=img_vector_store, document_store=chunk_store,
                                  embed_func=img_emb.embed_documents, k=1, score_threshold=0.4)
        res = img_retriever.invoke(query)
        # res contains the path of the retrieved image.
        print(res)
    
  2. Run the following command. Set other parameters as required. For details, see ClientParam.
    python3 retrieve_img_demo.py --image-path ./car1.jpg --image-path ./car2.jpg --query "Car"
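
To see how the repeated --image-path options in the command above reach the script, the following minimal sketch parses the same command line with the standard argparse module only. It defines the same two option names as retrieve_img_demo.py; the build_parser helper and the hard-coded argv list are hypothetical and exist only to make the example self-contained, and no RAG SDK call is involved.

```python
import argparse


def build_parser():
    """Build a parser with the same two option names as the demo script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--query", type=str,
                        help="Text query used to search for images")
    # action="append" adds each occurrence of the option to a list.
    parser.add_argument("--image-path", type=str, action="append",
                        help="Path of an image to import; may be repeated")
    return parser


# Simulated command line, equivalent to:
#   --image-path ./car1.jpg --image-path ./car2.jpg --query "Car"
argv = ["--image-path", "./car1.jpg", "--image-path", "./car2.jpg",
        "--query", "Car"]

# argparse converts the dash in --image-path to an underscore: image_path.
args = vars(build_parser().parse_args(argv))

print(args["image_path"])
print(args["query"])
```

Each --image-path occurrence becomes one element of the list that the script later passes to upload_files, so more images can be imported by repeating the option.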